From 4c6b67b474588fa1d477fc2fdc2864a60bbbdb64 Mon Sep 17 00:00:00 2001 From: Shuang Yu Date: Mon, 10 Nov 2025 22:15:56 -0800 Subject: [PATCH 01/40] kv-cache: prepare clean commit without excluded files Signed-off-by: Zhaopeng Qiu --- nemo_rl/algorithms/grpo.py | 255 ++++++++++++++++- .../data/datasets/response_datasets/dapo.py | 105 +++++++ .../models/generation/vllm/vllm_backend.py | 5 +- .../models/generation/vllm/vllm_generation.py | 17 +- nemo_rl/models/generation/vllm/vllm_worker.py | 13 +- nemo_rl/models/policy/interfaces.py | 25 ++ nemo_rl/models/policy/lm_policy.py | 67 +++++ .../models/policy/megatron_policy_worker.py | 258 ++++++++++++++++++ 8 files changed, 721 insertions(+), 24 deletions(-) create mode 100644 nemo_rl/data/datasets/response_datasets/dapo.py diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 2cb9e001c9..76358078a7 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -180,6 +180,35 @@ class MasterConfig(TypedDict): # Setup & Initialization # =============================================================================== +# Function to check if KV cache scales should be calculated and synchronized during refit +# TODO: Where and how to calculate these kv cache scales? +# TODO: This should be checked only once and reused during the whole training process. Should the flag be stored somewhere? +def _should_sync_kv_scales(master_config: MasterConfig) -> bool: + """ + Check if KV cache scales should be synchronized during refit. 
+ + Returns True if: + - vLLM backend is used for generation + - Either kv_cache_dtype is fp8 OR vLLM precision is fp8 (which implies fp8 kv cache) + - This indicates we need to sync _k_scale and _v_scale values + """ + generation_config = master_config["policy"]["generation"] + if generation_config is None: + return False + + backend = generation_config.get("backend", "") + if backend != "vllm": + return False + + vllm_cfg = generation_config.get("vllm_cfg", {}) + kv_cache_dtype = vllm_cfg.get("kv_cache_dtype", "auto") + vllm_precision = vllm_cfg.get("precision", "auto") + + # Check if either kv_cache_dtype is explicitly fp8 or vLLM precision is fp8 + # When vLLM precision is fp8, it typically implies fp8 kv cache as well + # should enable kv scale sync when both are true + return kv_cache_dtype == "fp8" and vllm_precision == "fp8" + def setup( master_config: MasterConfig, @@ -861,6 +890,129 @@ def _should_use_penguin(master_config: MasterConfig) -> bool: return should_use_penguin +# Need a function to compute the kv cache scales for all the attention layers with the updated policy model +# TODO: Determine the inputs and outputs. inputs: the trained policy model? training data? +# TODO: How to do the calculation? Calculating the kv cache scales needs to do a forward pass with some training data, get the activations of each attention layer and compute the scales based on the activations. +# TODO: The calculation needs to be done only when the sync_kv_scales flag is True, and after policy model is updated. +# TODO: The output should be a dictionary of the kv cache scales for all the attention layers? The structure should be consistent with the required format that can be loaded by vllm using model_runner.model.load_weights() as the other weights. 
+# Code snippet reference: /lustre/fsw/portfolios/coreai/users/shuangy/src/vllm/vllm/model_executor/layers/quantization/kv_cache.py: +# def create_weights(self, layer: torch.nn.Module): +# """ +# Create "weight" (aka q_scale, k_scale and v_scale) +# for an attention layer. +# """ + # Initialize the Q and KV cache scales to -1.0, an invalid value. + # If the q and k/v_scales appear in the checkpoint, it will be + # overwritten when loading weights. +# layer.q_scale = torch.nn.Parameter(torch.tensor(-1.0), +# requires_grad=False) +# layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0), +# requires_grad=False) +# layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0), +# requires_grad=False) + # Initialize P = softmax(QK^T) scales +# layer.prob_scale = torch.nn.Parameter(torch.tensor(-1.0), +# requires_grad=False) +# requires_grad=False) +# TODO: How to pass the kv scales to refit_policy_generation()? refit_policy_generation() is the function that updates the weights of the policy generation interface. +# When refit_policy_generation() invokes update_weights_from_ipc_handles() or update_weights_from_collective(), if it is fp8 and sync_kv_scales is True, the kv scales should be passed to the policy generation interface. load_weights() once invoked will load the kv scales. +# In order for vllm to really load the kv scales, the kv scales should be passed to the policy generation interface in the same format as the other weights. +# Additionally, vllm process_weights_after_loading() will be invoked after load_weights() to copy the kv scales to the _k_scale and _v_scale attributes. Reference code: /lustre/fsw/portfolios/coreai/users/shuangy/src/vllm/vllm/model_executor/layers/quantization/kv_cache.py + + +def compute_kv_scales_with_data( + policy: ColocatablePolicyInterface, + sample_data: BatchedDataDict, + master_config: MasterConfig, + max_samples: int = 32, +) -> dict[str, float]: + """ + Compute KV cache scales for all attention layers using calibration data. 
+ + Args: + policy: The policy model to calibrate + sample_data: Calibration data batch + master_config: Configuration containing model settings + max_samples: Maximum number of samples to use for calibration + + Returns: + Dictionary mapping parameter names to scale values for K/V cache quantization + """ + # TODO: Review the implementation of this function. + print(f"[KV_SCALES] Computing KV cache scales with {min(max_samples, sample_data.size)} samples...") + + # Limit the number of samples for calibration + if sample_data.size > max_samples: + sample_data = sample_data.slice(0, max_samples) + + # Convert to input format expected by policy + import torch + from nemo_rl.data.llm_message_utils import batched_message_log_to_flat_message + + try: + # Extract tokenized inputs from the batch + batched_flat, input_lengths = batched_message_log_to_flat_message( + sample_data["message_log"], + pad_value_dict={"token_ids": 0} # Use 0 as pad token for calibration + ) + input_ids = batched_flat["token_ids"] + + # Convert to tensor if needed + if not isinstance(input_ids, torch.Tensor): + input_ids = torch.tensor(input_ids, dtype=torch.long) + + # For distributed policy, we'll use a simplified approach + # TODO: Implement proper distributed calibration through worker_group + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + input_ids = input_ids.to(device) + + print(f"[KV_SCALES] Calibration input shape: {input_ids.shape}") + + # Skip the complex hook-based calibration for distributed policy + # TODO: Implement proper distributed calibration using policy.worker_group + + # For distributed Policy, we cannot directly access model.named_modules() + # Instead, we'll use a simplified approach with default scales + print("[KV_SCALES] Using simplified calibration for distributed policy") + + # TODO: For a quick prototype, use a pseudo default scales. + # Need to update later to would use worker_group to run calibration? 
+ default_k_scale = 0.1 # Conservative scale for K projections + default_v_scale = 0.1 # Conservative scale for V projections + # TODO: Current use Qwen3-8B-Base as an example, should be obtained from model config + num_layers = 36 # Default number of layers - should be obtained from model config + + # Generate default KV scales for distributed policy + kv_scales = {} + print("[KV_SCALES] Generating default KV scales for distributed policy") + + # Generate scales for typical transformer layers + for layer_idx in range(num_layers): + k_param_name = f"model.layers.{layer_idx}.self_attn.k_scale" + v_param_name = f"model.layers.{layer_idx}.self_attn.v_scale" + + kv_scales[k_param_name] = default_k_scale + kv_scales[v_param_name] = default_v_scale + + print(f"[KV_SCALES] Computed {len(kv_scales)} KV cache scales") + return kv_scales + + except Exception as e: + print(f"[KV_SCALES] Error computing KV scales: {e}") + # For training stability, we can either: + # 1. Re-raise the exception to fail fast and debug issues early + # 2. Fall back to default scales to continue training + # Current choice: fallback for robustness, but log the error clearly + print("[KV_SCALES] Falling back to default scales to maintain training stability") + print("[KV_SCALES] Note: This may impact FP8 quantization quality") + + # Return default scales + default_scales = {} + for name, module in policy.model.named_modules(): + if "self_attn" in name: + default_scales[f"{name}.k_scale"] = 1.0 + default_scales[f"{name}.v_scale"] = 1.0 + return default_scales def refit_policy_generation( policy: ColocatablePolicyInterface, @@ -868,6 +1020,7 @@ def refit_policy_generation( colocated_inference: bool, _refit_buffer_size_gb: Optional[int] = None, timer: Optional[Timer] = None, + kv_scales: Optional[dict[str, float]] = None, ) -> None: """Refit the policy generation interface with the latest policy weights. 
@@ -878,6 +1031,7 @@ def refit_policy_generation( If it is None, the buffer size will be computed by the remaining memory. This parameter is primarily used for testing. timer: Optional Timer used to time the prepare/transfer/update phase + kv_scales: Optional dictionary of KV cache scales for FP8 quantization. """ if colocated_inference: policy.offload_before_refit() @@ -915,7 +1069,11 @@ def refit_policy_generation( else: # update weights through nccl futures_train = policy.broadcast_weights_for_collective() - futures_inference = policy_generation.update_weights_from_collective() + if kv_scales: + print(f"[KV_SCALES] Refit: Adding {len(kv_scales)} KV scales to collective weight update") + futures_inference = policy_generation.update_weights_from_collective(kv_scales=kv_scales) + else: + futures_inference = policy_generation.update_weights_from_collective() # wait for all futures to complete ray.get(futures_train) results = ray.get(futures_inference) @@ -964,6 +1122,24 @@ def grpo_train( ) timeout.start_iterations() + # Check if we need to sync KV cache scales (infer from config) + sync_kv_scales = _should_sync_kv_scales(master_config) + kv_scales_cache = None # Cache computed KV scales for reuse + + if sync_kv_scales: + generation_config = master_config["policy"]["generation"] + vllm_cfg = generation_config.get("vllm_cfg", {}) + backend = generation_config.get("backend", "") + kv_cache_dtype = vllm_cfg.get("kv_cache_dtype", "auto") + vllm_precision = vllm_cfg.get("precision", "auto") + policy_backend = "megatron" if master_config["policy"].get("megatron_cfg", {}).get("enabled", False) else "dtensor" + + print(f"[KV_SCALES] FP8 KV cache detected, will sync _k_scale and _v_scale during refit") + print(f"[KV_SCALES] Configuration: policy_backend={policy_backend}, generation_backend={backend}") + print(f"[KV_SCALES] vLLM settings: precision={vllm_precision}, kv_cache_dtype={kv_cache_dtype}") + else: + print("[KV_SCALES] KV cache scale sync not needed (non-FP8 mode)") 
+ NEED_REFIT = True # If policy_generation is None, use the policy as the generation interface (megatron framework backend) if policy_generation is None: @@ -1047,21 +1223,55 @@ def grpo_train( ) input_ids = batched_flat["token_ids"] - # Generate responses - this updates the LLMMessageLogType in repeated_batch - print( - f"▶ Generating responses for batch of size {repeated_batch.size}...", - flush=True, - ) - with timer.time("prepare_for_generation/total"): - if NEED_REFIT and POLICY_GENERATION_STALE: - refit_policy_generation( - policy, policy_generation, colocated_inference, timer=timer + # Generate responses - this updates the LLMMessageLogType in repeated_batch + print( + f"▶ Generating responses for batch of size {repeated_batch.size}...", + flush=True, + ) + with timer.time("prepare_for_generation"): + if NEED_REFIT and POLICY_GENERATION_STALE: + # Compute KV scales if needed for FP8 quantization + if sync_kv_scales and kv_scales_cache is None: + print("[KV_SCALES] Computing KV cache scales for the first time...") + policy.prepare_for_lp_inference() + # kv_scales_cache = compute_kv_scales_with_data( + # policy, repeated_batch, master_config + # ) + kv_scales_cache = {} + + # Create calibration data from flattened messages + calibration_data = BatchedDataDict[ClippedPGLossDataDict]( + { + "input_ids": batched_flat["token_ids"], + "input_lengths": input_lengths, + # "advantages": batched_flat["advantages"], + # "generation_logprobs": batched_flat["generation_logprobs"], + # "token_mask": batched_flat["token_loss_mask"], + # "sample_mask": batched_flat["loss_multiplier"], + } ) - POLICY_GENERATION_STALE = False - else: + # this will be mini-batched inside the policy, so maintain the packed multimodal structure + calibration_data.update(batched_flat.get_multimodal_dict(as_tensors=False)) + calibration_data.to("cpu") + kv_scales = policy.calibrate_qkv_fp8_scales(calibration_data, include_q=True)["layers"] + for k, v in kv_scales.items(): + layer_idx = 
k.split("_")[1] + k_param_name = f"model.layers.{layer_idx}.self_attn.k_scale" + v_param_name = f"model.layers.{layer_idx}.self_attn.v_scale" + q_param_name = f"model.layers.{layer_idx}.self_attn.attn.q_scale" + + kv_scales_cache[q_param_name] = v["q_scale"] + kv_scales_cache[k_param_name] = v["k_scale"] + kv_scales_cache[v_param_name] = v["v_scale"] + + refit_policy_generation( + policy, policy_generation, colocated_inference, timer=timer + ) + POLICY_GENERATION_STALE = False + else: if colocated_inference: policy.offload_after_refit() # unload optimizer to make space for generation - policy_generation.prepare_for_generation() + policy_generation.prepare_for_generation() dynamic_sampling_num_gen_batches += 1 with timer.time("generation"): @@ -1253,6 +1463,25 @@ def grpo_train( with timer.time("policy_training"): train_results = policy.train(train_data, loss_fn) + # Recompute KV scales after policy training if needed + if sync_kv_scales: + print("[KV_SCALES] Recomputing KV cache scales after policy update...") + # kv_scales_cache = compute_kv_scales_with_data( + # policy, repeated_batch, master_config + # ) + kv_scales = policy.calibrate_qkv_fp8_scales(train_data, include_q=True)["layers"] + for k, v in kv_scales.items(): + layer_idx = k.split("_")[1] + k_param_name = f"model.layers.{layer_idx}.self_attn.k_scale" + v_param_name = f"model.layers.{layer_idx}.self_attn.v_scale" + q_param_name = f"model.layers.{layer_idx}.self_attn.attn.q_scale" + + kv_scales_cache[q_param_name] = v["q_scale"] + kv_scales_cache[k_param_name] = v["k_scale"] + kv_scales_cache[v_param_name] = v["v_scale"] + # Set generation as stale to force refit with new scales + POLICY_GENERATION_STALE = True + is_last_step = (total_steps + 1 >= max_num_steps) or ( (current_epoch + 1 == max_num_epochs) and (current_step + 1 == len(dataloader)) diff --git a/nemo_rl/data/datasets/response_datasets/dapo.py b/nemo_rl/data/datasets/response_datasets/dapo.py new file mode 100644 index 0000000000..d56dc5cc91 
--- /dev/null +++ b/nemo_rl/data/datasets/response_datasets/dapo.py @@ -0,0 +1,105 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Any + +from datasets import Dataset, load_dataset + +from nemo_rl.data.interfaces import TaskDataSpec + + +def format_math(data: dict[str, str | float | int]) -> dict[str, list[Any] | str]: + return { + "messages": [ + { + "role": "user", + "content": data["problem"], + }, + { + "role": "assistant", + "content": data["answer"], + }, + ], + # For v0.1 release, nemo rl datasets require a task_name key such that user can map a task processor per unique task. + "task_name": "math", + } + + +def format_dapo_math(data: dict[str, Any]) -> dict[str, list[Any] | str]: + # Extract user content from prompt field + user_content = "" + for message in data["prompt"]: + if message["role"] == "user": + user_content = message["content"] + break + + # Extract ground truth from reward_model field + assistant_content = data["reward_model"]["ground_truth"] + + return { + "messages": [ + { + "role": "user", + "content": user_content, + }, + { + "role": "assistant", + "content": assistant_content, + }, + ], + # For v0.1 release, nemo rl datasets require a task_name key such that user can map a task processor per unique task. 
+ "task_name": "math", + } + + +def prepare_dapo_dataset(seed: int = 42) -> dict[str, Dataset | None]: + """Load and split the DAPO dataset into train and test sets.""" + # Load the original dataset for training + train_ds = load_dataset("BytedTsinghua-SIA/DAPO-Math-17k", split="train") + + # Load hendrydong/aime24 dataset for validation + val_ds = load_dataset("HuggingFaceH4/aime_2024", split="train") + + # Shuffle the training dataset with the specified seed + train_ds = train_ds.shuffle(seed=seed) + + # Format the examples, removing original columns + train_formatted = train_ds.map(format_dapo_math, remove_columns=train_ds.column_names) + val_formatted = val_ds.map(format_math, remove_columns=val_ds.column_names) + + # Compute accuracy 16 times per sample (matching the DeepScaleR evaluation setting) + val_repeated = [] + for _ in range(16): + val_repeated.extend(val_formatted) + val_formatted = val_formatted.from_list(val_repeated) + + return { + "train": train_formatted, + "validation": val_formatted, + } + + +class DAPODataset: + def __init__(self, seed: int = 42) -> None: + """Initialize the DAPO dataset with train/test split. 
+ + Args: + seed: Random seed for reproducible splitting + """ + self.formatted_ds = prepare_dapo_dataset(seed=seed) + + self.task_spec = TaskDataSpec( + task_name="DAPO", + ) diff --git a/nemo_rl/models/generation/vllm/vllm_backend.py b/nemo_rl/models/generation/vllm/vllm_backend.py index f5b8463ae0..e2546aeb8a 100644 --- a/nemo_rl/models/generation/vllm/vllm_backend.py +++ b/nemo_rl/models/generation/vllm/vllm_backend.py @@ -14,7 +14,7 @@ import gc import traceback from typing import Any - +import traceback import torch import zmq @@ -169,12 +169,13 @@ def update_weights_via_ipc_zmq(self) -> bool: f"Error in VllmInternalWorkerExtension.update_weights_via_ipc_zmq: {e}.\n" f"{traceback.format_exc()}" ) + print(traceback.format_exc()) return False @wrap_with_nvtx_name( "vllm_internal_worker_extension/update_weights_from_collective" ) - def update_weights_from_collective(self) -> bool: + def update_weights_from_collective(self, kv_scales: Optional[dict[str, float]] = None) -> bool: """Update the model weights from collective communication.""" assert self.state_dict_info is not None, ( "state_dict_info is not prepared. 
" diff --git a/nemo_rl/models/generation/vllm/vllm_generation.py b/nemo_rl/models/generation/vllm/vllm_generation.py index 5dcc7eaf2e..b5c81a81fd 100644 --- a/nemo_rl/models/generation/vllm/vllm_generation.py +++ b/nemo_rl/models/generation/vllm/vllm_generation.py @@ -788,7 +788,7 @@ def update_weights_via_ipc_zmq(self) -> list[ray.ObjectRef]: # this function should co-work with lm_policy, so we should wait for all futures to complete outside return futures - def update_weights_from_collective(self) -> list[ray.ObjectRef]: + def update_weights_from_collective(self, kv_scales: Optional[dict[str, float]] = None) -> list[ray.ObjectRef]: """Update weights of the policy using collective communication.""" if not self.worker_group or not self.worker_group.workers: raise RuntimeError("Worker group is not initialized") @@ -801,10 +801,17 @@ def update_weights_from_collective(self) -> list[ray.ObjectRef]: ) # Use run_all_workers_single_data for methods that don't need data - futures = self.worker_group.run_all_workers_single_data( - method_name, - run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"], - ) + if kv_scales: + futures = self.worker_group.run_all_workers_single_data( + method_name, + kv_scales=kv_scales, + run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"], + ) + else: + futures = self.worker_group.run_all_workers_single_data( + method_name, + run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"], + ) # this function should co-work with lm_policy, so we should wait for all futures to complete outside return futures diff --git a/nemo_rl/models/generation/vllm/vllm_worker.py b/nemo_rl/models/generation/vllm/vllm_worker.py index a97d68e669..82eac8c452 100644 --- a/nemo_rl/models/generation/vllm/vllm_worker.py +++ b/nemo_rl/models/generation/vllm/vllm_worker.py @@ -665,7 +665,7 @@ def update_weights_via_ipc_zmq(self) -> bool: return False @wrap_with_nvtx_name("vllm_genertion_worker/update_weights_from_collective") - def 
update_weights_from_collective(self) -> bool: + def update_weights_from_collective(self, kv_scales: Optional[dict[str, float]] = None) -> bool: """Update the model weights from collective communication.""" try: assert self.llm is not None, ( @@ -677,9 +677,14 @@ def update_weights_from_collective(self) -> bool: "update_weights_from_collective can only be used with async_engine=False. Use update_weights_from_collective_async instead." ) - result_or_coro = self.llm.collective_rpc( - "update_weights_from_collective", args=tuple() - ) + if kv_scales: + result_or_coro = self.llm.collective_rpc( + "update_weights_from_collective", args=(kv_scales,) + ) + else: + result_or_coro = self.llm.collective_rpc( + "update_weights_from_collective", args=tuple() + ) worker_result = result_or_coro[0] if not worker_result: diff --git a/nemo_rl/models/policy/interfaces.py b/nemo_rl/models/policy/interfaces.py index e221621403..aa229b6902 100644 --- a/nemo_rl/models/policy/interfaces.py +++ b/nemo_rl/models/policy/interfaces.py @@ -121,6 +121,31 @@ def score( """Score a batch of data using the policy.""" pass + @abstractmethod + def calibrate_qkv_fp8_scales( + self, + data: BatchedDataDict[GenerationDatumSpec], + micro_batch_size: Optional[int] = None, + percentile: float = 99.9, + margin: float = 1.05, + save_path: Optional[str] = None, + include_q: bool = False, + ) -> dict[str, Any]: + """Calibrate FP8 scales for Q/K/V activations used by KV cache. + + Args: + data: BatchedDataDict containing input_ids and input_lengths. + micro_batch_size: Optional override for micro batch size during calibration. + percentile: Percentile for per-tensor amax estimation. + margin: Safety margin multiplier applied to amax. + save_path: If provided, rank0 will write JSON results to this path. + include_q: Whether to also compute scale for Q in addition to K/V. + + Returns: + Dict with overall configuration and per-layer scales. 
+ """ + pass + @abstractmethod def prepare_for_training(self, *args: Any, **kwargs: Any) -> None: pass diff --git a/nemo_rl/models/policy/lm_policy.py b/nemo_rl/models/policy/lm_policy.py index c1fde9bcf5..ff30dadcfb 100644 --- a/nemo_rl/models/policy/lm_policy.py +++ b/nemo_rl/models/policy/lm_policy.py @@ -690,6 +690,73 @@ def prepare_refit_info(self) -> Optional[dict[str, Any]]: # Only get the first worker's info since all workers will have the same result return results[0] + def calibrate_qkv_fp8_scales( + self, + data: BatchedDataDict[GenerationDatumSpec], + micro_batch_size: Optional[int] = None, + percentile: float = 99.9, + margin: float = 1.05, + save_path: Optional[str] = None, + include_q: bool = False, + ) -> dict[str, Any]: + """触发各 Megatron worker 的 KV-cache FP8 scale 标定,并返回结果。 + + 说明:后端 `MegatronPolicyWorker.calibrate_qkv_fp8_scales` 已实现分布式规约, + 返回的是跨 rank 合并后的结果。因此这里按 DP 分片输入并并行调用,最终取第一个 + worker 的返回即可。 + """ + dp_size = self.sharding_annotations.get_axis_size("data_parallel") + # 仅分 DP 维度;对于动态/打包模式,沿用现有分片逻辑 + if self.use_dynamic_batches: + self.dynamic_batching_args["max_tokens_per_microbatch"] = self.cfg[ + "dynamic_batching" + ]["logprob_mb_tokens"] + sharded_data, _ = data.shard_by_batch_size( # type: ignore + dp_size, + batch_size=None, + dynamic_batching_args=self.dynamic_batching_args, + ) + elif self.use_sequence_packing: + self.sequence_packing_args["max_tokens_per_microbatch"] = self.cfg[ + "sequence_packing" + ]["logprob_mb_tokens"] + sharded_data, _ = data.shard_by_batch_size( + dp_size, + batch_size=None, + sequence_packing_args=self.sequence_packing_args, + ) + else: + sharded_data = data.shard_by_batch_size( # type: ignore + dp_size, + batch_size=None, + ) + + futures = self.worker_group.run_all_workers_sharded_data( + "calibrate_qkv_fp8_scales", + data=sharded_data, + in_sharded_axes=["data_parallel"], + replicate_on_axes=[ + "context_parallel", + "tensor_parallel", + "pipeline_parallel", + ], + output_is_replicated=[ + 
"context_parallel", + "tensor_parallel", + "pipeline_parallel", + ], + common_kwargs={ + "micro_batch_size": micro_batch_size, + "percentile": percentile, + "margin": margin, + "save_path": save_path, + "include_q": include_q, + }, + ) + results = self.worker_group.get_all_worker_results(futures) + # 由于后端已在分布式内合并,这里返回任一结果均一致 + return results[0] + def get_free_memory_bytes(self) -> int: """Get the available free memory.""" futures = self.worker_group.run_all_workers_single_data("get_free_memory_bytes") diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 3bf211be5c..4896678f7d 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -17,6 +17,8 @@ import time import warnings from collections import defaultdict +import json +import re from contextlib import AbstractContextManager, contextmanager, nullcontext from functools import partial from typing import Any, Iterator, Optional, TypeVar @@ -2193,6 +2195,9 @@ def save_checkpoint( weights_path: The specific directory path where the checkpoint will be saved. optimizer_path: If not None, optimizer and scheduler states are saved if they exist. """ + allocated = torch.cuda.memory_allocated() / (1024**3) # Convert to GB + reserved = torch.cuda.memory_reserved() / (1024**3) # Convert to GB + print(f"[SHARON] GPU Memory before saving checkpoint: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved") if not torch.distributed.is_initialized(): raise RuntimeError( "Distributed process group is not initialized. Cannot save checkpoint." 
@@ -2253,6 +2258,29 @@ def save_checkpoint( if not is_training: # Restore training state if it was changed self.model.train() + + torch.randn(1).cuda() # wake up torch allocator + if hasattr(self, "optimizer") and self.optimizer is not None: + # Iterate through the state dictionaries for each parameter group + if isinstance(self.optimizer, ChainedOptimizer): + optimizer_state = self.optimizer.state + else: + optimizer_state = self.optimizer._get_state() + for _, state in optimizer_state.items(): + # Iterate through the state items (e.g., momentum, variance) for a parameter + for k, v in state.items(): + # Check if the item is a tensor and on the GPU + if torch.is_tensor(v) and v.is_cuda: + # Move the tensor to CPU and update the state dictionary + state[k] = v.to("cpu") + + gc.collect() + torch.cuda.empty_cache() + + allocated = torch.cuda.memory_allocated() / (1024**3) # Convert to GB + reserved = torch.cuda.memory_reserved() / (1024**3) # Convert to GB + print(f"[SHARON] GPU Memory after saving checkpoint: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved") + except Exception as e: print(f"Failed to save checkpoint to {weights_path}: {e}") @@ -2373,3 +2401,233 @@ def re_enable_float32_expert_bias(self) -> None: router, "_maintain_float32_expert_bias" ): router._maintain_float32_expert_bias() + + @torch.no_grad() + def calibrate_qkv_fp8_scales( + self, + *, + data: BatchedDataDict[Any], + micro_batch_size: Optional[int] = None, + percentile: float = 99.9, + margin: float = 1.05, + save_path: Optional[str] = None, + include_q: bool = False, + ) -> dict[str, Any]: + """One-shot 标定 Q/K/V 的激活 scale(用于 FP8 KV cache)。 + + - 通过 forward hook 捕获每层 `query_key_value` 的输出,分割 Q/K/V,并统计分位 amax。 + - 在并行(DP/TP/PP)环境下,先本地统计分位,再对所有 rank 取最大值以保证保守性。 + - 默认仅返回并保存 K/V 的 scale(E5M2),可选返回 Q(E4M3)。 + + Args: + data: 用于标定的一小批代表性样本,遵循 get_logprobs 的输入约定。 + micro_batch_size: 标定时的 micro batch 大小;若 None 则复用 logprob_batch_size。 + percentile: amax 的分位数(如 99.9)。 + margin: 余量系数,例如 1.05。 + 
save_path: 若提供,则 rank0 会将结果保存为 JSON。 + include_q: 是否也返回 Q 的 scale(一般只需 K/V)。 + + Returns: + { "format": "fp8", "percentile": float, "margin": float, + "layers": { layer_name: {"k_scale": float, "v_scale": float[, "q_scale": float] } } } + """ + # FP8 动态范围(对称) + FP8_MAX_E4M3 = 448.0 + FP8_MAX_E5M2 = 57344.0 + + self.model.eval() + + # 记录每层的 q/k/v 的局部分位 amax + layer_to_samples_q: dict[str, list[float]] = defaultdict(list) + layer_to_samples_k: dict[str, list[float]] = defaultdict(list) + layer_to_samples_v: dict[str, list[float]] = defaultdict(list) + hook_handles = [] + + def _extract_layer_key(module_name: str) -> str: + # 期望形如 "module.decoder.layers..self_attention.query_key_value" + m = re.search(r"module\.decoder\.layers\.(\d+)", module_name) + if m is not None: + return f"layer_{m.group(1)}" + return module_name + + def _hook_builder(module_name: str): + layer_key = _extract_layer_key(module_name) + + def _hook(module, inputs, output): + out = output[0] if isinstance(output, (tuple, list)) else output + try: + last_dim = out.shape[-1] + assert last_dim % 3 == 0 + qkv_stride = last_dim // 3 + q = out[..., :qkv_stride] + k = out[..., qkv_stride : 2 * qkv_stride] + v = out[..., 2 * qkv_stride : 3 * qkv_stride] + # per-tensor 绝对最大值(局部) + layer_to_samples_q[layer_key].append(float(torch.amax(torch.abs(q)).item())) + layer_to_samples_k[layer_key].append(float(torch.amax(torch.abs(k)).item())) + layer_to_samples_v[layer_key].append(float(torch.amax(torch.abs(v)).item())) + except Exception as e: + print(f"[ALEXQ] Error extracting layer key: {e}") + pass + + return _hook + + # 新增:优先在 core_attention 的 forward_pre 上挂 hook,以捕获已做完 q/k norm 与 RoPE 的 q/k/v + def _pre_hook_builder_core_attention(module_name: str): + layer_key = _extract_layer_key(module_name) + + def _pre_hook(module, inputs): + try: + args = inputs if isinstance(inputs, (tuple, list)) else (inputs,) + if len(args) == 1 and isinstance(args[0], (tuple, list)): + args = args[0] + # 期望前 3 个为 q, k, v(Megatron 
CoreAttention 的典型签名) + q = args[0] + k = args[1] + v = args[2] + if include_q: + layer_to_samples_q[layer_key].append(float(torch.amax(torch.abs(q)).item())) + layer_to_samples_k[layer_key].append(float(torch.amax(torch.abs(k)).item())) + layer_to_samples_v[layer_key].append(float(torch.amax(torch.abs(v)).item())) + except Exception as e: + print(f"[ALEXQ] Error in core_attention pre-hook on {module_name}: {e}") + pass + + return _pre_hook + + matched_modules = [] + # 1) 优先尝试在 core_attention 上注册 forward_pre_hook + for name, module in self.model.named_modules(): + print(f"[ALEXQ] Module name: {name}") + if "self_attention.core_attention" in name: + try: + handle = module.register_forward_pre_hook(_pre_hook_builder_core_attention(name)) + hook_handles.append(handle) + matched_modules.append((name, module.__class__.__name__, "pre")) + except Exception as e: + print(f"[ALEXQ] Error registering pre-hook on {name}: {e}") + continue + + # 2) 若未命中 core_attention,则退回到原先在 QKV 投影输出上的 forward_hook + if not hook_handles: + qkv_name_patterns = ("query_key_value", "linear_qkv", ".qkv", "_qkv") + for name, module in self.model.named_modules(): + if any(pat in name for pat in qkv_name_patterns): + try: + handle = module.register_forward_hook(_hook_builder(name)) + hook_handles.append(handle) + matched_modules.append((name, module.__class__.__name__, "post")) + except Exception as e: + print(f"[ALEXQ] Error registering hook on {name}: {e}") + continue + + if not hook_handles: + print("[ALEXQ] No QKV proj modules matched for hook. 
Example module/param names:") + try: + # 打印前若干个模块与参数名,帮助定位实际命名 + cnt = 0 + for n, _m in self.model.named_modules(): + if cnt >= 10: + break + print(f" [module] {n}") + cnt += 1 + cnt = 0 + for n, _p in self.model.named_parameters(): + if cnt >= 10: + break + print(f" [param] {n}") + cnt += 1 + except Exception: + pass + else: + # 轻量打印匹配到的前几个模块,确认命中 + print("[ALEXQ] Registered hooks on modules (showing up to 8):") + for i, (mn, cls, kind) in enumerate(matched_modules[:8]): + print(f" {i:02d}: {mn} <{cls}> [{kind}]") + + # 运行一次推理流程以触发 hooks(重用 get_logprobs 的前向路径) + try: + _ = self.get_logprobs(data=data, micro_batch_size=micro_batch_size) + finally: + for h in hook_handles: + try: + h.remove() + except Exception as e: + print(f"[ALEXQ] Error removing hook: {e}") + pass + + # 计算本地分位 amax + def _percentile(values: list[float], p: float) -> float: + if not values: + return 0.0 + t = torch.tensor(sorted(values), device="cuda", dtype=torch.float32) + rank = max(0, min(len(values) - 1, int(round((p / 100.0) * (len(values) - 1))))) + return float(t[rank].item()) + + local_layer_to_pamax = {} + for layer_key in set(list(layer_to_samples_k.keys()) + list(layer_to_samples_v.keys()) + (list(layer_to_samples_q.keys()) if include_q else [])): + entry = {} + if include_q: + entry["q_amax_p"] = _percentile(layer_to_samples_q.get(layer_key, []), percentile) + entry["k_amax_p"] = _percentile(layer_to_samples_k.get(layer_key, []), percentile) + entry["v_amax_p"] = _percentile(layer_to_samples_v.get(layer_key, []), percentile) + local_layer_to_pamax[layer_key] = entry + + # 合并所有 rank:对分位 amax 取最大(保守) + world_size = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1 + gathered = [None for _ in range(world_size)] if world_size > 1 else None + if world_size > 1: + torch.distributed.all_gather_object(gathered, local_layer_to_pamax) + merged = defaultdict(dict) + for d in gathered: # type: ignore + if d is None: + continue + for k, v in d.items(): + dst = 
merged[k] + for kk, vv in v.items(): + dst[kk] = max(dst.get(kk, 0.0), float(vv)) + layer_to_pamax = dict(merged) + else: + layer_to_pamax = local_layer_to_pamax + + # 计算 scale(对称量化):scale = pamax / fp8_max + result_layers = {} + for layer_key, vals in layer_to_pamax.items(): + out_entry = {} + if include_q: + q_scale = (vals.get("q_amax_p", 0.0) * margin) / FP8_MAX_E4M3 + out_entry["q_scale"] = float(q_scale) + k_scale = (vals.get("k_amax_p", 0.0) * margin) / FP8_MAX_E4M3 + v_scale = (vals.get("v_amax_p", 0.0) * margin) / FP8_MAX_E4M3 + out_entry["k_scale"] = float(k_scale) + out_entry["v_scale"] = float(v_scale) + result_layers[layer_key] = out_entry + + final_result = { + "format": "fp8", + "percentile": percentile, + "margin": margin, + "layers": result_layers, + } + print(f"[ALEXQ] Calibrated KV cache scales: {final_result}") + + # 将结果在所有 rank 同步(广播 rank0 的结果) + if world_size > 1: + if torch.distributed.get_rank() == 0: + obj_list = [final_result] + torch.distributed.broadcast_object_list(obj_list, src=0) + final_result = obj_list[0] + else: + obj_list = [None] + torch.distributed.broadcast_object_list(obj_list, src=0) + final_result = obj_list[0] # type: ignore + + # 可选保存至 JSON(仅 rank0) + if save_path is not None and (not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0): + try: + with open(save_path, "w") as f: + json.dump(final_result, f) + except Exception: + pass + + return final_result From 312d2048a2689c11b1cec7c5ce8319cfda401142 Mon Sep 17 00:00:00 2001 From: Shuang Yu Date: Mon, 10 Nov 2025 22:20:39 -0800 Subject: [PATCH 02/40] kv cache fp8 code refine and cleanup Signed-off-by: Zhaopeng Qiu --- nemo_rl/algorithms/grpo.py | 98 +++++++------------ .../models/generation/vllm/vllm_backend.py | 2 - nemo_rl/models/policy/lm_policy.py | 9 +- .../models/policy/megatron_policy_worker.py | 69 ++++++------- 4 files changed, 74 insertions(+), 104 deletions(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 
76358078a7..1011e648c1 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -180,36 +180,6 @@ class MasterConfig(TypedDict): # Setup & Initialization # =============================================================================== -# Function to check if KV cache scales should be calculated and synchronized during refit -# TODO: Where and how to calcualte this kv cache scales? -# TODO: This should be checked only once and reused during the whole training process. Should the flag be stored somewhere? -def _should_sync_kv_scales(master_config: MasterConfig) -> bool: - """ - Check if KV cache scales should be synchronized during refit. - - Returns True if: - - vLLM backend is used for generation - - Either kv_cache_dtype is fp8 OR vLLM precision is fp8 (which implies fp8 kv cache) - - This indicates we need to sync _k_scale and _v_scale values - """ - generation_config = master_config["policy"]["generation"] - if generation_config is None: - return False - - backend = generation_config.get("backend", "") - if backend != "vllm": - return False - - vllm_cfg = generation_config.get("vllm_cfg", {}) - kv_cache_dtype = vllm_cfg.get("kv_cache_dtype", "auto") - vllm_precision = vllm_cfg.get("precision", "auto") - - # Check if either kv_cache_dtype is explicitly fp8 or vLLM precision is fp8 - # When vLLM precision is fp8, it typically implies fp8 kv cache as well - # should enable kv scale sync when both are true - return kv_cache_dtype == "fp8" and vllm_precision == "fp8" - - def setup( master_config: MasterConfig, tokenizer: TokenizerType, @@ -859,11 +829,11 @@ def _should_use_async_rollouts(master_config: MasterConfig) -> bool: generation_config = master_config["policy"]["generation"] if generation_config is None: return False - + backend = generation_config.get("backend", "") if backend != "vllm": return False - + vllm_cfg = generation_config.get("vllm_cfg", {}) return vllm_cfg.get("async_engine", False) @@ -1124,7 +1094,7 @@ def grpo_train( # Check 
if we need to sync KV cache scales (infer from config) sync_kv_scales = _should_sync_kv_scales(master_config) - kv_scales_cache = None # Cache computed KV scales for reuse + kv_scales_cache = None # Cache reused for compuated kv scales if sync_kv_scales: generation_config = master_config["policy"]["generation"] @@ -1134,11 +1104,20 @@ def grpo_train( vllm_precision = vllm_cfg.get("precision", "auto") policy_backend = "megatron" if master_config["policy"].get("megatron_cfg", {}).get("enabled", False) else "dtensor" - print(f"[KV_SCALES] FP8 KV cache detected, will sync _k_scale and _v_scale during refit") + print(f"[KV_SCALES] FP8 KV cache detected, will sync q_scale, _k_scale and _v_scale during refit") print(f"[KV_SCALES] Configuration: policy_backend={policy_backend}, generation_backend={backend}") print(f"[KV_SCALES] vLLM settings: precision={vllm_precision}, kv_cache_dtype={kv_cache_dtype}") + + # Temoprary assert check to flag error when kv cache fp8 is enabled but either of thefollowing conditions are met: + # 1. policy backend is dtensor + # 2. async rollouts is enabled + # 3. pipeline_model_parallel_size is greater than 1 for the megatron backend + # TODO: Add the related support + assert policy_backend != "dtensor", "DTensor backend is not supported with kv cache fp8 enabled." + assert not _should_use_async_rollouts(master_config), "Async rollouts is not supported with kv cache fp8 enabled." + assert master_config["policy"]["megatron_cfg"]["pipeline_model_parallel_size"] == 1, "Pipeline model parallel size must be 1 for megatron backend with kv cache fp8 enabled." 
else: - print("[KV_SCALES] KV cache scale sync not needed (non-FP8 mode)") + print("[KV_SCALES] KV cache scale sync not needed (non-FP8 mode or kv_cache_dtype is not fp8)") NEED_REFIT = True # If policy_generation is None, use the policy as the generation interface (megatron framework backend) @@ -1169,6 +1148,7 @@ def grpo_train( colocated_inference = master_config["policy"]["generation"]["colocated"]["enabled"] # Run validation at the start if configured + # TODO: Add validation with kv scales if needed if val_at_start and current_step == 0: print("\n🔍 Running initial validation...", flush=True) if NEED_REFIT and POLICY_GENERATION_STALE: @@ -1234,20 +1214,12 @@ def grpo_train( if sync_kv_scales and kv_scales_cache is None: print("[KV_SCALES] Computing KV cache scales for the first time...") policy.prepare_for_lp_inference() - # kv_scales_cache = compute_kv_scales_with_data( - # policy, repeated_batch, master_config - # ) kv_scales_cache = {} - # Create calibration data from flattened messages calibration_data = BatchedDataDict[ClippedPGLossDataDict]( { "input_ids": batched_flat["token_ids"], - "input_lengths": input_lengths, - # "advantages": batched_flat["advantages"], - # "generation_logprobs": batched_flat["generation_logprobs"], - # "token_mask": batched_flat["token_loss_mask"], - # "sample_mask": batched_flat["loss_multiplier"], + "input_lengths": input_lengths } ) # this will be mini-batched inside the policy, so maintain the packed multimodal structure @@ -1258,6 +1230,7 @@ def grpo_train( layer_idx = k.split("_")[1] k_param_name = f"model.layers.{layer_idx}.self_attn.k_scale" v_param_name = f"model.layers.{layer_idx}.self_attn.v_scale" + # q_param_name is different from k_param_name and v_param_name because vllm handles the param mappings differently for q and k/v q_param_name = f"model.layers.{layer_idx}.self_attn.attn.q_scale" kv_scales_cache[q_param_name] = v["q_scale"] @@ -1265,7 +1238,8 @@ def grpo_train( kv_scales_cache[v_param_name] = v["v_scale"] 
refit_policy_generation( - policy, policy_generation, colocated_inference, timer=timer + policy, policy_generation, colocated_inference, timer=timer, + kv_scales=kv_scales_cache if sync_kv_scales else None ) POLICY_GENERATION_STALE = False else: @@ -1463,24 +1437,21 @@ def grpo_train( with timer.time("policy_training"): train_results = policy.train(train_data, loss_fn) - # Recompute KV scales after policy training if needed - if sync_kv_scales: - print("[KV_SCALES] Recomputing KV cache scales after policy update...") - # kv_scales_cache = compute_kv_scales_with_data( - # policy, repeated_batch, master_config - # ) - kv_scales = policy.calibrate_qkv_fp8_scales(train_data, include_q=True)["layers"] - for k, v in kv_scales.items(): - layer_idx = k.split("_")[1] - k_param_name = f"model.layers.{layer_idx}.self_attn.k_scale" - v_param_name = f"model.layers.{layer_idx}.self_attn.v_scale" - q_param_name = f"model.layers.{layer_idx}.self_attn.attn.q_scale" - - kv_scales_cache[q_param_name] = v["q_scale"] - kv_scales_cache[k_param_name] = v["k_scale"] - kv_scales_cache[v_param_name] = v["v_scale"] - # Set generation as stale to force refit with new scales - POLICY_GENERATION_STALE = True + # Recompute KV scales after policy training if needed + if sync_kv_scales: + print("[KV_SCALES] Recomputing KV cache scales after policy update...") + kv_scales = policy.calibrate_qkv_fp8_scales(train_data, include_q=True)["layers"] + for k, v in kv_scales.items(): + layer_idx = k.split("_")[1] + k_param_name = f"model.layers.{layer_idx}.self_attn.k_scale" + v_param_name = f"model.layers.{layer_idx}.self_attn.v_scale" + q_param_name = f"model.layers.{layer_idx}.self_attn.attn.q_scale" + + kv_scales_cache[q_param_name] = v["q_scale"] + kv_scales_cache[k_param_name] = v["k_scale"] + kv_scales_cache[v_param_name] = v["v_scale"] + # Set generation as stale to force refit with new scales + POLICY_GENERATION_STALE = True is_last_step = (total_steps + 1 >= max_num_steps) or ( (current_epoch + 1 
== max_num_epochs) @@ -1488,6 +1459,7 @@ def grpo_train( ) # Run validation if it's a validation step + # TODO: Add validation with kv scales if needed if val_period > 0 and (total_steps + 1) % val_period == 0: if NEED_REFIT and POLICY_GENERATION_STALE: refit_policy_generation( diff --git a/nemo_rl/models/generation/vllm/vllm_backend.py b/nemo_rl/models/generation/vllm/vllm_backend.py index e2546aeb8a..290f51d8ed 100644 --- a/nemo_rl/models/generation/vllm/vllm_backend.py +++ b/nemo_rl/models/generation/vllm/vllm_backend.py @@ -14,7 +14,6 @@ import gc import traceback from typing import Any -import traceback import torch import zmq @@ -169,7 +168,6 @@ def update_weights_via_ipc_zmq(self) -> bool: f"Error in VllmInternalWorkerExtension.update_weights_via_ipc_zmq: {e}.\n" f"{traceback.format_exc()}" ) - print(traceback.format_exc()) return False @wrap_with_nvtx_name( diff --git a/nemo_rl/models/policy/lm_policy.py b/nemo_rl/models/policy/lm_policy.py index ff30dadcfb..32739cd79f 100644 --- a/nemo_rl/models/policy/lm_policy.py +++ b/nemo_rl/models/policy/lm_policy.py @@ -699,11 +699,11 @@ def calibrate_qkv_fp8_scales( save_path: Optional[str] = None, include_q: bool = False, ) -> dict[str, Any]: - """触发各 Megatron worker 的 KV-cache FP8 scale 标定,并返回结果。 + """Trigger KV-cache FP8 scale calibration across Megatron workers and return results. - 说明:后端 `MegatronPolicyWorker.calibrate_qkv_fp8_scales` 已实现分布式规约, - 返回的是跨 rank 合并后的结果。因此这里按 DP 分片输入并并行调用,最终取第一个 - worker 的返回即可。 + Note: The backend `MegatronPolicyWorker.calibrate_qkv_fp8_scales` already implements + distributed reduction, returning results merged across ranks. Therefore, we shard the + input by DP and call in parallel, then take the result from the first worker. 
""" dp_size = self.sharding_annotations.get_axis_size("data_parallel") # 仅分 DP 维度;对于动态/打包模式,沿用现有分片逻辑 @@ -754,7 +754,6 @@ def calibrate_qkv_fp8_scales( }, ) results = self.worker_group.get_all_worker_results(futures) - # 由于后端已在分布式内合并,这里返回任一结果均一致 return results[0] def get_free_memory_bytes(self) -> int: diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 4896678f7d..c7eadb7cf9 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -2195,9 +2195,10 @@ def save_checkpoint( weights_path: The specific directory path where the checkpoint will be saved. optimizer_path: If not None, optimizer and scheduler states are saved if they exist. """ + # Temporary fix to avoid OOM after saving checkpoint allocated = torch.cuda.memory_allocated() / (1024**3) # Convert to GB reserved = torch.cuda.memory_reserved() / (1024**3) # Convert to GB - print(f"[SHARON] GPU Memory before saving checkpoint: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved") + print(f"GPU Memory before saving checkpoint: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved") if not torch.distributed.is_initialized(): raise RuntimeError( "Distributed process group is not initialized. Cannot save checkpoint." 
@@ -2259,6 +2260,7 @@ def save_checkpoint( if not is_training: # Restore training state if it was changed self.model.train() + # Temporary fix to avoid OOM after saving checkpoint: https://github.com/NVIDIA-NeMo/RL/issues/1057 torch.randn(1).cuda() # wake up torch allocator if hasattr(self, "optimizer") and self.optimizer is not None: # Iterate through the state dictionaries for each parameter group @@ -2279,7 +2281,7 @@ def save_checkpoint( allocated = torch.cuda.memory_allocated() / (1024**3) # Convert to GB reserved = torch.cuda.memory_reserved() / (1024**3) # Convert to GB - print(f"[SHARON] GPU Memory after saving checkpoint: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved") + print(f"GPU Memory after saving checkpoint: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved") except Exception as e: @@ -2413,38 +2415,37 @@ def calibrate_qkv_fp8_scales( save_path: Optional[str] = None, include_q: bool = False, ) -> dict[str, Any]: - """One-shot 标定 Q/K/V 的激活 scale(用于 FP8 KV cache)。 + """One-shot calibration of Q/K/V activation scales (for FP8 KV cache). - - 通过 forward hook 捕获每层 `query_key_value` 的输出,分割 Q/K/V,并统计分位 amax。 - - 在并行(DP/TP/PP)环境下,先本地统计分位,再对所有 rank 取最大值以保证保守性。 - - 默认仅返回并保存 K/V 的 scale(E5M2),可选返回 Q(E4M3)。 + - Captures each layer's `query_key_value` output through forward hooks, splits Q/K/V, and computes percentile amax. + - In parallel (DP/TP/PP) environments, first computes local percentiles, then takes max across all ranks for conservativeness. + - By default only returns and saves K/V scales (E5M2), optionally returns Q (E4M3). Args: - data: 用于标定的一小批代表性样本,遵循 get_logprobs 的输入约定。 - micro_batch_size: 标定时的 micro batch 大小;若 None 则复用 logprob_batch_size。 - percentile: amax 的分位数(如 99.9)。 - margin: 余量系数,例如 1.05。 - save_path: 若提供,则 rank0 会将结果保存为 JSON。 - include_q: 是否也返回 Q 的 scale(一般只需 K/V)。 + data: Representative sample batch for calibration, following get_logprobs input conventions. 
+ micro_batch_size: Micro batch size during calibration; if None, reuses logprob_batch_size. + percentile: Percentile for amax (e.g. 99.9). + margin: Margin factor, e.g. 1.05. + save_path: If provided, rank0 will save results as JSON. + include_q: Whether to also return Q scale (usually only K/V needed). Returns: { "format": "fp8", "percentile": float, "margin": float, "layers": { layer_name: {"k_scale": float, "v_scale": float[, "q_scale": float] } } } """ - # FP8 动态范围(对称) + FP8_MAX_E4M3 = 448.0 - FP8_MAX_E5M2 = 57344.0 self.model.eval() - # 记录每层的 q/k/v 的局部分位 amax + # Record local percentile amax for q/k/v of each layer layer_to_samples_q: dict[str, list[float]] = defaultdict(list) layer_to_samples_k: dict[str, list[float]] = defaultdict(list) layer_to_samples_v: dict[str, list[float]] = defaultdict(list) hook_handles = [] def _extract_layer_key(module_name: str) -> str: - # 期望形如 "module.decoder.layers..self_attention.query_key_value" + # Expected format: "module.decoder.layers..self_attention.query_key_value" m = re.search(r"module\.decoder\.layers\.(\d+)", module_name) if m is not None: return f"layer_{m.group(1)}" @@ -2472,7 +2473,7 @@ def _hook(module, inputs, output): return _hook - # 新增:优先在 core_attention 的 forward_pre 上挂 hook,以捕获已做完 q/k norm 与 RoPE 的 q/k/v + # Hook to capture q/k/v after q/k norm and RoPE def _pre_hook_builder_core_attention(module_name: str): layer_key = _extract_layer_key(module_name) @@ -2490,25 +2491,25 @@ def _pre_hook(module, inputs): layer_to_samples_k[layer_key].append(float(torch.amax(torch.abs(k)).item())) layer_to_samples_v[layer_key].append(float(torch.amax(torch.abs(v)).item())) except Exception as e: - print(f"[ALEXQ] Error in core_attention pre-hook on {module_name}: {e}") + print(f"[KV_SCALES] Error in core_attention pre-hook on {module_name}: {e}") pass return _pre_hook matched_modules = [] - # 1) 优先尝试在 core_attention 上注册 forward_pre_hook + # 1) Try to register forward_pre_hook on core_attention first for name, module in 
self.model.named_modules(): - print(f"[ALEXQ] Module name: {name}") + print(f"[KV_SCALES] Module name: {name}") if "self_attention.core_attention" in name: try: handle = module.register_forward_pre_hook(_pre_hook_builder_core_attention(name)) hook_handles.append(handle) matched_modules.append((name, module.__class__.__name__, "pre")) except Exception as e: - print(f"[ALEXQ] Error registering pre-hook on {name}: {e}") + print(f"[KV_SCALES] Error registering pre-hook on {name}: {e}") continue - # 2) 若未命中 core_attention,则退回到原先在 QKV 投影输出上的 forward_hook + # 2) If core_attention is not hit, fall back to forward_hook on QKV projection output if not hook_handles: qkv_name_patterns = ("query_key_value", "linear_qkv", ".qkv", "_qkv") for name, module in self.model.named_modules(): @@ -2518,13 +2519,13 @@ def _pre_hook(module, inputs): hook_handles.append(handle) matched_modules.append((name, module.__class__.__name__, "post")) except Exception as e: - print(f"[ALEXQ] Error registering hook on {name}: {e}") + print(f"[KV_SCALES] Error registering hook on {name}: {e}") continue if not hook_handles: - print("[ALEXQ] No QKV proj modules matched for hook. Example module/param names:") + print("[KV_SCALES] No QKV proj modules matched for hook. 
Example module/param names:") try: - # 打印前若干个模块与参数名,帮助定位实际命名 + # Print the first 10 modules and parameters to help locate the actual names cnt = 0 for n, _m in self.model.named_modules(): if cnt >= 10: @@ -2540,12 +2541,12 @@ def _pre_hook(module, inputs): except Exception: pass else: - # 轻量打印匹配到的前几个模块,确认命中 - print("[ALEXQ] Registered hooks on modules (showing up to 8):") + # Lightly print the first few modules and parameters to confirm hits + print("[KV_SCALES] Registered hooks on modules (showing up to 8):") for i, (mn, cls, kind) in enumerate(matched_modules[:8]): print(f" {i:02d}: {mn} <{cls}> [{kind}]") - # 运行一次推理流程以触发 hooks(重用 get_logprobs 的前向路径) + # Run a forward pass to trigger hooks (reuse get_logprobs forward path) try: _ = self.get_logprobs(data=data, micro_batch_size=micro_batch_size) finally: @@ -2553,10 +2554,10 @@ def _pre_hook(module, inputs): try: h.remove() except Exception as e: - print(f"[ALEXQ] Error removing hook: {e}") + print(f"[KV_SCALES] Error removing hook: {e}") pass - # 计算本地分位 amax + # Compute local percentile amax def _percentile(values: list[float], p: float) -> float: if not values: return 0.0 @@ -2573,7 +2574,7 @@ def _percentile(values: list[float], p: float) -> float: entry["v_amax_p"] = _percentile(layer_to_samples_v.get(layer_key, []), percentile) local_layer_to_pamax[layer_key] = entry - # 合并所有 rank:对分位 amax 取最大(保守) + # Merge across all ranks: take maximum of percentile amax (conservative approach) world_size = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1 gathered = [None for _ in range(world_size)] if world_size > 1 else None if world_size > 1: @@ -2590,7 +2591,7 @@ def _percentile(values: list[float], p: float) -> float: else: layer_to_pamax = local_layer_to_pamax - # 计算 scale(对称量化):scale = pamax / fp8_max + # Compute scale (symmetric quantization): scale = pamax / fp8_max result_layers = {} for layer_key, vals in layer_to_pamax.items(): out_entry = {} @@ -2609,9 +2610,9 @@ def 
_percentile(values: list[float], p: float) -> float: "margin": margin, "layers": result_layers, } - print(f"[ALEXQ] Calibrated KV cache scales: {final_result}") + print(f"[KV_SCALES] Calibrated KV cache scales: {final_result}") - # 将结果在所有 rank 同步(广播 rank0 的结果) + # Sync results across all ranks (broadcast rank0's result) if world_size > 1: if torch.distributed.get_rank() == 0: obj_list = [final_result] From e2cbca849e4f38a24310471bcbaddba76be66c1b Mon Sep 17 00:00:00 2001 From: Shuang Yu Date: Mon, 10 Nov 2025 22:21:41 -0800 Subject: [PATCH 03/40] Fix indentiation error. Enable using environment variables to set FP8 max. Signed-off-by: Zhaopeng Qiu --- nemo_rl/algorithms/grpo.py | 105 +++++++++--------- nemo_rl/models/policy/lm_policy.py | 1 - .../models/policy/megatron_policy_worker.py | 26 +++-- 3 files changed, 72 insertions(+), 60 deletions(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 1011e648c1..861843cba7 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -1203,49 +1203,49 @@ def grpo_train( ) input_ids = batched_flat["token_ids"] - # Generate responses - this updates the LLMMessageLogType in repeated_batch - print( - f"▶ Generating responses for batch of size {repeated_batch.size}...", - flush=True, - ) - with timer.time("prepare_for_generation"): - if NEED_REFIT and POLICY_GENERATION_STALE: - # Compute KV scales if needed for FP8 quantization - if sync_kv_scales and kv_scales_cache is None: - print("[KV_SCALES] Computing KV cache scales for the first time...") - policy.prepare_for_lp_inference() - kv_scales_cache = {} - # Create calibration data from flattened messages - calibration_data = BatchedDataDict[ClippedPGLossDataDict]( - { - "input_ids": batched_flat["token_ids"], - "input_lengths": input_lengths - } + # Generate responses - this updates the LLMMessageLogType in repeated_batch + print( + f"▶ Generating responses for batch of size {repeated_batch.size}...", + flush=True, + ) + with 
timer.time("prepare_for_generation"): + if NEED_REFIT and POLICY_GENERATION_STALE: + # Compute KV scales if needed for FP8 quantization + if sync_kv_scales and kv_scales_cache is None: + print("[KV_SCALES] Computing KV cache scales for the first time...") + policy.prepare_for_lp_inference() + kv_scales_cache = {} + # Create calibration data from flattened messages + calibration_data = BatchedDataDict[ClippedPGLossDataDict]( + { + "input_ids": batched_flat["token_ids"], + "input_lengths": input_lengths + } + ) + # this will be mini-batched inside the policy, so maintain the packed multimodal structure + calibration_data.update(batched_flat.get_multimodal_dict(as_tensors=False)) + calibration_data.to("cpu") + kv_scales = policy.calibrate_qkv_fp8_scales(calibration_data, include_q=True)["layers"] + for k, v in kv_scales.items(): + layer_idx = k.split("_")[1] + k_param_name = f"model.layers.{layer_idx}.self_attn.k_scale" + v_param_name = f"model.layers.{layer_idx}.self_attn.v_scale" + # q_param_name is different from k_param_name and v_param_name because vllm handles the param mappings differently for q and k/v + q_param_name = f"model.layers.{layer_idx}.self_attn.attn.q_scale" + + kv_scales_cache[q_param_name] = v["q_scale"] + kv_scales_cache[k_param_name] = v["k_scale"] + kv_scales_cache[v_param_name] = v["v_scale"] + + refit_policy_generation( + policy, policy_generation, colocated_inference, timer=timer, + kv_scales=kv_scales_cache if sync_kv_scales else None ) - # this will be mini-batched inside the policy, so maintain the packed multimodal structure - calibration_data.update(batched_flat.get_multimodal_dict(as_tensors=False)) - calibration_data.to("cpu") - kv_scales = policy.calibrate_qkv_fp8_scales(calibration_data, include_q=True)["layers"] - for k, v in kv_scales.items(): - layer_idx = k.split("_")[1] - k_param_name = f"model.layers.{layer_idx}.self_attn.k_scale" - v_param_name = f"model.layers.{layer_idx}.self_attn.v_scale" - # q_param_name is different from 
k_param_name and v_param_name because vllm handles the param mappings differently for q and k/v - q_param_name = f"model.layers.{layer_idx}.self_attn.attn.q_scale" - - kv_scales_cache[q_param_name] = v["q_scale"] - kv_scales_cache[k_param_name] = v["k_scale"] - kv_scales_cache[v_param_name] = v["v_scale"] - - refit_policy_generation( - policy, policy_generation, colocated_inference, timer=timer, - kv_scales=kv_scales_cache if sync_kv_scales else None - ) - POLICY_GENERATION_STALE = False - else: + POLICY_GENERATION_STALE = False + else: if colocated_inference: policy.offload_after_refit() # unload optimizer to make space for generation - policy_generation.prepare_for_generation() + policy_generation.prepare_for_generation() dynamic_sampling_num_gen_batches += 1 with timer.time("generation"): @@ -1439,19 +1439,20 @@ def grpo_train( # Recompute KV scales after policy training if needed if sync_kv_scales: - print("[KV_SCALES] Recomputing KV cache scales after policy update...") - kv_scales = policy.calibrate_qkv_fp8_scales(train_data, include_q=True)["layers"] - for k, v in kv_scales.items(): - layer_idx = k.split("_")[1] - k_param_name = f"model.layers.{layer_idx}.self_attn.k_scale" - v_param_name = f"model.layers.{layer_idx}.self_attn.v_scale" - q_param_name = f"model.layers.{layer_idx}.self_attn.attn.q_scale" + with timer.time("recompute_kv_scales"): + print("[KV_SCALES] Recomputing KV cache scales after policy update...") + kv_scales = policy.calibrate_qkv_fp8_scales(train_data, include_q=True)["layers"] + for k, v in kv_scales.items(): + layer_idx = k.split("_")[1] + k_param_name = f"model.layers.{layer_idx}.self_attn.k_scale" + v_param_name = f"model.layers.{layer_idx}.self_attn.v_scale" + q_param_name = f"model.layers.{layer_idx}.self_attn.attn.q_scale" - kv_scales_cache[q_param_name] = v["q_scale"] - kv_scales_cache[k_param_name] = v["k_scale"] - kv_scales_cache[v_param_name] = v["v_scale"] - # Set generation as stale to force refit with new scales - 
POLICY_GENERATION_STALE = True + kv_scales_cache[q_param_name] = v["q_scale"] + kv_scales_cache[k_param_name] = v["k_scale"] + kv_scales_cache[v_param_name] = v["v_scale"] + # Set generation as stale to force refit with new scales + POLICY_GENERATION_STALE = True is_last_step = (total_steps + 1 >= max_num_steps) or ( (current_epoch + 1 == max_num_epochs) diff --git a/nemo_rl/models/policy/lm_policy.py b/nemo_rl/models/policy/lm_policy.py index 32739cd79f..feb25fcb2a 100644 --- a/nemo_rl/models/policy/lm_policy.py +++ b/nemo_rl/models/policy/lm_policy.py @@ -706,7 +706,6 @@ def calibrate_qkv_fp8_scales( input by DP and call in parallel, then take the result from the first worker. """ dp_size = self.sharding_annotations.get_axis_size("data_parallel") - # 仅分 DP 维度;对于动态/打包模式,沿用现有分片逻辑 if self.use_dynamic_batches: self.dynamic_batching_args["max_tokens_per_microbatch"] = self.cfg[ "dynamic_batching" diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index c7eadb7cf9..461c2493d5 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -2419,7 +2419,7 @@ def calibrate_qkv_fp8_scales( - Captures each layer's `query_key_value` output through forward hooks, splits Q/K/V, and computes percentile amax. - In parallel (DP/TP/PP) environments, first computes local percentiles, then takes max across all ranks for conservativeness. - - By default only returns and saves K/V scales (E5M2), optionally returns Q (E4M3). + - By default only returns and saves K/V scales, optionally returns Q. Args: data: Representative sample batch for calibration, following get_logprobs input conventions. @@ -2434,7 +2434,19 @@ def calibrate_qkv_fp8_scales( "layers": { layer_name: {"k_scale": float, "v_scale": float[, "q_scale": float] } } } """ - FP8_MAX_E4M3 = 448.0 + # Allow overriding FP8 max for Q, K, V via environment variables for ease of testing. 
+ # Defaults align with FP8 e4m3 max magnitude. + # Use different defaults for Q, K, V to adapt to distribution diffefences + def _get_env_float(name: str, default: float) -> float: + try: + val = os.getenv(name, None) + return float(val) if val is not None and val != "" else default + except Exception: + return default + + FP8_MAX_Q = _get_env_float("FP8_MAX_Q", 448.0) + FP8_MAX_K = _get_env_float("FP8_MAX_K", 448.0) + FP8_MAX_V = _get_env_float("FP8_MAX_V", 448.0) self.model.eval() @@ -2482,7 +2494,7 @@ def _pre_hook(module, inputs): args = inputs if isinstance(inputs, (tuple, list)) else (inputs,) if len(args) == 1 and isinstance(args[0], (tuple, list)): args = args[0] - # 期望前 3 个为 q, k, v(Megatron CoreAttention 的典型签名) + # Expected first 3 args to be q, k, v (typical signature for Megatron CoreAttention) q = args[0] k = args[1] v = args[2] @@ -2596,10 +2608,10 @@ def _percentile(values: list[float], p: float) -> float: for layer_key, vals in layer_to_pamax.items(): out_entry = {} if include_q: - q_scale = (vals.get("q_amax_p", 0.0) * margin) / FP8_MAX_E4M3 + q_scale = (vals.get("q_amax_p", 0.0) * margin) / FP8_MAX_Q out_entry["q_scale"] = float(q_scale) - k_scale = (vals.get("k_amax_p", 0.0) * margin) / FP8_MAX_E4M3 - v_scale = (vals.get("v_amax_p", 0.0) * margin) / FP8_MAX_E4M3 + k_scale = (vals.get("k_amax_p", 0.0) * margin) / FP8_MAX_K + v_scale = (vals.get("v_amax_p", 0.0) * margin) / FP8_MAX_V out_entry["k_scale"] = float(k_scale) out_entry["v_scale"] = float(v_scale) result_layers[layer_key] = out_entry @@ -2623,7 +2635,7 @@ def _percentile(values: list[float], p: float) -> float: torch.distributed.broadcast_object_list(obj_list, src=0) final_result = obj_list[0] # type: ignore - # 可选保存至 JSON(仅 rank0) + # Optional save to JSON (only rank0) if save_path is not None and (not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0): try: with open(save_path, "w") as f: From c92dc557f6daf05574b5fa265fb5194ed89f6186 Mon Sep 17 00:00:00 2001 
From: Shuang Yu Date: Mon, 10 Nov 2025 22:22:35 -0800 Subject: [PATCH 04/40] Correct typos Signed-off-by: Zhaopeng Qiu --- nemo_rl/algorithms/grpo.py | 8 +++++--- nemo_rl/models/policy/megatron_policy_worker.py | 6 +++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 861843cba7..4b3cea8dcb 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -180,6 +180,7 @@ class MasterConfig(TypedDict): # Setup & Initialization # =============================================================================== + def setup( master_config: MasterConfig, tokenizer: TokenizerType, @@ -984,6 +985,7 @@ def compute_kv_scales_with_data( default_scales[f"{name}.v_scale"] = 1.0 return default_scales + def refit_policy_generation( policy: ColocatablePolicyInterface, policy_generation: GenerationInterface, @@ -1094,7 +1096,7 @@ def grpo_train( # Check if we need to sync KV cache scales (infer from config) sync_kv_scales = _should_sync_kv_scales(master_config) - kv_scales_cache = None # Cache reused for compuated kv scales + kv_scales_cache = None # Cache reused for computed kv scales if sync_kv_scales: generation_config = master_config["policy"]["generation"] @@ -1108,14 +1110,14 @@ def grpo_train( print(f"[KV_SCALES] Configuration: policy_backend={policy_backend}, generation_backend={backend}") print(f"[KV_SCALES] vLLM settings: precision={vllm_precision}, kv_cache_dtype={kv_cache_dtype}") - # Temoprary assert check to flag error when kv cache fp8 is enabled but either of thefollowing conditions are met: + # Temporary assert check to flag error when kv cache fp8 is enabled but either of thefollowing conditions are met: # 1. policy backend is dtensor # 2. async rollouts is enabled # 3. pipeline_model_parallel_size is greater than 1 for the megatron backend # TODO: Add the related support assert policy_backend != "dtensor", "DTensor backend is not supported with kv cache fp8 enabled." 
assert not _should_use_async_rollouts(master_config), "Async rollouts is not supported with kv cache fp8 enabled." - assert master_config["policy"]["megatron_cfg"]["pipeline_model_parallel_size"] == 1, "Pipeline model parallel size must be 1 for megatron backend with kv cache fp8 enabled." + assert master_config["policy"]["megatron_cfg"].get("pipeline_model_parallel_size", 1) == 1, "Pipeline model parallel size must be 1 for megatron backend with kv cache fp8 enabled." else: print("[KV_SCALES] KV cache scale sync not needed (non-FP8 mode or kv_cache_dtype is not fp8)") diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 461c2493d5..6c8198b104 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -2475,12 +2475,12 @@ def _hook(module, inputs, output): q = out[..., :qkv_stride] k = out[..., qkv_stride : 2 * qkv_stride] v = out[..., 2 * qkv_stride : 3 * qkv_stride] - # per-tensor 绝对最大值(局部) + # per-tensor absolute maximum value (local) layer_to_samples_q[layer_key].append(float(torch.amax(torch.abs(q)).item())) layer_to_samples_k[layer_key].append(float(torch.amax(torch.abs(k)).item())) layer_to_samples_v[layer_key].append(float(torch.amax(torch.abs(v)).item())) except Exception as e: - print(f"[ALEXQ] Error extracting layer key: {e}") + print(f"[KV_SCALES] Error extracting layer key: {e}") pass return _hook @@ -2567,7 +2567,7 @@ def _pre_hook(module, inputs): h.remove() except Exception as e: print(f"[KV_SCALES] Error removing hook: {e}") - pass + pass # Compute local percentile amax def _percentile(values: list[float], p: float) -> float: From ad0978558d47d01da0df7cc37faf8c2dd22522b6 Mon Sep 17 00:00:00 2001 From: Shuang Yu Date: Mon, 10 Nov 2025 22:22:36 -0800 Subject: [PATCH 05/40] Update fp8.py Signed-off-by: Zhaopeng Qiu --- nemo_rl/models/generation/fp8.py | 114 ++++++++++++++++++++++++++++++- 1 file changed, 113 insertions(+), 1 
deletion(-) diff --git a/nemo_rl/models/generation/fp8.py b/nemo_rl/models/generation/fp8.py index 474cf88a46..00f3dc43d5 100644 --- a/nemo_rl/models/generation/fp8.py +++ b/nemo_rl/models/generation/fp8.py @@ -40,6 +40,7 @@ class FP8Config: num_first_layers_in_bf16: int = 0 num_last_layers_in_bf16: int = 0 model_parallel_size: int = None + kv_cache_dtype: str = "auto" @dataclass() @@ -104,6 +105,106 @@ def patched_run_workers(self, *args, **kwargs): fp8_patches_applied = True +def kv_cache_process_weights_after_loading(self, layer: torch.nn.Module) -> None: + """ + Modified version of BaseKVCacheMethod.process_weights_after_loading that doesn't delete + k_scale, v_scale, q_scale, and prob_scale parameters to allow for dynamic updates. + """ + import torch + from vllm.logger import init_logger + from vllm.platforms import current_platform + + logger = init_logger(__name__) + + # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0 + # regardless whether the kv-scale is available in the checkpoint. + # No need to process kv scales after loading if we are going to + # calculate them on the fly. 
+ if layer.kv_cache_dtype != "auto" and not layer.calculate_kv_scales: + if layer.k_scale > 0.0 and layer.v_scale > 0.0: + # We prefer to use separate k_scale and v_scale if present + k_scale = layer.k_scale.to("cpu").tolist() + v_scale = layer.v_scale.to("cpu").tolist() + if current_platform.is_fp8_fnuz(): + k_scale *= 2 + v_scale *= 2 + elif layer.k_scale < 0.0 and layer.v_scale < 0.0: + # If no scales were loaded (both scales are invalid negative + # values), use the default value of 1.0 + k_scale = 1.0 + v_scale = 1.0 + else: + # If we find a single kv_scale in the checkpoint, we remap + # kv_scale to k_scale during weight loading, and duplicate + # k_scale to v_scale here + assert layer.k_scale > 0.0 + scale_to_duplicate = max(layer.k_scale, layer.v_scale) + k_scale = scale_to_duplicate.to("cpu").tolist() + v_scale = scale_to_duplicate.to("cpu").tolist() + if current_platform.is_fp8_fnuz(): + k_scale *= 2 + v_scale *= 2 + + if not isinstance(k_scale, float) or not isinstance( + v_scale, float): + raise ValueError("Only support per-tensor scaling factor " + "for fp8 KV cache") + + if layer.q_scale < 0.0: + logger.warning_once( + "Checkpoint does not provide a q scaling factor. " + "Setting it to k_scale. This only matters for " + "the flash-attn backend.") + layer._q_scale.copy_(k_scale) + # These are used in the final Attention.forward() + layer._k_scale.copy_(k_scale) + layer._v_scale.copy_(v_scale) + layer._k_scale_float = k_scale + layer._v_scale_float = v_scale + if (k_scale == 1.0 and v_scale == 1.0 + and "e5m2" not in layer.kv_cache_dtype): + logger.warning_once( + "Using KV cache scaling factor 1.0 for fp8_e4m3. This " + "may cause accuracy issues. 
Please make sure k/v_scale " + "scaling factors are available in the fp8 checkpoint.") + + if layer.q_scale > 0.0: + q_scale = layer.q_scale + if current_platform.is_fp8_fnuz(): + q_scale *= 2 + layer.calculate_kv_scales = False + else: + q_scale = 1.0 + if layer.prob_scale > 0.0: + prob_scale = layer.prob_scale + if current_platform.is_fp8_fnuz(): + prob_scale *= 2 + else: + prob_scale = 1.0 + + is_singleton_float = lambda x: isinstance(x, float) or isinstance( + x, torch.Tensor) and x.numel() == 1 and x.is_floating_point() + if not is_singleton_float(q_scale) or not is_singleton_float( + prob_scale): + raise ValueError("Only support per-tensor scaling factor" + "for fp8-quantized Q/prob") + + # These are used in the final Attention.forward() + layer._q_scale.copy_(q_scale) + layer._prob_scale.copy_(prob_scale) + if layer.kv_cache_dtype == "fp8" and (q_scale == 1.0 + or prob_scale == 1.0): + logger.warning_once( + f"Using uncalibrated q_scale {q_scale} and/or prob_scale " + f"{prob_scale} with fp8 attention. This may cause accuracy " + "issues. 
Please make sure q/prob scaling factors are " + "available in the fp8 checkpoint.") + + # IMPORTANT: We DON'T delete the parameters here to allow for dynamic updates + # Original code deleted: layer.k_scale, layer.v_scale, layer.q_scale, layer.prob_scale + print(f"[KV_SCALES] Patched process_weights_after_loading: keeping k_scale, v_scale parameters for dynamic updates") + + def apply_fp8_patches(self, fp8_config): global global_fp8_config, fp8_patches_applied assert not fp8_patches_applied @@ -129,6 +230,11 @@ def apply_fp8_patches(self, fp8_config): patcher4 = patch(func4_path, _per_token_group_quant_fp8_colmajor) fp8_state.vllm_patches.append(patcher2, patcher3, patcher4) + # Patch the vllm kv_cache.py process_weights_after_loading() to remove the deletion of k_scale and v_scale + func5_path = "vllm.model_executor.layers.quantization.kv_cache.BaseKVCacheMethod.process_weights_after_loading" + patcher5 = patch(func5_path, kv_cache_process_weights_after_loading) + fp8_state.vllm_patches.append(patcher5) + for p in fp8_state.vllm_patches: p.start() @@ -151,6 +257,7 @@ def init_fp8(vllm_cfg, model_name, model_parallel_size): num_first_layers_in_bf16=vllm_cfg.get("num_first_layers_in_bf16", 0), num_last_layers_in_bf16=vllm_cfg.get("num_last_layers_in_bf16", 0), model_parallel_size=model_parallel_size, + kv_cache_dtype=vllm_cfg.get("kv_cache_dtype", "auto"), ) if vllm_cfg.get("use_deep_gemm", False): @@ -168,6 +275,7 @@ def init_fp8(vllm_cfg, model_name, model_parallel_size): # create fp8 kwargs for vllm's LLM(...) 
num_first_layers_in_bf16 = vllm_cfg.get("num_first_layers_in_bf16", 0) num_last_layers_in_bf16 = vllm_cfg.get("num_last_layers_in_bf16", 0) + kv_cache_dtype = vllm_cfg.get("kv_cache_dtype", "auto") fp8_block_quant_kwargs = dict(FP8_BLOCK_QUANT_KWARGS) if num_first_layers_in_bf16 > 0 or num_last_layers_in_bf16 > 0: @@ -192,8 +300,12 @@ def init_fp8(vllm_cfg, model_name, model_parallel_size): fp8_block_quant_kwargs["ignored_layers"] = bf16_params + # TODO: Remove this after debugging. + print(f"[KV_SCALES] Global FP8 config: {global_fp8_config}") vllm_kwargs = { "quantization": "fp8", + # Conditionally set kv_cache_dtype to fp8 if global config kv_cache_dtype is fp8 + "kv_cache_dtype": "fp8" if global_fp8_config.kv_cache_dtype == "fp8" else "auto", "hf_overrides": {"quantization_config": fp8_block_quant_kwargs}, } return vllm_kwargs @@ -552,4 +664,4 @@ def per_token_group_quant_fp8( per_token_group_quant_fp8 as vllm_per_token_group_quant_fp8, ) - return vllm_per_token_group_quant_fp8(*args, **kwargs) + return vllm_per_token_group_quant_fp8(*args, **kwargs) \ No newline at end of file From dca01e2d0856b0d9a76f281286eee103702f6407 Mon Sep 17 00:00:00 2001 From: Shuang Yu Date: Mon, 10 Nov 2025 22:22:36 -0800 Subject: [PATCH 06/40] Remove dapo.py from response_datasets Signed-off-by: Zhaopeng Qiu --- .../data/datasets/response_datasets/dapo.py | 105 ------------------ 1 file changed, 105 deletions(-) delete mode 100644 nemo_rl/data/datasets/response_datasets/dapo.py diff --git a/nemo_rl/data/datasets/response_datasets/dapo.py b/nemo_rl/data/datasets/response_datasets/dapo.py deleted file mode 100644 index d56dc5cc91..0000000000 --- a/nemo_rl/data/datasets/response_datasets/dapo.py +++ /dev/null @@ -1,105 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from typing import Any - -from datasets import Dataset, load_dataset - -from nemo_rl.data.interfaces import TaskDataSpec - - -def format_math(data: dict[str, str | float | int]) -> dict[str, list[Any] | str]: - return { - "messages": [ - { - "role": "user", - "content": data["problem"], - }, - { - "role": "assistant", - "content": data["answer"], - }, - ], - # For v0.1 release, nemo rl datasets require a task_name key such that user can map a task processor per unique task. - "task_name": "math", - } - - -def format_dapo_math(data: dict[str, Any]) -> dict[str, list[Any] | str]: - # Extract user content from prompt field - user_content = "" - for message in data["prompt"]: - if message["role"] == "user": - user_content = message["content"] - break - - # Extract ground truth from reward_model field - assistant_content = data["reward_model"]["ground_truth"] - - return { - "messages": [ - { - "role": "user", - "content": user_content, - }, - { - "role": "assistant", - "content": assistant_content, - }, - ], - # For v0.1 release, nemo rl datasets require a task_name key such that user can map a task processor per unique task. 
- "task_name": "math", - } - - -def prepare_dapo_dataset(seed: int = 42) -> dict[str, Dataset | None]: - """Load and split the DAPO dataset into train and test sets.""" - # Load the original dataset for training - train_ds = load_dataset("BytedTsinghua-SIA/DAPO-Math-17k", split="train") - - # Load hendrydong/aime24 dataset for validation - val_ds = load_dataset("HuggingFaceH4/aime_2024", split="train") - - # Shuffle the training dataset with the specified seed - train_ds = train_ds.shuffle(seed=seed) - - # Format the examples, removing original columns - train_formatted = train_ds.map(format_dapo_math, remove_columns=train_ds.column_names) - val_formatted = val_ds.map(format_math, remove_columns=val_ds.column_names) - - # Compute accuracy 16 times per sample (matching the DeepScaleR evaluation setting) - val_repeated = [] - for _ in range(16): - val_repeated.extend(val_formatted) - val_formatted = val_formatted.from_list(val_repeated) - - return { - "train": train_formatted, - "validation": val_formatted, - } - - -class DAPODataset: - def __init__(self, seed: int = 42) -> None: - """Initialize the DAPO dataset with train/test split. - - Args: - seed: Random seed for reproducible splitting - """ - self.formatted_ds = prepare_dapo_dataset(seed=seed) - - self.task_spec = TaskDataSpec( - task_name="DAPO", - ) From 03e47a9471eeba8039bf2935907cfda77035bac1 Mon Sep 17 00:00:00 2001 From: Shuang Yu Date: Mon, 10 Nov 2025 22:23:19 -0800 Subject: [PATCH 07/40] Update sanity check in grpo.py. Remove redundant code in megatron backend. 
Signed-off-by: Zhaopeng Qiu --- nemo_rl/algorithms/grpo.py | 22 ++----------------- .../models/policy/megatron_policy_worker.py | 16 +------------- 2 files changed, 3 insertions(+), 35 deletions(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 4b3cea8dcb..3424254d1b 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -1098,26 +1098,8 @@ def grpo_train( sync_kv_scales = _should_sync_kv_scales(master_config) kv_scales_cache = None # Cache reused for computed kv scales - if sync_kv_scales: - generation_config = master_config["policy"]["generation"] - vllm_cfg = generation_config.get("vllm_cfg", {}) - backend = generation_config.get("backend", "") - kv_cache_dtype = vllm_cfg.get("kv_cache_dtype", "auto") - vllm_precision = vllm_cfg.get("precision", "auto") - policy_backend = "megatron" if master_config["policy"].get("megatron_cfg", {}).get("enabled", False) else "dtensor" - - print(f"[KV_SCALES] FP8 KV cache detected, will sync q_scale, _k_scale and _v_scale during refit") - print(f"[KV_SCALES] Configuration: policy_backend={policy_backend}, generation_backend={backend}") - print(f"[KV_SCALES] vLLM settings: precision={vllm_precision}, kv_cache_dtype={kv_cache_dtype}") - - # Temporary assert check to flag error when kv cache fp8 is enabled but either of thefollowing conditions are met: - # 1. policy backend is dtensor - # 2. async rollouts is enabled - # 3. pipeline_model_parallel_size is greater than 1 for the megatron backend - # TODO: Add the related support - assert policy_backend != "dtensor", "DTensor backend is not supported with kv cache fp8 enabled." - assert not _should_use_async_rollouts(master_config), "Async rollouts is not supported with kv cache fp8 enabled." - assert master_config["policy"]["megatron_cfg"].get("pipeline_model_parallel_size", 1) == 1, "Pipeline model parallel size must be 1 for megatron backend with kv cache fp8 enabled." 
+ if sync_kv_scales: + print(f"[KV_SCALES] FP8 KV cache detected, will sync q_scale, _k_scale and _v_scale during refit") # Note: KV cache FP8 compatibility assertions are now handled in the setup function else: print("[KV_SCALES] KV cache scale sync not needed (non-FP8 mode or kv_cache_dtype is not fp8)") diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 6c8198b104..f3eb2d56a8 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -2509,9 +2509,8 @@ def _pre_hook(module, inputs): return _pre_hook matched_modules = [] - # 1) Try to register forward_pre_hook on core_attention first + # Try to register forward_pre_hook on core_attention first for name, module in self.model.named_modules(): - print(f"[KV_SCALES] Module name: {name}") if "self_attention.core_attention" in name: try: handle = module.register_forward_pre_hook(_pre_hook_builder_core_attention(name)) @@ -2521,19 +2520,6 @@ def _pre_hook(module, inputs): print(f"[KV_SCALES] Error registering pre-hook on {name}: {e}") continue - # 2) If core_attention is not hit, fall back to forward_hook on QKV projection output - if not hook_handles: - qkv_name_patterns = ("query_key_value", "linear_qkv", ".qkv", "_qkv") - for name, module in self.model.named_modules(): - if any(pat in name for pat in qkv_name_patterns): - try: - handle = module.register_forward_hook(_hook_builder(name)) - hook_handles.append(handle) - matched_modules.append((name, module.__class__.__name__, "post")) - except Exception as e: - print(f"[KV_SCALES] Error registering hook on {name}: {e}") - continue - if not hook_handles: print("[KV_SCALES] No QKV proj modules matched for hook. 
Example module/param names:") try: From 96955de7df3916c47f744c813ff859b66520609b Mon Sep 17 00:00:00 2001 From: Shuang Yu Date: Mon, 10 Nov 2025 22:23:19 -0800 Subject: [PATCH 08/40] Remove redundant comments Signed-off-by: Zhaopeng Qiu --- nemo_rl/algorithms/grpo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 3424254d1b..569e9ec046 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -1099,7 +1099,7 @@ def grpo_train( kv_scales_cache = None # Cache reused for computed kv scales if sync_kv_scales: - print(f"[KV_SCALES] FP8 KV cache detected, will sync q_scale, _k_scale and _v_scale during refit") # Note: KV cache FP8 compatibility assertions are now handled in the setup function + print(f"[KV_SCALES] FP8 KV cache detected, will sync q_scale, _k_scale and _v_scale during refit") else: print("[KV_SCALES] KV cache scale sync not needed (non-FP8 mode or kv_cache_dtype is not fp8)") From b1214cd0901632a2dfdf27812d970a5d4916b4b0 Mon Sep 17 00:00:00 2001 From: Shuang Yu Date: Mon, 10 Nov 2025 22:23:19 -0800 Subject: [PATCH 09/40] Remove _hook_builder Signed-off-by: Zhaopeng Qiu --- .../models/policy/megatron_policy_worker.py | 22 ------------------- 1 file changed, 22 deletions(-) diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index f3eb2d56a8..f190e1663e 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -2463,28 +2463,6 @@ def _extract_layer_key(module_name: str) -> str: return f"layer_{m.group(1)}" return module_name - def _hook_builder(module_name: str): - layer_key = _extract_layer_key(module_name) - - def _hook(module, inputs, output): - out = output[0] if isinstance(output, (tuple, list)) else output - try: - last_dim = out.shape[-1] - assert last_dim % 3 == 0 - qkv_stride = last_dim // 3 - q = out[..., :qkv_stride] - k = out[..., 
qkv_stride : 2 * qkv_stride] - v = out[..., 2 * qkv_stride : 3 * qkv_stride] - # per-tensor absolute maximum value (local) - layer_to_samples_q[layer_key].append(float(torch.amax(torch.abs(q)).item())) - layer_to_samples_k[layer_key].append(float(torch.amax(torch.abs(k)).item())) - layer_to_samples_v[layer_key].append(float(torch.amax(torch.abs(v)).item())) - except Exception as e: - print(f"[KV_SCALES] Error extracting layer key: {e}") - pass - - return _hook - # Hook to capture q/k/v after q/k norm and RoPE def _pre_hook_builder_core_attention(module_name: str): layer_key = _extract_layer_key(module_name) From be8853d68c56f73b219ec6b235610d25dfdbfedc Mon Sep 17 00:00:00 2001 From: Zhaopeng Qiu Date: Tue, 11 Nov 2025 02:20:11 -0800 Subject: [PATCH 10/40] rebase and update refitting process Signed-off-by: Zhaopeng Qiu --- .../grpo_math_8B_megatron_fp8_kvcache.yaml | 19 ++ nemo_rl/algorithms/grpo.py | 156 ++++------------ nemo_rl/models/generation/fp8.py | 11 +- nemo_rl/models/generation/vllm/config.py | 1 + .../models/generation/vllm/vllm_backend.py | 4 +- .../models/generation/vllm/vllm_generation.py | 17 +- nemo_rl/models/generation/vllm/vllm_worker.py | 13 +- nemo_rl/models/policy/interfaces.py | 2 +- nemo_rl/models/policy/lm_policy.py | 9 +- .../models/policy/megatron_policy_worker.py | 174 +++++++++++++----- 10 files changed, 209 insertions(+), 197 deletions(-) create mode 100644 examples/configs/grpo_math_8B_megatron_fp8_kvcache.yaml diff --git a/examples/configs/grpo_math_8B_megatron_fp8_kvcache.yaml b/examples/configs/grpo_math_8B_megatron_fp8_kvcache.yaml new file mode 100644 index 0000000000..8932af29b2 --- /dev/null +++ b/examples/configs/grpo_math_8B_megatron_fp8_kvcache.yaml @@ -0,0 +1,19 @@ +# GRPO Algorithm Configuration +defaults: "grpo_math_8B_megatron.yaml" + +loss_fn: + use_importance_sampling_correction: true + +policy: + model_name: "Qwen/Qwen3-8B-Base" + megatron_cfg: + converter_type: "Qwen3ForCausalLM" + tensor_model_parallel_size: 4 + 
pipeline_model_parallel_size: 1 + generation: + vllm_cfg: + precision: 'fp8' + kv_cache_dtype: 'fp8' + use_deep_gemm: true + num_last_layers_in_bf16: 0 + num_first_layers_in_bf16: 0 diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 569e9ec046..ed1e156cab 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -500,6 +500,14 @@ def init_vllm(): assert loss_config["use_importance_sampling_correction"] is True, ( "Importance sampling must be enabled for vLLM FP8 generation for good convergence!" ) + if generation_config["vllm_cfg"]["kv_cache_dtype"] == "fp8": + # Temporary additional FP8 KV cache compatibility checks + # TODO: Add the related support + assert policy_config["dtensor_cfg"]["enabled"] == False, "DTensor backend is not supported with kv cache fp8 enabled." + assert not _should_use_async_rollouts(master_config), "Async rollouts is not supported with kv cache fp8 enabled." + assert policy_config["megatron_cfg"]["pipeline_model_parallel_size"] == 1, "Pipeline model parallel size must be 1 for megatron backend with kv cache fp8 enabled." + + ## make vllm hf overrides match the training policy generation_config["vllm_cfg"]["hf_overrides"] = policy_config.get( "hf_config_overrides", {} ) @@ -861,129 +869,28 @@ def _should_use_penguin(master_config: MasterConfig) -> bool: return should_use_penguin -# Need a function to compute the kv cache scales for all the attention layers with the updated policy model -# TODO: Determine the inputs and outputs. inputs: the trained policy model? training data? -# TODO: How to do the calculation? Caculating the kv cache scales needs to do a foward path with some training data, get the activations of each attention layer and compute the scales based on the activations. -# TODO: The calcuation needs to be done only when the sync_kv_scales flat is True, and after policy model is updated. -# TODO: The output should be a dictionary of the kv cache scales for all the attention layers? 
The structure should be consistent with the required format that can be loaded by vllm using model_runner.model.load_weights() as the other weights. -# Code snippet reference: /lustre/fsw/portfolios/coreai/users/shuangy/src/vllm/vllm/model_executor/layers/quantization/kv_cache.py: -# def create_weights(self, layer: torch.nn.Module): -# """ -# Create "weight" (aka q_scale, k_scale and v_scale) -# for an attention layer. -# """ - # Initialize the Q and KV cache scales to -1.0, an invalid value. - # If the q and k/v_scales appear in the checkpoint, it will be - # overwritten when loading weights. -# layer.q_scale = torch.nn.Parameter(torch.tensor(-1.0), -# requires_grad=False) -# layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0), -# requires_grad=False) -# layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0), -# requires_grad=False) - # Initialize P = softmax(QK^T) scales -# layer.prob_scale = torch.nn.Parameter(torch.tensor(-1.0), -# requires_grad=False) -# requires_grad=False) -# TODO: How to pass the kv scales to refit_policy_generation()? refit_policy_generation() is the function that updates the weights of the policy generation interface. -# When refit_policy_generation() invokes update_weights_from_ipc_handles() or update_weights_from_collective(), if it is fp8 and sync_kv_scales is True, the kv scales should be passed to the policy generation interface. load_weights() once invoked will load the kv scales. -# In order for vllm to really load the kv scales, the kv scales should be passed to the policy generation interface in the same format as the other weights. -# Additionally, vllm process_weights_after_loading() will be invoked after load_weights() to copy the kv scales to the _k_scale and _v_scale attributes. 
Reference code: /lustre/fsw/portfolios/coreai/users/shuangy/src/vllm/vllm/model_executor/layers/quantization/kv_cache.py - - -def compute_kv_scales_with_data( - policy: ColocatablePolicyInterface, - sample_data: BatchedDataDict, - master_config: MasterConfig, - max_samples: int = 32, -) -> dict[str, float]: + +# Function to check if KV cache scales should be calculated and synchronized during refit +def _should_sync_kv_scales(master_config: MasterConfig) -> bool: """ - Compute KV cache scales for all attention layers using calibration data. + Check if KV cache scales should be synchronized during refit. - Args: - policy: The policy model to calibrate - sample_data: Calibration data batch - master_config: Configuration containing model settings - max_samples: Maximum number of samples to use for calibration - - Returns: - Dictionary mapping parameter names to scale values for K/V cache quantization + Returns True if: + - vLLM precision is fp8 and kv_cache_dtype is fp8 """ - # TODO: Review the implementation of this function. 
- print(f"[KV_SCALES] Computing KV cache scales with {min(max_samples, sample_data.size)} samples...") - - # Limit the number of samples for calibration - if sample_data.size > max_samples: - sample_data = sample_data.slice(0, max_samples) + generation_config = master_config["policy"]["generation"] + if generation_config is None: + return False - # Convert to input format expected by policy - import torch - from nemo_rl.data.llm_message_utils import batched_message_log_to_flat_message + backend = generation_config.get("backend", "") + if backend != "vllm": + return False - try: - # Extract tokenized inputs from the batch - batched_flat, input_lengths = batched_message_log_to_flat_message( - sample_data["message_log"], - pad_value_dict={"token_ids": 0} # Use 0 as pad token for calibration - ) - input_ids = batched_flat["token_ids"] - - # Convert to tensor if needed - if not isinstance(input_ids, torch.Tensor): - input_ids = torch.tensor(input_ids, dtype=torch.long) - - # For distributed policy, we'll use a simplified approach - # TODO: Implement proper distributed calibration through worker_group - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - input_ids = input_ids.to(device) - - print(f"[KV_SCALES] Calibration input shape: {input_ids.shape}") - - # Skip the complex hook-based calibration for distributed policy - # TODO: Implement proper distributed calibration using policy.worker_group - - # For distributed Policy, we cannot directly access model.named_modules() - # Instead, we'll use a simplified approach with default scales - print("[KV_SCALES] Using simplified calibration for distributed policy") - - # TODO: For a quick prototype, use a pseudo default scales. - # Need to update later to would use worker_group to run calibration? 
- default_k_scale = 0.1 # Conservative scale for K projections - default_v_scale = 0.1 # Conservative scale for V projections - # TODO: Current use Qwen3-8B-Base as an example, should be obtained from model config - num_layers = 36 # Default number of layers - should be obtained from model config - - # Generate default KV scales for distributed policy - kv_scales = {} - print("[KV_SCALES] Generating default KV scales for distributed policy") - - # Generate scales for typical transformer layers - for layer_idx in range(num_layers): - k_param_name = f"model.layers.{layer_idx}.self_attn.k_scale" - v_param_name = f"model.layers.{layer_idx}.self_attn.v_scale" - - kv_scales[k_param_name] = default_k_scale - kv_scales[v_param_name] = default_v_scale - - print(f"[KV_SCALES] Computed {len(kv_scales)} KV cache scales") - return kv_scales + vllm_cfg = generation_config.get("vllm_cfg", {}) + kv_cache_dtype = vllm_cfg.get("kv_cache_dtype", "auto") + vllm_precision = vllm_cfg.get("precision", "auto") - except Exception as e: - print(f"[KV_SCALES] Error computing KV scales: {e}") - # For training stability, we can either: - # 1. Re-raise the exception to fail fast and debug issues early - # 2. 
Fall back to default scales to continue training - # Current choice: fallback for robustness, but log the error clearly - print("[KV_SCALES] Falling back to default scales to maintain training stability") - print("[KV_SCALES] Note: This may impact FP8 quantization quality") - - # Return default scales - default_scales = {} - for name, module in policy.model.named_modules(): - if "self_attn" in name: - default_scales[f"{name}.k_scale"] = 1.0 - default_scales[f"{name}.v_scale"] = 1.0 - return default_scales + return kv_cache_dtype == "fp8" and vllm_precision == "fp8" def refit_policy_generation( @@ -1009,6 +916,9 @@ def refit_policy_generation( policy.offload_before_refit() policy_generation.prepare_for_generation(tags=["weights"]) + if kv_scales: + print(f"[KV_SCALES] Refit: Adding {len(kv_scales)} KV scales to weight update") + # Create a context manager that does nothing when timer is None timer_context = ( timer.time("prepare_for_generation/transfer_and_update_weights") @@ -1031,7 +941,7 @@ def refit_policy_generation( ) futures_train = policy.stream_weights_via_ipc_zmq( - buffer_size_bytes=buffer_size_bytes + buffer_size_bytes=buffer_size_bytes, kv_scales=kv_scales ) futures_inference = policy_generation.update_weights_via_ipc_zmq() # wait for all futures to complete @@ -1040,12 +950,8 @@ def refit_policy_generation( update_success = all(result for result in results if result is not None) else: # update weights through nccl - futures_train = policy.broadcast_weights_for_collective() - if kv_scales: - print(f"[KV_SCALES] Refit: Adding {len(kv_scales)} KV scales to collective weight update") - futures_inference = policy_generation.update_weights_from_collective(kv_scales=kv_scales) - else: - futures_inference = policy_generation.update_weights_from_collective() + futures_train = policy.broadcast_weights_for_collective(kv_scales=kv_scales) + futures_inference = policy_generation.update_weights_from_collective() # wait for all futures to complete 
ray.get(futures_train) results = ray.get(futures_inference) @@ -1192,7 +1098,7 @@ def grpo_train( f"▶ Generating responses for batch of size {repeated_batch.size}...", flush=True, ) - with timer.time("prepare_for_generation"): + with timer.time("prepare_for_generation/total"): if NEED_REFIT and POLICY_GENERATION_STALE: # Compute KV scales if needed for FP8 quantization if sync_kv_scales and kv_scales_cache is None: diff --git a/nemo_rl/models/generation/fp8.py b/nemo_rl/models/generation/fp8.py index 00f3dc43d5..70b20c6d7c 100644 --- a/nemo_rl/models/generation/fp8.py +++ b/nemo_rl/models/generation/fp8.py @@ -148,7 +148,7 @@ def kv_cache_process_weights_after_loading(self, layer: torch.nn.Module) -> None if not isinstance(k_scale, float) or not isinstance( v_scale, float): raise ValueError("Only support per-tensor scaling factor " - "for fp8 KV cache") + "for fp8 KV cache") if layer.q_scale < 0.0: logger.warning_once( @@ -156,6 +156,8 @@ def kv_cache_process_weights_after_loading(self, layer: torch.nn.Module) -> None "Setting it to k_scale. 
This only matters for " "the flash-attn backend.") layer._q_scale.copy_(k_scale) + layer._q_scale_float = k_scale + # These are used in the final Attention.forward() layer._k_scale.copy_(k_scale) layer._v_scale.copy_(v_scale) @@ -187,13 +189,16 @@ def kv_cache_process_weights_after_loading(self, layer: torch.nn.Module) -> None if not is_singleton_float(q_scale) or not is_singleton_float( prob_scale): raise ValueError("Only support per-tensor scaling factor" - "for fp8-quantized Q/prob") + "for fp8-quantized Q/prob") # These are used in the final Attention.forward() layer._q_scale.copy_(q_scale) + layer._q_scale_float = q_scale.item() if isinstance( + q_scale, torch.Tensor) else q_scale + layer._prob_scale.copy_(prob_scale) if layer.kv_cache_dtype == "fp8" and (q_scale == 1.0 - or prob_scale == 1.0): + or prob_scale == 1.0): logger.warning_once( f"Using uncalibrated q_scale {q_scale} and/or prob_scale " f"{prob_scale} with fp8 attention. This may cause accuracy " diff --git a/nemo_rl/models/generation/vllm/config.py b/nemo_rl/models/generation/vllm/config.py index 8ea82ec4db..c3a0171679 100644 --- a/nemo_rl/models/generation/vllm/config.py +++ b/nemo_rl/models/generation/vllm/config.py @@ -28,6 +28,7 @@ class VllmSpecificArgs(TypedDict): async_engine: bool load_format: NotRequired[str] precision: NotRequired[str] + kv_cache_dtype: NotRequired[str] enforce_eager: NotRequired[bool] # By default, NeMo RL only has a Python handle to the vllm.LLM generation engine. The expose_http_server flag here will expose that generation engine as an HTTP server. # Exposing vLLM as a server is useful in instances where the multi-turn rollout is performed with utilities outside of NeMo RL, but the user still wants to take advantage of the refit logic in NeMo RL that keeps the policy and generation up to date. 
diff --git a/nemo_rl/models/generation/vllm/vllm_backend.py b/nemo_rl/models/generation/vllm/vllm_backend.py index 290f51d8ed..678c2d7028 100644 --- a/nemo_rl/models/generation/vllm/vllm_backend.py +++ b/nemo_rl/models/generation/vllm/vllm_backend.py @@ -13,7 +13,7 @@ # limitations under the License. import gc import traceback -from typing import Any +from typing import Any, Optional import torch import zmq @@ -173,7 +173,7 @@ def update_weights_via_ipc_zmq(self) -> bool: @wrap_with_nvtx_name( "vllm_internal_worker_extension/update_weights_from_collective" ) - def update_weights_from_collective(self, kv_scales: Optional[dict[str, float]] = None) -> bool: + def update_weights_from_collective(self) -> bool: """Update the model weights from collective communication.""" assert self.state_dict_info is not None, ( "state_dict_info is not prepared. " diff --git a/nemo_rl/models/generation/vllm/vllm_generation.py b/nemo_rl/models/generation/vllm/vllm_generation.py index b5c81a81fd..5dcc7eaf2e 100644 --- a/nemo_rl/models/generation/vllm/vllm_generation.py +++ b/nemo_rl/models/generation/vllm/vllm_generation.py @@ -788,7 +788,7 @@ def update_weights_via_ipc_zmq(self) -> list[ray.ObjectRef]: # this function should co-work with lm_policy, so we should wait for all futures to complete outside return futures - def update_weights_from_collective(self, kv_scales: Optional[dict[str, float]] = None) -> list[ray.ObjectRef]: + def update_weights_from_collective(self) -> list[ray.ObjectRef]: """Update weights of the policy using collective communication.""" if not self.worker_group or not self.worker_group.workers: raise RuntimeError("Worker group is not initialized") @@ -801,17 +801,10 @@ def update_weights_from_collective(self, kv_scales: Optional[dict[str, float]] = ) # Use run_all_workers_single_data for methods that don't need data - if kv_scales: - futures = self.worker_group.run_all_workers_single_data( - method_name, - kv_scales=kv_scales, - 
run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"], - ) - else: - futures = self.worker_group.run_all_workers_single_data( - method_name, - run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"], - ) + futures = self.worker_group.run_all_workers_single_data( + method_name, + run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"], + ) # this function should co-work with lm_policy, so we should wait for all futures to complete outside return futures diff --git a/nemo_rl/models/generation/vllm/vllm_worker.py b/nemo_rl/models/generation/vllm/vllm_worker.py index 82eac8c452..a97d68e669 100644 --- a/nemo_rl/models/generation/vllm/vllm_worker.py +++ b/nemo_rl/models/generation/vllm/vllm_worker.py @@ -665,7 +665,7 @@ def update_weights_via_ipc_zmq(self) -> bool: return False @wrap_with_nvtx_name("vllm_genertion_worker/update_weights_from_collective") - def update_weights_from_collective(self, kv_scales: Optional[dict[str, float]] = None) -> bool: + def update_weights_from_collective(self) -> bool: """Update the model weights from collective communication.""" try: assert self.llm is not None, ( @@ -677,14 +677,9 @@ def update_weights_from_collective(self, kv_scales: Optional[dict[str, float]] = "update_weights_from_collective can only be used with async_engine=False. Use update_weights_from_collective_async instead." 
) - if kv_scales: - result_or_coro = self.llm.collective_rpc( - "update_weights_from_collective", args=(kv_scales,) - ) - else: - result_or_coro = self.llm.collective_rpc( - "update_weights_from_collective", args=tuple() - ) + result_or_coro = self.llm.collective_rpc( + "update_weights_from_collective", args=tuple() + ) worker_result = result_or_coro[0] if not worker_result: diff --git a/nemo_rl/models/policy/interfaces.py b/nemo_rl/models/policy/interfaces.py index aa229b6902..e486821536 100644 --- a/nemo_rl/models/policy/interfaces.py +++ b/nemo_rl/models/policy/interfaces.py @@ -189,5 +189,5 @@ def stream_weights_via_ipc_zmq( pass @abstractmethod - def broadcast_weights_for_collective(self) -> list[ray.ObjectRef]: + def broadcast_weights_for_collective(self, kv_scales: Optional[dict[str, float]] = None) -> list[ray.ObjectRef]: pass diff --git a/nemo_rl/models/policy/lm_policy.py b/nemo_rl/models/policy/lm_policy.py index feb25fcb2a..72a67b1458 100644 --- a/nemo_rl/models/policy/lm_policy.py +++ b/nemo_rl/models/policy/lm_policy.py @@ -762,17 +762,18 @@ def get_free_memory_bytes(self) -> int: free_memory_bytes = min(ray.get(future) for future in futures) return free_memory_bytes - def stream_weights_via_ipc_zmq(self, buffer_size_bytes: int) -> list[ray.ObjectRef]: + def stream_weights_via_ipc_zmq(self, buffer_size_bytes: int, kv_scales: Optional[dict[str, float]] = None) -> list[ray.ObjectRef]: """Send the weights for IPC handles via ZMQ socket.""" futures = self.worker_group.run_all_workers_single_data( - "stream_weights_via_ipc_zmq", buffer_size_bytes=buffer_size_bytes + "stream_weights_via_ipc_zmq", buffer_size_bytes=buffer_size_bytes, kv_scales=kv_scales ) return futures - def broadcast_weights_for_collective(self) -> list[ray.ObjectRef]: + def broadcast_weights_for_collective(self, kv_scales: Optional[dict[str, float]] = None) -> list[ray.ObjectRef]: """Broadcast the weights for collective communication.""" futures = 
self.worker_group.run_all_workers_single_data( - "broadcast_weights_for_collective" + "broadcast_weights_for_collective", + kv_scales=kv_scales, ) # this function should co-work with vllm, so we should wait for all futures to complete outside return futures diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index f190e1663e..6e7e467a28 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -1918,6 +1918,32 @@ def prepare_refit_info(self) -> None: for name, tensor in hf_params_generator: metadata = (tensor.shape, tensor.dtype) refit_param_info_hf[name] = metadata + # Also include KV/Q scale metadata so consumer can rely solely on state_dict_info + try: + import re + # Infer number of layers by scanning existing parameter names + max_layer_idx = -1 + layer_regex = re.compile(r"^model\\.layers\\.(\\d+)\\.") + for k in refit_param_info_hf.keys(): + m = layer_regex.match(k) + if m: + idx = int(m.group(1)) + if idx > max_layer_idx: + max_layer_idx = idx + num_layers = max_layer_idx + 1 if max_layer_idx >= 0 else 0 + # Append q/k/v scale placeholders (shape [1], dtype float32) + for layer_idx in range(num_layers): + q_key = f"model.layers.{layer_idx}.self_attn.attn.q_scale" + k_key = f"model.layers.{layer_idx}.self_attn.k_scale" + v_key = f"model.layers.{layer_idx}.self_attn.v_scale" + if q_key not in refit_param_info_hf: + refit_param_info_hf[q_key] = ([1], torch.float32) + if k_key not in refit_param_info_hf: + refit_param_info_hf[k_key] = ([1], torch.float32) + if v_key not in refit_param_info_hf: + refit_param_info_hf[v_key] = ([1], torch.float32) + except Exception: + pass return refit_param_info_hf def _calculate_refit_param_info(self) -> list[tuple[str, int]]: @@ -1980,7 +2006,7 @@ def get_free_memory_bytes(self) -> int: @torch.no_grad() @wrap_with_nvtx_name("megatron_policy_worker/stream_weights_via_ipc_zmq") - def 
stream_weights_via_ipc_zmq(self, buffer_size_bytes: int = 0) -> None: + def stream_weights_via_ipc_zmq(self, buffer_size_bytes: int = 0, kv_scales: Optional[dict[str, float]] = None) -> None: """Stream model weights to peer process via ZMQ IPC socket.""" self.maybe_init_zmq() @@ -1993,9 +2019,42 @@ def stream_weights_via_ipc_zmq(self, buffer_size_bytes: int = 0) -> None: conversion_tasks=self.refit_conversion_tasks, # used for metadata caching ) + def _enumerate_kv_scale_keys(): + import re + max_layer_idx = -1 + layer_regex = re.compile(r"^model\\.layers\\.(\\d+)\\.") + # Use state_dict keys to infer number of layers + for name in self.model.state_dict().keys(): + m = layer_regex.match(name) + if m: + idx = int(m.group(1)) + if idx > max_layer_idx: + max_layer_idx = idx + num_layers = max_layer_idx + 1 if max_layer_idx >= 0 else 0 + keys = [] + for layer_idx in range(num_layers): + keys.append(f"model.layers.{layer_idx}.self_attn.attn.q_scale") + keys.append(f"model.layers.{layer_idx}.self_attn.k_scale") + keys.append(f"model.layers.{layer_idx}.self_attn.v_scale") + return keys + + def iter_with_kv_scales(): + for name, tensor in hf_params_generator: + yield name, tensor + # Always append kv-scale entries to match metadata; use provided value or default 1.0 + for param_name in _enumerate_kv_scale_keys(): + if kv_scales and param_name in kv_scales: + scale_value = kv_scales[param_name] + else: + scale_value = 1.0 + scale_tensor = torch.tensor( + scale_value, dtype=torch.float32, device="cuda" + ).reshape(1) + yield param_name, scale_tensor + # Use the shared implementation stream_weights_via_ipc_zmq_impl( - params_generator=hf_params_generator, + params_generator=iter_with_kv_scales(), buffer_size_bytes=buffer_size_bytes, zmq_socket=self.zmq_socket, rank=self.rank, @@ -2003,7 +2062,7 @@ def stream_weights_via_ipc_zmq(self, buffer_size_bytes: int = 0) -> None: ) @torch.no_grad() - def broadcast_weights_for_collective(self) -> None: + def 
broadcast_weights_for_collective(self, kv_scales: Optional[dict[str, float]] = None) -> None: """Broadcast the weights for collective communication.""" hf_params_generator = self.megatron_bridge.export_hf_weights( [self.model], @@ -2011,9 +2070,42 @@ def broadcast_weights_for_collective(self) -> None: conversion_tasks=self.refit_conversion_tasks, # used for metadata caching ) + def _enumerate_kv_scale_keys(): + import re + max_layer_idx = -1 + layer_regex = re.compile(r"^model\\.layers\\.(\\d+)\\.") + # Use state_dict keys to infer number of layers + for name in self.model.state_dict().keys(): + m = layer_regex.match(name) + if m: + idx = int(m.group(1)) + if idx > max_layer_idx: + max_layer_idx = idx + num_layers = max_layer_idx + 1 if max_layer_idx >= 0 else 0 + keys = [] + for layer_idx in range(num_layers): + keys.append(f"model.layers.{layer_idx}.self_attn.attn.q_scale") + keys.append(f"model.layers.{layer_idx}.self_attn.k_scale") + keys.append(f"model.layers.{layer_idx}.self_attn.v_scale") + return keys + + def iter_with_kv_scales(): + for name, tensor in hf_params_generator: + yield name, tensor + # Always append kv-scale entries to match metadata; use provided value or default 1.0 + for param_name in _enumerate_kv_scale_keys(): + if kv_scales and param_name in kv_scales: + scale_value = kv_scales[param_name] + else: + scale_value = 1.0 + scale_tensor = torch.tensor( + scale_value, dtype=torch.float32, device="cuda" + ).reshape(1) + yield param_name, scale_tensor + # param_iterator will return (name, tensor), we only need tensor packed_broadcast_producer( - iterator=hf_params_generator, + iterator=iter_with_kv_scales(), group=self.model_update_group, src=0, post_iter_func=lambda x: x[1], @@ -2367,43 +2459,6 @@ def check_tensor_parallel_attributes(self) -> dict[str, Any]: "tp_size": self.megatron_cfg.model.tensor_model_parallel_size, } - -class CustomFloat16Module(Float16Module): - """Float 16 Module. 
- - Attributes: - config (TransformerConfig): Transformer config - fp16 (bool) : Specifies if the model runs in fp16 mode - bf16 (bool) : Specifies if the model runs in bf16 mode - - Args: - config (TransformerConfig): The transformer config used to initalize the model - """ - - def __init__(self, config: TransformerConfig, module: torch.nn.Module): - super(CustomFloat16Module, self).__init__(config, module) - self.re_enable_float32_expert_bias() - - def re_enable_float32_expert_bias(self) -> None: - """Ensure MoE router expert bias stays in float32 for numerical stability. - - Walks the wrapped module to find MoE routers and invokes the - `_maintain_float32_expert_bias()` helper which recreates or casts the - expert bias tensors to float32 as required by Megatron-LM. - """ - module = self.module - # Handle VLM models where language model is nested - if hasattr(module, "language_model"): - module = module.language_model - if hasattr(module, "decoder") and hasattr(module.decoder, "layers"): - for layer in module.decoder.layers: - mlp = getattr(layer, "mlp", None) - router = getattr(mlp, "router", None) if mlp is not None else None - if router is not None and hasattr( - router, "_maintain_float32_expert_bias" - ): - router._maintain_float32_expert_bias() - @torch.no_grad() def calibrate_qkv_fp8_scales( self, @@ -2608,3 +2663,40 @@ def _percentile(values: list[float], p: float) -> float: pass return final_result + + +class CustomFloat16Module(Float16Module): + """Float 16 Module. 
+ + Attributes: + config (TransformerConfig): Transformer config + fp16 (bool) : Specifies if the model runs in fp16 mode + bf16 (bool) : Specifies if the model runs in bf16 mode + + Args: + config (TransformerConfig): The transformer config used to initalize the model + """ + + def __init__(self, config: TransformerConfig, module: torch.nn.Module): + super(CustomFloat16Module, self).__init__(config, module) + self.re_enable_float32_expert_bias() + + def re_enable_float32_expert_bias(self) -> None: + """Ensure MoE router expert bias stays in float32 for numerical stability. + + Walks the wrapped module to find MoE routers and invokes the + `_maintain_float32_expert_bias()` helper which recreates or casts the + expert bias tensors to float32 as required by Megatron-LM. + """ + module = self.module + # Handle VLM models where language model is nested + if hasattr(module, "language_model"): + module = module.language_model + if hasattr(module, "decoder") and hasattr(module.decoder, "layers"): + for layer in module.decoder.layers: + mlp = getattr(layer, "mlp", None) + router = getattr(mlp, "router", None) if mlp is not None else None + if router is not None and hasattr( + router, "_maintain_float32_expert_bias" + ): + router._maintain_float32_expert_bias() From 4ae3ec0170aaeca8df3ffdbc80698767dd3a225b Mon Sep 17 00:00:00 2001 From: Zhaopeng Qiu Date: Tue, 11 Nov 2025 08:07:44 -0800 Subject: [PATCH 11/40] fix refitting bugs after rebase Signed-off-by: Zhaopeng Qiu --- nemo_rl/models/generation/fp8.py | 6 ++ .../models/generation/vllm/vllm_backend.py | 16 +++++ .../models/policy/megatron_policy_worker.py | 72 +++++++++---------- 3 files changed, 55 insertions(+), 39 deletions(-) diff --git a/nemo_rl/models/generation/fp8.py b/nemo_rl/models/generation/fp8.py index 70b20c6d7c..6291ef7d62 100644 --- a/nemo_rl/models/generation/fp8.py +++ b/nemo_rl/models/generation/fp8.py @@ -115,6 +115,9 @@ def kv_cache_process_weights_after_loading(self, layer: torch.nn.Module) -> None 
from vllm.platforms import current_platform logger = init_logger(__name__) + + # print(f"[KV_SCALES] kv_cache_process_weights_after_loading: layer.k_scale = {layer.k_scale}, layer.v_scale = {layer.v_scale}") + print(f"[@@KV_SCALES@@] [fp8.py] kv_cache_process_weights_after_loading: layer.k_scale = {layer.k_scale}, layer.v_scale = {layer.v_scale}") # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0 # regardless whether the kv-scale is available in the checkpoint. @@ -408,9 +411,12 @@ def load_weights(weights, model_runner): model = model_runner.model for k, v in weights: + if "scale" in k: + print(f"[@@KV_SCALES@@] [fp8.py] load_weights: Parameter {k}, value = {v.item() if v.numel() == 1 else v}") if not _is_fp8_weight(k, model): weights_quantized.append((k, v)) continue + print(f"[@@KV_SCALES@@] [fp8.py] load_weights: Casting weight {k} into fp8") # Cast the weight into fp8 and its scale factor param_lp, param_scale = cast_tensor_to_fp8_blockwise( v.to(torch.float), diff --git a/nemo_rl/models/generation/vllm/vllm_backend.py b/nemo_rl/models/generation/vllm/vllm_backend.py index 678c2d7028..a4d1f9163b 100644 --- a/nemo_rl/models/generation/vllm/vllm_backend.py +++ b/nemo_rl/models/generation/vllm/vllm_backend.py @@ -160,6 +160,22 @@ def update_weights_via_ipc_zmq(self) -> bool: buffer = None self.zmq_socket.send(IPCProtocol.ACK.value.encode()) + # When kv_scales is provided, we need to invoke process_weights_after_loading() + # to copy the kv scales to the _k_scale and _v_scale attributes used during inference + # if kv_scales: + print(f"[@@KV_SCALES@@] [vllm_backend.py] update_weights_via_ipc_zmq: Processing KV cache scales after weight loading") + from vllm.model_executor.model_loader.utils import process_weights_after_loading + + # Get target device for processing + target_device = next(self.model_runner.model.parameters()).device + + # Call process_weights_after_loading to handle KV scales + process_weights_after_loading( + 
self.model_runner.model, + self.model_runner.model_config, + target_device + ) + gc.collect() torch.cuda.empty_cache() return True diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 6e7e467a28..8d9f33b2a3 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -1902,6 +1902,13 @@ def maybe_init_zmq(self): self.zmq_socket.setsockopt(zmq.LINGER, 0) self.zmq_socket.bind(self.get_zmq_address()) + def _extract_layer_key(self, module_name: str) -> int: + # Expected format: "module.decoder.layers..self_attention.query_key_value" + m = re.search(r"model\.layers\.(\d+)", module_name) + if m is not None: + return int(m.group(1)) + return -1 + @torch.no_grad() @wrap_with_nvtx_name("megatron_policy_worker/prepare_refit_info") def prepare_refit_info(self) -> None: @@ -1921,16 +1928,17 @@ def prepare_refit_info(self) -> None: # Also include KV/Q scale metadata so consumer can rely solely on state_dict_info try: import re - # Infer number of layers by scanning existing parameter names + # Infer number of layers by scanning existing parameter names (robust to various prefixes) max_layer_idx = -1 - layer_regex = re.compile(r"^model\\.layers\\.(\\d+)\\.") + for k in refit_param_info_hf.keys(): - m = layer_regex.match(k) - if m: - idx = int(m.group(1)) - if idx > max_layer_idx: - max_layer_idx = idx - num_layers = max_layer_idx + 1 if max_layer_idx >= 0 else 0 + # print(f"[@@KV_SCALES@@] prepare_refit_info: k = {k}") + idx = self._extract_layer_key(k) + if idx > max_layer_idx: + max_layer_idx = idx + + num_layers = (max_layer_idx + 1) if max_layer_idx >= 0 else 0 + print(f"[@@KV_SCALES@@] prepare_refit_info: num_layers = {num_layers}") # Append q/k/v scale placeholders (shape [1], dtype float32) for layer_idx in range(num_layers): q_key = f"model.layers.{layer_idx}.self_attn.attn.q_scale" @@ -2019,30 +2027,24 @@ def stream_weights_via_ipc_zmq(self, 
buffer_size_bytes: int = 0, kv_scales: Opt conversion_tasks=self.refit_conversion_tasks, # used for metadata caching ) - def _enumerate_kv_scale_keys(): - import re + def iter_with_kv_scales(): max_layer_idx = -1 - layer_regex = re.compile(r"^model\\.layers\\.(\\d+)\\.") - # Use state_dict keys to infer number of layers - for name in self.model.state_dict().keys(): - m = layer_regex.match(name) - if m: - idx = int(m.group(1)) - if idx > max_layer_idx: - max_layer_idx = idx + for name, tensor in hf_params_generator: + idx = self._extract_layer_key(name) + if idx > max_layer_idx: + max_layer_idx = idx + + yield name, tensor + num_layers = max_layer_idx + 1 if max_layer_idx >= 0 else 0 + print(f"[@@KV_SCALES@@] iter_with_kv_scales: num_layers = {num_layers}") keys = [] for layer_idx in range(num_layers): keys.append(f"model.layers.{layer_idx}.self_attn.attn.q_scale") keys.append(f"model.layers.{layer_idx}.self_attn.k_scale") keys.append(f"model.layers.{layer_idx}.self_attn.v_scale") - return keys - - def iter_with_kv_scales(): - for name, tensor in hf_params_generator: - yield name, tensor # Always append kv-scale entries to match metadata; use provided value or default 1.0 - for param_name in _enumerate_kv_scale_keys(): + for param_name in keys: if kv_scales and param_name in kv_scales: scale_value = kv_scales[param_name] else: @@ -2070,30 +2072,22 @@ def broadcast_weights_for_collective(self, kv_scales: Optional[dict[str, float]] conversion_tasks=self.refit_conversion_tasks, # used for metadata caching ) - def _enumerate_kv_scale_keys(): - import re + def iter_with_kv_scales(): max_layer_idx = -1 - layer_regex = re.compile(r"^model\\.layers\\.(\\d+)\\.") - # Use state_dict keys to infer number of layers - for name in self.model.state_dict().keys(): - m = layer_regex.match(name) - if m: - idx = int(m.group(1)) - if idx > max_layer_idx: - max_layer_idx = idx + for name, tensor in hf_params_generator: + idx = self._extract_layer_key(name) + if idx > max_layer_idx: + 
max_layer_idx = idx num_layers = max_layer_idx + 1 if max_layer_idx >= 0 else 0 + print(f"[@@KV_SCALES@@] iter_with_kv_scales: num_layers = {num_layers}") keys = [] for layer_idx in range(num_layers): keys.append(f"model.layers.{layer_idx}.self_attn.attn.q_scale") keys.append(f"model.layers.{layer_idx}.self_attn.k_scale") keys.append(f"model.layers.{layer_idx}.self_attn.v_scale") - return keys - - def iter_with_kv_scales(): - for name, tensor in hf_params_generator: yield name, tensor # Always append kv-scale entries to match metadata; use provided value or default 1.0 - for param_name in _enumerate_kv_scale_keys(): + for param_name in keys: if kv_scales and param_name in kv_scales: scale_value = kv_scales[param_name] else: From d66c26205b305a78690dbba9adfe10d128c989ef Mon Sep 17 00:00:00 2001 From: Shuang Yu Date: Fri, 14 Nov 2025 00:25:02 -0800 Subject: [PATCH 12/40] Refactor FP8 KV cache scale handling by centralizing vLLM parameter naming and deduplicating conversion logic Signed-off-by: Zhaopeng Qiu --- nemo_rl/algorithms/grpo.py | 27 ++----- nemo_rl/models/generation/fp8.py | 80 +++++++++++++++++++ .../models/policy/megatron_policy_worker.py | 57 +++++-------- 3 files changed, 105 insertions(+), 59 deletions(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index ed1e156cab..5ada498c85 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -59,6 +59,7 @@ run_async_penguin_rollout, run_multi_turn_rollout, ) +from nemo_rl.models.generation.fp8 import convert_calibration_to_vllm_format from nemo_rl.models.generation.interfaces import GenerationInterface from nemo_rl.models.generation.vllm import VllmConfig, VllmGeneration from nemo_rl.models.policy import PolicyConfig @@ -500,7 +501,7 @@ def init_vllm(): assert loss_config["use_importance_sampling_correction"] is True, ( "Importance sampling must be enabled for vLLM FP8 generation for good convergence!" 
) - if generation_config["vllm_cfg"]["kv_cache_dtype"] == "fp8": + if generation_config["vllm_cfg"].get("kv_cache_dtype") == "fp8": # Temporary additional FP8 KV cache compatibility checks # TODO: Add the related support assert policy_config["dtensor_cfg"]["enabled"] == False, "DTensor backend is not supported with kv cache fp8 enabled." @@ -1104,7 +1105,6 @@ def grpo_train( if sync_kv_scales and kv_scales_cache is None: print("[KV_SCALES] Computing KV cache scales for the first time...") policy.prepare_for_lp_inference() - kv_scales_cache = {} # Create calibration data from flattened messages calibration_data = BatchedDataDict[ClippedPGLossDataDict]( { @@ -1116,16 +1116,8 @@ def grpo_train( calibration_data.update(batched_flat.get_multimodal_dict(as_tensors=False)) calibration_data.to("cpu") kv_scales = policy.calibrate_qkv_fp8_scales(calibration_data, include_q=True)["layers"] - for k, v in kv_scales.items(): - layer_idx = k.split("_")[1] - k_param_name = f"model.layers.{layer_idx}.self_attn.k_scale" - v_param_name = f"model.layers.{layer_idx}.self_attn.v_scale" - # q_param_name is different from k_param_name and v_param_name because vllm handles the param mappings differently for q and k/v - q_param_name = f"model.layers.{layer_idx}.self_attn.attn.q_scale" - - kv_scales_cache[q_param_name] = v["q_scale"] - kv_scales_cache[k_param_name] = v["k_scale"] - kv_scales_cache[v_param_name] = v["v_scale"] + # Convert calibration results to vLLM parameter format + kv_scales_cache = convert_calibration_to_vllm_format(kv_scales) refit_policy_generation( policy, policy_generation, colocated_inference, timer=timer, @@ -1332,15 +1324,8 @@ def grpo_train( with timer.time("recompute_kv_scales"): print("[KV_SCALES] Recomputing KV cache scales after policy update...") kv_scales = policy.calibrate_qkv_fp8_scales(train_data, include_q=True)["layers"] - for k, v in kv_scales.items(): - layer_idx = k.split("_")[1] - k_param_name = f"model.layers.{layer_idx}.self_attn.k_scale" - 
v_param_name = f"model.layers.{layer_idx}.self_attn.v_scale" - q_param_name = f"model.layers.{layer_idx}.self_attn.attn.q_scale" - - kv_scales_cache[q_param_name] = v["q_scale"] - kv_scales_cache[k_param_name] = v["k_scale"] - kv_scales_cache[v_param_name] = v["v_scale"] + # Convert calibration results to vLLM parameter format + kv_scales_cache = convert_calibration_to_vllm_format(kv_scales) # Set generation as stale to force refit with new scales POLICY_GENERATION_STALE = True diff --git a/nemo_rl/models/generation/fp8.py b/nemo_rl/models/generation/fp8.py index 6291ef7d62..7f401f0020 100644 --- a/nemo_rl/models/generation/fp8.py +++ b/nemo_rl/models/generation/fp8.py @@ -213,6 +213,86 @@ def kv_cache_process_weights_after_loading(self, layer: torch.nn.Module) -> None print(f"[KV_SCALES] Patched process_weights_after_loading: keeping k_scale, v_scale parameters for dynamic updates") +def get_vllm_qkv_scale_names(layer_idx: int) -> dict[str, str]: + """Get vLLM-compatible parameter names for Q/K/V FP8 scales. + + This function centralizes the naming convention for Q/K/V scale parameters + that vLLM expects. These names must match vLLM's internal parameter structure. + + Args: + layer_idx: The transformer layer index (0-based) + + Returns: + Dictionary mapping scale types to vLLM parameter names: + - 'q_scale': Q activation scale name + - 'k_scale': K activation scale name + - 'v_scale': V activation scale name + + Note: + The q_scale has an extra '.attn.' component compared to k_scale/v_scale. 
+ This matches vLLM's parameter remapping logic in: + vllm.model_executor.model_loader.weight_utils.maybe_remap_kv_scale_name + + Example: + >>> get_vllm_qkv_scale_names(0) + { + 'q_scale': 'model.layers.0.self_attn.attn.q_scale', + 'k_scale': 'model.layers.0.self_attn.k_scale', + 'v_scale': 'model.layers.0.self_attn.v_scale' + } + """ + return { + "q_scale": f"model.layers.{layer_idx}.self_attn.attn.q_scale", + "k_scale": f"model.layers.{layer_idx}.self_attn.k_scale", + "v_scale": f"model.layers.{layer_idx}.self_attn.v_scale", + } + + +def convert_calibration_to_vllm_format( + calibration_results: dict[str, dict[str, float]] +) -> dict[str, float]: + """Convert NeMo-RL calibration results to vLLM parameter format. + + This function transforms the calibration output format (with layer_N keys) + into the flat dictionary format that vLLM expects for parameter loading. + + Args: + calibration_results: Dict with keys like "layer_0", "layer_1", etc. + Each value is a dict with keys: "q_scale", "k_scale", "v_scale" + and corresponding float scale values. + + Returns: + Flat dictionary mapping vLLM parameter names to scale values. + Keys follow vLLM's naming convention as defined in get_vllm_qkv_scale_names. + + Example: + >>> calib = { + ... "layer_0": {"q_scale": 1.0, "k_scale": 2.0, "v_scale": 3.0}, + ... "layer_1": {"q_scale": 1.5, "k_scale": 2.5, "v_scale": 3.5} + ... 
} + >>> convert_calibration_to_vllm_format(calib) + { + 'model.layers.0.self_attn.attn.q_scale': 1.0, + 'model.layers.0.self_attn.k_scale': 2.0, + 'model.layers.0.self_attn.v_scale': 3.0, + 'model.layers.1.self_attn.attn.q_scale': 1.5, + 'model.layers.1.self_attn.k_scale': 2.5, + 'model.layers.1.self_attn.v_scale': 3.5 + } + """ + vllm_scales = {} + for layer_key, scales in calibration_results.items(): + # Extract layer index from "layer_N" format + layer_idx = int(layer_key.split("_")[1]) + param_names = get_vllm_qkv_scale_names(layer_idx) + + vllm_scales[param_names["q_scale"]] = scales["q_scale"] + vllm_scales[param_names["k_scale"]] = scales["k_scale"] + vllm_scales[param_names["v_scale"]] = scales["v_scale"] + + return vllm_scales + + def apply_fp8_patches(self, fp8_config): global global_fp8_config, fp8_patches_applied assert not fp8_patches_applied diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 8d9f33b2a3..edfa4f17f6 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -110,6 +110,7 @@ from_parallel_logits_to_logprobs_packed_sequences, ) from nemo_rl.distributed.named_sharding import NamedSharding +from nemo_rl.models.generation.fp8 import get_vllm_qkv_scale_names from nemo_rl.models.generation.interfaces import ( GenerationDatumSpec, GenerationOutputSpec, @@ -1927,29 +1928,15 @@ def prepare_refit_info(self) -> None: refit_param_info_hf[name] = metadata # Also include KV/Q scale metadata so consumer can rely solely on state_dict_info try: - import re - # Infer number of layers by scanning existing parameter names (robust to various prefixes) - max_layer_idx = -1 - - for k in refit_param_info_hf.keys(): - # print(f"[@@KV_SCALES@@] prepare_refit_info: k = {k}") - idx = self._extract_layer_key(k) - if idx > max_layer_idx: - max_layer_idx = idx - - num_layers = (max_layer_idx + 1) if max_layer_idx >= 0 else 0 + # Get number of 
layers directly from transformer config + num_layers = self.megatron_bridge.transformer_config.num_layers print(f"[@@KV_SCALES@@] prepare_refit_info: num_layers = {num_layers}") # Append q/k/v scale placeholders (shape [1], dtype float32) for layer_idx in range(num_layers): - q_key = f"model.layers.{layer_idx}.self_attn.attn.q_scale" - k_key = f"model.layers.{layer_idx}.self_attn.k_scale" - v_key = f"model.layers.{layer_idx}.self_attn.v_scale" - if q_key not in refit_param_info_hf: - refit_param_info_hf[q_key] = ([1], torch.float32) - if k_key not in refit_param_info_hf: - refit_param_info_hf[k_key] = ([1], torch.float32) - if v_key not in refit_param_info_hf: - refit_param_info_hf[v_key] = ([1], torch.float32) + scale_names = get_vllm_qkv_scale_names(layer_idx) + for param_name in scale_names.values(): + if param_name not in refit_param_info_hf: + refit_param_info_hf[param_name] = ([1], torch.float32) except Exception: pass return refit_param_info_hf @@ -2028,21 +2015,17 @@ def stream_weights_via_ipc_zmq(self, buffer_size_bytes: int = 0, kv_scales: Opt ) def iter_with_kv_scales(): - max_layer_idx = -1 + # First yield all model weights for name, tensor in hf_params_generator: - idx = self._extract_layer_key(name) - if idx > max_layer_idx: - max_layer_idx = idx - yield name, tensor - num_layers = max_layer_idx + 1 if max_layer_idx >= 0 else 0 + # Get number of layers directly from transformer config + num_layers = self.megatron_bridge.transformer_config.num_layers print(f"[@@KV_SCALES@@] iter_with_kv_scales: num_layers = {num_layers}") keys = [] for layer_idx in range(num_layers): - keys.append(f"model.layers.{layer_idx}.self_attn.attn.q_scale") - keys.append(f"model.layers.{layer_idx}.self_attn.k_scale") - keys.append(f"model.layers.{layer_idx}.self_attn.v_scale") + scale_names = get_vllm_qkv_scale_names(layer_idx) + keys.extend(scale_names.values()) # Always append kv-scale entries to match metadata; use provided value or default 1.0 for param_name in keys: if 
kv_scales and param_name in kv_scales: @@ -2073,19 +2056,17 @@ def broadcast_weights_for_collective(self, kv_scales: Optional[dict[str, float]] ) def iter_with_kv_scales(): - max_layer_idx = -1 + # First yield all model weights for name, tensor in hf_params_generator: - idx = self._extract_layer_key(name) - if idx > max_layer_idx: - max_layer_idx = idx - num_layers = max_layer_idx + 1 if max_layer_idx >= 0 else 0 + yield name, tensor + + # Get number of layers directly from transformer config + num_layers = self.megatron_bridge.transformer_config.num_layers print(f"[@@KV_SCALES@@] iter_with_kv_scales: num_layers = {num_layers}") keys = [] for layer_idx in range(num_layers): - keys.append(f"model.layers.{layer_idx}.self_attn.attn.q_scale") - keys.append(f"model.layers.{layer_idx}.self_attn.k_scale") - keys.append(f"model.layers.{layer_idx}.self_attn.v_scale") - yield name, tensor + scale_names = get_vllm_qkv_scale_names(layer_idx) + keys.extend(scale_names.values()) # Always append kv-scale entries to match metadata; use provided value or default 1.0 for param_name in keys: if kv_scales and param_name in kv_scales: From bc26b407c1d7300a3174ee28e142db3131eeaa10 Mon Sep 17 00:00:00 2001 From: Zhaopeng Qiu Date: Fri, 14 Nov 2025 06:13:07 -0800 Subject: [PATCH 13/40] lint check Signed-off-by: Zhaopeng Qiu --- nemo_rl/algorithms/grpo.py | 76 ++++++++++----- nemo_rl/models/generation/fp8.py | 97 +++++++++++-------- .../models/generation/vllm/vllm_backend.py | 21 ++-- nemo_rl/models/policy/interfaces.py | 4 +- nemo_rl/models/policy/lm_policy.py | 16 ++- .../models/policy/megatron_policy_worker.py | 83 +++++++++++----- 6 files changed, 191 insertions(+), 106 deletions(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 5ada498c85..98fea3df14 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -504,9 +504,15 @@ def init_vllm(): if generation_config["vllm_cfg"].get("kv_cache_dtype") == "fp8": # Temporary additional FP8 KV 
cache compatibility checks # TODO: Add the related support - assert policy_config["dtensor_cfg"]["enabled"] == False, "DTensor backend is not supported with kv cache fp8 enabled." - assert not _should_use_async_rollouts(master_config), "Async rollouts is not supported with kv cache fp8 enabled." - assert policy_config["megatron_cfg"]["pipeline_model_parallel_size"] == 1, "Pipeline model parallel size must be 1 for megatron backend with kv cache fp8 enabled." + assert policy_config["dtensor_cfg"]["enabled"] == False, ( + "DTensor backend is not supported with kv cache fp8 enabled." + ) + assert not _should_use_async_rollouts(master_config), ( + "Async rollouts is not supported with kv cache fp8 enabled." + ) + assert policy_config["megatron_cfg"]["pipeline_model_parallel_size"] == 1, ( + "Pipeline model parallel size must be 1 for megatron backend with kv cache fp8 enabled." + ) ## make vllm hf overrides match the training policy generation_config["vllm_cfg"]["hf_overrides"] = policy_config.get( @@ -839,11 +845,11 @@ def _should_use_async_rollouts(master_config: MasterConfig) -> bool: generation_config = master_config["policy"]["generation"] if generation_config is None: return False - + backend = generation_config.get("backend", "") if backend != "vllm": return False - + vllm_cfg = generation_config.get("vllm_cfg", {}) return vllm_cfg.get("async_engine", False) @@ -873,24 +879,23 @@ def _should_use_penguin(master_config: MasterConfig) -> bool: # Function to check if KV cache scales should be calculated and synchronized during refit def _should_sync_kv_scales(master_config: MasterConfig) -> bool: - """ - Check if KV cache scales should be synchronized during refit. - + """Check if KV cache scales should be synchronized during refit. 
+ Returns True if: - vLLM precision is fp8 and kv_cache_dtype is fp8 """ generation_config = master_config["policy"]["generation"] if generation_config is None: return False - + backend = generation_config.get("backend", "") if backend != "vllm": return False - + vllm_cfg = generation_config.get("vllm_cfg", {}) kv_cache_dtype = vllm_cfg.get("kv_cache_dtype", "auto") vllm_precision = vllm_cfg.get("precision", "auto") - + return kv_cache_dtype == "fp8" and vllm_precision == "fp8" @@ -1003,12 +1008,16 @@ def grpo_train( # Check if we need to sync KV cache scales (infer from config) sync_kv_scales = _should_sync_kv_scales(master_config) - kv_scales_cache = None # Cache reused for computed kv scales - - if sync_kv_scales: - print(f"[KV_SCALES] FP8 KV cache detected, will sync q_scale, _k_scale and _v_scale during refit") + kv_scales_cache = None # Cache reused for computed kv scales + + if sync_kv_scales: + print( + "[KV_SCALES] FP8 KV cache detected, will sync q_scale, _k_scale and _v_scale during refit" + ) else: - print("[KV_SCALES] KV cache scale sync not needed (non-FP8 mode or kv_cache_dtype is not fp8)") + print( + "[KV_SCALES] KV cache scale sync not needed (non-FP8 mode or kv_cache_dtype is not fp8)" + ) NEED_REFIT = True # If policy_generation is None, use the policy as the generation interface (megatron framework backend) @@ -1103,25 +1112,36 @@ def grpo_train( if NEED_REFIT and POLICY_GENERATION_STALE: # Compute KV scales if needed for FP8 quantization if sync_kv_scales and kv_scales_cache is None: - print("[KV_SCALES] Computing KV cache scales for the first time...") + print( + "[KV_SCALES] Computing KV cache scales for the first time..." 
+ ) policy.prepare_for_lp_inference() # Create calibration data from flattened messages calibration_data = BatchedDataDict[ClippedPGLossDataDict]( { "input_ids": batched_flat["token_ids"], - "input_lengths": input_lengths + "input_lengths": input_lengths, } ) # this will be mini-batched inside the policy, so maintain the packed multimodal structure - calibration_data.update(batched_flat.get_multimodal_dict(as_tensors=False)) + calibration_data.update( + batched_flat.get_multimodal_dict(as_tensors=False) + ) calibration_data.to("cpu") - kv_scales = policy.calibrate_qkv_fp8_scales(calibration_data, include_q=True)["layers"] + kv_scales = policy.calibrate_qkv_fp8_scales( + calibration_data, include_q=True + )["layers"] # Convert calibration results to vLLM parameter format - kv_scales_cache = convert_calibration_to_vllm_format(kv_scales) - + kv_scales_cache = convert_calibration_to_vllm_format( + kv_scales + ) + refit_policy_generation( - policy, policy_generation, colocated_inference, timer=timer, - kv_scales=kv_scales_cache if sync_kv_scales else None + policy, + policy_generation, + colocated_inference, + timer=timer, + kv_scales=kv_scales_cache if sync_kv_scales else None, ) POLICY_GENERATION_STALE = False else: @@ -1322,8 +1342,12 @@ def grpo_train( # Recompute KV scales after policy training if needed if sync_kv_scales: with timer.time("recompute_kv_scales"): - print("[KV_SCALES] Recomputing KV cache scales after policy update...") - kv_scales = policy.calibrate_qkv_fp8_scales(train_data, include_q=True)["layers"] + print( + "[KV_SCALES] Recomputing KV cache scales after policy update..." 
+ ) + kv_scales = policy.calibrate_qkv_fp8_scales( + train_data, include_q=True + )["layers"] # Convert calibration results to vLLM parameter format kv_scales_cache = convert_calibration_to_vllm_format(kv_scales) # Set generation as stale to force refit with new scales diff --git a/nemo_rl/models/generation/fp8.py b/nemo_rl/models/generation/fp8.py index 7f401f0020..42e5e95b12 100644 --- a/nemo_rl/models/generation/fp8.py +++ b/nemo_rl/models/generation/fp8.py @@ -106,19 +106,22 @@ def patched_run_workers(self, *args, **kwargs): def kv_cache_process_weights_after_loading(self, layer: torch.nn.Module) -> None: - """ - Modified version of BaseKVCacheMethod.process_weights_after_loading that doesn't delete - k_scale, v_scale, q_scale, and prob_scale parameters to allow for dynamic updates. + """Modified version of BaseKVCacheMethod.process_weights_after_loading. + + Doesn't delete k_scale, v_scale, q_scale, and prob_scale parameters to allow + for dynamic updates. """ import torch from vllm.logger import init_logger from vllm.platforms import current_platform - + logger = init_logger(__name__) # print(f"[KV_SCALES] kv_cache_process_weights_after_loading: layer.k_scale = {layer.k_scale}, layer.v_scale = {layer.v_scale}") - print(f"[@@KV_SCALES@@] [fp8.py] kv_cache_process_weights_after_loading: layer.k_scale = {layer.k_scale}, layer.v_scale = {layer.v_scale}") - + print( + f"[@@KV_SCALES@@] [fp8.py] kv_cache_process_weights_after_loading: layer.k_scale = {layer.k_scale}, layer.v_scale = {layer.v_scale}" + ) + # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0 # regardless whether the kv-scale is available in the checkpoint. 
# No need to process kv scales after loading if we are going to @@ -148,16 +151,15 @@ def kv_cache_process_weights_after_loading(self, layer: torch.nn.Module) -> None k_scale *= 2 v_scale *= 2 - if not isinstance(k_scale, float) or not isinstance( - v_scale, float): - raise ValueError("Only support per-tensor scaling factor " - "for fp8 KV cache") + if not isinstance(k_scale, float) or not isinstance(v_scale, float): + raise ValueError("Only support per-tensor scaling factor for fp8 KV cache") if layer.q_scale < 0.0: logger.warning_once( "Checkpoint does not provide a q scaling factor. " "Setting it to k_scale. This only matters for " - "the flash-attn backend.") + "the flash-attn backend." + ) layer._q_scale.copy_(k_scale) layer._q_scale_float = k_scale @@ -166,12 +168,12 @@ def kv_cache_process_weights_after_loading(self, layer: torch.nn.Module) -> None layer._v_scale.copy_(v_scale) layer._k_scale_float = k_scale layer._v_scale_float = v_scale - if (k_scale == 1.0 and v_scale == 1.0 - and "e5m2" not in layer.kv_cache_dtype): + if k_scale == 1.0 and v_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype: logger.warning_once( "Using KV cache scaling factor 1.0 for fp8_e4m3. This " "may cause accuracy issues. Please make sure k/v_scale " - "scaling factors are available in the fp8 checkpoint.") + "scaling factors are available in the fp8 checkpoint." 
+ ) if layer.q_scale > 0.0: q_scale = layer.q_scale @@ -187,52 +189,59 @@ def kv_cache_process_weights_after_loading(self, layer: torch.nn.Module) -> None else: prob_scale = 1.0 - is_singleton_float = lambda x: isinstance(x, float) or isinstance( - x, torch.Tensor) and x.numel() == 1 and x.is_floating_point() - if not is_singleton_float(q_scale) or not is_singleton_float( - prob_scale): - raise ValueError("Only support per-tensor scaling factor" - "for fp8-quantized Q/prob") + is_singleton_float = ( + lambda x: isinstance(x, float) + or isinstance(x, torch.Tensor) + and x.numel() == 1 + and x.is_floating_point() + ) + if not is_singleton_float(q_scale) or not is_singleton_float(prob_scale): + raise ValueError( + "Only support per-tensor scaling factorfor fp8-quantized Q/prob" + ) # These are used in the final Attention.forward() layer._q_scale.copy_(q_scale) - layer._q_scale_float = q_scale.item() if isinstance( - q_scale, torch.Tensor) else q_scale + layer._q_scale_float = ( + q_scale.item() if isinstance(q_scale, torch.Tensor) else q_scale + ) layer._prob_scale.copy_(prob_scale) - if layer.kv_cache_dtype == "fp8" and (q_scale == 1.0 - or prob_scale == 1.0): + if layer.kv_cache_dtype == "fp8" and (q_scale == 1.0 or prob_scale == 1.0): logger.warning_once( f"Using uncalibrated q_scale {q_scale} and/or prob_scale " f"{prob_scale} with fp8 attention. This may cause accuracy " "issues. Please make sure q/prob scaling factors are " - "available in the fp8 checkpoint.") + "available in the fp8 checkpoint." 
+ ) # IMPORTANT: We DON'T delete the parameters here to allow for dynamic updates # Original code deleted: layer.k_scale, layer.v_scale, layer.q_scale, layer.prob_scale - print(f"[KV_SCALES] Patched process_weights_after_loading: keeping k_scale, v_scale parameters for dynamic updates") + print( + "[KV_SCALES] Patched process_weights_after_loading: keeping k_scale, v_scale parameters for dynamic updates" + ) def get_vllm_qkv_scale_names(layer_idx: int) -> dict[str, str]: """Get vLLM-compatible parameter names for Q/K/V FP8 scales. - + This function centralizes the naming convention for Q/K/V scale parameters that vLLM expects. These names must match vLLM's internal parameter structure. - + Args: layer_idx: The transformer layer index (0-based) - + Returns: Dictionary mapping scale types to vLLM parameter names: - 'q_scale': Q activation scale name - - 'k_scale': K activation scale name + - 'k_scale': K activation scale name - 'v_scale': V activation scale name - + Note: The q_scale has an extra '.attn.' component compared to k_scale/v_scale. This matches vLLM's parameter remapping logic in: vllm.model_executor.model_loader.weight_utils.maybe_remap_kv_scale_name - + Example: >>> get_vllm_qkv_scale_names(0) { @@ -249,22 +258,22 @@ def get_vllm_qkv_scale_names(layer_idx: int) -> dict[str, str]: def convert_calibration_to_vllm_format( - calibration_results: dict[str, dict[str, float]] + calibration_results: dict[str, dict[str, float]], ) -> dict[str, float]: """Convert NeMo-RL calibration results to vLLM parameter format. - + This function transforms the calibration output format (with layer_N keys) into the flat dictionary format that vLLM expects for parameter loading. - + Args: calibration_results: Dict with keys like "layer_0", "layer_1", etc. Each value is a dict with keys: "q_scale", "k_scale", "v_scale" and corresponding float scale values. - + Returns: Flat dictionary mapping vLLM parameter names to scale values. 
Keys follow vLLM's naming convention as defined in get_vllm_qkv_scale_names. - + Example: >>> calib = { ... "layer_0": {"q_scale": 1.0, "k_scale": 2.0, "v_scale": 3.0}, @@ -285,11 +294,11 @@ def convert_calibration_to_vllm_format( # Extract layer index from "layer_N" format layer_idx = int(layer_key.split("_")[1]) param_names = get_vllm_qkv_scale_names(layer_idx) - + vllm_scales[param_names["q_scale"]] = scales["q_scale"] vllm_scales[param_names["k_scale"]] = scales["k_scale"] vllm_scales[param_names["v_scale"]] = scales["v_scale"] - + return vllm_scales @@ -393,7 +402,9 @@ def init_fp8(vllm_cfg, model_name, model_parallel_size): vllm_kwargs = { "quantization": "fp8", # Conditionally set kv_cache_dtype to fp8 if global config kv_cache_dtype is fp8 - "kv_cache_dtype": "fp8" if global_fp8_config.kv_cache_dtype == "fp8" else "auto", + "kv_cache_dtype": "fp8" + if global_fp8_config.kv_cache_dtype == "fp8" + else "auto", "hf_overrides": {"quantization_config": fp8_block_quant_kwargs}, } return vllm_kwargs @@ -492,7 +503,9 @@ def load_weights(weights, model_runner): for k, v in weights: if "scale" in k: - print(f"[@@KV_SCALES@@] [fp8.py] load_weights: Parameter {k}, value = {v.item() if v.numel() == 1 else v}") + print( + f"[@@KV_SCALES@@] [fp8.py] load_weights: Parameter {k}, value = {v.item() if v.numel() == 1 else v}" + ) if not _is_fp8_weight(k, model): weights_quantized.append((k, v)) continue @@ -755,4 +768,4 @@ def per_token_group_quant_fp8( per_token_group_quant_fp8 as vllm_per_token_group_quant_fp8, ) - return vllm_per_token_group_quant_fp8(*args, **kwargs) \ No newline at end of file + return vllm_per_token_group_quant_fp8(*args, **kwargs) diff --git a/nemo_rl/models/generation/vllm/vllm_backend.py b/nemo_rl/models/generation/vllm/vllm_backend.py index a4d1f9163b..648ad548a5 100644 --- a/nemo_rl/models/generation/vllm/vllm_backend.py +++ b/nemo_rl/models/generation/vllm/vllm_backend.py @@ -13,7 +13,8 @@ # limitations under the License. 
import gc import traceback -from typing import Any, Optional +from typing import Any + import torch import zmq @@ -160,20 +161,22 @@ def update_weights_via_ipc_zmq(self) -> bool: buffer = None self.zmq_socket.send(IPCProtocol.ACK.value.encode()) - # When kv_scales is provided, we need to invoke process_weights_after_loading() + # When kv_scales is provided, we need to invoke process_weights_after_loading() # to copy the kv scales to the _k_scale and _v_scale attributes used during inference # if kv_scales: - print(f"[@@KV_SCALES@@] [vllm_backend.py] update_weights_via_ipc_zmq: Processing KV cache scales after weight loading") - from vllm.model_executor.model_loader.utils import process_weights_after_loading - + print( + "[@@KV_SCALES@@] [vllm_backend.py] update_weights_via_ipc_zmq: Processing KV cache scales after weight loading" + ) + from vllm.model_executor.model_loader.utils import ( + process_weights_after_loading, + ) + # Get target device for processing target_device = next(self.model_runner.model.parameters()).device - + # Call process_weights_after_loading to handle KV scales process_weights_after_loading( - self.model_runner.model, - self.model_runner.model_config, - target_device + self.model_runner.model, self.model_runner.model_config, target_device ) gc.collect() diff --git a/nemo_rl/models/policy/interfaces.py b/nemo_rl/models/policy/interfaces.py index e486821536..70e7ca4459 100644 --- a/nemo_rl/models/policy/interfaces.py +++ b/nemo_rl/models/policy/interfaces.py @@ -189,5 +189,7 @@ def stream_weights_via_ipc_zmq( pass @abstractmethod - def broadcast_weights_for_collective(self, kv_scales: Optional[dict[str, float]] = None) -> list[ray.ObjectRef]: + def broadcast_weights_for_collective( + self, kv_scales: Optional[dict[str, float]] = None + ) -> list[ray.ObjectRef]: pass diff --git a/nemo_rl/models/policy/lm_policy.py b/nemo_rl/models/policy/lm_policy.py index 72a67b1458..5e63c66f73 100644 --- a/nemo_rl/models/policy/lm_policy.py +++ 
b/nemo_rl/models/policy/lm_policy.py @@ -701,8 +701,8 @@ def calibrate_qkv_fp8_scales( ) -> dict[str, Any]: """Trigger KV-cache FP8 scale calibration across Megatron workers and return results. - Note: The backend `MegatronPolicyWorker.calibrate_qkv_fp8_scales` already implements - distributed reduction, returning results merged across ranks. Therefore, we shard the + Note: The backend `MegatronPolicyWorker.calibrate_qkv_fp8_scales` already implements + distributed reduction, returning results merged across ranks. Therefore, we shard the input by DP and call in parallel, then take the result from the first worker. """ dp_size = self.sharding_annotations.get_axis_size("data_parallel") @@ -762,14 +762,20 @@ def get_free_memory_bytes(self) -> int: free_memory_bytes = min(ray.get(future) for future in futures) return free_memory_bytes - def stream_weights_via_ipc_zmq(self, buffer_size_bytes: int, kv_scales: Optional[dict[str, float]] = None) -> list[ray.ObjectRef]: + def stream_weights_via_ipc_zmq( + self, buffer_size_bytes: int, kv_scales: Optional[dict[str, float]] = None + ) -> list[ray.ObjectRef]: """Send the weights for IPC handles via ZMQ socket.""" futures = self.worker_group.run_all_workers_single_data( - "stream_weights_via_ipc_zmq", buffer_size_bytes=buffer_size_bytes, kv_scales=kv_scales + "stream_weights_via_ipc_zmq", + buffer_size_bytes=buffer_size_bytes, + kv_scales=kv_scales, ) return futures - def broadcast_weights_for_collective(self, kv_scales: Optional[dict[str, float]] = None) -> list[ray.ObjectRef]: + def broadcast_weights_for_collective( + self, kv_scales: Optional[dict[str, float]] = None + ) -> list[ray.ObjectRef]: """Broadcast the weights for collective communication.""" futures = self.worker_group.run_all_workers_single_data( "broadcast_weights_for_collective", diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index edfa4f17f6..176cba0042 100644 --- 
a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. import gc +import json import math import os +import re import time import warnings from collections import defaultdict -import json -import re from contextlib import AbstractContextManager, contextmanager, nullcontext from functools import partial from typing import Any, Iterator, Optional, TypeVar @@ -2001,7 +2001,9 @@ def get_free_memory_bytes(self) -> int: @torch.no_grad() @wrap_with_nvtx_name("megatron_policy_worker/stream_weights_via_ipc_zmq") - def stream_weights_via_ipc_zmq(self, buffer_size_bytes: int = 0, kv_scales: Optional[dict[str, float]] = None) -> None: + def stream_weights_via_ipc_zmq( + self, buffer_size_bytes: int = 0, kv_scales: Optional[dict[str, float]] = None + ) -> None: """Stream model weights to peer process via ZMQ IPC socket.""" self.maybe_init_zmq() @@ -2047,7 +2049,9 @@ def iter_with_kv_scales(): ) @torch.no_grad() - def broadcast_weights_for_collective(self, kv_scales: Optional[dict[str, float]] = None) -> None: + def broadcast_weights_for_collective( + self, kv_scales: Optional[dict[str, float]] = None + ) -> None: """Broadcast the weights for collective communication.""" hf_params_generator = self.megatron_bridge.export_hf_weights( [self.model], @@ -2059,7 +2063,7 @@ def iter_with_kv_scales(): # First yield all model weights for name, tensor in hf_params_generator: yield name, tensor - + # Get number of layers directly from transformer config num_layers = self.megatron_bridge.transformer_config.num_layers print(f"[@@KV_SCALES@@] iter_with_kv_scales: num_layers = {num_layers}") @@ -2265,7 +2269,9 @@ def save_checkpoint( # Temporary fix to avoid OOM after saving checkpoint allocated = torch.cuda.memory_allocated() / (1024**3) # Convert to GB reserved = torch.cuda.memory_reserved() / (1024**3) # Convert to GB - 
print(f"GPU Memory before saving checkpoint: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved") + print( + f"GPU Memory before saving checkpoint: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved" + ) if not torch.distributed.is_initialized(): raise RuntimeError( "Distributed process group is not initialized. Cannot save checkpoint." @@ -2326,7 +2332,7 @@ def save_checkpoint( if not is_training: # Restore training state if it was changed self.model.train() - + # Temporary fix to avoid OOM after saving checkpoint: https://github.com/NVIDIA-NeMo/RL/issues/1057 torch.randn(1).cuda() # wake up torch allocator if hasattr(self, "optimizer") and self.optimizer is not None: @@ -2348,8 +2354,9 @@ def save_checkpoint( allocated = torch.cuda.memory_allocated() / (1024**3) # Convert to GB reserved = torch.cuda.memory_reserved() / (1024**3) # Convert to GB - print(f"GPU Memory after saving checkpoint: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved") - + print( + f"GPU Memory after saving checkpoint: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved" + ) except Exception as e: print(f"Failed to save checkpoint to {weights_path}: {e}") @@ -2507,11 +2514,19 @@ def _pre_hook(module, inputs): k = args[1] v = args[2] if include_q: - layer_to_samples_q[layer_key].append(float(torch.amax(torch.abs(q)).item())) - layer_to_samples_k[layer_key].append(float(torch.amax(torch.abs(k)).item())) - layer_to_samples_v[layer_key].append(float(torch.amax(torch.abs(v)).item())) + layer_to_samples_q[layer_key].append( + float(torch.amax(torch.abs(q)).item()) + ) + layer_to_samples_k[layer_key].append( + float(torch.amax(torch.abs(k)).item()) + ) + layer_to_samples_v[layer_key].append( + float(torch.amax(torch.abs(v)).item()) + ) except Exception as e: - print(f"[KV_SCALES] Error in core_attention pre-hook on {module_name}: {e}") + print( + f"[KV_SCALES] Error in core_attention pre-hook on {module_name}: {e}" + ) pass return _pre_hook @@ -2521,7 +2536,9 @@ def 
_pre_hook(module, inputs): for name, module in self.model.named_modules(): if "self_attention.core_attention" in name: try: - handle = module.register_forward_pre_hook(_pre_hook_builder_core_attention(name)) + handle = module.register_forward_pre_hook( + _pre_hook_builder_core_attention(name) + ) hook_handles.append(handle) matched_modules.append((name, module.__class__.__name__, "pre")) except Exception as e: @@ -2529,7 +2546,9 @@ def _pre_hook(module, inputs): continue if not hook_handles: - print("[KV_SCALES] No QKV proj modules matched for hook. Example module/param names:") + print( + "[KV_SCALES] No QKV proj modules matched for hook. Example module/param names:" + ) try: # Print the first 10 modules and parameters to help locate the actual names cnt = 0 @@ -2561,27 +2580,43 @@ def _pre_hook(module, inputs): h.remove() except Exception as e: print(f"[KV_SCALES] Error removing hook: {e}") - pass + pass # Compute local percentile amax def _percentile(values: list[float], p: float) -> float: if not values: return 0.0 t = torch.tensor(sorted(values), device="cuda", dtype=torch.float32) - rank = max(0, min(len(values) - 1, int(round((p / 100.0) * (len(values) - 1))))) + rank = max( + 0, min(len(values) - 1, int(round((p / 100.0) * (len(values) - 1)))) + ) return float(t[rank].item()) local_layer_to_pamax = {} - for layer_key in set(list(layer_to_samples_k.keys()) + list(layer_to_samples_v.keys()) + (list(layer_to_samples_q.keys()) if include_q else [])): + for layer_key in set( + list(layer_to_samples_k.keys()) + + list(layer_to_samples_v.keys()) + + (list(layer_to_samples_q.keys()) if include_q else []) + ): entry = {} if include_q: - entry["q_amax_p"] = _percentile(layer_to_samples_q.get(layer_key, []), percentile) - entry["k_amax_p"] = _percentile(layer_to_samples_k.get(layer_key, []), percentile) - entry["v_amax_p"] = _percentile(layer_to_samples_v.get(layer_key, []), percentile) + entry["q_amax_p"] = _percentile( + layer_to_samples_q.get(layer_key, []), 
percentile + ) + entry["k_amax_p"] = _percentile( + layer_to_samples_k.get(layer_key, []), percentile + ) + entry["v_amax_p"] = _percentile( + layer_to_samples_v.get(layer_key, []), percentile + ) local_layer_to_pamax[layer_key] = entry # Merge across all ranks: take maximum of percentile amax (conservative approach) - world_size = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1 + world_size = ( + torch.distributed.get_world_size() + if torch.distributed.is_initialized() + else 1 + ) gathered = [None for _ in range(world_size)] if world_size > 1 else None if world_size > 1: torch.distributed.all_gather_object(gathered, local_layer_to_pamax) @@ -2630,7 +2665,9 @@ def _percentile(values: list[float], p: float) -> float: final_result = obj_list[0] # type: ignore # Optional save to JSON (only rank0) - if save_path is not None and (not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0): + if save_path is not None and ( + not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 + ): try: with open(save_path, "w") as f: json.dump(final_result, f) From 7f709b4670060bf0e5c7f1ed5ceccdd82c099add Mon Sep 17 00:00:00 2001 From: Shuang Yu Date: Fri, 14 Nov 2025 23:16:38 -0800 Subject: [PATCH 14/40] Update to correct BF16 issue with load_weights Signed-off-by: Zhaopeng Qiu --- nemo_rl/algorithms/grpo.py | 28 ++-- nemo_rl/models/generation/fp8.py | 87 +++++++----- .../models/generation/vllm/vllm_backend.py | 68 ++++++--- nemo_rl/models/generation/vllm/vllm_worker.py | 18 ++- .../models/policy/megatron_policy_worker.py | 131 +++++++++++------- 5 files changed, 218 insertions(+), 114 deletions(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 98fea3df14..1cfa50b3c1 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -502,7 +502,8 @@ def init_vllm(): "Importance sampling must be enabled for vLLM FP8 generation for good convergence!" 
) if generation_config["vllm_cfg"].get("kv_cache_dtype") == "fp8": - # Temporary additional FP8 KV cache compatibility checks + # FP8 KV cache compatibility checks + # These checks are independent of model precision (can use bf16 or fp8 weights with fp8 KV cache) # TODO: Add the related support assert policy_config["dtensor_cfg"]["enabled"] == False, ( "DTensor backend is not supported with kv cache fp8 enabled." @@ -882,7 +883,10 @@ def _should_sync_kv_scales(master_config: MasterConfig) -> bool: """Check if KV cache scales should be synchronized during refit. Returns True if: - - vLLM precision is fp8 and kv_cache_dtype is fp8 + - kv_cache_dtype is fp8 (independent of precision) + + Note: KV cache scales are only needed when kv_cache_dtype is FP8. + The model precision (fp8 or bf16) is independent of this requirement. """ generation_config = master_config["policy"]["generation"] if generation_config is None: @@ -894,9 +898,9 @@ def _should_sync_kv_scales(master_config: MasterConfig) -> bool: vllm_cfg = generation_config.get("vllm_cfg", {}) kv_cache_dtype = vllm_cfg.get("kv_cache_dtype", "auto") - vllm_precision = vllm_cfg.get("precision", "auto") - - return kv_cache_dtype == "fp8" and vllm_precision == "fp8" + + # Only check kv_cache_dtype, not precision. 
This allows: precision=bf16 with kv_cache_dtype=fp8 + return kv_cache_dtype == "fp8" def refit_policy_generation( @@ -1008,16 +1012,12 @@ def grpo_train( # Check if we need to sync KV cache scales (infer from config) sync_kv_scales = _should_sync_kv_scales(master_config) - kv_scales_cache = None # Cache reused for computed kv scales - - if sync_kv_scales: - print( - "[KV_SCALES] FP8 KV cache detected, will sync q_scale, _k_scale and _v_scale during refit" - ) + kv_scales_cache = None # Cache reused for computed kv scales + + if sync_kv_scales: + print(f"[KV_SCALES] FP8 KV cache enabled (kv_cache_dtype=fp8), will sync q_scale, k_scale and v_scale during refit") else: - print( - "[KV_SCALES] KV cache scale sync not needed (non-FP8 mode or kv_cache_dtype is not fp8)" - ) + print("[KV_SCALES] KV cache scale sync not needed (kv_cache_dtype is not fp8, regardless of model precision)") NEED_REFIT = True # If policy_generation is None, use the policy as the generation interface (megatron framework backend) diff --git a/nemo_rl/models/generation/fp8.py b/nemo_rl/models/generation/fp8.py index 42e5e95b12..67252372e4 100644 --- a/nemo_rl/models/generation/fp8.py +++ b/nemo_rl/models/generation/fp8.py @@ -41,6 +41,7 @@ class FP8Config: num_last_layers_in_bf16: int = 0 model_parallel_size: int = None kv_cache_dtype: str = "auto" + use_fp8_weights: bool = True # Whether model weights are quantized to FP8 @dataclass() @@ -308,29 +309,42 @@ def apply_fp8_patches(self, fp8_config): global_fp8_config = fp8_config - # This patch is used to support torch.compile with vllm parameter subclasses, such as - # PerTensorScaleParameter. Because we need weight loaders to update fp8 weights each - # refit, we patch fp8 parameters to have a reference to their weight loader. Eventually - # with pytorch 2.8, parameter subclassing with torch.compile will be natively supported, in - # which this patch can be removed. 
- func1_path = "vllm.model_executor.layers.quantization.fp8.Fp8LinearMethod.process_weights_after_loading" - patcher1 = patch(func1_path, process_weights_after_loading) - fp8_state.vllm_patches.append(patcher1) - # These patches add support for pow2, e8 dynamic activation scalings factors which are believed to have higher - # SNR compared to plain fp32 scaling factors. This feature is still under active research. - if global_fp8_config.use_activation_pow2_scale: - func2_path = "vllm.model_executor.layers.quantization.utils.fp8_utils.per_token_group_quant_fp8" - func3_path = "vllm.model_executor.layers.quantization.utils.fp8_utils._per_token_group_quant_fp8" - func4_path = "vllm.model_executor.layers.quantization.utils.fp8_utils._per_token_group_quant_fp8_colmajor" - patcher2 = patch(func2_path, per_token_group_quant_fp8) - patcher3 = patch(func3_path, _per_token_group_quant_fp8) - patcher4 = patch(func4_path, _per_token_group_quant_fp8_colmajor) - fp8_state.vllm_patches.append(patcher2, patcher3, patcher4) - - # Patch the vllm kv_cache.py process_weights_after_loading() to remove the deletion of k_scale and v_scale - func5_path = "vllm.model_executor.layers.quantization.kv_cache.BaseKVCacheMethod.process_weights_after_loading" - patcher5 = patch(func5_path, kv_cache_process_weights_after_loading) - fp8_state.vllm_patches.append(patcher5) + # Apply patches conditionally based on configuration + # Only apply weight patches if using FP8 weights + # Only apply KV cache patches if using FP8 KV cache + + # Apply weight-related patches only when using FP8 weights (precision=fp8) + if global_fp8_config.use_fp8_weights: + print("[FP8_PATCHES] Applying FP8 weight quantization patches (precision=fp8)") + + # This patch is used to support torch.compile with vllm parameter subclasses, such as + # PerTensorScaleParameter. Because we need weight loaders to update fp8 weights each + # refit, we patch fp8 parameters to have a reference to their weight loader. 
Eventually + # with pytorch 2.8, parameter subclassing with torch.compile will be natively supported, in + # which this patch can be removed. + func1_path = "vllm.model_executor.layers.quantization.fp8.Fp8LinearMethod.process_weights_after_loading" + patcher1 = patch(func1_path, process_weights_after_loading) + fp8_state.vllm_patches.append(patcher1) + + # These patches add support for pow2, e8 dynamic activation scalings factors which are believed to have higher + # SNR compared to plain fp32 scaling factors. This feature is still under active research. + if global_fp8_config.use_activation_pow2_scale: + func2_path = "vllm.model_executor.layers.quantization.utils.fp8_utils.per_token_group_quant_fp8" + func3_path = "vllm.model_executor.layers.quantization.utils.fp8_utils._per_token_group_quant_fp8" + func4_path = "vllm.model_executor.layers.quantization.utils.fp8_utils._per_token_group_quant_fp8_colmajor" + patcher2 = patch(func2_path, per_token_group_quant_fp8) + patcher3 = patch(func3_path, _per_token_group_quant_fp8) + patcher4 = patch(func4_path, _per_token_group_quant_fp8_colmajor) + fp8_state.vllm_patches.append(patcher2, patcher3, patcher4) + + # Apply KV cache patches only when using FP8 KV cache (kv_cache_dtype=fp8) + if global_fp8_config.kv_cache_dtype == "fp8": + print("[FP8_PATCHES] Applying FP8 KV cache patches (kv_cache_dtype=fp8)") + + # Patch the vllm kv_cache.py process_weights_after_loading() to remove the deletion of k_scale and v_scale + func5_path = "vllm.model_executor.layers.quantization.kv_cache.BaseKVCacheMethod.process_weights_after_loading" + patcher5 = patch(func5_path, kv_cache_process_weights_after_loading) + fp8_state.vllm_patches.append(patcher5) for p in fp8_state.vllm_patches: p.start() @@ -346,6 +360,9 @@ def init_fp8(vllm_cfg, model_name, model_parallel_size): ) global global_fp8_config + # Determine if we're using FP8 weights based on precision setting + use_fp8_weights = vllm_cfg.get("precision") == "fp8" + global_fp8_config = 
FP8Config( use_weight_pow2_scale=vllm_cfg.get("pow2_weight_scaling_factors", False), use_activation_pow2_scale=vllm_cfg.get( @@ -355,6 +372,7 @@ def init_fp8(vllm_cfg, model_name, model_parallel_size): num_last_layers_in_bf16=vllm_cfg.get("num_last_layers_in_bf16", 0), model_parallel_size=model_parallel_size, kv_cache_dtype=vllm_cfg.get("kv_cache_dtype", "auto"), + use_fp8_weights=use_fp8_weights, ) if vllm_cfg.get("use_deep_gemm", False): @@ -399,14 +417,21 @@ def init_fp8(vllm_cfg, model_name, model_parallel_size): # TODO: Remove this after debugging. print(f"[KV_SCALES] Global FP8 config: {global_fp8_config}") - vllm_kwargs = { - "quantization": "fp8", - # Conditionally set kv_cache_dtype to fp8 if global config kv_cache_dtype is fp8 - "kv_cache_dtype": "fp8" - if global_fp8_config.kv_cache_dtype == "fp8" - else "auto", - "hf_overrides": {"quantization_config": fp8_block_quant_kwargs}, - } + + # CHANGE: Return different kwargs based on whether we're using FP8 weights + if use_fp8_weights: + # Full FP8: quantize weights and optionally use FP8 KV cache + vllm_kwargs = { + "quantization": "fp8", + "kv_cache_dtype": kv_cache_dtype, + "hf_overrides": {"quantization_config": fp8_block_quant_kwargs}, + } + else: + # Only FP8 KV cache, no weight quantization + vllm_kwargs = { + "kv_cache_dtype": kv_cache_dtype, + } + return vllm_kwargs diff --git a/nemo_rl/models/generation/vllm/vllm_backend.py b/nemo_rl/models/generation/vllm/vllm_backend.py index 648ad548a5..6034e31f85 100644 --- a/nemo_rl/models/generation/vllm/vllm_backend.py +++ b/nemo_rl/models/generation/vllm/vllm_backend.py @@ -161,23 +161,31 @@ def update_weights_via_ipc_zmq(self) -> bool: buffer = None self.zmq_socket.send(IPCProtocol.ACK.value.encode()) - # When kv_scales is provided, we need to invoke process_weights_after_loading() - # to copy the kv scales to the _k_scale and _v_scale attributes used during inference - # if kv_scales: - print( - "[@@KV_SCALES@@] [vllm_backend.py] 
update_weights_via_ipc_zmq: Processing KV cache scales after weight loading" - ) - from vllm.model_executor.model_loader.utils import ( - process_weights_after_loading, - ) - - # Get target device for processing - target_device = next(self.model_runner.model.parameters()).device - - # Call process_weights_after_loading to handle KV scales - process_weights_after_loading( - self.model_runner.model, self.model_runner.model_config, target_device - ) + # CHANGE: Only invoke process_weights_after_loading when kv_cache_dtype is FP8 + # Check if KV cache is using FP8 + use_fp8_kv_cache = False + if hasattr(self.model_runner.vllm_config, 'cache_config'): + kv_cache_dtype = getattr(self.model_runner.vllm_config.cache_config, 'cache_dtype', None) + print(f"[KV_SCALES] [vllm_backend.py] update_weights_via_ipc_zmq: kv_cache_dtype is {kv_cache_dtype}") + use_fp8_kv_cache = kv_cache_dtype is not None and 'fp8' in str(kv_cache_dtype).lower() + + if use_fp8_kv_cache: + # When kv_scales is provided, we need to invoke process_weights_after_loading() + # to copy the kv scales to the _k_scale and _v_scale attributes used during inference + print(f"[@@KV_SCALES@@] [vllm_backend.py] update_weights_via_ipc_zmq: kv_cache_dtype is FP8, processing KV cache scales after weight loading") + from vllm.model_executor.model_loader.utils import process_weights_after_loading + + # Get target device for processing + target_device = next(self.model_runner.model.parameters()).device + + # Call process_weights_after_loading to handle KV scales + process_weights_after_loading( + self.model_runner.model, + self.model_runner.model_config, + target_device + ) + else: + print(f"[KV_SCALES] [vllm_backend.py] update_weights_via_ipc_zmq: kv_cache_dtype is not FP8, skipping process_weights_after_loading") gc.collect() torch.cuda.empty_cache() @@ -226,6 +234,32 @@ def _load_model_weights(weights, model_runner): src=0, post_unpack_func=load_model_weight_func, ) + + # CHANGE: Only invoke 
process_weights_after_loading when kv_cache_dtype is FP8 + # Check if KV cache is using FP8 + use_fp8_kv_cache = False + if hasattr(self.model_runner.vllm_config, 'cache_config'): + kv_cache_dtype = getattr(self.model_runner.vllm_config.cache_config, 'cache_dtype', None) + use_fp8_kv_cache = kv_cache_dtype is not None and 'fp8' in str(kv_cache_dtype).lower() + + if use_fp8_kv_cache: + # When KV scales are broadcast, we need to invoke process_weights_after_loading() + # to copy the kv scales to the _k_scale and _v_scale attributes used during inference + print(f"[@@KV_SCALES@@] [vllm_backend.py] update_weights_from_collective: kv_cache_dtype is FP8, processing KV cache scales after weight loading") + from vllm.model_executor.model_loader.utils import process_weights_after_loading + + # Get target device for processing + target_device = next(self.model_runner.model.parameters()).device + + # Call process_weights_after_loading to handle KV scales + process_weights_after_loading( + self.model_runner.model, + self.model_runner.model_config, + target_device + ) + else: + print(f"[KV_SCALES] [vllm_backend.py] update_weights_from_collective: kv_cache_dtype is not FP8, skipping process_weights_after_loading") + except Exception as e: print( f"Error in VllmInternalWorkerExtension.update_weights_from_collective: {e}" diff --git a/nemo_rl/models/generation/vllm/vllm_worker.py b/nemo_rl/models/generation/vllm/vllm_worker.py index a97d68e669..0ca55a2ebb 100644 --- a/nemo_rl/models/generation/vllm/vllm_worker.py +++ b/nemo_rl/models/generation/vllm/vllm_worker.py @@ -287,15 +287,25 @@ def _patch_vllm_init_workers_ray(): ) vllm_kwargs["ray_workers_use_nsight"] = True - if self.cfg["vllm_cfg"]["precision"] == "fp8": + # Call init_fp8 when either precision is fp8 OR kv_cache_dtype is fp8 + # This ensures vLLM patches are applied for KV cache FP8 even with bf16 weights + if self.cfg["vllm_cfg"]["precision"] == "fp8" or self.cfg["vllm_cfg"].get("kv_cache_dtype") == "fp8": from 
nemo_rl.models.generation.fp8 import init_fp8 fp8_kwargs = init_fp8( self.cfg["vllm_cfg"], self.model_name, model_parallel_size ) - vllm_kwargs.update(fp8_kwargs) - # overriden by quant config, however vllm complains if this not passed - self.precision = "bfloat16" + + # For FP8 precision, we need quantization="fp8" and weight quantization config + if self.cfg["vllm_cfg"]["precision"] == "fp8": + vllm_kwargs.update(fp8_kwargs) + # overriden by quant config, however vllm complains if this not passed + self.precision = "bfloat16" + else: + # For non-FP8 precision with FP8 KV cache, only set kv_cache_dtype + # Don't set quantization="fp8" as weights are not quantized + vllm_kwargs["kv_cache_dtype"] = fp8_kwargs["kv_cache_dtype"] + print(f"[KV_SCALES] Using FP8 KV cache with precision={self.precision} (weights not quantized)") if not isinstance(vllm_kwargs.get("hf_overrides"), dict): vllm_kwargs["hf_overrides"] = {} diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 176cba0042..71491fb65f 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -1926,19 +1926,32 @@ def prepare_refit_info(self) -> None: for name, tensor in hf_params_generator: metadata = (tensor.shape, tensor.dtype) refit_param_info_hf[name] = metadata - # Also include KV/Q scale metadata so consumer can rely solely on state_dict_info - try: - # Get number of layers directly from transformer config - num_layers = self.megatron_bridge.transformer_config.num_layers - print(f"[@@KV_SCALES@@] prepare_refit_info: num_layers = {num_layers}") - # Append q/k/v scale placeholders (shape [1], dtype float32) - for layer_idx in range(num_layers): - scale_names = get_vllm_qkv_scale_names(layer_idx) - for param_name in scale_names.values(): - if param_name not in refit_param_info_hf: - refit_param_info_hf[param_name] = ([1], torch.float32) - except Exception: - pass + + # Only include KV/Q 
scale metadata when kv_cache_dtype is FP8 + # Check if we're using FP8 KV cache + use_fp8_kv_cache = False + if "generation" in self.cfg and self.cfg["generation"] is not None: + vllm_cfg = self.cfg["generation"].get("vllm_cfg", {}) + kv_cache_dtype = vllm_cfg.get("kv_cache_dtype", "auto") + use_fp8_kv_cache = kv_cache_dtype == "fp8" + + if use_fp8_kv_cache: + # Include KV/Q scale metadata so consumer can rely solely on state_dict_info + try: + # Get number of layers directly from transformer config + num_layers = self.megatron_bridge.transformer_config.num_layers + print(f"[@@KV_SCALES@@] prepare_refit_info: kv_cache_dtype=fp8, adding scale metadata for {num_layers} layers") + # Append q/k/v scale placeholders (shape [1], dtype float32) + for layer_idx in range(num_layers): + scale_names = get_vllm_qkv_scale_names(layer_idx) + for param_name in scale_names.values(): + if param_name not in refit_param_info_hf: + refit_param_info_hf[param_name] = ([1], torch.float32) + except Exception: + pass + else: + print(f"[KV_SCALES] prepare_refit_info: kv_cache_dtype is not fp8, skipping KV scale metadata") + return refit_param_info_hf def _calculate_refit_param_info(self) -> list[tuple[str, int]]: @@ -2016,28 +2029,39 @@ def stream_weights_via_ipc_zmq( conversion_tasks=self.refit_conversion_tasks, # used for metadata caching ) + # CHANGE: Only check and stream KV scales when kv_cache_dtype is FP8 + use_fp8_kv_cache = False + if "generation" in self.cfg and self.cfg["generation"] is not None: + vllm_cfg = self.cfg["generation"].get("vllm_cfg", {}) + kv_cache_dtype = vllm_cfg.get("kv_cache_dtype", "auto") + use_fp8_kv_cache = kv_cache_dtype == "fp8" + def iter_with_kv_scales(): # First yield all model weights for name, tensor in hf_params_generator: yield name, tensor - # Get number of layers directly from transformer config - num_layers = self.megatron_bridge.transformer_config.num_layers - print(f"[@@KV_SCALES@@] iter_with_kv_scales: num_layers = {num_layers}") - keys = [] - 
for layer_idx in range(num_layers): - scale_names = get_vllm_qkv_scale_names(layer_idx) - keys.extend(scale_names.values()) - # Always append kv-scale entries to match metadata; use provided value or default 1.0 - for param_name in keys: - if kv_scales and param_name in kv_scales: - scale_value = kv_scales[param_name] - else: - scale_value = 1.0 - scale_tensor = torch.tensor( - scale_value, dtype=torch.float32, device="cuda" - ).reshape(1) - yield param_name, scale_tensor + # CHANGE: Only append KV scales when kv_cache_dtype is FP8 + if use_fp8_kv_cache: + # Get number of layers directly from transformer config + num_layers = self.megatron_bridge.transformer_config.num_layers + print(f"[@@KV_SCALES@@] stream_weights_via_ipc_zmq: kv_cache_dtype=fp8, streaming KV scales for {num_layers} layers") + keys = [] + for layer_idx in range(num_layers): + scale_names = get_vllm_qkv_scale_names(layer_idx) + keys.extend(scale_names.values()) + # Append kv-scale entries to match metadata; use provided value or default 1.0 + for param_name in keys: + if kv_scales and param_name in kv_scales: + scale_value = kv_scales[param_name] + else: + scale_value = 1.0 + scale_tensor = torch.tensor( + scale_value, dtype=torch.float32, device="cuda" + ).reshape(1) + yield param_name, scale_tensor + else: + print(f"[KV_SCALES] stream_weights_via_ipc_zmq: kv_cache_dtype is not fp8, skipping KV scales") # Use the shared implementation stream_weights_via_ipc_zmq_impl( @@ -2059,28 +2083,39 @@ def broadcast_weights_for_collective( conversion_tasks=self.refit_conversion_tasks, # used for metadata caching ) + # CHANGE: Only check and broadcast KV scales when kv_cache_dtype is FP8 + use_fp8_kv_cache = False + if "generation" in self.cfg and self.cfg["generation"] is not None: + vllm_cfg = self.cfg["generation"].get("vllm_cfg", {}) + kv_cache_dtype = vllm_cfg.get("kv_cache_dtype", "auto") + use_fp8_kv_cache = kv_cache_dtype == "fp8" + def iter_with_kv_scales(): # First yield all model weights for name, 
tensor in hf_params_generator: yield name, tensor - - # Get number of layers directly from transformer config - num_layers = self.megatron_bridge.transformer_config.num_layers - print(f"[@@KV_SCALES@@] iter_with_kv_scales: num_layers = {num_layers}") - keys = [] - for layer_idx in range(num_layers): - scale_names = get_vllm_qkv_scale_names(layer_idx) - keys.extend(scale_names.values()) - # Always append kv-scale entries to match metadata; use provided value or default 1.0 - for param_name in keys: - if kv_scales and param_name in kv_scales: - scale_value = kv_scales[param_name] - else: - scale_value = 1.0 - scale_tensor = torch.tensor( - scale_value, dtype=torch.float32, device="cuda" - ).reshape(1) - yield param_name, scale_tensor + + # CHANGE: Only append KV scales when kv_cache_dtype is FP8 + if use_fp8_kv_cache: + # Get number of layers directly from transformer config + num_layers = self.megatron_bridge.transformer_config.num_layers + print(f"[@@KV_SCALES@@] broadcast_weights_for_collective: kv_cache_dtype=fp8, broadcasting KV scales for {num_layers} layers") + keys = [] + for layer_idx in range(num_layers): + scale_names = get_vllm_qkv_scale_names(layer_idx) + keys.extend(scale_names.values()) + # Append kv-scale entries to match metadata; use provided value or default 1.0 + for param_name in keys: + if kv_scales and param_name in kv_scales: + scale_value = kv_scales[param_name] + else: + scale_value = 1.0 + scale_tensor = torch.tensor( + scale_value, dtype=torch.float32, device="cuda" + ).reshape(1) + yield param_name, scale_tensor + else: + print(f"[KV_SCALES] broadcast_weights_for_collective: kv_cache_dtype is not fp8, skipping KV scales") # param_iterator will return (name, tensor), we only need tensor packed_broadcast_producer( From 964d929f32fb91a822a5f782902d99e8691e2be3 Mon Sep 17 00:00:00 2001 From: Shuang Yu Date: Sat, 15 Nov 2025 21:54:55 -0800 Subject: [PATCH 15/40] WIP: changes before rebase Signed-off-by: Zhaopeng Qiu --- 
nemo_rl/algorithms/grpo.py | 25 ++++-- nemo_rl/models/generation/fp8.py | 81 +++++++++++++++++-- nemo_rl/models/generation/vllm/config.py | 1 + .../models/generation/vllm/vllm_backend.py | 71 +++++++++++----- .../models/policy/megatron_policy_worker.py | 56 ++++++++----- 5 files changed, 181 insertions(+), 53 deletions(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 1cfa50b3c1..e5ca2516b5 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -883,9 +883,13 @@ def _should_sync_kv_scales(master_config: MasterConfig) -> bool: """Check if KV cache scales should be synchronized during refit. Returns True if: - - kv_cache_dtype is fp8 (independent of precision) + - kv_cache_dtype is fp8 AND + - calculate_kv_scales is False (static scales mode) - Note: KV cache scales are only needed when kv_cache_dtype is FP8. + When calculate_kv_scales=True (dynamic mode), vLLM calculates scales + automatically during forward passes, so no sync is needed. + + Note: KV cache scales are only relevant when kv_cache_dtype is FP8. The model precision (fp8 or bf16) is independent of this requirement. """ generation_config = master_config["policy"]["generation"] @@ -898,9 +902,10 @@ def _should_sync_kv_scales(master_config: MasterConfig) -> bool: vllm_cfg = generation_config.get("vllm_cfg", {}) kv_cache_dtype = vllm_cfg.get("kv_cache_dtype", "auto") + calculate_kv_scales = vllm_cfg.get("calculate_kv_scales", False) - # Only check kv_cache_dtype, not precision. 
This allows: precision=bf16 with kv_cache_dtype=fp8 - return kv_cache_dtype == "fp8" + # Only sync scales when using FP8 KV cache with static scales (not dynamic calculation) + return kv_cache_dtype == "fp8" and not calculate_kv_scales def refit_policy_generation( @@ -1014,10 +1019,18 @@ def grpo_train( sync_kv_scales = _should_sync_kv_scales(master_config) kv_scales_cache = None # Cache reused for computed kv scales + vllm_cfg = master_config["policy"]["generation"].get("vllm_cfg", {}) + kv_cache_dtype = vllm_cfg.get("kv_cache_dtype", "auto") + calculate_kv_scales = vllm_cfg.get("calculate_kv_scales", False) + if sync_kv_scales: - print(f"[KV_SCALES] FP8 KV cache enabled (kv_cache_dtype=fp8), will sync q_scale, k_scale and v_scale during refit") + print(f"[KV_SCALES] FP8 KV cache with static scales (kv_cache_dtype=fp8, calculate_kv_scales=False)") + print(f"[KV_SCALES] Will compute and sync q_scale, k_scale, v_scale during refit") + elif kv_cache_dtype == "fp8" and calculate_kv_scales: + print(f"[KV_SCALES] FP8 KV cache with dynamic calculation (kv_cache_dtype=fp8, calculate_kv_scales=True)") + print(f"[KV_SCALES] vLLM will calculate scales dynamically on each forward pass, no sync needed") else: - print("[KV_SCALES] KV cache scale sync not needed (kv_cache_dtype is not fp8, regardless of model precision)") + print("[KV_SCALES] KV cache scale sync not needed (kv_cache_dtype is not fp8)") NEED_REFIT = True # If policy_generation is None, use the policy as the generation interface (megatron framework backend) diff --git a/nemo_rl/models/generation/fp8.py b/nemo_rl/models/generation/fp8.py index 67252372e4..bd304ccff8 100644 --- a/nemo_rl/models/generation/fp8.py +++ b/nemo_rl/models/generation/fp8.py @@ -42,6 +42,7 @@ class FP8Config: model_parallel_size: int = None kv_cache_dtype: str = "auto" use_fp8_weights: bool = True # Whether model weights are quantized to FP8 + calculate_kv_scales: bool = False # Whether to dynamically calculate KV scales @dataclass() @@ 
-303,6 +304,48 @@ def convert_calibration_to_vllm_format( return vllm_scales +def reset_calculate_kv_scales_in_worker(worker): + """Reset calculate_kv_scales flag for all attention layers after wake_up. + + This is called after wake_up to ensure KV scales are recalculated with new weights. + """ + print("[FP8_PATCHES] reset_calculate_kv_scales_in_worker called") + try: + model = worker.model_runner.model + print(f"[FP8_PATCHES] Searching for attention layers in model: {type(model).__name__}") + + # Iterate through all modules to find attention layers + attention_layers_found = 0 + for name, module in model.named_modules(): + # Check if this is an Attention layer with calculate_kv_scales attribute + if hasattr(module, 'calculate_kv_scales') and hasattr(module, 'kv_cache_dtype'): + if module.kv_cache_dtype == "fp8": + print(f"[FP8_PATCHES] Found attention layer: {name}, kv_cache_dtype={module.kv_cache_dtype}, calculate_kv_scales={module.calculate_kv_scales}") + module.calculate_kv_scales = True + attention_layers_found += 1 + print(f"[FP8_PATCHES] Reset calculate_kv_scales=True for layer: {name}") + + print(f"[FP8_PATCHES] Total attention layers reset: {attention_layers_found}") + except Exception as e: + print(f"[FP8_PATCHES] Error in reset_calculate_kv_scales_in_worker: {e}") + import traceback + traceback.print_exc() + + +def patched_wake_up(original_wake_up): + """Wrapper for Worker.wake_up that resets calculate_kv_scales after waking up.""" + def wake_up_wrapper(self, tags=None): + print("[FP8_PATCHES] patched_wake_up called") + # Call original wake_up + result = original_wake_up(self, tags) + + # Reset calculate_kv_scales for all attention layers + reset_calculate_kv_scales_in_worker(self) + + return result + return wake_up_wrapper + + def apply_fp8_patches(self, fp8_config): global global_fp8_config, fp8_patches_applied assert not fp8_patches_applied @@ -341,10 +384,20 @@ def apply_fp8_patches(self, fp8_config): if global_fp8_config.kv_cache_dtype == "fp8": 
print("[FP8_PATCHES] Applying FP8 KV cache patches (kv_cache_dtype=fp8)") - # Patch the vllm kv_cache.py process_weights_after_loading() to remove the deletion of k_scale and v_scale - func5_path = "vllm.model_executor.layers.quantization.kv_cache.BaseKVCacheMethod.process_weights_after_loading" - patcher5 = patch(func5_path, kv_cache_process_weights_after_loading) - fp8_state.vllm_patches.append(patcher5) + if global_fp8_config.calculate_kv_scales: + # Dynamic calculation mode: patch wake_up to reset calculate_kv_scales + print("[FP8_PATCHES] Enabling dynamic KV scale recalculation (calculate_kv_scales=True)") + from vllm.v1.worker.gpu_worker import Worker + original_wake_up = Worker.wake_up + Worker.wake_up = patched_wake_up(original_wake_up) + print("[FP8_PATCHES] Patched Worker.wake_up to reset calculate_kv_scales after wake_up") + else: + # Static scales mode: patch process_weights_after_loading to preserve k_scale/v_scale for manual updates + print("[FP8_PATCHES] Using static KV scales (calculate_kv_scales=False)") + func5_path = "vllm.model_executor.layers.quantization.kv_cache.BaseKVCacheMethod.process_weights_after_loading" + patcher5 = patch(func5_path, kv_cache_process_weights_after_loading) + fp8_state.vllm_patches.append(patcher5) + print("[FP8_PATCHES] Patched process_weights_after_loading to preserve k_scale/v_scale for updates") for p in fp8_state.vllm_patches: p.start() @@ -363,6 +416,17 @@ def init_fp8(vllm_cfg, model_name, model_parallel_size): # Determine if we're using FP8 weights based on precision setting use_fp8_weights = vllm_cfg.get("precision") == "fp8" + # Extract calculate_kv_scales from config (default to False for backward compatibility) + calculate_kv_scales = vllm_cfg.get("calculate_kv_scales", False) + kv_cache_dtype = vllm_cfg.get("kv_cache_dtype", "auto") + + # Validate configuration + if calculate_kv_scales and kv_cache_dtype != "fp8": + raise ValueError( + f"calculate_kv_scales=True requires kv_cache_dtype='fp8', " + f"but got 
kv_cache_dtype='{kv_cache_dtype}'" + ) + global_fp8_config = FP8Config( use_weight_pow2_scale=vllm_cfg.get("pow2_weight_scaling_factors", False), use_activation_pow2_scale=vllm_cfg.get( @@ -371,8 +435,9 @@ def init_fp8(vllm_cfg, model_name, model_parallel_size): num_first_layers_in_bf16=vllm_cfg.get("num_first_layers_in_bf16", 0), num_last_layers_in_bf16=vllm_cfg.get("num_last_layers_in_bf16", 0), model_parallel_size=model_parallel_size, - kv_cache_dtype=vllm_cfg.get("kv_cache_dtype", "auto"), + kv_cache_dtype=kv_cache_dtype, use_fp8_weights=use_fp8_weights, + calculate_kv_scales=calculate_kv_scales, ) if vllm_cfg.get("use_deep_gemm", False): @@ -390,7 +455,6 @@ def init_fp8(vllm_cfg, model_name, model_parallel_size): # create fp8 kwargs for vllm's LLM(...) num_first_layers_in_bf16 = vllm_cfg.get("num_first_layers_in_bf16", 0) num_last_layers_in_bf16 = vllm_cfg.get("num_last_layers_in_bf16", 0) - kv_cache_dtype = vllm_cfg.get("kv_cache_dtype", "auto") fp8_block_quant_kwargs = dict(FP8_BLOCK_QUANT_KWARGS) if num_first_layers_in_bf16 > 0 or num_last_layers_in_bf16 > 0: @@ -432,6 +496,11 @@ def init_fp8(vllm_cfg, model_name, model_parallel_size): "kv_cache_dtype": kv_cache_dtype, } + # Add calculate_kv_scales to the kwargs if it's set + # This will be passed to vLLM's CacheConfig + if calculate_kv_scales: + vllm_kwargs["calculate_kv_scales"] = calculate_kv_scales + return vllm_kwargs diff --git a/nemo_rl/models/generation/vllm/config.py b/nemo_rl/models/generation/vllm/config.py index c3a0171679..ff6b01d08e 100644 --- a/nemo_rl/models/generation/vllm/config.py +++ b/nemo_rl/models/generation/vllm/config.py @@ -29,6 +29,7 @@ class VllmSpecificArgs(TypedDict): load_format: NotRequired[str] precision: NotRequired[str] kv_cache_dtype: NotRequired[str] + calculate_kv_scales: NotRequired[bool] enforce_eager: NotRequired[bool] # By default, NeMo RL only has a Python handle to the vllm.LLM generation engine. 
The expose_http_server flag here will expose that generation engine as an HTTP server. # Exposing vLLM as a server is useful in instances where the multi-turn rollout is performed with utilities outside of NeMo RL, but the user still wants to take advantage of the refit logic in NeMo RL that keeps the policy and generation up to date. diff --git a/nemo_rl/models/generation/vllm/vllm_backend.py b/nemo_rl/models/generation/vllm/vllm_backend.py index 6034e31f85..0d0074398a 100644 --- a/nemo_rl/models/generation/vllm/vllm_backend.py +++ b/nemo_rl/models/generation/vllm/vllm_backend.py @@ -161,18 +161,34 @@ def update_weights_via_ipc_zmq(self) -> bool: buffer = None self.zmq_socket.send(IPCProtocol.ACK.value.encode()) - # CHANGE: Only invoke process_weights_after_loading when kv_cache_dtype is FP8 - # Check if KV cache is using FP8 - use_fp8_kv_cache = False + # CHANGE: Only invoke process_weights_after_loading for static FP8 KV scales + # (kv_cache_dtype=fp8 AND calculate_kv_scales=False) + # When calculate_kv_scales=True, vLLM calculates scales dynamically, no need to process + use_static_fp8_kv_scales = False if hasattr(self.model_runner.vllm_config, 'cache_config'): kv_cache_dtype = getattr(self.model_runner.vllm_config.cache_config, 'cache_dtype', None) - print(f"[KV_SCALES] [vllm_backend.py] update_weights_via_ipc_zmq: kv_cache_dtype is {kv_cache_dtype}") - use_fp8_kv_cache = kv_cache_dtype is not None and 'fp8' in str(kv_cache_dtype).lower() + is_fp8 = kv_cache_dtype is not None and 'fp8' in str(kv_cache_dtype).lower() + + if is_fp8: + # Check if dynamic calculation is enabled by examining attention layers + # If any attention layer has calculate_kv_scales attribute, we're in dynamic mode + calculate_kv_scales_enabled = False + for module in self.model_runner.model.modules(): + if hasattr(module, 'calculate_kv_scales') and hasattr(module, 'kv_cache_dtype'): + if module.kv_cache_dtype == "fp8": + # If calculate_kv_scales is True, we're in dynamic mode + 
calculate_kv_scales_enabled = True + break + + # Only use static scales when kv_cache is fp8 AND NOT using dynamic calculation + use_static_fp8_kv_scales = not calculate_kv_scales_enabled + print(f"[KV_SCALES] update_weights_via_ipc_zmq: kv_cache_dtype={kv_cache_dtype}, calculate_kv_scales={calculate_kv_scales_enabled}, use_static_fp8_kv_scales={use_static_fp8_kv_scales}") + else: + print(f"[KV_SCALES] update_weights_via_ipc_zmq: kv_cache_dtype={kv_cache_dtype}, not FP8") - if use_fp8_kv_cache: - # When kv_scales is provided, we need to invoke process_weights_after_loading() - # to copy the kv scales to the _k_scale and _v_scale attributes used during inference - print(f"[@@KV_SCALES@@] [vllm_backend.py] update_weights_via_ipc_zmq: kv_cache_dtype is FP8, processing KV cache scales after weight loading") + if use_static_fp8_kv_scales: + # Static FP8 KV scale mode: process KV scales after weight loading + print(f"[KV_SCALES] update_weights_via_ipc_zmq: Static FP8 KV scales mode, processing scales after weight loading") from vllm.model_executor.model_loader.utils import process_weights_after_loading # Get target device for processing @@ -185,7 +201,7 @@ def update_weights_via_ipc_zmq(self) -> bool: target_device ) else: - print(f"[KV_SCALES] [vllm_backend.py] update_weights_via_ipc_zmq: kv_cache_dtype is not FP8, skipping process_weights_after_loading") + print(f"[KV_SCALES] update_weights_via_ipc_zmq: Not using static FP8 KV scales, skipping process_weights_after_loading") gc.collect() torch.cuda.empty_cache() @@ -235,17 +251,34 @@ def _load_model_weights(weights, model_runner): post_unpack_func=load_model_weight_func, ) - # CHANGE: Only invoke process_weights_after_loading when kv_cache_dtype is FP8 - # Check if KV cache is using FP8 - use_fp8_kv_cache = False + # CHANGE: Only invoke process_weights_after_loading for static FP8 KV scales + # (kv_cache_dtype=fp8 AND calculate_kv_scales=False) + # When calculate_kv_scales=True, vLLM calculates scales dynamically, no 
need to process + use_static_fp8_kv_scales = False if hasattr(self.model_runner.vllm_config, 'cache_config'): kv_cache_dtype = getattr(self.model_runner.vllm_config.cache_config, 'cache_dtype', None) - use_fp8_kv_cache = kv_cache_dtype is not None and 'fp8' in str(kv_cache_dtype).lower() + is_fp8 = kv_cache_dtype is not None and 'fp8' in str(kv_cache_dtype).lower() + + if is_fp8: + # Check if dynamic calculation is enabled by examining attention layers + # If any attention layer has calculate_kv_scales attribute, we're in dynamic mode + calculate_kv_scales_enabled = False + for module in self.model_runner.model.modules(): + if hasattr(module, 'calculate_kv_scales') and hasattr(module, 'kv_cache_dtype'): + if module.kv_cache_dtype == "fp8": + # If calculate_kv_scales is True, we're in dynamic mode + calculate_kv_scales_enabled = True + break + + # Only use static scales when kv_cache is fp8 AND NOT using dynamic calculation + use_static_fp8_kv_scales = not calculate_kv_scales_enabled + print(f"[KV_SCALES] update_weights_from_collective: kv_cache_dtype={kv_cache_dtype}, calculate_kv_scales={calculate_kv_scales_enabled}, use_static_fp8_kv_scales={use_static_fp8_kv_scales}") + else: + print(f"[KV_SCALES] update_weights_from_collective: kv_cache_dtype={kv_cache_dtype}, not FP8") - if use_fp8_kv_cache: - # When KV scales are broadcast, we need to invoke process_weights_after_loading() - # to copy the kv scales to the _k_scale and _v_scale attributes used during inference - print(f"[@@KV_SCALES@@] [vllm_backend.py] update_weights_from_collective: kv_cache_dtype is FP8, processing KV cache scales after weight loading") + if use_static_fp8_kv_scales: + # Static FP8 KV scale mode: process KV scales after weight loading + print(f"[KV_SCALES] update_weights_from_collective: Static FP8 KV scales mode, processing scales after weight loading") from vllm.model_executor.model_loader.utils import process_weights_after_loading # Get target device for processing @@ -258,7 +291,7 @@ 
def _load_model_weights(weights, model_runner): target_device ) else: - print(f"[KV_SCALES] [vllm_backend.py] update_weights_from_collective: kv_cache_dtype is not FP8, skipping process_weights_after_loading") + print(f"[KV_SCALES] update_weights_from_collective: Not using static FP8 KV scales, skipping process_weights_after_loading") except Exception as e: print( diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 71491fb65f..516b467644 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -1927,20 +1927,24 @@ def prepare_refit_info(self) -> None: metadata = (tensor.shape, tensor.dtype) refit_param_info_hf[name] = metadata - # Only include KV/Q scale metadata when kv_cache_dtype is FP8 - # Check if we're using FP8 KV cache - use_fp8_kv_cache = False + # Only include KV/Q scale metadata when using static FP8 KV cache scales + # (kv_cache_dtype=fp8 AND calculate_kv_scales=False) + # When calculate_kv_scales=True, vLLM calculates scales dynamically, no need to sync + use_static_fp8_kv_scales = False if "generation" in self.cfg and self.cfg["generation"] is not None: vllm_cfg = self.cfg["generation"].get("vllm_cfg", {}) kv_cache_dtype = vllm_cfg.get("kv_cache_dtype", "auto") - use_fp8_kv_cache = kv_cache_dtype == "fp8" + calculate_kv_scales = vllm_cfg.get("calculate_kv_scales", False) + # Only use static scales when kv_cache is fp8 but NOT dynamically calculated + use_static_fp8_kv_scales = kv_cache_dtype == "fp8" and not calculate_kv_scales - if use_fp8_kv_cache: - # Include KV/Q scale metadata so consumer can rely solely on state_dict_info + if use_static_fp8_kv_scales: + # Static FP8 KV scale mode: Include KV/Q scale metadata for syncing try: # Get number of layers directly from transformer config num_layers = self.megatron_bridge.transformer_config.num_layers - print(f"[@@KV_SCALES@@] prepare_refit_info: kv_cache_dtype=fp8, adding scale 
metadata for {num_layers} layers") + print(f"[KV_SCALES] prepare_refit_info: Static FP8 KV scales mode (kv_cache_dtype=fp8, calculate_kv_scales=False)") + print(f"[KV_SCALES] Adding scale metadata for {num_layers} layers") # Append q/k/v scale placeholders (shape [1], dtype float32) for layer_idx in range(num_layers): scale_names = get_vllm_qkv_scale_names(layer_idx) @@ -1950,7 +1954,7 @@ def prepare_refit_info(self) -> None: except Exception: pass else: - print(f"[KV_SCALES] prepare_refit_info: kv_cache_dtype is not fp8, skipping KV scale metadata") + print(f"[KV_SCALES] prepare_refit_info: Not using static FP8 KV scales, skipping KV scale metadata") return refit_param_info_hf @@ -2029,23 +2033,27 @@ def stream_weights_via_ipc_zmq( conversion_tasks=self.refit_conversion_tasks, # used for metadata caching ) - # CHANGE: Only check and stream KV scales when kv_cache_dtype is FP8 - use_fp8_kv_cache = False + # CHANGE: Only stream KV scales in static FP8 KV scale mode + # (kv_cache_dtype=fp8 AND calculate_kv_scales=False) + # When calculate_kv_scales=True, vLLM calculates scales dynamically, no need to stream + use_static_fp8_kv_scales = False if "generation" in self.cfg and self.cfg["generation"] is not None: vllm_cfg = self.cfg["generation"].get("vllm_cfg", {}) kv_cache_dtype = vllm_cfg.get("kv_cache_dtype", "auto") - use_fp8_kv_cache = kv_cache_dtype == "fp8" + calculate_kv_scales = vllm_cfg.get("calculate_kv_scales", False) + # Only use static scales when kv_cache is fp8 but NOT dynamically calculated + use_static_fp8_kv_scales = kv_cache_dtype == "fp8" and not calculate_kv_scales def iter_with_kv_scales(): # First yield all model weights for name, tensor in hf_params_generator: yield name, tensor - # CHANGE: Only append KV scales when kv_cache_dtype is FP8 - if use_fp8_kv_cache: + # CHANGE: Only append KV scales in static mode + if use_static_fp8_kv_scales: # Get number of layers directly from transformer config num_layers = 
self.megatron_bridge.transformer_config.num_layers - print(f"[@@KV_SCALES@@] stream_weights_via_ipc_zmq: kv_cache_dtype=fp8, streaming KV scales for {num_layers} layers") + print(f"[KV_SCALES] stream_weights_via_ipc_zmq: Static FP8 KV scales mode, streaming scales for {num_layers} layers") keys = [] for layer_idx in range(num_layers): scale_names = get_vllm_qkv_scale_names(layer_idx) @@ -2061,7 +2069,7 @@ def iter_with_kv_scales(): ).reshape(1) yield param_name, scale_tensor else: - print(f"[KV_SCALES] stream_weights_via_ipc_zmq: kv_cache_dtype is not fp8, skipping KV scales") + print(f"[KV_SCALES] stream_weights_via_ipc_zmq: Not using static FP8 KV scales, skipping KV scales") # Use the shared implementation stream_weights_via_ipc_zmq_impl( @@ -2083,23 +2091,27 @@ def broadcast_weights_for_collective( conversion_tasks=self.refit_conversion_tasks, # used for metadata caching ) - # CHANGE: Only check and broadcast KV scales when kv_cache_dtype is FP8 - use_fp8_kv_cache = False + # CHANGE: Only broadcast KV scales in static FP8 KV scale mode + # (kv_cache_dtype=fp8 AND calculate_kv_scales=False) + # When calculate_kv_scales=True, vLLM calculates scales dynamically, no need to broadcast + use_static_fp8_kv_scales = False if "generation" in self.cfg and self.cfg["generation"] is not None: vllm_cfg = self.cfg["generation"].get("vllm_cfg", {}) kv_cache_dtype = vllm_cfg.get("kv_cache_dtype", "auto") - use_fp8_kv_cache = kv_cache_dtype == "fp8" + calculate_kv_scales = vllm_cfg.get("calculate_kv_scales", False) + # Only use static scales when kv_cache is fp8 but NOT dynamically calculated + use_static_fp8_kv_scales = kv_cache_dtype == "fp8" and not calculate_kv_scales def iter_with_kv_scales(): # First yield all model weights for name, tensor in hf_params_generator: yield name, tensor - # CHANGE: Only append KV scales when kv_cache_dtype is FP8 - if use_fp8_kv_cache: + # CHANGE: Only append KV scales in static mode + if use_static_fp8_kv_scales: # Get number of layers 
directly from transformer config num_layers = self.megatron_bridge.transformer_config.num_layers - print(f"[@@KV_SCALES@@] broadcast_weights_for_collective: kv_cache_dtype=fp8, broadcasting KV scales for {num_layers} layers") + print(f"[KV_SCALES] broadcast_weights_for_collective: Static FP8 KV scales mode, broadcasting scales for {num_layers} layers") keys = [] for layer_idx in range(num_layers): scale_names = get_vllm_qkv_scale_names(layer_idx) @@ -2115,7 +2127,7 @@ def iter_with_kv_scales(): ).reshape(1) yield param_name, scale_tensor else: - print(f"[KV_SCALES] broadcast_weights_for_collective: kv_cache_dtype is not fp8, skipping KV scales") + print(f"[KV_SCALES] broadcast_weights_for_collective: Not using static FP8 KV scales, skipping KV scales") # param_iterator will return (name, tensor), we only need tensor packed_broadcast_producer( From ede399c7e860b2bb37176b14c182196c7ed5cafd Mon Sep 17 00:00:00 2001 From: Shuang Yu Date: Sun, 16 Nov 2025 20:20:20 -0800 Subject: [PATCH 16/40] Code draft to support dynamic kv scales calculation Signed-off-by: Zhaopeng Qiu --- .../models/generation/vllm/vllm_backend.py | 22 ++++--------------- 1 file changed, 4 insertions(+), 18 deletions(-) diff --git a/nemo_rl/models/generation/vllm/vllm_backend.py b/nemo_rl/models/generation/vllm/vllm_backend.py index 0d0074398a..561ae72a51 100644 --- a/nemo_rl/models/generation/vllm/vllm_backend.py +++ b/nemo_rl/models/generation/vllm/vllm_backend.py @@ -170,15 +170,8 @@ def update_weights_via_ipc_zmq(self) -> bool: is_fp8 = kv_cache_dtype is not None and 'fp8' in str(kv_cache_dtype).lower() if is_fp8: - # Check if dynamic calculation is enabled by examining attention layers - # If any attention layer has calculate_kv_scales attribute, we're in dynamic mode - calculate_kv_scales_enabled = False - for module in self.model_runner.model.modules(): - if hasattr(module, 'calculate_kv_scales') and hasattr(module, 'kv_cache_dtype'): - if module.kv_cache_dtype == "fp8": - # If 
calculate_kv_scales is True, we're in dynamic mode - calculate_kv_scales_enabled = True - break + from nemo_rl.models.generation import fp8 + calculate_kv_scales_enabled = fp8.global_fp8_config.calculate_kv_scales if fp8.global_fp8_config else False # Only use static scales when kv_cache is fp8 AND NOT using dynamic calculation use_static_fp8_kv_scales = not calculate_kv_scales_enabled @@ -260,15 +253,8 @@ def _load_model_weights(weights, model_runner): is_fp8 = kv_cache_dtype is not None and 'fp8' in str(kv_cache_dtype).lower() if is_fp8: - # Check if dynamic calculation is enabled by examining attention layers - # If any attention layer has calculate_kv_scales attribute, we're in dynamic mode - calculate_kv_scales_enabled = False - for module in self.model_runner.model.modules(): - if hasattr(module, 'calculate_kv_scales') and hasattr(module, 'kv_cache_dtype'): - if module.kv_cache_dtype == "fp8": - # If calculate_kv_scales is True, we're in dynamic mode - calculate_kv_scales_enabled = True - break + from nemo_rl.models.generation import fp8 + calculate_kv_scales_enabled = fp8.global_fp8_config.calculate_kv_scales if fp8.global_fp8_config else False # Only use static scales when kv_cache is fp8 AND NOT using dynamic calculation use_static_fp8_kv_scales = not calculate_kv_scales_enabled From ff455d61927eda3655588bec6ff771d8e5cb61fe Mon Sep 17 00:00:00 2001 From: Shuang Yu Date: Sun, 16 Nov 2025 21:57:58 -0800 Subject: [PATCH 17/40] Refit should only takes care of kv_scales when kv_cache_dtype is fp8 otherwise bf16 won't work Signed-off-by: Zhaopeng Qiu --- nemo_rl/algorithms/grpo.py | 33 ++---- nemo_rl/models/generation/fp8.py | 108 +++--------------- nemo_rl/models/generation/vllm/config.py | 1 - .../models/generation/vllm/vllm_backend.py | 52 ++------- nemo_rl/models/generation/vllm/vllm_worker.py | 19 +-- .../models/policy/megatron_policy_worker.py | 54 ++++----- 6 files changed, 69 insertions(+), 198 deletions(-) diff --git a/nemo_rl/algorithms/grpo.py 
b/nemo_rl/algorithms/grpo.py index e5ca2516b5..ec3d50afc9 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -502,9 +502,12 @@ def init_vllm(): "Importance sampling must be enabled for vLLM FP8 generation for good convergence!" ) if generation_config["vllm_cfg"].get("kv_cache_dtype") == "fp8": + # FP8 KV cache requires FP8 model precision + assert generation_config["vllm_cfg"]["precision"] == "fp8", ( + "kv_cache_dtype='fp8' requires precision='fp8'. " + "FP8 KV cache can only be used together with FP8 model weights." + ) # FP8 KV cache compatibility checks - # These checks are independent of model precision (can use bf16 or fp8 weights with fp8 KV cache) - # TODO: Add the related support assert policy_config["dtensor_cfg"]["enabled"] == False, ( "DTensor backend is not supported with kv cache fp8 enabled." ) @@ -882,15 +885,9 @@ def _should_use_penguin(master_config: MasterConfig) -> bool: def _should_sync_kv_scales(master_config: MasterConfig) -> bool: """Check if KV cache scales should be synchronized during refit. - Returns True if: - - kv_cache_dtype is fp8 AND - - calculate_kv_scales is False (static scales mode) - - When calculate_kv_scales=True (dynamic mode), vLLM calculates scales - automatically during forward passes, so no sync is needed. - - Note: KV cache scales are only relevant when kv_cache_dtype is FP8. - The model precision (fp8 or bf16) is independent of this requirement. + Returns True if kv_cache_dtype is fp8 (which requires precision=fp8). + KV scales are always computed and synced statically during training + when using FP8 KV cache. 
""" generation_config = master_config["policy"]["generation"] if generation_config is None: @@ -902,10 +899,9 @@ def _should_sync_kv_scales(master_config: MasterConfig) -> bool: vllm_cfg = generation_config.get("vllm_cfg", {}) kv_cache_dtype = vllm_cfg.get("kv_cache_dtype", "auto") - calculate_kv_scales = vllm_cfg.get("calculate_kv_scales", False) - # Only sync scales when using FP8 KV cache with static scales (not dynamic calculation) - return kv_cache_dtype == "fp8" and not calculate_kv_scales + # Sync scales when using FP8 KV cache (always static in this design) + return kv_cache_dtype == "fp8" def refit_policy_generation( @@ -1019,16 +1015,9 @@ def grpo_train( sync_kv_scales = _should_sync_kv_scales(master_config) kv_scales_cache = None # Cache reused for computed kv scales - vllm_cfg = master_config["policy"]["generation"].get("vllm_cfg", {}) - kv_cache_dtype = vllm_cfg.get("kv_cache_dtype", "auto") - calculate_kv_scales = vllm_cfg.get("calculate_kv_scales", False) - if sync_kv_scales: - print(f"[KV_SCALES] FP8 KV cache with static scales (kv_cache_dtype=fp8, calculate_kv_scales=False)") + print(f"[KV_SCALES] FP8 KV cache enabled (kv_cache_dtype=fp8, precision=fp8)") print(f"[KV_SCALES] Will compute and sync q_scale, k_scale, v_scale during refit") - elif kv_cache_dtype == "fp8" and calculate_kv_scales: - print(f"[KV_SCALES] FP8 KV cache with dynamic calculation (kv_cache_dtype=fp8, calculate_kv_scales=True)") - print(f"[KV_SCALES] vLLM will calculate scales dynamically on each forward pass, no sync needed") else: print("[KV_SCALES] KV cache scale sync not needed (kv_cache_dtype is not fp8)") diff --git a/nemo_rl/models/generation/fp8.py b/nemo_rl/models/generation/fp8.py index bd304ccff8..b13d2fbabf 100644 --- a/nemo_rl/models/generation/fp8.py +++ b/nemo_rl/models/generation/fp8.py @@ -42,7 +42,6 @@ class FP8Config: model_parallel_size: int = None kv_cache_dtype: str = "auto" use_fp8_weights: bool = True # Whether model weights are quantized to FP8 - 
calculate_kv_scales: bool = False # Whether to dynamically calculate KV scales @dataclass() @@ -111,7 +110,7 @@ def kv_cache_process_weights_after_loading(self, layer: torch.nn.Module) -> None """Modified version of BaseKVCacheMethod.process_weights_after_loading. Doesn't delete k_scale, v_scale, q_scale, and prob_scale parameters to allow - for dynamic updates. + for dynamic updates during refit. """ import torch from vllm.logger import init_logger @@ -126,9 +125,7 @@ def kv_cache_process_weights_after_loading(self, layer: torch.nn.Module) -> None # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0 # regardless whether the kv-scale is available in the checkpoint. - # No need to process kv scales after loading if we are going to - # calculate them on the fly. - if layer.kv_cache_dtype != "auto" and not layer.calculate_kv_scales: + if layer.kv_cache_dtype != "auto": if layer.k_scale > 0.0 and layer.v_scale > 0.0: # We prefer to use separate k_scale and v_scale if present k_scale = layer.k_scale.to("cpu").tolist() @@ -181,7 +178,6 @@ def kv_cache_process_weights_after_loading(self, layer: torch.nn.Module) -> None q_scale = layer.q_scale if current_platform.is_fp8_fnuz(): q_scale *= 2 - layer.calculate_kv_scales = False else: q_scale = 1.0 if layer.prob_scale > 0.0: @@ -304,48 +300,6 @@ def convert_calibration_to_vllm_format( return vllm_scales -def reset_calculate_kv_scales_in_worker(worker): - """Reset calculate_kv_scales flag for all attention layers after wake_up. - - This is called after wake_up to ensure KV scales are recalculated with new weights. 
- """ - print("[FP8_PATCHES] reset_calculate_kv_scales_in_worker called") - try: - model = worker.model_runner.model - print(f"[FP8_PATCHES] Searching for attention layers in model: {type(model).__name__}") - - # Iterate through all modules to find attention layers - attention_layers_found = 0 - for name, module in model.named_modules(): - # Check if this is an Attention layer with calculate_kv_scales attribute - if hasattr(module, 'calculate_kv_scales') and hasattr(module, 'kv_cache_dtype'): - if module.kv_cache_dtype == "fp8": - print(f"[FP8_PATCHES] Found attention layer: {name}, kv_cache_dtype={module.kv_cache_dtype}, calculate_kv_scales={module.calculate_kv_scales}") - module.calculate_kv_scales = True - attention_layers_found += 1 - print(f"[FP8_PATCHES] Reset calculate_kv_scales=True for layer: {name}") - - print(f"[FP8_PATCHES] Total attention layers reset: {attention_layers_found}") - except Exception as e: - print(f"[FP8_PATCHES] Error in reset_calculate_kv_scales_in_worker: {e}") - import traceback - traceback.print_exc() - - -def patched_wake_up(original_wake_up): - """Wrapper for Worker.wake_up that resets calculate_kv_scales after waking up.""" - def wake_up_wrapper(self, tags=None): - print("[FP8_PATCHES] patched_wake_up called") - # Call original wake_up - result = original_wake_up(self, tags) - - # Reset calculate_kv_scales for all attention layers - reset_calculate_kv_scales_in_worker(self) - - return result - return wake_up_wrapper - - def apply_fp8_patches(self, fp8_config): global global_fp8_config, fp8_patches_applied assert not fp8_patches_applied @@ -383,21 +337,12 @@ def apply_fp8_patches(self, fp8_config): # Apply KV cache patches only when using FP8 KV cache (kv_cache_dtype=fp8) if global_fp8_config.kv_cache_dtype == "fp8": print("[FP8_PATCHES] Applying FP8 KV cache patches (kv_cache_dtype=fp8)") - - if global_fp8_config.calculate_kv_scales: - # Dynamic calculation mode: patch wake_up to reset calculate_kv_scales - print("[FP8_PATCHES] 
Enabling dynamic KV scale recalculation (calculate_kv_scales=True)") - from vllm.v1.worker.gpu_worker import Worker - original_wake_up = Worker.wake_up - Worker.wake_up = patched_wake_up(original_wake_up) - print("[FP8_PATCHES] Patched Worker.wake_up to reset calculate_kv_scales after wake_up") - else: - # Static scales mode: patch process_weights_after_loading to preserve k_scale/v_scale for manual updates - print("[FP8_PATCHES] Using static KV scales (calculate_kv_scales=False)") - func5_path = "vllm.model_executor.layers.quantization.kv_cache.BaseKVCacheMethod.process_weights_after_loading" - patcher5 = patch(func5_path, kv_cache_process_weights_after_loading) - fp8_state.vllm_patches.append(patcher5) - print("[FP8_PATCHES] Patched process_weights_after_loading to preserve k_scale/v_scale for updates") + # Static scales mode: patch process_weights_after_loading to preserve k_scale/v_scale for manual updates + print("[FP8_PATCHES] Using static KV scales") + func5_path = "vllm.model_executor.layers.quantization.kv_cache.BaseKVCacheMethod.process_weights_after_loading" + patcher5 = patch(func5_path, kv_cache_process_weights_after_loading) + fp8_state.vllm_patches.append(patcher5) + print("[FP8_PATCHES] Patched process_weights_after_loading to preserve k_scale/v_scale for updates") for p in fp8_state.vllm_patches: p.start() @@ -415,16 +360,13 @@ def init_fp8(vllm_cfg, model_name, model_parallel_size): global global_fp8_config # Determine if we're using FP8 weights based on precision setting use_fp8_weights = vllm_cfg.get("precision") == "fp8" - - # Extract calculate_kv_scales from config (default to False for backward compatibility) - calculate_kv_scales = vllm_cfg.get("calculate_kv_scales", False) kv_cache_dtype = vllm_cfg.get("kv_cache_dtype", "auto") - # Validate configuration - if calculate_kv_scales and kv_cache_dtype != "fp8": + # Validate configuration: kv_cache_dtype=fp8 requires precision=fp8 + if kv_cache_dtype == "fp8" and not use_fp8_weights: raise 
ValueError( - f"calculate_kv_scales=True requires kv_cache_dtype='fp8', " - f"but got kv_cache_dtype='{kv_cache_dtype}'" + "kv_cache_dtype='fp8' requires precision='fp8'. " + "FP8 KV cache can only be used together with FP8 model weights." ) global_fp8_config = FP8Config( @@ -437,7 +379,6 @@ def init_fp8(vllm_cfg, model_name, model_parallel_size): model_parallel_size=model_parallel_size, kv_cache_dtype=kv_cache_dtype, use_fp8_weights=use_fp8_weights, - calculate_kv_scales=calculate_kv_scales, ) if vllm_cfg.get("use_deep_gemm", False): @@ -482,24 +423,13 @@ def init_fp8(vllm_cfg, model_name, model_parallel_size): # TODO: Remove this after debugging. print(f"[KV_SCALES] Global FP8 config: {global_fp8_config}") - # CHANGE: Return different kwargs based on whether we're using FP8 weights - if use_fp8_weights: - # Full FP8: quantize weights and optionally use FP8 KV cache - vllm_kwargs = { - "quantization": "fp8", - "kv_cache_dtype": kv_cache_dtype, - "hf_overrides": {"quantization_config": fp8_block_quant_kwargs}, - } - else: - # Only FP8 KV cache, no weight quantization - vllm_kwargs = { - "kv_cache_dtype": kv_cache_dtype, - } - - # Add calculate_kv_scales to the kwargs if it's set - # This will be passed to vLLM's CacheConfig - if calculate_kv_scales: - vllm_kwargs["calculate_kv_scales"] = calculate_kv_scales + # Return FP8 kwargs (precision=fp8 is required at this point) + # kv_cache_dtype can be "auto" or "fp8" + vllm_kwargs = { + "quantization": "fp8", + "kv_cache_dtype": kv_cache_dtype, + "hf_overrides": {"quantization_config": fp8_block_quant_kwargs}, + } return vllm_kwargs diff --git a/nemo_rl/models/generation/vllm/config.py b/nemo_rl/models/generation/vllm/config.py index ff6b01d08e..c3a0171679 100644 --- a/nemo_rl/models/generation/vllm/config.py +++ b/nemo_rl/models/generation/vllm/config.py @@ -29,7 +29,6 @@ class VllmSpecificArgs(TypedDict): load_format: NotRequired[str] precision: NotRequired[str] kv_cache_dtype: NotRequired[str] - calculate_kv_scales: 
NotRequired[bool] enforce_eager: NotRequired[bool] # By default, NeMo RL only has a Python handle to the vllm.LLM generation engine. The expose_http_server flag here will expose that generation engine as an HTTP server. # Exposing vLLM as a server is useful in instances where the multi-turn rollout is performed with utilities outside of NeMo RL, but the user still wants to take advantage of the refit logic in NeMo RL that keeps the policy and generation up to date. diff --git a/nemo_rl/models/generation/vllm/vllm_backend.py b/nemo_rl/models/generation/vllm/vllm_backend.py index 561ae72a51..d049f463eb 100644 --- a/nemo_rl/models/generation/vllm/vllm_backend.py +++ b/nemo_rl/models/generation/vllm/vllm_backend.py @@ -161,27 +161,15 @@ def update_weights_via_ipc_zmq(self) -> bool: buffer = None self.zmq_socket.send(IPCProtocol.ACK.value.encode()) - # CHANGE: Only invoke process_weights_after_loading for static FP8 KV scales - # (kv_cache_dtype=fp8 AND calculate_kv_scales=False) - # When calculate_kv_scales=True, vLLM calculates scales dynamically, no need to process - use_static_fp8_kv_scales = False + # Process weights after loading for FP8 KV cache (static scales) + use_fp8_kv_cache = False if hasattr(self.model_runner.vllm_config, 'cache_config'): kv_cache_dtype = getattr(self.model_runner.vllm_config.cache_config, 'cache_dtype', None) - is_fp8 = kv_cache_dtype is not None and 'fp8' in str(kv_cache_dtype).lower() - - if is_fp8: - from nemo_rl.models.generation import fp8 - calculate_kv_scales_enabled = fp8.global_fp8_config.calculate_kv_scales if fp8.global_fp8_config else False - - # Only use static scales when kv_cache is fp8 AND NOT using dynamic calculation - use_static_fp8_kv_scales = not calculate_kv_scales_enabled - print(f"[KV_SCALES] update_weights_via_ipc_zmq: kv_cache_dtype={kv_cache_dtype}, calculate_kv_scales={calculate_kv_scales_enabled}, use_static_fp8_kv_scales={use_static_fp8_kv_scales}") - else: - print(f"[KV_SCALES] update_weights_via_ipc_zmq: 
kv_cache_dtype={kv_cache_dtype}, not FP8") + use_fp8_kv_cache = kv_cache_dtype is not None and 'fp8' in str(kv_cache_dtype).lower() - if use_static_fp8_kv_scales: - # Static FP8 KV scale mode: process KV scales after weight loading - print(f"[KV_SCALES] update_weights_via_ipc_zmq: Static FP8 KV scales mode, processing scales after weight loading") + if use_fp8_kv_cache: + # FP8 KV cache: process KV scales after weight loading + print(f"[KV_SCALES] update_weights_via_ipc_zmq: FP8 KV cache detected, processing scales after weight loading") from vllm.model_executor.model_loader.utils import process_weights_after_loading # Get target device for processing @@ -193,8 +181,6 @@ def update_weights_via_ipc_zmq(self) -> bool: self.model_runner.model_config, target_device ) - else: - print(f"[KV_SCALES] update_weights_via_ipc_zmq: Not using static FP8 KV scales, skipping process_weights_after_loading") gc.collect() torch.cuda.empty_cache() @@ -244,27 +230,15 @@ def _load_model_weights(weights, model_runner): post_unpack_func=load_model_weight_func, ) - # CHANGE: Only invoke process_weights_after_loading for static FP8 KV scales - # (kv_cache_dtype=fp8 AND calculate_kv_scales=False) - # When calculate_kv_scales=True, vLLM calculates scales dynamically, no need to process - use_static_fp8_kv_scales = False + # Process weights after loading for FP8 KV cache (static scales) + use_fp8_kv_cache = False if hasattr(self.model_runner.vllm_config, 'cache_config'): kv_cache_dtype = getattr(self.model_runner.vllm_config.cache_config, 'cache_dtype', None) - is_fp8 = kv_cache_dtype is not None and 'fp8' in str(kv_cache_dtype).lower() - - if is_fp8: - from nemo_rl.models.generation import fp8 - calculate_kv_scales_enabled = fp8.global_fp8_config.calculate_kv_scales if fp8.global_fp8_config else False - - # Only use static scales when kv_cache is fp8 AND NOT using dynamic calculation - use_static_fp8_kv_scales = not calculate_kv_scales_enabled - print(f"[KV_SCALES] 
update_weights_from_collective: kv_cache_dtype={kv_cache_dtype}, calculate_kv_scales={calculate_kv_scales_enabled}, use_static_fp8_kv_scales={use_static_fp8_kv_scales}") - else: - print(f"[KV_SCALES] update_weights_from_collective: kv_cache_dtype={kv_cache_dtype}, not FP8") + use_fp8_kv_cache = kv_cache_dtype is not None and 'fp8' in str(kv_cache_dtype).lower() - if use_static_fp8_kv_scales: - # Static FP8 KV scale mode: process KV scales after weight loading - print(f"[KV_SCALES] update_weights_from_collective: Static FP8 KV scales mode, processing scales after weight loading") + if use_fp8_kv_cache: + # FP8 KV cache: process KV scales after weight loading + print(f"[KV_SCALES] update_weights_from_collective: FP8 KV cache detected, processing scales after weight loading") from vllm.model_executor.model_loader.utils import process_weights_after_loading # Get target device for processing @@ -276,8 +250,6 @@ def _load_model_weights(weights, model_runner): self.model_runner.model_config, target_device ) - else: - print(f"[KV_SCALES] update_weights_from_collective: Not using static FP8 KV scales, skipping process_weights_after_loading") except Exception as e: print( diff --git a/nemo_rl/models/generation/vllm/vllm_worker.py b/nemo_rl/models/generation/vllm/vllm_worker.py index 0ca55a2ebb..2a97931792 100644 --- a/nemo_rl/models/generation/vllm/vllm_worker.py +++ b/nemo_rl/models/generation/vllm/vllm_worker.py @@ -287,25 +287,18 @@ def _patch_vllm_init_workers_ray(): ) vllm_kwargs["ray_workers_use_nsight"] = True - # Call init_fp8 when either precision is fp8 OR kv_cache_dtype is fp8 - # This ensures vLLM patches are applied for KV cache FP8 even with bf16 weights - if self.cfg["vllm_cfg"]["precision"] == "fp8" or self.cfg["vllm_cfg"].get("kv_cache_dtype") == "fp8": + # Call init_fp8 when precision is fp8 + # (kv_cache_dtype can be fp8 or auto, validated in init_fp8) + if self.cfg["vllm_cfg"]["precision"] == "fp8": from nemo_rl.models.generation.fp8 import init_fp8 
fp8_kwargs = init_fp8( self.cfg["vllm_cfg"], self.model_name, model_parallel_size ) - # For FP8 precision, we need quantization="fp8" and weight quantization config - if self.cfg["vllm_cfg"]["precision"] == "fp8": - vllm_kwargs.update(fp8_kwargs) - # overriden by quant config, however vllm complains if this not passed - self.precision = "bfloat16" - else: - # For non-FP8 precision with FP8 KV cache, only set kv_cache_dtype - # Don't set quantization="fp8" as weights are not quantized - vllm_kwargs["kv_cache_dtype"] = fp8_kwargs["kv_cache_dtype"] - print(f"[KV_SCALES] Using FP8 KV cache with precision={self.precision} (weights not quantized)") + vllm_kwargs.update(fp8_kwargs) + # overriden by quant config, however vllm complains if this not passed + self.precision = "bfloat16" if not isinstance(vllm_kwargs.get("hf_overrides"), dict): vllm_kwargs["hf_overrides"] = {} diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 516b467644..9939e39f5e 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -1927,23 +1927,19 @@ def prepare_refit_info(self) -> None: metadata = (tensor.shape, tensor.dtype) refit_param_info_hf[name] = metadata - # Only include KV/Q scale metadata when using static FP8 KV cache scales - # (kv_cache_dtype=fp8 AND calculate_kv_scales=False) - # When calculate_kv_scales=True, vLLM calculates scales dynamically, no need to sync - use_static_fp8_kv_scales = False + # Include KV/Q scale metadata when using FP8 KV cache + use_fp8_kv_cache = False if "generation" in self.cfg and self.cfg["generation"] is not None: vllm_cfg = self.cfg["generation"].get("vllm_cfg", {}) kv_cache_dtype = vllm_cfg.get("kv_cache_dtype", "auto") - calculate_kv_scales = vllm_cfg.get("calculate_kv_scales", False) - # Only use static scales when kv_cache is fp8 but NOT dynamically calculated - use_static_fp8_kv_scales = kv_cache_dtype == "fp8" and not 
calculate_kv_scales + use_fp8_kv_cache = kv_cache_dtype == "fp8" - if use_static_fp8_kv_scales: - # Static FP8 KV scale mode: Include KV/Q scale metadata for syncing + if use_fp8_kv_cache: + # FP8 KV cache: Include KV/Q scale metadata for syncing try: # Get number of layers directly from transformer config num_layers = self.megatron_bridge.transformer_config.num_layers - print(f"[KV_SCALES] prepare_refit_info: Static FP8 KV scales mode (kv_cache_dtype=fp8, calculate_kv_scales=False)") + print(f"[KV_SCALES] prepare_refit_info: FP8 KV cache enabled (kv_cache_dtype=fp8)") print(f"[KV_SCALES] Adding scale metadata for {num_layers} layers") # Append q/k/v scale placeholders (shape [1], dtype float32) for layer_idx in range(num_layers): @@ -1954,7 +1950,7 @@ def prepare_refit_info(self) -> None: except Exception: pass else: - print(f"[KV_SCALES] prepare_refit_info: Not using static FP8 KV scales, skipping KV scale metadata") + print(f"[KV_SCALES] prepare_refit_info: FP8 KV cache not enabled, skipping KV scale metadata") return refit_param_info_hf @@ -2033,27 +2029,23 @@ def stream_weights_via_ipc_zmq( conversion_tasks=self.refit_conversion_tasks, # used for metadata caching ) - # CHANGE: Only stream KV scales in static FP8 KV scale mode - # (kv_cache_dtype=fp8 AND calculate_kv_scales=False) - # When calculate_kv_scales=True, vLLM calculates scales dynamically, no need to stream - use_static_fp8_kv_scales = False + # Stream KV scales when using FP8 KV cache + use_fp8_kv_cache = False if "generation" in self.cfg and self.cfg["generation"] is not None: vllm_cfg = self.cfg["generation"].get("vllm_cfg", {}) kv_cache_dtype = vllm_cfg.get("kv_cache_dtype", "auto") - calculate_kv_scales = vllm_cfg.get("calculate_kv_scales", False) - # Only use static scales when kv_cache is fp8 but NOT dynamically calculated - use_static_fp8_kv_scales = kv_cache_dtype == "fp8" and not calculate_kv_scales + use_fp8_kv_cache = kv_cache_dtype == "fp8" def iter_with_kv_scales(): # First yield all 
model weights for name, tensor in hf_params_generator: yield name, tensor - # CHANGE: Only append KV scales in static mode - if use_static_fp8_kv_scales: + # Append KV scales for FP8 KV cache + if use_fp8_kv_cache: # Get number of layers directly from transformer config num_layers = self.megatron_bridge.transformer_config.num_layers - print(f"[KV_SCALES] stream_weights_via_ipc_zmq: Static FP8 KV scales mode, streaming scales for {num_layers} layers") + print(f"[KV_SCALES] stream_weights_via_ipc_zmq: FP8 KV cache enabled, streaming scales for {num_layers} layers") keys = [] for layer_idx in range(num_layers): scale_names = get_vllm_qkv_scale_names(layer_idx) @@ -2069,7 +2061,7 @@ def iter_with_kv_scales(): ).reshape(1) yield param_name, scale_tensor else: - print(f"[KV_SCALES] stream_weights_via_ipc_zmq: Not using static FP8 KV scales, skipping KV scales") + print(f"[KV_SCALES] stream_weights_via_ipc_zmq: FP8 KV cache not enabled, skipping KV scales") # Use the shared implementation stream_weights_via_ipc_zmq_impl( @@ -2091,27 +2083,23 @@ def broadcast_weights_for_collective( conversion_tasks=self.refit_conversion_tasks, # used for metadata caching ) - # CHANGE: Only broadcast KV scales in static FP8 KV scale mode - # (kv_cache_dtype=fp8 AND calculate_kv_scales=False) - # When calculate_kv_scales=True, vLLM calculates scales dynamically, no need to broadcast - use_static_fp8_kv_scales = False + # Broadcast KV scales when using FP8 KV cache + use_fp8_kv_cache = False if "generation" in self.cfg and self.cfg["generation"] is not None: vllm_cfg = self.cfg["generation"].get("vllm_cfg", {}) kv_cache_dtype = vllm_cfg.get("kv_cache_dtype", "auto") - calculate_kv_scales = vllm_cfg.get("calculate_kv_scales", False) - # Only use static scales when kv_cache is fp8 but NOT dynamically calculated - use_static_fp8_kv_scales = kv_cache_dtype == "fp8" and not calculate_kv_scales + use_fp8_kv_cache = kv_cache_dtype == "fp8" def iter_with_kv_scales(): # First yield all model weights 
for name, tensor in hf_params_generator: yield name, tensor - # CHANGE: Only append KV scales in static mode - if use_static_fp8_kv_scales: + # Append KV scales for FP8 KV cache + if use_fp8_kv_cache: # Get number of layers directly from transformer config num_layers = self.megatron_bridge.transformer_config.num_layers - print(f"[KV_SCALES] broadcast_weights_for_collective: Static FP8 KV scales mode, broadcasting scales for {num_layers} layers") + print(f"[KV_SCALES] broadcast_weights_for_collective: FP8 KV cache enabled, broadcasting scales for {num_layers} layers") keys = [] for layer_idx in range(num_layers): scale_names = get_vllm_qkv_scale_names(layer_idx) @@ -2127,7 +2115,7 @@ def iter_with_kv_scales(): ).reshape(1) yield param_name, scale_tensor else: - print(f"[KV_SCALES] broadcast_weights_for_collective: Not using static FP8 KV scales, skipping KV scales") + print(f"[KV_SCALES] broadcast_weights_for_collective: FP8 KV cache not enabled, skipping KV scales") # param_iterator will return (name, tensor), we only need tensor packed_broadcast_producer( From ac2b5ede44d5a491e18ec9b6edb6242cea8c9664 Mon Sep 17 00:00:00 2001 From: Zhaopeng Qiu Date: Mon, 17 Nov 2025 06:28:21 -0800 Subject: [PATCH 18/40] lint check Signed-off-by: Zhaopeng Qiu --- nemo_rl/algorithms/grpo.py | 14 ++-- nemo_rl/models/generation/fp8.py | 18 +++--- .../models/generation/vllm/vllm_backend.py | 64 ++++++++++++------- nemo_rl/models/generation/vllm/vllm_worker.py | 2 +- .../models/policy/megatron_policy_worker.py | 32 +++++++--- 5 files changed, 81 insertions(+), 49 deletions(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index ec3d50afc9..e63d68eb86 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -899,7 +899,7 @@ def _should_sync_kv_scales(master_config: MasterConfig) -> bool: vllm_cfg = generation_config.get("vllm_cfg", {}) kv_cache_dtype = vllm_cfg.get("kv_cache_dtype", "auto") - + # Sync scales when using FP8 KV cache (always static 
in this design) return kv_cache_dtype == "fp8" @@ -1013,11 +1013,13 @@ def grpo_train( # Check if we need to sync KV cache scales (infer from config) sync_kv_scales = _should_sync_kv_scales(master_config) - kv_scales_cache = None # Cache reused for computed kv scales - - if sync_kv_scales: - print(f"[KV_SCALES] FP8 KV cache enabled (kv_cache_dtype=fp8, precision=fp8)") - print(f"[KV_SCALES] Will compute and sync q_scale, k_scale, v_scale during refit") + kv_scales_cache = None # Cache reused for computed kv scales + + if sync_kv_scales: + print("[KV_SCALES] FP8 KV cache enabled (kv_cache_dtype=fp8, precision=fp8)") + print( + "[KV_SCALES] Will compute and sync q_scale, k_scale, v_scale during refit" + ) else: print("[KV_SCALES] KV cache scale sync not needed (kv_cache_dtype is not fp8)") diff --git a/nemo_rl/models/generation/fp8.py b/nemo_rl/models/generation/fp8.py index b13d2fbabf..903d1a4e58 100644 --- a/nemo_rl/models/generation/fp8.py +++ b/nemo_rl/models/generation/fp8.py @@ -309,11 +309,11 @@ def apply_fp8_patches(self, fp8_config): # Apply patches conditionally based on configuration # Only apply weight patches if using FP8 weights # Only apply KV cache patches if using FP8 KV cache - + # Apply weight-related patches only when using FP8 weights (precision=fp8) if global_fp8_config.use_fp8_weights: print("[FP8_PATCHES] Applying FP8 weight quantization patches (precision=fp8)") - + # This patch is used to support torch.compile with vllm parameter subclasses, such as # PerTensorScaleParameter. Because we need weight loaders to update fp8 weights each # refit, we patch fp8 parameters to have a reference to their weight loader. 
Eventually @@ -322,7 +322,7 @@ def apply_fp8_patches(self, fp8_config): func1_path = "vllm.model_executor.layers.quantization.fp8.Fp8LinearMethod.process_weights_after_loading" patcher1 = patch(func1_path, process_weights_after_loading) fp8_state.vllm_patches.append(patcher1) - + # These patches add support for pow2, e8 dynamic activation scalings factors which are believed to have higher # SNR compared to plain fp32 scaling factors. This feature is still under active research. if global_fp8_config.use_activation_pow2_scale: @@ -342,7 +342,9 @@ def apply_fp8_patches(self, fp8_config): func5_path = "vllm.model_executor.layers.quantization.kv_cache.BaseKVCacheMethod.process_weights_after_loading" patcher5 = patch(func5_path, kv_cache_process_weights_after_loading) fp8_state.vllm_patches.append(patcher5) - print("[FP8_PATCHES] Patched process_weights_after_loading to preserve k_scale/v_scale for updates") + print( + "[FP8_PATCHES] Patched process_weights_after_loading to preserve k_scale/v_scale for updates" + ) for p in fp8_state.vllm_patches: p.start() @@ -361,14 +363,14 @@ def init_fp8(vllm_cfg, model_name, model_parallel_size): # Determine if we're using FP8 weights based on precision setting use_fp8_weights = vllm_cfg.get("precision") == "fp8" kv_cache_dtype = vllm_cfg.get("kv_cache_dtype", "auto") - + # Validate configuration: kv_cache_dtype=fp8 requires precision=fp8 if kv_cache_dtype == "fp8" and not use_fp8_weights: raise ValueError( "kv_cache_dtype='fp8' requires precision='fp8'. " "FP8 KV cache can only be used together with FP8 model weights." ) - + global_fp8_config = FP8Config( use_weight_pow2_scale=vllm_cfg.get("pow2_weight_scaling_factors", False), use_activation_pow2_scale=vllm_cfg.get( @@ -422,7 +424,7 @@ def init_fp8(vllm_cfg, model_name, model_parallel_size): # TODO: Remove this after debugging. 
print(f"[KV_SCALES] Global FP8 config: {global_fp8_config}") - + # Return FP8 kwargs (precision=fp8 is required at this point) # kv_cache_dtype can be "auto" or "fp8" vllm_kwargs = { @@ -430,7 +432,7 @@ def init_fp8(vllm_cfg, model_name, model_parallel_size): "kv_cache_dtype": kv_cache_dtype, "hf_overrides": {"quantization_config": fp8_block_quant_kwargs}, } - + return vllm_kwargs diff --git a/nemo_rl/models/generation/vllm/vllm_backend.py b/nemo_rl/models/generation/vllm/vllm_backend.py index d049f463eb..092ce223d8 100644 --- a/nemo_rl/models/generation/vllm/vllm_backend.py +++ b/nemo_rl/models/generation/vllm/vllm_backend.py @@ -163,23 +163,31 @@ def update_weights_via_ipc_zmq(self) -> bool: # Process weights after loading for FP8 KV cache (static scales) use_fp8_kv_cache = False - if hasattr(self.model_runner.vllm_config, 'cache_config'): - kv_cache_dtype = getattr(self.model_runner.vllm_config.cache_config, 'cache_dtype', None) - use_fp8_kv_cache = kv_cache_dtype is not None and 'fp8' in str(kv_cache_dtype).lower() - + if hasattr(self.model_runner.vllm_config, "cache_config"): + kv_cache_dtype = getattr( + self.model_runner.vllm_config.cache_config, "cache_dtype", None + ) + use_fp8_kv_cache = ( + kv_cache_dtype is not None and "fp8" in str(kv_cache_dtype).lower() + ) + if use_fp8_kv_cache: # FP8 KV cache: process KV scales after weight loading - print(f"[KV_SCALES] update_weights_via_ipc_zmq: FP8 KV cache detected, processing scales after weight loading") - from vllm.model_executor.model_loader.utils import process_weights_after_loading - + print( + "[KV_SCALES] update_weights_via_ipc_zmq: FP8 KV cache detected, processing scales after weight loading" + ) + from vllm.model_executor.model_loader.utils import ( + process_weights_after_loading, + ) + # Get target device for processing target_device = next(self.model_runner.model.parameters()).device - + # Call process_weights_after_loading to handle KV scales process_weights_after_loading( - 
self.model_runner.model, - self.model_runner.model_config, - target_device + self.model_runner.model, + self.model_runner.model_config, + target_device, ) gc.collect() @@ -229,28 +237,36 @@ def _load_model_weights(weights, model_runner): src=0, post_unpack_func=load_model_weight_func, ) - + # Process weights after loading for FP8 KV cache (static scales) use_fp8_kv_cache = False - if hasattr(self.model_runner.vllm_config, 'cache_config'): - kv_cache_dtype = getattr(self.model_runner.vllm_config.cache_config, 'cache_dtype', None) - use_fp8_kv_cache = kv_cache_dtype is not None and 'fp8' in str(kv_cache_dtype).lower() - + if hasattr(self.model_runner.vllm_config, "cache_config"): + kv_cache_dtype = getattr( + self.model_runner.vllm_config.cache_config, "cache_dtype", None + ) + use_fp8_kv_cache = ( + kv_cache_dtype is not None and "fp8" in str(kv_cache_dtype).lower() + ) + if use_fp8_kv_cache: # FP8 KV cache: process KV scales after weight loading - print(f"[KV_SCALES] update_weights_from_collective: FP8 KV cache detected, processing scales after weight loading") - from vllm.model_executor.model_loader.utils import process_weights_after_loading - + print( + "[KV_SCALES] update_weights_from_collective: FP8 KV cache detected, processing scales after weight loading" + ) + from vllm.model_executor.model_loader.utils import ( + process_weights_after_loading, + ) + # Get target device for processing target_device = next(self.model_runner.model.parameters()).device - + # Call process_weights_after_loading to handle KV scales process_weights_after_loading( - self.model_runner.model, - self.model_runner.model_config, - target_device + self.model_runner.model, + self.model_runner.model_config, + target_device, ) - + except Exception as e: print( f"Error in VllmInternalWorkerExtension.update_weights_from_collective: {e}" diff --git a/nemo_rl/models/generation/vllm/vllm_worker.py b/nemo_rl/models/generation/vllm/vllm_worker.py index 2a97931792..1021faf95a 100644 --- 
a/nemo_rl/models/generation/vllm/vllm_worker.py +++ b/nemo_rl/models/generation/vllm/vllm_worker.py @@ -295,7 +295,7 @@ def _patch_vllm_init_workers_ray(): fp8_kwargs = init_fp8( self.cfg["vllm_cfg"], self.model_name, model_parallel_size ) - + vllm_kwargs.update(fp8_kwargs) # overriden by quant config, however vllm complains if this not passed self.precision = "bfloat16" diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 9939e39f5e..bc5b535e2e 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -1926,20 +1926,22 @@ def prepare_refit_info(self) -> None: for name, tensor in hf_params_generator: metadata = (tensor.shape, tensor.dtype) refit_param_info_hf[name] = metadata - + # Include KV/Q scale metadata when using FP8 KV cache use_fp8_kv_cache = False if "generation" in self.cfg and self.cfg["generation"] is not None: vllm_cfg = self.cfg["generation"].get("vllm_cfg", {}) kv_cache_dtype = vllm_cfg.get("kv_cache_dtype", "auto") use_fp8_kv_cache = kv_cache_dtype == "fp8" - + if use_fp8_kv_cache: # FP8 KV cache: Include KV/Q scale metadata for syncing try: # Get number of layers directly from transformer config num_layers = self.megatron_bridge.transformer_config.num_layers - print(f"[KV_SCALES] prepare_refit_info: FP8 KV cache enabled (kv_cache_dtype=fp8)") + print( + "[KV_SCALES] prepare_refit_info: FP8 KV cache enabled (kv_cache_dtype=fp8)" + ) print(f"[KV_SCALES] Adding scale metadata for {num_layers} layers") # Append q/k/v scale placeholders (shape [1], dtype float32) for layer_idx in range(num_layers): @@ -1950,8 +1952,10 @@ def prepare_refit_info(self) -> None: except Exception: pass else: - print(f"[KV_SCALES] prepare_refit_info: FP8 KV cache not enabled, skipping KV scale metadata") - + print( + "[KV_SCALES] prepare_refit_info: FP8 KV cache not enabled, skipping KV scale metadata" + ) + return refit_param_info_hf def 
_calculate_refit_param_info(self) -> list[tuple[str, int]]: @@ -2045,7 +2049,9 @@ def iter_with_kv_scales(): if use_fp8_kv_cache: # Get number of layers directly from transformer config num_layers = self.megatron_bridge.transformer_config.num_layers - print(f"[KV_SCALES] stream_weights_via_ipc_zmq: FP8 KV cache enabled, streaming scales for {num_layers} layers") + print( + f"[KV_SCALES] stream_weights_via_ipc_zmq: FP8 KV cache enabled, streaming scales for {num_layers} layers" + ) keys = [] for layer_idx in range(num_layers): scale_names = get_vllm_qkv_scale_names(layer_idx) @@ -2061,7 +2067,9 @@ def iter_with_kv_scales(): ).reshape(1) yield param_name, scale_tensor else: - print(f"[KV_SCALES] stream_weights_via_ipc_zmq: FP8 KV cache not enabled, skipping KV scales") + print( + "[KV_SCALES] stream_weights_via_ipc_zmq: FP8 KV cache not enabled, skipping KV scales" + ) # Use the shared implementation stream_weights_via_ipc_zmq_impl( @@ -2094,12 +2102,14 @@ def iter_with_kv_scales(): # First yield all model weights for name, tensor in hf_params_generator: yield name, tensor - + # Append KV scales for FP8 KV cache if use_fp8_kv_cache: # Get number of layers directly from transformer config num_layers = self.megatron_bridge.transformer_config.num_layers - print(f"[KV_SCALES] broadcast_weights_for_collective: FP8 KV cache enabled, broadcasting scales for {num_layers} layers") + print( + f"[KV_SCALES] broadcast_weights_for_collective: FP8 KV cache enabled, broadcasting scales for {num_layers} layers" + ) keys = [] for layer_idx in range(num_layers): scale_names = get_vllm_qkv_scale_names(layer_idx) @@ -2115,7 +2125,9 @@ def iter_with_kv_scales(): ).reshape(1) yield param_name, scale_tensor else: - print(f"[KV_SCALES] broadcast_weights_for_collective: FP8 KV cache not enabled, skipping KV scales") + print( + "[KV_SCALES] broadcast_weights_for_collective: FP8 KV cache not enabled, skipping KV scales" + ) # param_iterator will return (name, tensor), we only need tensor 
packed_broadcast_producer( From 9362035883bf4adf41bb7d0c9c01430ff5fc60f1 Mon Sep 17 00:00:00 2001 From: Zhaopeng Qiu Date: Mon, 17 Nov 2025 07:31:14 -0800 Subject: [PATCH 19/40] remove debug prints Signed-off-by: Zhaopeng Qiu --- nemo_rl/algorithms/grpo.py | 18 +++------------ nemo_rl/models/generation/fp8.py | 23 ------------------- .../models/generation/vllm/vllm_backend.py | 6 ----- 3 files changed, 3 insertions(+), 44 deletions(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index e63d68eb86..3fcd57dc36 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -927,9 +927,6 @@ def refit_policy_generation( policy.offload_before_refit() policy_generation.prepare_for_generation(tags=["weights"]) - if kv_scales: - print(f"[KV_SCALES] Refit: Adding {len(kv_scales)} KV scales to weight update") - # Create a context manager that does nothing when timer is None timer_context = ( timer.time("prepare_for_generation/transfer_and_update_weights") @@ -1015,14 +1012,6 @@ def grpo_train( sync_kv_scales = _should_sync_kv_scales(master_config) kv_scales_cache = None # Cache reused for computed kv scales - if sync_kv_scales: - print("[KV_SCALES] FP8 KV cache enabled (kv_cache_dtype=fp8, precision=fp8)") - print( - "[KV_SCALES] Will compute and sync q_scale, k_scale, v_scale during refit" - ) - else: - print("[KV_SCALES] KV cache scale sync not needed (kv_cache_dtype is not fp8)") - NEED_REFIT = True # If policy_generation is None, use the policy as the generation interface (megatron framework backend) if policy_generation is None: @@ -1116,9 +1105,7 @@ def grpo_train( if NEED_REFIT and POLICY_GENERATION_STALE: # Compute KV scales if needed for FP8 quantization if sync_kv_scales and kv_scales_cache is None: - print( - "[KV_SCALES] Computing KV cache scales for the first time..." 
- ) + print("▶ Computing KV cache scales...", flush=True) policy.prepare_for_lp_inference() # Create calibration data from flattened messages calibration_data = BatchedDataDict[ClippedPGLossDataDict]( @@ -1347,7 +1334,8 @@ def grpo_train( if sync_kv_scales: with timer.time("recompute_kv_scales"): print( - "[KV_SCALES] Recomputing KV cache scales after policy update..." + "▶ Recomputing KV cache scales after policy update...", + flush=True, ) kv_scales = policy.calibrate_qkv_fp8_scales( train_data, include_q=True diff --git a/nemo_rl/models/generation/fp8.py b/nemo_rl/models/generation/fp8.py index 903d1a4e58..55b15d7755 100644 --- a/nemo_rl/models/generation/fp8.py +++ b/nemo_rl/models/generation/fp8.py @@ -118,11 +118,6 @@ def kv_cache_process_weights_after_loading(self, layer: torch.nn.Module) -> None logger = init_logger(__name__) - # print(f"[KV_SCALES] kv_cache_process_weights_after_loading: layer.k_scale = {layer.k_scale}, layer.v_scale = {layer.v_scale}") - print( - f"[@@KV_SCALES@@] [fp8.py] kv_cache_process_weights_after_loading: layer.k_scale = {layer.k_scale}, layer.v_scale = {layer.v_scale}" - ) - # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0 # regardless whether the kv-scale is available in the checkpoint. 
if layer.kv_cache_dtype != "auto": @@ -215,9 +210,6 @@ def kv_cache_process_weights_after_loading(self, layer: torch.nn.Module) -> None # IMPORTANT: We DON'T delete the parameters here to allow for dynamic updates # Original code deleted: layer.k_scale, layer.v_scale, layer.q_scale, layer.prob_scale - print( - "[KV_SCALES] Patched process_weights_after_loading: keeping k_scale, v_scale parameters for dynamic updates" - ) def get_vllm_qkv_scale_names(layer_idx: int) -> dict[str, str]: @@ -312,8 +304,6 @@ def apply_fp8_patches(self, fp8_config): # Apply weight-related patches only when using FP8 weights (precision=fp8) if global_fp8_config.use_fp8_weights: - print("[FP8_PATCHES] Applying FP8 weight quantization patches (precision=fp8)") - # This patch is used to support torch.compile with vllm parameter subclasses, such as # PerTensorScaleParameter. Because we need weight loaders to update fp8 weights each # refit, we patch fp8 parameters to have a reference to their weight loader. Eventually @@ -336,15 +326,10 @@ def apply_fp8_patches(self, fp8_config): # Apply KV cache patches only when using FP8 KV cache (kv_cache_dtype=fp8) if global_fp8_config.kv_cache_dtype == "fp8": - print("[FP8_PATCHES] Applying FP8 KV cache patches (kv_cache_dtype=fp8)") # Static scales mode: patch process_weights_after_loading to preserve k_scale/v_scale for manual updates - print("[FP8_PATCHES] Using static KV scales") func5_path = "vllm.model_executor.layers.quantization.kv_cache.BaseKVCacheMethod.process_weights_after_loading" patcher5 = patch(func5_path, kv_cache_process_weights_after_loading) fp8_state.vllm_patches.append(patcher5) - print( - "[FP8_PATCHES] Patched process_weights_after_loading to preserve k_scale/v_scale for updates" - ) for p in fp8_state.vllm_patches: p.start() @@ -422,9 +407,6 @@ def init_fp8(vllm_cfg, model_name, model_parallel_size): fp8_block_quant_kwargs["ignored_layers"] = bf16_params - # TODO: Remove this after debugging. 
- print(f"[KV_SCALES] Global FP8 config: {global_fp8_config}") - # Return FP8 kwargs (precision=fp8 is required at this point) # kv_cache_dtype can be "auto" or "fp8" vllm_kwargs = { @@ -528,14 +510,9 @@ def load_weights(weights, model_runner): model = model_runner.model for k, v in weights: - if "scale" in k: - print( - f"[@@KV_SCALES@@] [fp8.py] load_weights: Parameter {k}, value = {v.item() if v.numel() == 1 else v}" - ) if not _is_fp8_weight(k, model): weights_quantized.append((k, v)) continue - print(f"[@@KV_SCALES@@] [fp8.py] load_weights: Casting weight {k} into fp8") # Cast the weight into fp8 and its scale factor param_lp, param_scale = cast_tensor_to_fp8_blockwise( v.to(torch.float), diff --git a/nemo_rl/models/generation/vllm/vllm_backend.py b/nemo_rl/models/generation/vllm/vllm_backend.py index 092ce223d8..ef25c7d287 100644 --- a/nemo_rl/models/generation/vllm/vllm_backend.py +++ b/nemo_rl/models/generation/vllm/vllm_backend.py @@ -173,9 +173,6 @@ def update_weights_via_ipc_zmq(self) -> bool: if use_fp8_kv_cache: # FP8 KV cache: process KV scales after weight loading - print( - "[KV_SCALES] update_weights_via_ipc_zmq: FP8 KV cache detected, processing scales after weight loading" - ) from vllm.model_executor.model_loader.utils import ( process_weights_after_loading, ) @@ -250,9 +247,6 @@ def _load_model_weights(weights, model_runner): if use_fp8_kv_cache: # FP8 KV cache: process KV scales after weight loading - print( - "[KV_SCALES] update_weights_from_collective: FP8 KV cache detected, processing scales after weight loading" - ) from vllm.model_executor.model_loader.utils import ( process_weights_after_loading, ) From f35662ddb3279a2df1e4d4d4055837446ebbf637 Mon Sep 17 00:00:00 2001 From: Zhaopeng Qiu Date: Mon, 17 Nov 2025 07:32:45 -0800 Subject: [PATCH 20/40] make refitting with kv scales cleaner Signed-off-by: Zhaopeng Qiu --- .../models/policy/megatron_policy_worker.py | 184 +++++------------- 1 file changed, 51 insertions(+), 133 deletions(-) diff 
--git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index bc5b535e2e..9162ef8cee 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -1903,13 +1903,6 @@ def maybe_init_zmq(self): self.zmq_socket.setsockopt(zmq.LINGER, 0) self.zmq_socket.bind(self.get_zmq_address()) - def _extract_layer_key(self, module_name: str) -> int: - # Expected format: "module.decoder.layers..self_attention.query_key_value" - m = re.search(r"model\.layers\.(\d+)", module_name) - if m is not None: - return int(m.group(1)) - return -1 - @torch.no_grad() @wrap_with_nvtx_name("megatron_policy_worker/prepare_refit_info") def prepare_refit_info(self) -> None: @@ -1918,43 +1911,9 @@ def prepare_refit_info(self) -> None: # Collect tensor metadata for refit / hf side info refit_param_info_hf = {} - hf_params_generator = self.megatron_bridge.export_hf_weights( - [self.model], - show_progress=False, - conversion_tasks=self.refit_conversion_tasks, # used for metadata caching - ) - for name, tensor in hf_params_generator: - metadata = (tensor.shape, tensor.dtype) - refit_param_info_hf[name] = metadata - - # Include KV/Q scale metadata when using FP8 KV cache - use_fp8_kv_cache = False - if "generation" in self.cfg and self.cfg["generation"] is not None: - vllm_cfg = self.cfg["generation"].get("vllm_cfg", {}) - kv_cache_dtype = vllm_cfg.get("kv_cache_dtype", "auto") - use_fp8_kv_cache = kv_cache_dtype == "fp8" - - if use_fp8_kv_cache: - # FP8 KV cache: Include KV/Q scale metadata for syncing - try: - # Get number of layers directly from transformer config - num_layers = self.megatron_bridge.transformer_config.num_layers - print( - "[KV_SCALES] prepare_refit_info: FP8 KV cache enabled (kv_cache_dtype=fp8)" - ) - print(f"[KV_SCALES] Adding scale metadata for {num_layers} layers") - # Append q/k/v scale placeholders (shape [1], dtype float32) - for layer_idx in range(num_layers): - 
scale_names = get_vllm_qkv_scale_names(layer_idx) - for param_name in scale_names.values(): - if param_name not in refit_param_info_hf: - refit_param_info_hf[param_name] = ([1], torch.float32) - except Exception: - pass - else: - print( - "[KV_SCALES] prepare_refit_info: FP8 KV cache not enabled, skipping KV scale metadata" - ) + # Reuse shared iterator that appends FP8 KV/Q scales when enabled + for name, tensor in self._iter_params_with_optional_kv_scales(): + refit_param_info_hf[name] = (tensor.shape, tensor.dtype) return refit_param_info_hf @@ -2016,64 +1975,67 @@ def get_free_memory_bytes(self) -> int: device_idx = torch.cuda.current_device() return get_free_memory_bytes(device_idx) - @torch.no_grad() - @wrap_with_nvtx_name("megatron_policy_worker/stream_weights_via_ipc_zmq") - def stream_weights_via_ipc_zmq( - self, buffer_size_bytes: int = 0, kv_scales: Optional[dict[str, float]] = None - ) -> None: - """Stream model weights to peer process via ZMQ IPC socket.""" - self.maybe_init_zmq() - - from nemo_rl.models.policy.utils import stream_weights_via_ipc_zmq_impl + def _iter_params_with_optional_kv_scales( + self, + kv_scales: Optional[dict[str, float]] = None, + ) -> Iterator[tuple[str, torch.Tensor]]: + """Yield exported HF parameters and optionally append FP8 KV/Q scale tensors. - # Generate HF parameters for streaming - hf_params_generator = self.megatron_bridge.export_hf_weights( + This helper is used by both IPC-based streaming and collective broadcast + so that the logic for adding KV scales stays consistent in one place. + """ + base_iter = self.megatron_bridge.export_hf_weights( [self.model], show_progress=False, conversion_tasks=self.refit_conversion_tasks, # used for metadata caching ) - # Stream KV scales when using FP8 KV cache + # Yield the original parameters first. + for name, tensor in base_iter: + yield name, tensor + + # Check whether FP8 KV cache is enabled. 
use_fp8_kv_cache = False if "generation" in self.cfg and self.cfg["generation"] is not None: vllm_cfg = self.cfg["generation"].get("vllm_cfg", {}) kv_cache_dtype = vllm_cfg.get("kv_cache_dtype", "auto") use_fp8_kv_cache = kv_cache_dtype == "fp8" - def iter_with_kv_scales(): - # First yield all model weights - for name, tensor in hf_params_generator: - yield name, tensor - - # Append KV scales for FP8 KV cache - if use_fp8_kv_cache: - # Get number of layers directly from transformer config - num_layers = self.megatron_bridge.transformer_config.num_layers - print( - f"[KV_SCALES] stream_weights_via_ipc_zmq: FP8 KV cache enabled, streaming scales for {num_layers} layers" - ) - keys = [] - for layer_idx in range(num_layers): - scale_names = get_vllm_qkv_scale_names(layer_idx) - keys.extend(scale_names.values()) - # Append kv-scale entries to match metadata; use provided value or default 1.0 - for param_name in keys: - if kv_scales and param_name in kv_scales: - scale_value = kv_scales[param_name] - else: - scale_value = 1.0 - scale_tensor = torch.tensor( - scale_value, dtype=torch.float32, device="cuda" - ).reshape(1) - yield param_name, scale_tensor + if not use_fp8_kv_cache: + return + + # Append KV (and potentially Q) scale entries to match metadata. 
+ num_layers = self.megatron_bridge.transformer_config.num_layers + keys: list[str] = [] + for layer_idx in range(num_layers): + scale_names = get_vllm_qkv_scale_names(layer_idx) + keys.extend(scale_names.values()) + + for param_name in keys: + if kv_scales and param_name in kv_scales: + scale_value = kv_scales[param_name] else: - print( - "[KV_SCALES] stream_weights_via_ipc_zmq: FP8 KV cache not enabled, skipping KV scales" - ) + scale_value = 1.0 + scale_tensor = torch.tensor( + scale_value, dtype=torch.float32, device="cuda" + ).reshape(1) + yield param_name, scale_tensor + + @torch.no_grad() + @wrap_with_nvtx_name("megatron_policy_worker/stream_weights_via_ipc_zmq") + def stream_weights_via_ipc_zmq( + self, buffer_size_bytes: int = 0, kv_scales: Optional[dict[str, float]] = None + ) -> None: + """Stream model weights to peer process via ZMQ IPC socket.""" + self.maybe_init_zmq() - # Use the shared implementation + from nemo_rl.models.policy.utils import stream_weights_via_ipc_zmq_impl + + # Use the shared implementation to append optional KV scales. 
stream_weights_via_ipc_zmq_impl( - params_generator=iter_with_kv_scales(), + params_generator=self._iter_params_with_optional_kv_scales( + kv_scales=kv_scales + ), buffer_size_bytes=buffer_size_bytes, zmq_socket=self.zmq_socket, rank=self.rank, @@ -2085,53 +2047,9 @@ def broadcast_weights_for_collective( self, kv_scales: Optional[dict[str, float]] = None ) -> None: """Broadcast the weights for collective communication.""" - hf_params_generator = self.megatron_bridge.export_hf_weights( - [self.model], - show_progress=False, - conversion_tasks=self.refit_conversion_tasks, # used for metadata caching - ) - - # Broadcast KV scales when using FP8 KV cache - use_fp8_kv_cache = False - if "generation" in self.cfg and self.cfg["generation"] is not None: - vllm_cfg = self.cfg["generation"].get("vllm_cfg", {}) - kv_cache_dtype = vllm_cfg.get("kv_cache_dtype", "auto") - use_fp8_kv_cache = kv_cache_dtype == "fp8" - - def iter_with_kv_scales(): - # First yield all model weights - for name, tensor in hf_params_generator: - yield name, tensor - - # Append KV scales for FP8 KV cache - if use_fp8_kv_cache: - # Get number of layers directly from transformer config - num_layers = self.megatron_bridge.transformer_config.num_layers - print( - f"[KV_SCALES] broadcast_weights_for_collective: FP8 KV cache enabled, broadcasting scales for {num_layers} layers" - ) - keys = [] - for layer_idx in range(num_layers): - scale_names = get_vllm_qkv_scale_names(layer_idx) - keys.extend(scale_names.values()) - # Append kv-scale entries to match metadata; use provided value or default 1.0 - for param_name in keys: - if kv_scales and param_name in kv_scales: - scale_value = kv_scales[param_name] - else: - scale_value = 1.0 - scale_tensor = torch.tensor( - scale_value, dtype=torch.float32, device="cuda" - ).reshape(1) - yield param_name, scale_tensor - else: - print( - "[KV_SCALES] broadcast_weights_for_collective: FP8 KV cache not enabled, skipping KV scales" - ) - - # param_iterator will return 
(name, tensor), we only need tensor + # param_iterator will return (name, tensor), we only need tensor. packed_broadcast_producer( - iterator=iter_with_kv_scales(), + iterator=self._iter_params_with_optional_kv_scales(kv_scales=kv_scales), group=self.model_update_group, src=0, post_iter_func=lambda x: x[1], From 667661e0d4340b689e4db7c75ff73c0319242e0f Mon Sep 17 00:00:00 2001 From: Zhaopeng Qiu Date: Mon, 17 Nov 2025 21:50:15 -0800 Subject: [PATCH 21/40] remove debug print; raise errors of calibration process; refine refitting code Signed-off-by: Zhaopeng Qiu --- .../models/generation/vllm/vllm_backend.py | 83 +++++++---------- nemo_rl/models/policy/interfaces.py | 2 - nemo_rl/models/policy/lm_policy.py | 2 - .../models/policy/megatron_policy_worker.py | 89 +++++-------------- 4 files changed, 57 insertions(+), 119 deletions(-) diff --git a/nemo_rl/models/generation/vllm/vllm_backend.py b/nemo_rl/models/generation/vllm/vllm_backend.py index ef25c7d287..41cd6e27ca 100644 --- a/nemo_rl/models/generation/vllm/vllm_backend.py +++ b/nemo_rl/models/generation/vllm/vllm_backend.py @@ -96,6 +96,35 @@ def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None: """ self.state_dict_info = state_dict_info # pyrefly: ignore[implicitly-defined-attribute] This class does not define __init__ so assignments like this should be ignored + def _maybe_process_fp8_kv_cache(self) -> None: + """Process weights after loading for FP8 KV cache (static scales).""" + use_fp8_kv_cache = False + if hasattr(self.model_runner.vllm_config, "cache_config"): + kv_cache_dtype = getattr( + self.model_runner.vllm_config.cache_config, "cache_dtype", None + ) + use_fp8_kv_cache = ( + kv_cache_dtype is not None and "fp8" in str(kv_cache_dtype).lower() + ) + + if not use_fp8_kv_cache: + return + + # FP8 KV cache: process KV scales after weight loading + from vllm.model_executor.model_loader.utils import ( + process_weights_after_loading, + ) + + # Get target device for processing + target_device 
= next(self.model_runner.model.parameters()).device + + # Call process_weights_after_loading to handle KV scales + process_weights_after_loading( + self.model_runner.model, + self.model_runner.model_config, + target_device, + ) + @wrap_with_nvtx_name("vllm_internal_worker_extension/update_weights_via_ipc_zmq") def update_weights_via_ipc_zmq(self) -> bool: """Receive and update model weights via ZMQ IPC socket. @@ -161,31 +190,8 @@ def update_weights_via_ipc_zmq(self) -> bool: buffer = None self.zmq_socket.send(IPCProtocol.ACK.value.encode()) - # Process weights after loading for FP8 KV cache (static scales) - use_fp8_kv_cache = False - if hasattr(self.model_runner.vllm_config, "cache_config"): - kv_cache_dtype = getattr( - self.model_runner.vllm_config.cache_config, "cache_dtype", None - ) - use_fp8_kv_cache = ( - kv_cache_dtype is not None and "fp8" in str(kv_cache_dtype).lower() - ) - - if use_fp8_kv_cache: - # FP8 KV cache: process KV scales after weight loading - from vllm.model_executor.model_loader.utils import ( - process_weights_after_loading, - ) - - # Get target device for processing - target_device = next(self.model_runner.model.parameters()).device - - # Call process_weights_after_loading to handle KV scales - process_weights_after_loading( - self.model_runner.model, - self.model_runner.model_config, - target_device, - ) + # Process weights after loading for FP8 KV cache + self._maybe_process_fp8_kv_cache() gc.collect() torch.cuda.empty_cache() @@ -235,31 +241,8 @@ def _load_model_weights(weights, model_runner): post_unpack_func=load_model_weight_func, ) - # Process weights after loading for FP8 KV cache (static scales) - use_fp8_kv_cache = False - if hasattr(self.model_runner.vllm_config, "cache_config"): - kv_cache_dtype = getattr( - self.model_runner.vllm_config.cache_config, "cache_dtype", None - ) - use_fp8_kv_cache = ( - kv_cache_dtype is not None and "fp8" in str(kv_cache_dtype).lower() - ) - - if use_fp8_kv_cache: - # FP8 KV cache: process KV 
scales after weight loading - from vllm.model_executor.model_loader.utils import ( - process_weights_after_loading, - ) - - # Get target device for processing - target_device = next(self.model_runner.model.parameters()).device - - # Call process_weights_after_loading to handle KV scales - process_weights_after_loading( - self.model_runner.model, - self.model_runner.model_config, - target_device, - ) + # Process weights after loading for FP8 KV cache + self._maybe_process_fp8_kv_cache() except Exception as e: print( diff --git a/nemo_rl/models/policy/interfaces.py b/nemo_rl/models/policy/interfaces.py index 70e7ca4459..76949749ac 100644 --- a/nemo_rl/models/policy/interfaces.py +++ b/nemo_rl/models/policy/interfaces.py @@ -128,7 +128,6 @@ def calibrate_qkv_fp8_scales( micro_batch_size: Optional[int] = None, percentile: float = 99.9, margin: float = 1.05, - save_path: Optional[str] = None, include_q: bool = False, ) -> dict[str, Any]: """Calibrate FP8 scales for Q/K/V activations used by KV cache. @@ -138,7 +137,6 @@ def calibrate_qkv_fp8_scales( micro_batch_size: Optional override for micro batch size during calibration. percentile: Percentile for per-tensor amax estimation. margin: Safety margin multiplier applied to amax. - save_path: If provided, rank0 will write JSON results to this path. include_q: Whether to also compute scale for Q in addition to K/V. Returns: diff --git a/nemo_rl/models/policy/lm_policy.py b/nemo_rl/models/policy/lm_policy.py index 5e63c66f73..a6b314d75f 100644 --- a/nemo_rl/models/policy/lm_policy.py +++ b/nemo_rl/models/policy/lm_policy.py @@ -696,7 +696,6 @@ def calibrate_qkv_fp8_scales( micro_batch_size: Optional[int] = None, percentile: float = 99.9, margin: float = 1.05, - save_path: Optional[str] = None, include_q: bool = False, ) -> dict[str, Any]: """Trigger KV-cache FP8 scale calibration across Megatron workers and return results. 
@@ -748,7 +747,6 @@ def calibrate_qkv_fp8_scales( "micro_batch_size": micro_batch_size, "percentile": percentile, "margin": margin, - "save_path": save_path, "include_q": include_q, }, ) diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 9162ef8cee..46feedabfd 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import gc -import json import math import os import re @@ -2414,7 +2413,6 @@ def calibrate_qkv_fp8_scales( micro_batch_size: Optional[int] = None, percentile: float = 99.9, margin: float = 1.05, - save_path: Optional[str] = None, include_q: bool = False, ) -> dict[str, Any]: """One-shot calibration of Q/K/V activation scales (for FP8 KV cache). @@ -2470,29 +2468,23 @@ def _pre_hook_builder_core_attention(module_name: str): layer_key = _extract_layer_key(module_name) def _pre_hook(module, inputs): - try: - args = inputs if isinstance(inputs, (tuple, list)) else (inputs,) - if len(args) == 1 and isinstance(args[0], (tuple, list)): - args = args[0] - # Expected first 3 args to be q, k, v (typical signature for Megatron CoreAttention) - q = args[0] - k = args[1] - v = args[2] - if include_q: - layer_to_samples_q[layer_key].append( - float(torch.amax(torch.abs(q)).item()) - ) - layer_to_samples_k[layer_key].append( - float(torch.amax(torch.abs(k)).item()) - ) - layer_to_samples_v[layer_key].append( - float(torch.amax(torch.abs(v)).item()) - ) - except Exception as e: - print( - f"[KV_SCALES] Error in core_attention pre-hook on {module_name}: {e}" + args = inputs if isinstance(inputs, (tuple, list)) else (inputs,) + if len(args) == 1 and isinstance(args[0], (tuple, list)): + args = args[0] + # Expected first 3 args to be q, k, v (typical signature for Megatron CoreAttention) + q = args[0] + k = args[1] + v = args[2] + 
if include_q: + layer_to_samples_q[layer_key].append( + float(torch.amax(torch.abs(q)).item()) ) - pass + layer_to_samples_k[layer_key].append( + float(torch.amax(torch.abs(k)).item()) + ) + layer_to_samples_v[layer_key].append( + float(torch.amax(torch.abs(v)).item()) + ) return _pre_hook @@ -2507,34 +2499,12 @@ def _pre_hook(module, inputs): hook_handles.append(handle) matched_modules.append((name, module.__class__.__name__, "pre")) except Exception as e: - print(f"[KV_SCALES] Error registering pre-hook on {name}: {e}") - continue - - if not hook_handles: - print( - "[KV_SCALES] No QKV proj modules matched for hook. Example module/param names:" - ) - try: - # Print the first 10 modules and parameters to help locate the actual names - cnt = 0 - for n, _m in self.model.named_modules(): - if cnt >= 10: - break - print(f" [module] {n}") - cnt += 1 - cnt = 0 - for n, _p in self.model.named_parameters(): - if cnt >= 10: - break - print(f" [param] {n}") - cnt += 1 - except Exception: - pass - else: - # Lightly print the first few modules and parameters to confirm hits - print("[KV_SCALES] Registered hooks on modules (showing up to 8):") - for i, (mn, cls, kind) in enumerate(matched_modules[:8]): - print(f" {i:02d}: {mn} <{cls}> [{kind}]") + print( + f"Error registering pre-hook for qkv scale calibration on {name}: {e}" + " Please check if the model is compatible with the current calibration logic. " + "The expected module name is 'self_attention.core_attention'." 
+ ) + raise # Run a forward pass to trigger hooks (reuse get_logprobs forward path) try: @@ -2544,8 +2514,8 @@ def _pre_hook(module, inputs): try: h.remove() except Exception as e: - print(f"[KV_SCALES] Error removing hook: {e}") - pass + print(f"Error removing hook for qkv scale calibration: {e}") + raise # Compute local percentile amax def _percentile(values: list[float], p: float) -> float: @@ -2616,7 +2586,6 @@ def _percentile(values: list[float], p: float) -> float: "margin": margin, "layers": result_layers, } - print(f"[KV_SCALES] Calibrated KV cache scales: {final_result}") # Sync results across all ranks (broadcast rank0's result) if world_size > 1: @@ -2629,16 +2598,6 @@ def _percentile(values: list[float], p: float) -> float: torch.distributed.broadcast_object_list(obj_list, src=0) final_result = obj_list[0] # type: ignore - # Optional save to JSON (only rank0) - if save_path is not None and ( - not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 - ): - try: - with open(save_path, "w") as f: - json.dump(final_result, f) - except Exception: - pass - return final_result From b84b10ed640fcb2d617cd9864f2408aa5dd34a6b Mon Sep 17 00:00:00 2001 From: Zhaopeng Qiu Date: Mon, 17 Nov 2025 22:02:31 -0800 Subject: [PATCH 22/40] remove old hotfix about save_ckpt Signed-off-by: Zhaopeng Qiu --- .../models/policy/megatron_policy_worker.py | 31 ------------------- 1 file changed, 31 deletions(-) diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 46feedabfd..ba5835f82d 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -2230,12 +2230,6 @@ def save_checkpoint( weights_path: The specific directory path where the checkpoint will be saved. optimizer_path: If not None, optimizer and scheduler states are saved if they exist. 
""" - # Temporary fix to avoid OOM after saving checkpoint - allocated = torch.cuda.memory_allocated() / (1024**3) # Convert to GB - reserved = torch.cuda.memory_reserved() / (1024**3) # Convert to GB - print( - f"GPU Memory before saving checkpoint: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved" - ) if not torch.distributed.is_initialized(): raise RuntimeError( "Distributed process group is not initialized. Cannot save checkpoint." @@ -2297,31 +2291,6 @@ def save_checkpoint( if not is_training: # Restore training state if it was changed self.model.train() - # Temporary fix to avoid OOM after saving checkpoint: https://github.com/NVIDIA-NeMo/RL/issues/1057 - torch.randn(1).cuda() # wake up torch allocator - if hasattr(self, "optimizer") and self.optimizer is not None: - # Iterate through the state dictionaries for each parameter group - if isinstance(self.optimizer, ChainedOptimizer): - optimizer_state = self.optimizer.state - else: - optimizer_state = self.optimizer._get_state() - for _, state in optimizer_state.items(): - # Iterate through the state items (e.g., momentum, variance) for a parameter - for k, v in state.items(): - # Check if the item is a tensor and on the GPU - if torch.is_tensor(v) and v.is_cuda: - # Move the tensor to CPU and update the state dictionary - state[k] = v.to("cpu") - - gc.collect() - torch.cuda.empty_cache() - - allocated = torch.cuda.memory_allocated() / (1024**3) # Convert to GB - reserved = torch.cuda.memory_reserved() / (1024**3) # Convert to GB - print( - f"GPU Memory after saving checkpoint: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved" - ) - except Exception as e: print(f"Failed to save checkpoint to {weights_path}: {e}") raise From 50c4abd499c2895b07235669d6e1b50458b3a498 Mon Sep 17 00:00:00 2001 From: Zhaopeng Qiu Date: Mon, 17 Nov 2025 23:55:01 -0800 Subject: [PATCH 23/40] avoid importing vllm at grpo.py Signed-off-by: Zhaopeng Qiu --- nemo_rl/algorithms/grpo.py | 11 ++--------- 
nemo_rl/models/generation/fp8.py | 3 +++ nemo_rl/models/policy/megatron_policy_worker.py | 9 +++++++-- 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 3fcd57dc36..3d78e02629 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -59,7 +59,6 @@ run_async_penguin_rollout, run_multi_turn_rollout, ) -from nemo_rl.models.generation.fp8 import convert_calibration_to_vllm_format from nemo_rl.models.generation.interfaces import GenerationInterface from nemo_rl.models.generation.vllm import VllmConfig, VllmGeneration from nemo_rl.models.policy import PolicyConfig @@ -1119,13 +1118,9 @@ def grpo_train( batched_flat.get_multimodal_dict(as_tensors=False) ) calibration_data.to("cpu") - kv_scales = policy.calibrate_qkv_fp8_scales( + kv_scales_cache = policy.calibrate_qkv_fp8_scales( calibration_data, include_q=True )["layers"] - # Convert calibration results to vLLM parameter format - kv_scales_cache = convert_calibration_to_vllm_format( - kv_scales - ) refit_policy_generation( policy, @@ -1337,11 +1332,9 @@ def grpo_train( "▶ Recomputing KV cache scales after policy update...", flush=True, ) - kv_scales = policy.calibrate_qkv_fp8_scales( + kv_scales_cache = policy.calibrate_qkv_fp8_scales( train_data, include_q=True )["layers"] - # Convert calibration results to vLLM parameter format - kv_scales_cache = convert_calibration_to_vllm_format(kv_scales) # Set generation as stale to force refit with new scales POLICY_GENERATION_STALE = True diff --git a/nemo_rl/models/generation/fp8.py b/nemo_rl/models/generation/fp8.py index 55b15d7755..ca7e17ec03 100644 --- a/nemo_rl/models/generation/fp8.py +++ b/nemo_rl/models/generation/fp8.py @@ -252,6 +252,9 @@ def convert_calibration_to_vllm_format( ) -> dict[str, float]: """Convert NeMo-RL calibration results to vLLM parameter format. + Currently only used by megatron policy worker. 
+ After FP8 KV cache is supported by DTensor path, this function can be reused. + This function transforms the calibration output format (with layer_N keys) into the flat dictionary format that vLLM expects for parameter loading. diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index ba5835f82d..d19a320954 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -109,7 +109,10 @@ from_parallel_logits_to_logprobs_packed_sequences, ) from nemo_rl.distributed.named_sharding import NamedSharding -from nemo_rl.models.generation.fp8 import get_vllm_qkv_scale_names +from nemo_rl.models.generation.fp8 import ( + convert_calibration_to_vllm_format, + get_vllm_qkv_scale_names, +) from nemo_rl.models.generation.interfaces import ( GenerationDatumSpec, GenerationOutputSpec, @@ -2549,11 +2552,13 @@ def _percentile(values: list[float], p: float) -> float: out_entry["v_scale"] = float(v_scale) result_layers[layer_key] = out_entry + vllm_format_scales = convert_calibration_to_vllm_format(result_layers) + final_result = { "format": "fp8", "percentile": percentile, "margin": margin, - "layers": result_layers, + "layers": vllm_format_scales, } # Sync results across all ranks (broadcast rank0's result) From 0dbf7ab7775d78dbac2d0e4097edeec75d22f98b Mon Sep 17 00:00:00 2001 From: Zhaopeng Qiu Date: Tue, 18 Nov 2025 04:46:50 -0800 Subject: [PATCH 24/40] add placeholder func and parameter for dtensor path Signed-off-by: Zhaopeng Qiu --- .../models/policy/dtensor_policy_worker.py | 24 +++++++++++++++++-- .../models/policy/dtensor_policy_worker_v2.py | 24 +++++++++++++++++-- 2 files changed, 44 insertions(+), 4 deletions(-) diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index 9b15733e2e..e15c7e7d11 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py 
@@ -1746,9 +1746,27 @@ def get_free_memory_bytes(self) -> int: device_idx = torch.cuda.current_device() return get_free_memory_bytes(device_idx) + @torch.no_grad() + def calibrate_qkv_fp8_scales( + self, + data: BatchedDataDict[Any], + micro_batch_size: Optional[int] = None, + percentile: float = 99.9, + margin: float = 1.05, + include_q: bool = False, + ) -> dict[str, Any]: + """Placeholder for FP8 Q/K/V scale calibration, not implemented for DTensorPolicyWorker.""" + raise NotImplementedError( + "calibrate_qkv_fp8_scales is not implemented for DTensorPolicyWorker" + ) + @torch.no_grad() @wrap_with_nvtx_name("dtensor_policy_worker/stream_weights_via_ipc_zmq") - def stream_weights_via_ipc_zmq(self, buffer_size_bytes: int = 0) -> None: + def stream_weights_via_ipc_zmq( + self, + buffer_size_bytes: int = 0, + kv_scales: Optional[dict[str, float]] = None, + ) -> None: """Stream model weights to peer process via ZMQ IPC socket.""" self.maybe_init_zmq() # Manually move model to cuda for cpu offload case @@ -1782,7 +1800,9 @@ def dtensor_params_generator(): ) @torch.no_grad() - def broadcast_weights_for_collective(self) -> None: + def broadcast_weights_for_collective( + self, kv_scales: Optional[dict[str, float]] = None + ) -> None: """Broadcast the weights for collective communication.""" # Manually move model to cuda for cpu offload case if self.cpu_offload: diff --git a/nemo_rl/models/policy/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/dtensor_policy_worker_v2.py index d1691b22ef..ea5e53d5e9 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/dtensor_policy_worker_v2.py @@ -1707,9 +1707,27 @@ def get_free_memory_bytes(self) -> int: device_idx = torch.cuda.current_device() return get_free_memory_bytes(device_idx) + @torch.no_grad() + def calibrate_qkv_fp8_scales( + self, + data: BatchedDataDict[Any], + micro_batch_size: Optional[int] = None, + percentile: float = 99.9, + margin: float = 1.05, + include_q: bool = False, + ) 
-> dict[str, Any]: + """Placeholder for FP8 Q/K/V scale calibration, not implemented for DTensorPolicyWorkerV2.""" + raise NotImplementedError( + "calibrate_qkv_fp8_scales is not implemented for DTensorPolicyWorkerV2" + ) + @torch.no_grad() @wrap_with_nvtx_name("dtensor_policy_worker_v2/stream_weights_via_ipc_zmq") - def stream_weights_via_ipc_zmq(self, buffer_size_bytes: int = 0) -> None: + def stream_weights_via_ipc_zmq( + self, + buffer_size_bytes: int = 0, + kv_scales: Optional[dict[str, float]] = None, + ) -> None: """Stream model weights to peer process via ZMQ IPC socket.""" self.maybe_init_zmq() # Manually move model to cuda for cpu offload case @@ -1743,7 +1761,9 @@ def dtensor_params_generator(): ) @torch.no_grad() - def broadcast_weights_for_collective(self) -> None: + def broadcast_weights_for_collective( + self, kv_scales: Optional[dict[str, float]] = None + ) -> None: """Broadcast the weights for collective communication.""" # Manually move model to cuda for cpu offload case if self.cpu_offload: From 0ea586b2072c5715ecb38da60638fc6c3dd3accf Mon Sep 17 00:00:00 2001 From: Shuang Yu Date: Fri, 21 Nov 2025 10:55:50 +0800 Subject: [PATCH 25/40] Refit should take care of kv_scales in the validation phase Signed-off-by: Shuang Yu --- nemo_rl/algorithms/grpo.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 3d78e02629..7c97bbc606 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -1348,7 +1348,10 @@ def grpo_train( if val_period > 0 and (total_steps + 1) % val_period == 0: if NEED_REFIT and POLICY_GENERATION_STALE: refit_policy_generation( - policy, policy_generation, colocated_inference + policy, + policy_generation, + colocated_inference, + kv_scales=kv_scales_cache if sync_kv_scales else None, ) POLICY_GENERATION_STALE = False else: From 8f759a1f9c71bcaf6e2ad51a278c004c571ed809 Mon Sep 17 00:00:00 2001 From: Shuang Yu Date: Fri, 21 Nov 2025 
10:59:12 +0800 Subject: [PATCH 26/40] Remove TODO comment Signed-off-by: Shuang Yu --- nemo_rl/algorithms/grpo.py | 1 - 1 file changed, 1 deletion(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 7c97bbc606..ccabc3eec1 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -1344,7 +1344,6 @@ def grpo_train( ) # Run validation if it's a validation step - # TODO: Add validation with kv scales if needed if val_period > 0 and (total_steps + 1) % val_period == 0: if NEED_REFIT and POLICY_GENERATION_STALE: refit_policy_generation( From 6f3bed7519e4dc24c5cba3b12b9be2d31980e0e5 Mon Sep 17 00:00:00 2001 From: Shuang Yu Date: Tue, 25 Nov 2025 13:06:46 +0800 Subject: [PATCH 27/40] Rename the example yaml file to grpo_math_qwen3_8B_fp8_kvcache.yaml and move it to the recipes path Signed-off-by: Shuang Yu --- .../llm/grpo_math_qwen3_8B_fp8_kvcache.yaml} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename examples/configs/{grpo_math_8B_megatron_fp8_kvcache.yaml => recipes/llm/grpo_math_qwen3_8B_fp8_kvcache.yaml} (100%) diff --git a/examples/configs/grpo_math_8B_megatron_fp8_kvcache.yaml b/examples/configs/recipes/llm/grpo_math_qwen3_8B_fp8_kvcache.yaml similarity index 100% rename from examples/configs/grpo_math_8B_megatron_fp8_kvcache.yaml rename to examples/configs/recipes/llm/grpo_math_qwen3_8B_fp8_kvcache.yaml From 4089ab58ea0d70f9904a9899f31d74a51be5858f Mon Sep 17 00:00:00 2001 From: Shuang Yu Date: Tue, 25 Nov 2025 13:53:54 +0800 Subject: [PATCH 28/40] Add kv_cache fp8 test case to test_vllm_generation_with_megatron_training Signed-off-by: Shuang Yu --- tests/unit/models/generation/test_vllm_generation.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py index 79e4112485..463422afd8 100644 --- a/tests/unit/models/generation/test_vllm_generation.py +++
b/tests/unit/models/generation/test_vllm_generation.py @@ -1920,14 +1920,19 @@ async def test_vllm_refit_non_colocated_update_weights( @pytest.mark.timeout(360) @pytest.mark.parametrize("tensor_parallel_size", [1, 2]) @pytest.mark.parametrize("vllm_precision", ["bfloat16", "fp8"]) +@pytest.mark.parametrize("kv_cache_dtype", [None, "fp8"]) def test_vllm_generation_with_megatron_training( - cluster, tokenizer, tensor_parallel_size, vllm_precision + cluster, tokenizer, tensor_parallel_size, vllm_precision, kv_cache_dtype ): """Test that uses vLLM for generation and Megatron policy for training and logprob computation. This test validates that vLLM and Megatron policies can work together. """ + # Skip invalid configurations: kv_cache_dtype=fp8 requires precision=fp8 + if kv_cache_dtype == "fp8" and vllm_precision != "fp8": + pytest.skip("kv_cache_dtype='fp8' requires precision='fp8'") + # Skip the fp8 tests if the GPU is not H100 or newer (compute capability < 9.0) if vllm_precision == "fp8": major_capability, _ = torch.cuda.get_device_capability() @@ -1951,6 +1956,8 @@ def test_vllm_generation_with_megatron_training( vllm_config["tokenizer"]["name"] = model_name vllm_config["vllm_cfg"]["async_engine"] = False vllm_config["vllm_cfg"]["precision"] = vllm_precision + if kv_cache_dtype is not None: + vllm_config["vllm_cfg"]["kv_cache_dtype"] = kv_cache_dtype vllm_config = configure_generation_config(vllm_config, test_tokenizer) # Megatron config with same model From ac6f66c906de549581b685f42e4003e289776348 Mon Sep 17 00:00:00 2001 From: alexchiu Date: Wed, 26 Nov 2025 10:18:30 +0800 Subject: [PATCH 29/40] update pp>1 assert info Co-authored-by: Guyue Huang <140554423+guyueh1@users.noreply.github.com> Signed-off-by: alexchiu --- nemo_rl/algorithms/grpo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index d4f8b85e52..56a0da8a3b 100644 --- a/nemo_rl/algorithms/grpo.py +++ 
b/nemo_rl/algorithms/grpo.py @@ -514,7 +514,7 @@ def init_vllm(): "Async rollouts is not supported with kv cache fp8 enabled." ) assert policy_config["megatron_cfg"]["pipeline_model_parallel_size"] == 1, ( - "Pipeline model parallel size must be 1 for megatron backend with kv cache fp8 enabled." + "Currently when using FP8 KV cache in generation, then in megatron we only support pipeline_model_parallel_size=1. We will add more support in future." ) ## make vllm hf overrides match the training policy From af60c9a8e82f15df9421ad01d2537ba69720d00f Mon Sep 17 00:00:00 2001 From: Zhaopeng Qiu Date: Tue, 25 Nov 2025 18:32:54 -0800 Subject: [PATCH 30/40] update guard statements in DTensor path files Signed-off-by: Zhaopeng Qiu --- nemo_rl/models/policy/dtensor_policy_worker.py | 10 ++++++++++ nemo_rl/models/policy/dtensor_policy_worker_v2.py | 10 ++++++++++ 2 files changed, 20 insertions(+) diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index 84850de975..efce840837 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -1764,6 +1764,11 @@ def stream_weights_via_ipc_zmq( kv_scales: Optional[dict[str, float]] = None, ) -> None: """Stream model weights to peer process via ZMQ IPC socket.""" + if kv_scales is not None: + raise NotImplementedError( + "FP8 kvcache is not currently supported for DTensor path, we will support it in the future." + ) + self.maybe_init_zmq() # Manually move model to cuda for cpu offload case if self.cpu_offload: @@ -1800,6 +1805,11 @@ def broadcast_weights_for_collective( self, kv_scales: Optional[dict[str, float]] = None ) -> None: """Broadcast the weights for collective communication.""" + if kv_scales is not None: + raise NotImplementedError( + "FP8 kvcache is not currently supported for DTensor path, we will support it in the future." 
+ ) + # Manually move model to cuda for cpu offload case if self.cpu_offload: print( diff --git a/nemo_rl/models/policy/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/dtensor_policy_worker_v2.py index 3fe671aaf8..5751ce7f62 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/dtensor_policy_worker_v2.py @@ -1723,6 +1723,11 @@ def stream_weights_via_ipc_zmq( kv_scales: Optional[dict[str, float]] = None, ) -> None: """Stream model weights to peer process via ZMQ IPC socket.""" + if kv_scales is not None: + raise NotImplementedError( + "FP8 kvcache is not currently supported for DTensor path, we will support it in the future." + ) + self.maybe_init_zmq() # Manually move model to cuda for cpu offload case if self.cpu_offload: @@ -1759,6 +1764,11 @@ def broadcast_weights_for_collective( self, kv_scales: Optional[dict[str, float]] = None ) -> None: """Broadcast the weights for collective communication.""" + if kv_scales is not None: + raise NotImplementedError( + "FP8 kvcache is not currently supported for DTensor path, we will support it in the future." 
+ ) + # Manually move model to cuda for cpu offload case if self.cpu_offload: print( From 3a3711939afc426f0be58b0867938208f7c50c63 Mon Sep 17 00:00:00 2001 From: Zhaopeng Qiu Date: Wed, 26 Nov 2025 00:51:02 -0800 Subject: [PATCH 31/40] add l1 test; update config yaml Signed-off-by: Zhaopeng Qiu --- .../llm/grpo_math_qwen3_8B_fp8_kvcache.yaml | 19 --------- tests/functional/L1_Functional_Tests_GPU.sh | 1 + tests/functional/grpo_fp8_kvcache.sh | 41 +++++++++++++++++++ 3 files changed, 42 insertions(+), 19 deletions(-) delete mode 100644 examples/configs/recipes/llm/grpo_math_qwen3_8B_fp8_kvcache.yaml create mode 100644 tests/functional/grpo_fp8_kvcache.sh diff --git a/examples/configs/recipes/llm/grpo_math_qwen3_8B_fp8_kvcache.yaml b/examples/configs/recipes/llm/grpo_math_qwen3_8B_fp8_kvcache.yaml deleted file mode 100644 index 8932af29b2..0000000000 --- a/examples/configs/recipes/llm/grpo_math_qwen3_8B_fp8_kvcache.yaml +++ /dev/null @@ -1,19 +0,0 @@ -# GRPO Algorithm Configuration -defaults: "grpo_math_8B_megatron.yaml" - -loss_fn: - use_importance_sampling_correction: true - -policy: - model_name: "Qwen/Qwen3-8B-Base" - megatron_cfg: - converter_type: "Qwen3ForCausalLM" - tensor_model_parallel_size: 4 - pipeline_model_parallel_size: 1 - generation: - vllm_cfg: - precision: 'fp8' - kv_cache_dtype: 'fp8' - use_deep_gemm: true - num_last_layers_in_bf16: 0 - num_first_layers_in_bf16: 0 diff --git a/tests/functional/L1_Functional_Tests_GPU.sh b/tests/functional/L1_Functional_Tests_GPU.sh index 99fc24c52b..fb1fb2101d 100644 --- a/tests/functional/L1_Functional_Tests_GPU.sh +++ b/tests/functional/L1_Functional_Tests_GPU.sh @@ -25,6 +25,7 @@ time uv run --no-sync bash ./tests/functional/grpo_async.sh time uv run --no-sync bash ./tests/functional/grpo_megatron.sh time uv run --no-sync bash ./tests/functional/grpo_multiturn.sh time uv run --no-sync bash ./tests/functional/grpo_non_colocated.sh +time uv run --no-sync bash ./tests/functional/grpo_fp8_kvcache.sh time uv run 
--no-sync bash ./tests/functional/dpo.sh time uv run --no-sync bash ./tests/functional/rm.sh time uv run --no-sync bash ./tests/functional/eval.sh diff --git a/tests/functional/grpo_fp8_kvcache.sh b/tests/functional/grpo_fp8_kvcache.sh new file mode 100644 index 0000000000..0132f0e11d --- /dev/null +++ b/tests/functional/grpo_fp8_kvcache.sh @@ -0,0 +1,41 @@ +#!/bin/bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..) +# Mark the current repo as safe, since wandb fetches metadata about the repo +git config --global --add safe.directory $PROJECT_ROOT + +set -eou pipefail + +EXP_NAME=$(basename $0 .sh) +EXP_DIR=$SCRIPT_DIR/$EXP_NAME +LOG_DIR=$EXP_DIR/logs +JSON_METRICS=$EXP_DIR/metrics.json +RUN_LOG=$EXP_DIR/run.log +export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-} + +rm -rf $EXP_DIR $LOG_DIR +mkdir -p $EXP_DIR $LOG_DIR + +cd $PROJECT_ROOT +uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJECT_ROOT/nemo_rl \ + $PROJECT_ROOT/examples/run_grpo_math.py \ + --config $PROJECT_ROOT/examples/configs/recipes/llm/grpo-qwen3-8b-base-1n8g-fp8-kvcache-megatron.yaml \ + policy.model_name=Qwen/Qwen3-0.6B-Base \ + cluster.gpus_per_node=2 \ + grpo.max_num_steps=3 \ + policy.megatron_cfg.tensor_model_parallel_size=2 \ + policy.megatron_cfg.scheduler.lr_warmup_iters=0 \ + policy.generation.vllm_cfg.use_deep_gemm=false \ + logger.tensorboard_enabled=true \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=false \ + logger.monitor_gpus=true \ + checkpointing.enabled=false \ + $@ \ + 2>&1 | tee $RUN_LOG + +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +uv run tests/check_metrics.py $JSON_METRICS \ + 'data["train/token_mult_prob_error"]["3"] < 1.5' From 231a739d81af5d379a0c17d5da14aa5b2d8fe7c3 Mon Sep 17 00:00:00 2001 From: Zhaopeng Qiu Date: Wed, 26 Nov 2025 00:52:33 -0800 Subject: [PATCH 32/40] at first calibration align with training data processing to 
ensure parallel training compatibility Signed-off-by: Zhaopeng Qiu --- nemo_rl/algorithms/grpo.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 66de08b5e7..d972db2ffa 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -1114,16 +1114,27 @@ def grpo_train( if sync_kv_scales and kv_scales_cache is None: print("▶ Computing KV cache scales...", flush=True) policy.prepare_for_lp_inference() + # Align with training data processing to ensure parallel training compatibility + calib_flat, calib_input_lengths = ( + batched_message_log_to_flat_message( + repeated_batch["message_log"], + pad_value_dict={ + "token_ids": tokenizer.pad_token_id + }, + make_sequence_length_divisible_by=master_config[ + "policy" + ]["make_sequence_length_divisible_by"], + ) + ) # Create calibration data from flattened messages calibration_data = BatchedDataDict[ClippedPGLossDataDict]( { - "input_ids": batched_flat["token_ids"], - "input_lengths": input_lengths, + "input_ids": calib_flat["token_ids"], + "input_lengths": calib_input_lengths, } ) - # this will be mini-batched inside the policy, so maintain the packed multimodal structure calibration_data.update( - batched_flat.get_multimodal_dict(as_tensors=False) + calib_flat.get_multimodal_dict(as_tensors=False) ) calibration_data.to("cpu") kv_scales_cache = policy.calibrate_qkv_fp8_scales( From 4f1324a8dcd7e160988177ca90c200512c80f187 Mon Sep 17 00:00:00 2001 From: Zhaopeng Qiu Date: Wed, 26 Nov 2025 17:59:46 -0800 Subject: [PATCH 33/40] remove l1 test; upload missed recipe yaml Signed-off-by: Zhaopeng Qiu --- ...en3-8b-base-1n8g-fp8-kvcache-megatron.yaml | 49 +++++++++++++++++++ tests/functional/L1_Functional_Tests_GPU.sh | 1 - tests/functional/grpo_fp8_kvcache.sh | 41 ---------------- 3 files changed, 49 insertions(+), 42 deletions(-) create mode 100644 
examples/configs/recipes/llm/grpo-qwen3-8b-base-1n8g-fp8-kvcache-megatron.yaml delete mode 100644 tests/functional/grpo_fp8_kvcache.sh diff --git a/examples/configs/recipes/llm/grpo-qwen3-8b-base-1n8g-fp8-kvcache-megatron.yaml b/examples/configs/recipes/llm/grpo-qwen3-8b-base-1n8g-fp8-kvcache-megatron.yaml new file mode 100644 index 0000000000..78b4597c2c --- /dev/null +++ b/examples/configs/recipes/llm/grpo-qwen3-8b-base-1n8g-fp8-kvcache-megatron.yaml @@ -0,0 +1,49 @@ +defaults: ../../grpo_math_1B.yaml +grpo: + val_period: 20 +checkpointing: + enabled: false + checkpoint_dir: results/grpo_qwen3_8b_fp8_kvcache +loss_fn: + use_importance_sampling_correction: true +policy: + model_name: Qwen/Qwen3-8B-Base + train_micro_batch_size: 1 + logprob_batch_size: 1 + max_total_sequence_length: 8192 + dtensor_cfg: + enabled: false + optimizer: null + scheduler: null + megatron_cfg: + enabled: true + converter_type: Qwen3ForCausalLM + tensor_model_parallel_size: 4 + optimizer: + lr: 1.0e-06 + min_lr: 1.0e-06 + weight_decay: 0.1 + use_precision_aware_optimizer: false + scheduler: + lr_decay_iters: null + lr_warmup_iters: 10 + lr_warmup_init: 1.0e-07 + make_sequence_length_divisible_by: ${mul:${policy.megatron_cfg.tensor_model_parallel_size}, + 2} + generation: + vllm_cfg: + precision: fp8 + kv_cache_dtype: fp8 + use_deep_gemm: true +data: + max_input_seq_length: 2048 + prompt_file: null + dataset_name: DAPOMath17K +env: + dapo: + num_workers: 16 + math: + num_workers: 16 + math_verify_impl: dapo_math_verify +cluster: + gpus_per_node: 8 diff --git a/tests/functional/L1_Functional_Tests_GPU.sh b/tests/functional/L1_Functional_Tests_GPU.sh index fb1fb2101d..99fc24c52b 100644 --- a/tests/functional/L1_Functional_Tests_GPU.sh +++ b/tests/functional/L1_Functional_Tests_GPU.sh @@ -25,7 +25,6 @@ time uv run --no-sync bash ./tests/functional/grpo_async.sh time uv run --no-sync bash ./tests/functional/grpo_megatron.sh time uv run --no-sync bash ./tests/functional/grpo_multiturn.sh time uv 
run --no-sync bash ./tests/functional/grpo_non_colocated.sh -time uv run --no-sync bash ./tests/functional/grpo_fp8_kvcache.sh time uv run --no-sync bash ./tests/functional/dpo.sh time uv run --no-sync bash ./tests/functional/rm.sh time uv run --no-sync bash ./tests/functional/eval.sh diff --git a/tests/functional/grpo_fp8_kvcache.sh b/tests/functional/grpo_fp8_kvcache.sh deleted file mode 100644 index 0132f0e11d..0000000000 --- a/tests/functional/grpo_fp8_kvcache.sh +++ /dev/null @@ -1,41 +0,0 @@ -#!/bin/bash - -SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) -PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..) -# Mark the current repo as safe, since wandb fetches metadata about the repo -git config --global --add safe.directory $PROJECT_ROOT - -set -eou pipefail - -EXP_NAME=$(basename $0 .sh) -EXP_DIR=$SCRIPT_DIR/$EXP_NAME -LOG_DIR=$EXP_DIR/logs -JSON_METRICS=$EXP_DIR/metrics.json -RUN_LOG=$EXP_DIR/run.log -export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-} - -rm -rf $EXP_DIR $LOG_DIR -mkdir -p $EXP_DIR $LOG_DIR - -cd $PROJECT_ROOT -uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJECT_ROOT/nemo_rl \ - $PROJECT_ROOT/examples/run_grpo_math.py \ - --config $PROJECT_ROOT/examples/configs/recipes/llm/grpo-qwen3-8b-base-1n8g-fp8-kvcache-megatron.yaml \ - policy.model_name=Qwen/Qwen3-0.6B-Base \ - cluster.gpus_per_node=2 \ - grpo.max_num_steps=3 \ - policy.megatron_cfg.tensor_model_parallel_size=2 \ - policy.megatron_cfg.scheduler.lr_warmup_iters=0 \ - policy.generation.vllm_cfg.use_deep_gemm=false \ - logger.tensorboard_enabled=true \ - logger.log_dir=$LOG_DIR \ - logger.wandb_enabled=false \ - logger.monitor_gpus=true \ - checkpointing.enabled=false \ - $@ \ - 2>&1 | tee $RUN_LOG - -uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS - -uv run tests/check_metrics.py $JSON_METRICS \ - 'data["train/token_mult_prob_error"]["3"] < 1.5' From 94d16ec98520e1e7a7d8ef68109b8f91cc5a7577 Mon Sep 17 00:00:00 
2001 From: Zhaopeng Qiu Date: Wed, 26 Nov 2025 19:51:04 -0800 Subject: [PATCH 34/40] resolve fp8 patch conflicts Signed-off-by: Zhaopeng Qiu --- nemo_rl/models/generation/fp8.py | 143 +++---------------------------- 1 file changed, 12 insertions(+), 131 deletions(-) diff --git a/nemo_rl/models/generation/fp8.py b/nemo_rl/models/generation/fp8.py index 116c224920..1dddfe6c24 100644 --- a/nemo_rl/models/generation/fp8.py +++ b/nemo_rl/models/generation/fp8.py @@ -107,112 +107,6 @@ def patched_run_workers(self, *args, **kwargs): fp8_patches_applied = True -def kv_cache_process_weights_after_loading(self, layer: torch.nn.Module) -> None: - """Modified version of BaseKVCacheMethod.process_weights_after_loading. - - Doesn't delete k_scale, v_scale, q_scale, and prob_scale parameters to allow - for dynamic updates during refit. - """ - import torch - from vllm.logger import init_logger - from vllm.platforms import current_platform - - logger = init_logger(__name__) - - # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0 - # regardless whether the kv-scale is available in the checkpoint. 
- if layer.kv_cache_dtype != "auto": - if layer.k_scale > 0.0 and layer.v_scale > 0.0: - # We prefer to use separate k_scale and v_scale if present - k_scale = layer.k_scale.to("cpu").tolist() - v_scale = layer.v_scale.to("cpu").tolist() - if current_platform.is_fp8_fnuz(): - k_scale *= 2 - v_scale *= 2 - elif layer.k_scale < 0.0 and layer.v_scale < 0.0: - # If no scales were loaded (both scales are invalid negative - # values), use the default value of 1.0 - k_scale = 1.0 - v_scale = 1.0 - else: - # If we find a single kv_scale in the checkpoint, we remap - # kv_scale to k_scale during weight loading, and duplicate - # k_scale to v_scale here - assert layer.k_scale > 0.0 - scale_to_duplicate = max(layer.k_scale, layer.v_scale) - k_scale = scale_to_duplicate.to("cpu").tolist() - v_scale = scale_to_duplicate.to("cpu").tolist() - if current_platform.is_fp8_fnuz(): - k_scale *= 2 - v_scale *= 2 - - if not isinstance(k_scale, float) or not isinstance(v_scale, float): - raise ValueError("Only support per-tensor scaling factor for fp8 KV cache") - - if layer.q_scale < 0.0: - logger.warning_once( - "Checkpoint does not provide a q scaling factor. " - "Setting it to k_scale. This only matters for " - "the flash-attn backend." - ) - layer._q_scale.copy_(k_scale) - layer._q_scale_float = k_scale - - # These are used in the final Attention.forward() - layer._k_scale.copy_(k_scale) - layer._v_scale.copy_(v_scale) - layer._k_scale_float = k_scale - layer._v_scale_float = v_scale - if k_scale == 1.0 and v_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype: - logger.warning_once( - "Using KV cache scaling factor 1.0 for fp8_e4m3. This " - "may cause accuracy issues. Please make sure k/v_scale " - "scaling factors are available in the fp8 checkpoint." 
- ) - - if layer.q_scale > 0.0: - q_scale = layer.q_scale - if current_platform.is_fp8_fnuz(): - q_scale *= 2 - else: - q_scale = 1.0 - if layer.prob_scale > 0.0: - prob_scale = layer.prob_scale - if current_platform.is_fp8_fnuz(): - prob_scale *= 2 - else: - prob_scale = 1.0 - - is_singleton_float = ( - lambda x: isinstance(x, float) - or isinstance(x, torch.Tensor) - and x.numel() == 1 - and x.is_floating_point() - ) - if not is_singleton_float(q_scale) or not is_singleton_float(prob_scale): - raise ValueError( - "Only support per-tensor scaling factorfor fp8-quantized Q/prob" - ) - - # These are used in the final Attention.forward() - layer._q_scale.copy_(q_scale) - layer._q_scale_float = ( - q_scale.item() if isinstance(q_scale, torch.Tensor) else q_scale - ) - - layer._prob_scale.copy_(prob_scale) - if layer.kv_cache_dtype == "fp8" and (q_scale == 1.0 or prob_scale == 1.0): - logger.warning_once( - f"Using uncalibrated q_scale {q_scale} and/or prob_scale " - f"{prob_scale} with fp8 attention. This may cause accuracy " - "issues. Please make sure q/prob scaling factors are " - "available in the fp8 checkpoint." - ) - - # IMPORTANT: We DON'T delete the parameters here to allow for dynamic updates - # Original code deleted: layer.k_scale, layer.v_scale, layer.q_scale, layer.prob_scale - - def get_vllm_qkv_scale_names(layer_idx: int) -> dict[str, str]: """Get vLLM-compatible parameter names for Q/K/V FP8 scales. 
@@ -316,6 +210,9 @@ def apply_fp8_patches(self, fp8_config): func1_path = "vllm.model_executor.layers.quantization.fp8.Fp8LinearMethod.process_weights_after_loading" patcher1 = patch(func1_path, process_weights_after_loading) fp8_state.vllm_patches.append(patcher1) + func2_path = "vllm.model_executor.layers.quantization.fp8.Fp8MoEMethod.process_weights_after_loading" + patcher2 = patch(func2_path, process_weights_after_loading_moe) + fp8_state.vllm_patches.append(patcher2) # These patches add support for pow2, e8 dynamic activation scalings factors which are believed to have higher # SNR compared to plain fp32 scaling factors. This feature is still under active research. @@ -332,32 +229,8 @@ def apply_fp8_patches(self, fp8_config): if global_fp8_config.kv_cache_dtype == "fp8": # Static scales mode: patch process_weights_after_loading to preserve k_scale/v_scale for manual updates func5_path = "vllm.model_executor.layers.quantization.kv_cache.BaseKVCacheMethod.process_weights_after_loading" - patcher5 = patch(func5_path, kv_cache_process_weights_after_loading) + patcher5 = patch(func5_path, process_weights_after_loading_kv) fp8_state.vllm_patches.append(patcher5) - # This patch is used to support torch.compile with vllm parameter subclasses, such as - # PerTensorScaleParameter. Because we need weight loaders to update fp8 weights each - # refit, we patch fp8 parameters to have a reference to their weight loader. Eventually - # with pytorch 2.8, parameter subclassing with torch.compile will be natively supported, in - # which this patch can be removed. 
- func1_path = "vllm.model_executor.layers.quantization.fp8.Fp8LinearMethod.process_weights_after_loading" - patcher1 = patch(func1_path, process_weights_after_loading) - fp8_state.vllm_patches.append(patcher1) - func2_path = "vllm.model_executor.layers.quantization.fp8.Fp8MoEMethod.process_weights_after_loading" - patcher2 = patch(func2_path, process_weights_after_loading_moe) - fp8_state.vllm_patches.append(patcher2) - func3_path = "vllm.model_executor.layers.quantization.kv_cache.BaseKVCacheMethod.process_weights_after_loading" - patcher3 = patch(func3_path, process_weights_after_loading_kv) - fp8_state.vllm_patches.append(patcher3) - # These patches add support for pow2, e8 dynamic activation scalings factors which are believed to have higher - # SNR compared to plain fp32 scaling factors. This feature is still under active research. - if global_fp8_config.use_activation_pow2_scale: - func2_path = "vllm.model_executor.layers.quantization.utils.fp8_utils.per_token_group_quant_fp8" - func3_path = "vllm.model_executor.layers.quantization.utils.fp8_utils._per_token_group_quant_fp8" - func4_path = "vllm.model_executor.layers.quantization.utils.fp8_utils._per_token_group_quant_fp8_colmajor" - patcher2 = patch(func2_path, per_token_group_quant_fp8) - patcher3 = patch(func3_path, _per_token_group_quant_fp8) - patcher4 = patch(func4_path, _per_token_group_quant_fp8_colmajor) - fp8_state.vllm_patches.append(patcher2, patcher3, patcher4) for p in fp8_state.vllm_patches: p.start() @@ -763,6 +636,11 @@ def process_weights_after_loading_moe(self, layer) -> None: def process_weights_after_loading_kv(self, layer) -> None: + """Modified version of BaseKVCacheMethod.process_weights_after_loading. + + Doesn't delete k_scale, v_scale, q_scale, and prob_scale parameters to allow + for dynamic updates during refit. + """ # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0 # regardless whether the kv-scale is available in the checkpoint. 
# No need to process kv scales after loading if we are going to @@ -840,6 +718,9 @@ def process_weights_after_loading_kv(self, layer) -> None: layer._prob_scale.copy_(prob_scale) + # IMPORTANT: We DON'T delete the parameters here to allow for dynamic updates + # Original code deleted: layer.k_scale, layer.v_scale, layer.q_scale, layer.prob_scale + @triton.jit def _per_token_group_quant_fp8( From 48b20aaf1da5e8dae90466a8ac898ed72ec4026d Mon Sep 17 00:00:00 2001 From: Zhaopeng Qiu Date: Thu, 27 Nov 2025 00:42:28 -0800 Subject: [PATCH 35/40] add nightly test Signed-off-by: Zhaopeng Qiu --- ...qwen3-8b-base-1n8g-fp8-kvcache-megatron.sh | 40 +++++++++++++++++++ tests/test_suites/nightly.txt | 1 + 2 files changed, 41 insertions(+) create mode 100755 tests/test_suites/llm/grpo-qwen3-8b-base-1n8g-fp8-kvcache-megatron.sh diff --git a/tests/test_suites/llm/grpo-qwen3-8b-base-1n8g-fp8-kvcache-megatron.sh b/tests/test_suites/llm/grpo-qwen3-8b-base-1n8g-fp8-kvcache-megatron.sh new file mode 100755 index 0000000000..edcec4fa4a --- /dev/null +++ b/tests/test_suites/llm/grpo-qwen3-8b-base-1n8g-fp8-kvcache-megatron.sh @@ -0,0 +1,40 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source $SCRIPT_DIR/common.env + +# ===== BEGIN CONFIG ===== +NUM_NODES=1 +STEPS_PER_RUN=60 +MAX_STEPS=60 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up +NUM_MINUTES=240 +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +# Run the experiment +cd $PROJECT_ROOT +uv run examples/run_grpo_math.py \ + --config $CONFIG_PATH \ + grpo.max_num_steps=$MAX_STEPS \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=True \ + logger.wandb.project=nemo-rl \ + logger.wandb.name=$EXP_NAME \ + logger.monitor_gpus=True \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=True \ + checkpointing.checkpoint_dir=$CKPT_DIR \ + $@ \ + 2>&1 | tee $RUN_LOG + +# Convert tensorboard logs to json +uv run tests/json_dump_tb_logs.py $LOG_DIR 
--output_path $JSON_METRICS + +# Only run metrics if the target step is reached +if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then + # With a few number of steps the logprob can have spikes that can move the average up. + # Enabling fp8 kvcache can cause the logprob to be slightly higher than fp8 linear only path, so we allow a larger tolerance. + uv run tests/check_metrics.py $JSON_METRICS \ + 'mean(data["train/token_mult_prob_error"], ignore_top_p=0.1) < 1.5' +fi diff --git a/tests/test_suites/nightly.txt b/tests/test_suites/nightly.txt index 72b4c6debc..ca4b96fd16 100644 --- a/tests/test_suites/nightly.txt +++ b/tests/test_suites/nightly.txt @@ -41,6 +41,7 @@ tests/test_suites/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.sh tests/test_suites/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8-rollouts.v3.sh tests/test_suites/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8-e2e.sh tests/test_suites/llm/grpo-moonlight-16ba3b-4n8g-megatron-fp8-e2e.sh +tests/test_suites/llm/grpo-qwen3-8b-base-1n8g-fp8-kvcache-megatron.sh # Non-colocated tests/test_suites/llm/grpo-llama3.1-8b-instruct-2n8g-fsdp2tp1-noncolocated.sh From 603c3662c6695a3958e85f52c8812520969200b4 Mon Sep 17 00:00:00 2001 From: Zhaopeng Qiu Date: Thu, 27 Nov 2025 02:29:20 -0800 Subject: [PATCH 36/40] increase gpu hours for new nightly test Signed-off-by: Zhaopeng Qiu --- tests/unit/test_recipes_and_test_suites.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unit/test_recipes_and_test_suites.py b/tests/unit/test_recipes_and_test_suites.py index dbdf009dae..9bf6283fdf 100644 --- a/tests/unit/test_recipes_and_test_suites.py +++ b/tests/unit/test_recipes_and_test_suites.py @@ -170,7 +170,7 @@ def test_all_recipe_yamls_accounted_for_in_test_suites( ) -def test_nightly_compute_stays_below_1040_hours(nightly_test_suite, tracker): +def test_nightly_compute_stays_below_1070_hours(nightly_test_suite, 
tracker): command = f"DRYRUN=1 HF_HOME=... HF_DATASETS_CACHE=... CONTAINER= ACCOUNT= PARTITION= ./tools/launch {' '.join(nightly_test_suite)}" print(f"Running command: {command}") @@ -202,8 +202,8 @@ def test_nightly_compute_stays_below_1040_hours(nightly_test_suite, tracker): f"Last line of output was not as expected: '{last_line}'" ) total_gpu_hours = float(last_line.split(":")[-1].strip()) - assert total_gpu_hours <= 1040, ( - f"Total GPU hours exceeded 1040: {last_line}. We should revisit the test suites to reduce the total GPU hours." + assert total_gpu_hours <= 1070, ( + f"Total GPU hours exceeded 1070: {last_line}. We should revisit the test suites to reduce the total GPU hours." ) tracker.track("total_nightly_gpu_hours", total_gpu_hours) From 7ca82f30b9d967826c578a46db3cffd2ded39cc2 Mon Sep 17 00:00:00 2001 From: Zhaopeng Qiu Date: Thu, 27 Nov 2025 02:32:08 -0800 Subject: [PATCH 37/40] allow a larger logprob tolerance Signed-off-by: Zhaopeng Qiu --- .../llm/grpo-qwen3-8b-base-1n8g-fp8-kvcache-megatron.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_suites/llm/grpo-qwen3-8b-base-1n8g-fp8-kvcache-megatron.sh b/tests/test_suites/llm/grpo-qwen3-8b-base-1n8g-fp8-kvcache-megatron.sh index edcec4fa4a..d1068a7ffa 100755 --- a/tests/test_suites/llm/grpo-qwen3-8b-base-1n8g-fp8-kvcache-megatron.sh +++ b/tests/test_suites/llm/grpo-qwen3-8b-base-1n8g-fp8-kvcache-megatron.sh @@ -36,5 +36,5 @@ if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | ma # With a few number of steps the logprob can have spikes that can move the average up. # Enabling fp8 kvcache can cause the logprob to be slightly higher than fp8 linear only path, so we allow a larger tolerance. 
uv run tests/check_metrics.py $JSON_METRICS \ - 'mean(data["train/token_mult_prob_error"], ignore_top_p=0.1) < 1.5' + 'mean(data["train/token_mult_prob_error"], ignore_top_p=0.15) < 2.0' fi From b34ad763a9bce7ede29c302b5b654ffc0af4ef1d Mon Sep 17 00:00:00 2001 From: alexchiu Date: Mon, 1 Dec 2025 16:07:26 +0800 Subject: [PATCH 38/40] update kv_cache_dtype with choices Co-authored-by: Terry Kong Signed-off-by: alexchiu --- nemo_rl/models/generation/vllm/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo_rl/models/generation/vllm/config.py b/nemo_rl/models/generation/vllm/config.py index c3a0171679..731f60f226 100644 --- a/nemo_rl/models/generation/vllm/config.py +++ b/nemo_rl/models/generation/vllm/config.py @@ -28,7 +28,7 @@ class VllmSpecificArgs(TypedDict): async_engine: bool load_format: NotRequired[str] precision: NotRequired[str] - kv_cache_dtype: NotRequired[str] + kv_cache_dtype: Literal["auto", "fp8"] enforce_eager: NotRequired[bool] # By default, NeMo RL only has a Python handle to the vllm.LLM generation engine. The expose_http_server flag here will expose that generation engine as an HTTP server. # Exposing vLLM as a server is useful in instances where the multi-turn rollout is performed with utilities outside of NeMo RL, but the user still wants to take advantage of the refit logic in NeMo RL that keeps the policy and generation up to date. 
From 40fa1acb3680f042bfa5385e7a8fb05534297eff Mon Sep 17 00:00:00 2001 From: Zhaopeng Qiu Date: Mon, 1 Dec 2025 23:47:01 -0800 Subject: [PATCH 39/40] add default kv_cache_dtype; update checking logic code Signed-off-by: Zhaopeng Qiu --- examples/configs/distillation_math.yaml | 1 + examples/configs/grpo_math_1B.yaml | 1 + examples/configs/vlm_grpo_3B.yaml | 1 + examples/configs/vlm_grpo_3B_megatron.yaml | 1 + nemo_rl/algorithms/grpo.py | 14 ++++++-------- nemo_rl/models/generation/fp8.py | 13 +++++++++---- nemo_rl/models/generation/vllm/config.py | 4 ++-- nemo_rl/models/generation/vllm/vllm_worker.py | 2 +- .../models/policy/megatron_policy_worker.py | 18 +++++++++++++----- 9 files changed, 35 insertions(+), 20 deletions(-) diff --git a/examples/configs/distillation_math.yaml b/examples/configs/distillation_math.yaml index b77c6d3893..b4082a969c 100644 --- a/examples/configs/distillation_math.yaml +++ b/examples/configs/distillation_math.yaml @@ -173,6 +173,7 @@ policy: &POLICY_BASE vllm_cfg: async_engine: false precision: ${...precision} + kv_cache_dtype: "auto" tensor_parallel_size: 1 pipeline_parallel_size: 1 expert_parallel_size: 1 # When EP > 1, EP must be a multiple of TP since vLLM's EP = DP * TP diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index e380c61bd2..af5d3d7335 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -217,6 +217,7 @@ policy: vllm_cfg: async_engine: false precision: ${policy.precision} + kv_cache_dtype: "auto" tensor_parallel_size: 1 pipeline_parallel_size: 1 expert_parallel_size: 1 # When EP > 1, EP must be a multiple of TP since vLLM's EP = DP * TP diff --git a/examples/configs/vlm_grpo_3B.yaml b/examples/configs/vlm_grpo_3B.yaml index 4e21205491..6b9d3ac077 100644 --- a/examples/configs/vlm_grpo_3B.yaml +++ b/examples/configs/vlm_grpo_3B.yaml @@ -206,6 +206,7 @@ policy: vllm_cfg: async_engine: false # Only for internal testing, will be enabled by 
https://github.com/NVIDIA/NeMo-RL/issues/447. precision: ${policy.precision} + kv_cache_dtype: "auto" tensor_parallel_size: 1 pipeline_parallel_size: 1 expert_parallel_size: 1 diff --git a/examples/configs/vlm_grpo_3B_megatron.yaml b/examples/configs/vlm_grpo_3B_megatron.yaml index dd206d75ac..49c29c2138 100644 --- a/examples/configs/vlm_grpo_3B_megatron.yaml +++ b/examples/configs/vlm_grpo_3B_megatron.yaml @@ -106,6 +106,7 @@ policy: vllm_cfg: async_engine: false precision: ${policy.precision} + kv_cache_dtype: "auto" tensor_parallel_size: 1 pipeline_parallel_size: 1 expert_parallel_size: 1 diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index d972db2ffa..58e8347f61 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -500,10 +500,10 @@ def init_vllm(): assert loss_config["use_importance_sampling_correction"] is True, ( "Importance sampling must be enabled for vLLM FP8 generation for good convergence!" ) - if generation_config["vllm_cfg"].get("kv_cache_dtype") == "fp8": + if generation_config["vllm_cfg"]["kv_cache_dtype"].startswith("fp8"): # FP8 KV cache requires FP8 model precision assert generation_config["vllm_cfg"]["precision"] == "fp8", ( - "kv_cache_dtype='fp8' requires precision='fp8'. " + f"kv_cache_dtype='{generation_config['vllm_cfg']['kv_cache_dtype']}' requires precision='fp8'. " "FP8 KV cache can only be used together with FP8 model weights." 
) # FP8 KV cache compatibility checks @@ -900,15 +900,13 @@ def _should_sync_kv_scales(master_config: MasterConfig) -> bool: if generation_config is None: return False - backend = generation_config.get("backend", "") - if backend != "vllm": + if generation_config["backend"] != "vllm": return False - vllm_cfg = generation_config.get("vllm_cfg", {}) - kv_cache_dtype = vllm_cfg.get("kv_cache_dtype", "auto") + vllm_cfg = cast(VllmConfig, generation_config)["vllm_cfg"] - # Sync scales when using FP8 KV cache (always static in this design) - return kv_cache_dtype == "fp8" + # Sync scales when using FP8 KV cache + return vllm_cfg["kv_cache_dtype"].startswith("fp8") def refit_policy_generation( diff --git a/nemo_rl/models/generation/fp8.py b/nemo_rl/models/generation/fp8.py index 1dddfe6c24..0939e3582b 100644 --- a/nemo_rl/models/generation/fp8.py +++ b/nemo_rl/models/generation/fp8.py @@ -243,12 +243,18 @@ def init_fp8(vllm_cfg, model_name, model_parallel_size): global global_fp8_config # Determine if we're using FP8 weights based on precision setting use_fp8_weights = vllm_cfg.get("precision") == "fp8" - kv_cache_dtype = vllm_cfg.get("kv_cache_dtype", "auto") + kv_cache_dtype = vllm_cfg["kv_cache_dtype"] + + # Validate configuration: kv_cache_dtype + if kv_cache_dtype not in ["auto", "fp8", "fp8_e4m3"]: + raise ValueError( + f"kv_cache_dtype must be one of ['auto', 'fp8', 'fp8_e4m3'], but got {kv_cache_dtype}" + ) # Validate configuration: kv_cache_dtype=fp8 requires precision=fp8 - if kv_cache_dtype == "fp8" and not use_fp8_weights: + if kv_cache_dtype.startswith("fp8") and not use_fp8_weights: raise ValueError( - "kv_cache_dtype='fp8' requires precision='fp8'. " + f"kv_cache_dtype='{kv_cache_dtype}' requires precision='fp8'. " "FP8 KV cache can only be used together with FP8 model weights." 
) @@ -322,7 +328,6 @@ def init_fp8(vllm_cfg, model_name, model_parallel_size): print("ignored_layers", fp8_block_quant_kwargs["ignored_layers"]) # Return FP8 kwargs (precision=fp8 is required at this point) - # kv_cache_dtype can be "auto" or "fp8" vllm_kwargs = { "quantization": "fp8", "kv_cache_dtype": kv_cache_dtype, diff --git a/nemo_rl/models/generation/vllm/config.py b/nemo_rl/models/generation/vllm/config.py index 731f60f226..65a174b2bc 100644 --- a/nemo_rl/models/generation/vllm/config.py +++ b/nemo_rl/models/generation/vllm/config.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, NotRequired, TypedDict +from typing import Any, Literal, NotRequired, TypedDict from nemo_rl.models.generation.interfaces import GenerationConfig @@ -28,7 +28,7 @@ class VllmSpecificArgs(TypedDict): async_engine: bool load_format: NotRequired[str] precision: NotRequired[str] - kv_cache_dtype: Literal["auto", "fp8"] + kv_cache_dtype: Literal["auto", "fp8", "fp8_e4m3"] enforce_eager: NotRequired[bool] # By default, NeMo RL only has a Python handle to the vllm.LLM generation engine. The expose_http_server flag here will expose that generation engine as an HTTP server. # Exposing vLLM as a server is useful in instances where the multi-turn rollout is performed with utilities outside of NeMo RL, but the user still wants to take advantage of the refit logic in NeMo RL that keeps the policy and generation up to date. 
diff --git a/nemo_rl/models/generation/vllm/vllm_worker.py b/nemo_rl/models/generation/vllm/vllm_worker.py index 94dec10aac..6c0d3577e9 100644 --- a/nemo_rl/models/generation/vllm/vllm_worker.py +++ b/nemo_rl/models/generation/vllm/vllm_worker.py @@ -287,7 +287,7 @@ def _patch_vllm_init_workers_ray(): vllm_kwargs["ray_workers_use_nsight"] = True # Call init_fp8 when precision is fp8 - # (kv_cache_dtype can be fp8 or auto, validated in init_fp8) + # (kv_cache_dtype can be fp8/fp8_e4m3 or auto, validated in init_fp8) if self.cfg["vllm_cfg"]["precision"] == "fp8": from nemo_rl.models.generation.fp8 import init_fp8 diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 8da1427dc3..35d9011c14 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -20,7 +20,7 @@ from collections import defaultdict from contextlib import AbstractContextManager, contextmanager, nullcontext from functools import partial -from typing import Any, Iterator, Optional, TypeVar +from typing import Any, Iterator, Optional, TypeVar, cast import ray import torch @@ -118,6 +118,7 @@ GenerationOutputSpec, verify_right_padding, ) +from nemo_rl.models.generation.vllm.config import VllmConfig from nemo_rl.models.megatron.common import ( _pack_sequences_for_megatron, broadcast_tensor, @@ -2021,10 +2022,17 @@ def _iter_params_with_optional_kv_scales( # Check whether FP8 KV cache is enabled. 
use_fp8_kv_cache = False - if "generation" in self.cfg and self.cfg["generation"] is not None: - vllm_cfg = self.cfg["generation"].get("vllm_cfg", {}) - kv_cache_dtype = vllm_cfg.get("kv_cache_dtype", "auto") - use_fp8_kv_cache = kv_cache_dtype == "fp8" + if ( + "generation" in self.cfg + and self.cfg["generation"] is not None + and self.cfg["generation"]["backend"] == "vllm" + ): + generation_cfg = cast(VllmConfig, self.cfg["generation"]) + use_fp8_kv_cache = ( + "vllm_cfg" in generation_cfg + and "kv_cache_dtype" in generation_cfg["vllm_cfg"] + and generation_cfg["vllm_cfg"]["kv_cache_dtype"].startswith("fp8") + ) if not use_fp8_kv_cache: return From db3ea88624aad10b6f4c14749a355374eab4b4b9 Mon Sep 17 00:00:00 2001 From: Zhaopeng Qiu Date: Tue, 2 Dec 2025 01:41:16 -0800 Subject: [PATCH 40/40] add requires_kv_scale_sync property to GenerationInterface Signed-off-by: Zhaopeng Qiu --- nemo_rl/algorithms/grpo.py | 27 +++---------------- nemo_rl/models/generation/interfaces.py | 5 ++++ .../models/generation/vllm/vllm_generation.py | 10 +++++++ 3 files changed, 19 insertions(+), 23 deletions(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index cd13a55826..8d2f06b431 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -889,27 +889,6 @@ def _should_use_penguin(master_config: MasterConfig) -> bool: return should_use_penguin -# Function to check if KV cache scales should be calculated and synchronized during refit -def _should_sync_kv_scales(master_config: MasterConfig) -> bool: - """Check if KV cache scales should be synchronized during refit. - - Returns True if kv_cache_dtype is fp8 (which requires precision=fp8). - KV scales are always computed and synced statically during training - when using FP8 KV cache. 
- """ - generation_config = master_config["policy"]["generation"] - if generation_config is None: - return False - - if generation_config["backend"] != "vllm": - return False - - vllm_cfg = cast(VllmConfig, generation_config)["vllm_cfg"] - - # Sync scales when using FP8 KV cache - return vllm_cfg["kv_cache_dtype"].startswith("fp8") - - def refit_policy_generation( policy: ColocatablePolicyInterface, policy_generation: GenerationInterface, @@ -1014,8 +993,6 @@ def grpo_train( ) timeout.start_iterations() - # Check if we need to sync KV cache scales (infer from config) - sync_kv_scales = _should_sync_kv_scales(master_config) kv_scales_cache = None # Cache reused for computed kv scales NEED_REFIT = True @@ -1026,6 +1003,10 @@ def grpo_train( POLICY_GENERATION_STALE = True # tracks if generation needs a refit before running assert policy_generation is not None # for mypy type check + # Check if we need to sync KV cache scales + # When falling back to the policy as policy_generation, we use getattr to check. 
+ sync_kv_scales = getattr(policy_generation, "requires_kv_scale_sync", False) + # common config/state itmes current_step = grpo_save_state["current_step"] # current step within an epoch total_steps = grpo_save_state["total_steps"] # total steps across all epochs diff --git a/nemo_rl/models/generation/interfaces.py b/nemo_rl/models/generation/interfaces.py index f7f58b383f..d134027bdf 100644 --- a/nemo_rl/models/generation/interfaces.py +++ b/nemo_rl/models/generation/interfaces.py @@ -236,6 +236,11 @@ def prepare_for_generation(self, *args: Any, **kwargs: Any) -> bool: def finish_generation(self, *args: Any, **kwargs: Any) -> bool: pass + @property + def requires_kv_scale_sync(self) -> bool: + """Whether the generation backend requires KV cache scales synchronization.""" + return False + def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None: """Prepare the info for refit.""" raise NotImplementedError diff --git a/nemo_rl/models/generation/vllm/vllm_generation.py b/nemo_rl/models/generation/vllm/vllm_generation.py index 87e480a31e..8357856464 100644 --- a/nemo_rl/models/generation/vllm/vllm_generation.py +++ b/nemo_rl/models/generation/vllm/vllm_generation.py @@ -898,3 +898,13 @@ def invalidate_kv_cache(self) -> bool: except Exception as e: print(f"Error invalidating vLLM caches: {e}") return False + + @property + def requires_kv_scale_sync(self) -> bool: + """Check if KV cache scales should be synchronized during refit. + + Returns True if kv_cache_dtype is fp8/fp8_e4m3. + """ + return "kv_cache_dtype" in self.cfg["vllm_cfg"] and self.cfg["vllm_cfg"][ + "kv_cache_dtype" + ].startswith("fp8")