From 1d2dca2b1bc6e0fa2ce7f477d52a019a2f7939d4 Mon Sep 17 00:00:00 2001 From: Xue Huang Date: Mon, 20 Oct 2025 23:12:53 -0700 Subject: [PATCH 01/18] update for async vllm --- docs/advance/fp8.md | 38 ++ docs/index.rst | 1 + .../_generated_ppo_megatron_trainer.yaml | 4 +- verl/trainer/config/ppo_megatron_trainer.yaml | 4 + verl/utils/fp8_utils.py | 358 ++++++++++++++++++ verl/workers/config/rollout.py | 4 + .../rollout/vllm_rollout/vllm_async_server.py | 19 +- .../workers/sharding_manager/megatron_vllm.py | 10 +- 8 files changed, 433 insertions(+), 5 deletions(-) create mode 100644 docs/advance/fp8.md create mode 100644 verl/utils/fp8_utils.py diff --git a/docs/advance/fp8.md b/docs/advance/fp8.md new file mode 100644 index 00000000000..958d3e6eb62 --- /dev/null +++ b/docs/advance/fp8.md @@ -0,0 +1,38 @@ +# FP8 for verl + +Last updated: 09/18/2025. + +This module is still in development. Currently we support FP8 rollout, using FP8 blockwise scaling (used in Deepseek, +which is 1x128 quantization for activations and 128x128 quantization for model weights). + +We monkey patches several vLLM functions to enable FP8 rollout for reinforcement learning. +1. **load_weights**: A custom `load_weights` function to quantize the on-the-fly model weights from a higher-precision format to FP8. +2. **process weights after loading**: Replace `vllm.model_executor.layers.quantization.fp8.Fp8LinearMethod.process_weights_after_loading` +function to handle model weights loading after quantization. + +**Notes**: +- Currently, we only support VLLM rollout with Megatron training. SGLang rollout with Megatron training is on the roadmap. +- Only support Batch generate sequences. +- We also support FP8 per tensor quantization, but after preliminary testing, there were issues with the accuracy and it is not recommended to use it. 
+ +## Usage + +FP8 can be enabled in the config file `verl/trainer/config/ppo_megatron_trainer.yaml`: + +``` + rollout: + quantization: True + + use_block_quant_rollout: True +``` + +Or it can be enabled by command line: +- `actor_rollout_ref.rollout.quantization=True` +- `actor_rollout_ref.rollout.use_block_quant_rollout=True` + +## Plans + +- add performance of FP8 rollout +- add accuracy curves of FP8 rollout +- support SGLang rollout with FP8 quantization +- enable FP8 training in megatron \ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index b0eebbb3fdf..b0c2a63ae74 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -131,6 +131,7 @@ verl is fast with: advance/fully_async data/transfer_queue.md advance/grafana_prometheus.md + advance/fp8.md .. toctree:: :maxdepth: 1 diff --git a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml index 22faa9a495d..85ef78b99cc 100644 --- a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml @@ -265,6 +265,8 @@ actor_rollout_ref: port: 9090 file: /tmp/ray/session_latest/metrics/prometheus/prometheus.yml served_model_name: ${oc.select:actor_rollout_ref.model.path,null} + quantization: false + use_block_quant_rollout: true layer_name_map: qkv_layer_name: qkv gate_proj_layer_name: gate_up @@ -286,8 +288,6 @@ data: use_shm: false train_files: ~/data/rlhf/gsm8k/train.parquet val_files: ~/data/rlhf/gsm8k/test.parquet - train_max_samples: -1 - val_max_samples: -1 prompt_key: prompt reward_fn_key: data_source max_prompt_length: 512 diff --git a/verl/trainer/config/ppo_megatron_trainer.yaml b/verl/trainer/config/ppo_megatron_trainer.yaml index 0815451d178..0c1b77afcd5 100644 --- a/verl/trainer/config/ppo_megatron_trainer.yaml +++ b/verl/trainer/config/ppo_megatron_trainer.yaml @@ -45,6 +45,10 @@ actor_rollout_ref: use_remove_padding: false rollout: + quantization: False + + 
use_block_quant_rollout: True + layer_name_map: qkv_layer_name: qkv gate_proj_layer_name: gate_up diff --git a/verl/utils/fp8_utils.py b/verl/utils/fp8_utils.py new file mode 100644 index 00000000000..c6b15797021 --- /dev/null +++ b/verl/utils/fp8_utils.py @@ -0,0 +1,358 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from dataclasses import dataclass, field +from typing import Optional +from unittest.mock import patch + +import torch + +try: + from vllm._custom_ops import scaled_fp8_quant + from vllm.model_executor.layers.linear import LinearBase +except ImportError as e: + raise ImportError("FP8 quantization not available") from e + +logger = logging.getLogger(__name__) + +FP8_BLOCK_QUANT_KWARGS = { + "activation_scheme": "dynamic", + "fmt": "e4m3", + "quant_method": "fp8", + "weight_block_size": [128, 128], +} + + +# Ref: https://github.com/NVIDIA-NeMo/RL/commit/bc24887c72a6e1b2699a228bc87c588546dfe6b7 +@dataclass() +class FP8State: + # A cache of fp8 parameter names, we can check this cache to see if a + # param name corresponds to a fp8 weight + seen_params: set = field(default_factory=lambda: set()) + fp8_param_names: set = field(default_factory=lambda: set()) + vllm_patches: list = field(default_factory=lambda: []) + + +fp8_state: FP8State = FP8State() + + +def is_fp8_model(vllm_config): + from 
vllm.model_executor.layers.quantization.fp8 import Fp8Config + + if hasattr(vllm_config, "quant_config") and isinstance(vllm_config.quant_config, Fp8Config): + return True + + return False + + +def get_module_from_param_name(model, name: str): + # Split the name into parts (e.g., 'layers', '0', 'self_attn', 'q_proj', 'weight') + # The module path is all but the last part (the parameter's own name) + path_parts = name.split(".") + module_path = path_parts[:-1] + # Replace with the fused model name + packed_modules_mapping = model.packed_modules_mapping + reversed_mapping = { + original_name: fused_name + for fused_name, original_names_list in packed_modules_mapping.items() + for original_name in original_names_list + } + if module_path[-1] in reversed_mapping.keys(): + module_path[-1] = reversed_mapping[module_path[-1]] + + current_module = model + try: + # Traverse the model hierarchy + for part in module_path: + if isinstance(current_module, torch.nn.ModuleList): + current_module = current_module[int(part)] + else: + current_module = getattr(current_module, part) + except (AttributeError, IndexError, ValueError) as e: + print(f"Warning: Could not find module for parameter '{name}'. Error: {e}") + return current_module + + +def is_fp8_weight(name, model): + if name not in fp8_state.seen_params: + fp8_state.seen_params.add(name) + # Filter out bias params + if name.endswith("weight"): + module = get_module_from_param_name(model, name) + # We currently only quantize linear layers + if isinstance(module, LinearBase) and module.weight.dtype == torch.float8_e4m3fn: + fp8_state.fp8_param_names.add(name) + return name in fp8_state.fp8_param_names + + +def scaled_fp8_blockwise( + data_hp, + weight_block_size, +): + # cast tensor from high precision to FP8 with 128*128 blockwise quantization. 
+ assert len(data_hp.shape) == 2, "Only 2d input tensor is supported" + + block_size1 = weight_block_size[1] + block_size0 = weight_block_size[0] + assert data_hp.shape[1] % block_size1 == 0, ( + f"data_hp.shape[1] {data_hp.shape[1]} must be a multiple of block_size1: {block_size1}." + ) + assert data_hp.shape[0] % block_size0 == 0, ( + f"data_hp.shape[0] {data_hp.shape[0]} must be a multiple of block_size0: {block_size0}." + ) + + # FP8 + max_dtype = torch.finfo(torch.float8_e4m3fn).max + + original_shape = data_hp.shape + blk_m, blk_n = data_hp.shape[0] // block_size0, data_hp.shape[1] // block_size1 + + assert block_size1 == block_size0 + data_hp = data_hp.reshape(blk_m, block_size0, blk_n, block_size1) + + # Permute to (BLK_M, BLK_N, BLOCK_SIZE_M, BLOCK_SIZE_N) + data_hp = data_hp.permute(0, 2, 1, 3) + # Flatten to (BLK_M, BLK_N, BLOCK_SIZE_M * BLOCK_SIZE_N) + data_hp = data_hp.to(torch.float32).contiguous().flatten(start_dim=2) + + # Calculate max absolute value per block + max_abs = torch.amax(torch.abs(data_hp), dim=-1, keepdim=True) + + # Use FP32 scale + scale_fp = max_dtype / max_abs + scale_fp = torch.where(max_abs == 0, 1.0, scale_fp) + # preserve the behavior for 0 amax case + scale_fp = torch.where(max_abs == torch.inf, 1.0, scale_fp) + + descale_fp = torch.reciprocal(scale_fp) + + # Scale and saturate cast the data elements to max of target dtype + data_lp = torch.clamp(data_hp * scale_fp, min=-1 * max_dtype, max=max_dtype) + + fp_data = data_lp.to(torch.float8_e4m3fn) + + # (BLK_M, BLK_N, BLOCK_SIZE_M * BLOCK_SIZE_N) to (M, N) + fp_data = fp_data.reshape(blk_m, blk_n, block_size0, block_size1).permute(0, 2, 1, 3).reshape(original_shape) + + # Convert to target format, but still in original precision container + return fp_data, descale_fp + + +def quant_weights(weights, model, quant_config): + weights_quantized = [] + for k, v in weights: + if not is_fp8_weight(k, model): + weights_quantized.append((k, v)) + continue + # Cast the weight into fp8 and 
its scale factor + if quant_config.weight_block_size is not None: + logger.info("Using blockwise quantization") + param_lp, param_scale = scaled_fp8_blockwise( + v.to(torch.float), + weight_block_size=quant_config.weight_block_size, + ) + param_scale = param_scale.squeeze(-1) + weights_quantized.append([k, param_lp]) + weights_quantized.append([k + "_scale_inv", param_scale]) + + else: + logger.info("Using Per tensor quantization") + original_shape = v.shape + # Use per tensor quantization + quantized_tensor, scale = scaled_fp8_quant(v) + # Reshape back to original shape + quantized_tensor = quantized_tensor.view(original_shape) + + scale_k = k.replace(".weight", ".weight_scale") + scale = scale.view(1) + weights_quantized.extend([(k, quantized_tensor), (scale_k, scale)]) + + return weights_quantized + + +def load_quanted_weights(weights, model_runner): + model = model_runner.model + quant_config = model_runner.vllm_config.quant_config + + weights_quantized = quant_weights(weights, model, quant_config) + + # Monkey patch the param class to their subclass, as certain models + # will check the param type to call the proper weightloader + for name, param in model.named_parameters(): + if hasattr(param, "subclass_type"): + param.orig_type = param.__class__ + param.__class__ = param.subclass_type + # Finally load the weights into vllm + loaded_params = model.load_weights(weights_quantized) + # Undo the type change above to the original type + for name, param in model.named_parameters(): + if hasattr(param, "subclass_type"): + param.__class__ = param.orig_type + return loaded_params + + +def process_weights_after_loading(self, layer) -> None: + logger.debug("Applying patch process_weights_after_loading") + try: + from vllm.model_executor.layers.quantization.utils.w8a8_utils import requantize_with_max_scale + from vllm.model_executor.parameter import ( + BlockQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter, + ) + except Exception: + try: + from 
sglang.srt.layers.parameter import ( + BlockQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter, + ) + from sglang.srt.layers.quantization.utils import requantize_with_max_scale + except Exception: + print("error") + from torch.nn import Parameter + + def _create_param_from_subclass_attributes(custom_param): + param = Parameter(custom_param.data, requires_grad=False) + base_param_dir = dir(torch.nn.Parameter) + custom_param_dir = dir(custom_param) + # Find the attributes that are unique to the custom parameter + custom_attributes = [ + attr for attr in custom_param_dir if attr not in base_param_dir and not attr.startswith("__") + ] + # Set the custom attributes into the base parameter object + for attr in custom_attributes: + setattr(param, attr, getattr(custom_param, attr)) + + param.subclass_type = type(custom_param) + return param + + if self.block_quant: + assert self.block_quant and self.quant_config.is_checkpoint_fp8_serialized + assert self.quant_config.activation_scheme == "dynamic" + weight = layer.weight.data + weight_scale_inv = layer.weight_scale_inv.data + weight = self._maybe_pad_weight(weight) + + layer.weight = _create_param_from_subclass_attributes( + ModelWeightParameter( + data=weight, + output_dim=0, + input_dim=1, + weight_loader=layer.weight.weight_loader, + ) + ) + layer.weight_scale_inv = _create_param_from_subclass_attributes( + BlockQuantScaleParameter( + data=weight_scale_inv, + output_dim=0, + input_dim=1, + weight_loader=layer.weight_scale_inv.weight_loader, + ) + ) + + else: + weight = layer.weight.data + weight_scale = layer.weight_scale.data + + # # If using w8a8, torch._scaled_mm needs per tensor, so + # # requantize the logical shards as a single weight. + if not self.use_marlin: + # Dequant -> Quant with max scale so we can run per tensor. 
+ + weight_scale, weight = requantize_with_max_scale( + weight=weight, + weight_scale=weight_scale, + logical_widths=layer.logical_widths, + ) + + weight = self._maybe_pad_weight(weight) + # Update layer with new values. + # layer.weight = Parameter(weight.t(), requires_grad=False) + # layer.weight_scale = Parameter(weight_scale, requires_grad=False) + + layer.weight = _create_param_from_subclass_attributes( + ModelWeightParameter( + data=weight, + output_dim=0, + input_dim=1, + weight_loader=layer.weight.weight_loader, + ) + ) + layer.weight_scale = _create_param_from_subclass_attributes( + PerTensorScaleParameter( + data=weight_scale.repeat(len(layer.logical_widths)), + weight_loader=layer.weight_scale.weight_loader, + ) + ) + + +def apply(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: + from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import apply_fp8_marlin_linear + from vllm.model_executor.layers.quantization.utils.w8a8_utils import requantize_with_max_scale + + if self.use_marlin: + return apply_fp8_marlin_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + workspace=layer.workspace, + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + bias=bias, + ) + + if self.block_quant: + assert self.quant_config.weight_block_size is not None + return torch.ops.vllm.apply_w8a8_block_fp8_linear( + input=x, + weight=layer.weight, + block_size=self.quant_config.weight_block_size, + weight_scale=layer.weight_scale_inv, + input_scale=layer.input_scale, + bias=bias, + cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, + use_aiter_and_is_supported=self.use_aiter_and_is_supported, + ) + + weight_scale, weight = requantize_with_max_scale( + weight=layer.weight, + weight_scale=layer.weight_scale, + logical_widths=layer.logical_widths, + ) + return self.fp8_linear.apply( + input=x, + weight=weight.t(), + weight_scale=weight_scale, + 
out_dtype=self.out_dtype, + input_scale=layer.input_scale, + bias=bias, + ) + + +def apply_vllm_fp8_patches(block_quant=True): + print("xueh apply_vllm_fp8_patches") + if block_quant: + func_path = "vllm.model_executor.layers.quantization.fp8.Fp8LinearMethod.process_weights_after_loading" + patcher = patch(func_path, process_weights_after_loading) + patcher.start() + else: + func1_path = "vllm.model_executor.layers.quantization.fp8.Fp8LinearMethod.process_weights_after_loading" + patcher1 = patch(func1_path, process_weights_after_loading) + patcher1.start() + func2_path = "vllm.model_executor.layers.quantization.fp8.Fp8LinearMethod.apply" + patcher2 = patch(func2_path, apply) + patcher2.start() diff --git a/verl/workers/config/rollout.py b/verl/workers/config/rollout.py index b3f2aca3a95..cd92bc24df4 100644 --- a/verl/workers/config/rollout.py +++ b/verl/workers/config/rollout.py @@ -198,6 +198,10 @@ class RolloutConfig(BaseConfig): skip_tokenizer_init: bool = False + quantization: bool = False + + use_block_quant_rollout: bool = False + def __post_init__(self): """Validate the rollout config""" if self.expert_parallel_size > 1: diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py index 287c8dc4d96..975e6d27bd2 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py +++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py @@ -47,6 +47,7 @@ from verl.workers.rollout.replica import RolloutMode, RolloutReplica, TokenOutput from verl.workers.rollout.utils import get_free_port, is_valid_ipv6_address, run_unvicorn from verl.workers.rollout.vllm_rollout import vLLMAsyncRollout +from verl.utils.fp8_utils import apply_vllm_fp8_patches from verl.workers.rollout.vllm_rollout.utils import ( VLLM_LORA_INT_ID, VLLM_LORA_NAME, @@ -215,7 +216,18 @@ async def launch_server(self, master_address: str = None, master_port: int = Non max_new_tokens=self.config.response_length, ) 
logger.info(f"override_generation_config: {override_generation_config}") - + quantization = self.config.quantization + use_block_quant = self.config.use_block_quant_rollout + if quantization: + if use_block_quant: + FP8_BLOCK_QUANT_KWARGS = { + "activation_scheme": "dynamic", + "fmt": "e4m3", + "quant_method": "fp8", + "weight_block_size": [128, 128], + } + fp8_block_quant_kwargs = dict(FP8_BLOCK_QUANT_KWARGS) + apply_vllm_fp8_patches(block_quant=use_block_quant) args = { "dtype": self.config.dtype, "load_format": self.config.load_format, @@ -234,6 +246,8 @@ async def launch_server(self, master_address: str = None, master_port: int = Non "tensor_parallel_size": self.config.tensor_model_parallel_size, "seed": self.config.get("seed", 0), "override_generation_config": json.dumps(override_generation_config), + "quantization": "fp8" if quantization else None, + "hf_overrides": {"quantization_config": fp8_block_quant_kwargs} if quantization and use_block_quant else None, **engine_kwargs, } @@ -284,7 +298,8 @@ async def launch_server(self, master_address: str = None, master_port: int = Non server_args.append(f"--{k}") else: server_args.append(f"--{k}") - server_args.append(str(v)) + # Use json.dumps for dict to ensure valid JSON format + server_args.append(json.dumps(v) if isinstance(v, dict) else str(v)) if self.replica_rank == 0: pprint(server_args) diff --git a/verl/workers/sharding_manager/megatron_vllm.py b/verl/workers/sharding_manager/megatron_vllm.py index 6adc89c0985..f26f396c923 100644 --- a/verl/workers/sharding_manager/megatron_vllm.py +++ b/verl/workers/sharding_manager/megatron_vllm.py @@ -31,6 +31,7 @@ from verl.third_party.vllm import LLM, VLLM_SLEEP_LEVEL from verl.third_party.vllm import parallel_state as vllm_ps from verl.utils.device import get_torch_device, set_expandable_segments +from verl.utils.fp8_utils import is_fp8_model, load_quanted_weights from verl.utils.import_utils import deprecated from verl.utils.megatron_utils import 
load_megatron_model_to_gpu, offload_megatron_model_to_cpu, per_tensor_generator from verl.utils.memory_utils import aggressive_empty_cache @@ -172,7 +173,14 @@ def __enter__(self): from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader patch_vllm_moe_model_weight_loader(model) - loaded_params = model.load_weights(per_tensor_param) + + if is_fp8_model(self.model_runner.vllm_config): + # load_quanted_weights additionally casts bf16 weights into fp8 + logger.info("load weights weight quantization") + loaded_params = load_quanted_weights(per_tensor_param, self.model_runner) + else: + loaded_params = model.load_weights(weights=per_tensor_param) + info = f"vLLM load weights, loaded_params: {len(loaded_params)}" logger.info(info) From e02c3106607b5fc88f9e8d962263581bbf273b3a Mon Sep 17 00:00:00 2001 From: Xue Huang Date: Wed, 17 Sep 2025 15:12:10 +0800 Subject: [PATCH 02/18] blockwise fp8 rollout --- verl/utils/fp8_utils.py | 178 +++++++++++++++++- .../workers/sharding_manager/megatron_vllm.py | 10 +- 2 files changed, 169 insertions(+), 19 deletions(-) diff --git a/verl/utils/fp8_utils.py b/verl/utils/fp8_utils.py index c6b15797021..9df5e99bad7 100644 --- a/verl/utils/fp8_utils.py +++ b/verl/utils/fp8_utils.py @@ -17,12 +17,14 @@ from dataclasses import dataclass, field from typing import Optional from unittest.mock import patch +import vllm import torch try: from vllm._custom_ops import scaled_fp8_quant from vllm.model_executor.layers.linear import LinearBase + from vllm.model_executor.layers.fused_moe.layer import FusedMoE except ImportError as e: raise ImportError("FP8 quantization not available") from e @@ -77,7 +79,9 @@ def get_module_from_param_name(model, name: str): try: # Traverse the model hierarchy for part in module_path: - if isinstance(current_module, torch.nn.ModuleList): + if isinstance(current_module, FusedMoE): + return current_module + elif isinstance(current_module, torch.nn.ModuleList): current_module = current_module[int(part)] else: 
current_module = getattr(current_module, part) @@ -93,7 +97,13 @@ def is_fp8_weight(name, model): if name.endswith("weight"): module = get_module_from_param_name(model, name) # We currently only quantize linear layers - if isinstance(module, LinearBase) and module.weight.dtype == torch.float8_e4m3fn: + + if ( + (isinstance(module, LinearBase) and module.weight.dtype == torch.float8_e4m3fn) + or (isinstance(module, FusedMoE) + and module.w13_weight.dtype == torch.float8_e4m3fn + and module.w2_weight.dtype == torch.float8_e4m3fn) + ): fp8_state.fp8_param_names.add(name) return name in fp8_state.fp8_param_names @@ -161,12 +171,18 @@ def quant_weights(weights, model, quant_config): if quant_config.weight_block_size is not None: logger.info("Using blockwise quantization") param_lp, param_scale = scaled_fp8_blockwise( - v.to(torch.float), + v.to(torch.bfloat16), weight_block_size=quant_config.weight_block_size, ) param_scale = param_scale.squeeze(-1) weights_quantized.append([k, param_lp]) - weights_quantized.append([k + "_scale_inv", param_scale]) + if vllm.__version__ >= "0.11.0": + if "expert" in k: + weights_quantized.append([k + "_scale_inv", param_scale]) + else: + weights_quantized.append([k + "_scale", param_scale]) + else: + weights_quantized.append([k + "_scale_inv", param_scale]) else: logger.info("Using Per tensor quantization") @@ -204,7 +220,7 @@ def load_quanted_weights(weights, model_runner): return loaded_params -def process_weights_after_loading(self, layer) -> None: +def process_weights_after_loading_for_vllm10(self, layer) -> None: logger.debug("Applying patch process_weights_after_loading") try: from vllm.model_executor.layers.quantization.utils.w8a8_utils import requantize_with_max_scale @@ -300,6 +316,146 @@ def _create_param_from_subclass_attributes(custom_param): ) +def process_weights_after_loading_for_vllm11(self, layer) -> None: + """This function is used to process the weights after loading for a Linear layer. 
+ + Compared to the original process_weights_after_loading in vllm, we just avoid creation of + new torch.nn.Parameter objects, because that removes the weight_loader attribute which we need for refit. + """ + from torch.nn import Parameter + from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + maybe_post_process_fp8_weight_block, + process_fp8_weight_block_strategy, + ) + from vllm.model_executor.parameter import ( + BlockQuantScaleParameter, + ModelWeightParameter, + ) + + assert self.block_quant and self.quant_config.is_checkpoint_fp8_serialized + assert self.quant_config.activation_scheme == "dynamic" + + def _create_param_from_subclass_attributes(custom_param): + param = Parameter(custom_param.data, requires_grad=False) + base_param_dir = dir(torch.nn.Parameter) + custom_param_dir = dir(custom_param) + # Find the attributes that are unique to the custom parameter + custom_attributes = [ + attr + for attr in custom_param_dir + if attr not in base_param_dir and not attr.startswith("__") + ] + # Set the custom attributes into the base parameter object + for attr in custom_attributes: + setattr(param, attr, getattr(custom_param, attr)) + + param.subclass_type = type(custom_param) + return param + + weight_scale = ( + layer.weight_scale_inv + if hasattr(layer, "weight_scale_inv") + else layer.weight_scale + ) + weight, weight_scale = process_fp8_weight_block_strategy(layer.weight, weight_scale) + + layer.weight = _create_param_from_subclass_attributes( + ModelWeightParameter( + data=weight.data, + output_dim=0, + input_dim=1, + weight_loader=layer.weight.weight_loader, + ) + ) + layer.weight_scale = _create_param_from_subclass_attributes( + BlockQuantScaleParameter( + data=weight_scale.data, + output_dim=0, + input_dim=1, + weight_loader=layer.weight_scale_inv.weight_loader, + ) + ) + + del layer.weight_scale_inv + + maybe_post_process_fp8_weight_block(layer, self.cutlass_block_fp8_supported) + + +def process_weights_after_loading_moe(self, 
layer) -> None: + from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled + from vllm.model_executor.layers.quantization.fp8 import _swap_w13_to_w31, _is_col_major + from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used + from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + get_col_major_tma_aligned_tensor, requant_weight_ue8m0_inplace) + + self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() + assert self.quant_config.activation_scheme == "dynamic" + if self.flashinfer_moe_enabled: + w13_weight = _swap_w13_to_w31(layer.w13_weight.data) + w13_weight_scale_inv = _swap_w13_to_w31( + layer.w13_weight_scale_inv.data) + w2_weight = layer.w2_weight.data + w2_weight_scale_inv = layer.w2_weight_scale_inv.data + else: + w13_weight = layer.w13_weight.data + w13_weight_scale_inv = layer.w13_weight_scale_inv.data + w2_weight = layer.w2_weight + w2_weight_scale_inv = layer.w2_weight_scale_inv + + from torch.nn import Parameter + def _create_param_from_subclass_attributes(custom_data, custom_weight): + param = Parameter(custom_data, requires_grad=False) + base_param_dir = dir(torch.nn.Parameter) + custom_weight_dir = dir(custom_weight) + # Find the attributes that are unique to the custom parameter + custom_attributes = [ + attr for attr in custom_weight_dir if attr not in base_param_dir and not attr.startswith("__") + ] + # Set the custom attributes into the base parameter object + for attr in custom_attributes: + setattr(param, attr, getattr(custom_weight, attr)) + + return param + + layer.w13_weight = _create_param_from_subclass_attributes(w13_weight, layer.w13_weight) + layer.w13_weight_scale_inv = _create_param_from_subclass_attributes(w13_weight_scale_inv, layer.w13_weight_scale_inv) + layer.w2_weight = _create_param_from_subclass_attributes(w2_weight, layer.w2_weight) + layer.w2_weight_scale_inv = _create_param_from_subclass_attributes(w2_weight_scale_inv, layer.w2_weight_scale_inv) + + # DeepGemm 
scales need to be transposed and aligned. We try to do + # it ahead of time for performance reasons. + if self.allow_deep_gemm and not is_blackwell_deep_gemm_used(): + # Lazy import to avoid CUDA initialization problems. + if _is_col_major(layer.w13_weight_scale_inv): + layer.w13_weight_scale_inv = \ + get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv).contiguous() + if _is_col_major(layer.w2_weight_scale_inv): + layer.w2_weight_scale_inv = \ + get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv).contiguous() + + if is_blackwell_deep_gemm_used(): + assert layer.weight_block_size is not None + # Re-quantise the expert weights so their scales are UE8M0. + block_sz = tuple(layer.weight_block_size) + requant_weight_ue8m0_inplace( + layer.w13_weight.data, + layer.w13_weight_scale_inv.data, + block_sz, + ) + requant_weight_ue8m0_inplace( + layer.w2_weight.data, + layer.w2_weight_scale_inv.data, + block_sz, + ) + + if _is_col_major(layer.w13_weight_scale_inv): + layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor( + layer.w13_weight_scale_inv).contiguous() + if _is_col_major(layer.w2_weight_scale_inv): + layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor( + layer.w2_weight_scale_inv).contiguous() + + def apply(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import apply_fp8_marlin_linear from vllm.model_executor.layers.quantization.utils.w8a8_utils import requantize_with_max_scale @@ -344,14 +500,16 @@ def apply(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Te def apply_vllm_fp8_patches(block_quant=True): - print("xueh apply_vllm_fp8_patches") if block_quant: - func_path = "vllm.model_executor.layers.quantization.fp8.Fp8LinearMethod.process_weights_after_loading" - patcher = patch(func_path, process_weights_after_loading) - patcher.start() + func1_path = 
"vllm.model_executor.layers.quantization.fp8.Fp8LinearMethod.process_weights_after_loading" + patcher1 = patch(func1_path, process_weights_after_loading_for_vllm11 if vllm.__version__ >= "0.11.0" else process_weights_after_loading_for_vllm10) + patcher1.start() + func2_path = "vllm.model_executor.layers.quantization.fp8.Fp8MoEMethod.process_weights_after_loading" + patcher2 = patch(func2_path, process_weights_after_loading_moe) + patcher2.start() else: func1_path = "vllm.model_executor.layers.quantization.fp8.Fp8LinearMethod.process_weights_after_loading" - patcher1 = patch(func1_path, process_weights_after_loading) + patcher1 = patch(func1_path, process_weights_after_loading_for_vllm10) patcher1.start() func2_path = "vllm.model_executor.layers.quantization.fp8.Fp8LinearMethod.apply" patcher2 = patch(func2_path, apply) diff --git a/verl/workers/sharding_manager/megatron_vllm.py b/verl/workers/sharding_manager/megatron_vllm.py index f26f396c923..6adc89c0985 100644 --- a/verl/workers/sharding_manager/megatron_vllm.py +++ b/verl/workers/sharding_manager/megatron_vllm.py @@ -31,7 +31,6 @@ from verl.third_party.vllm import LLM, VLLM_SLEEP_LEVEL from verl.third_party.vllm import parallel_state as vllm_ps from verl.utils.device import get_torch_device, set_expandable_segments -from verl.utils.fp8_utils import is_fp8_model, load_quanted_weights from verl.utils.import_utils import deprecated from verl.utils.megatron_utils import load_megatron_model_to_gpu, offload_megatron_model_to_cpu, per_tensor_generator from verl.utils.memory_utils import aggressive_empty_cache @@ -173,14 +172,7 @@ def __enter__(self): from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader patch_vllm_moe_model_weight_loader(model) - - if is_fp8_model(self.model_runner.vllm_config): - # load_quanted_weights additionally casts bf16 weights into fp8 - logger.info("load weights weight quantization") - loaded_params = load_quanted_weights(per_tensor_param, self.model_runner) - else: - 
loaded_params = model.load_weights(weights=per_tensor_param) - + loaded_params = model.load_weights(per_tensor_param) info = f"vLLM load weights, loaded_params: {len(loaded_params)}" logger.info(info) From 9c843bfa5afa373b6686a0d0399f4cb2a3e4500d Mon Sep 17 00:00:00 2001 From: Xue Huang Date: Thu, 18 Sep 2025 10:34:48 +0800 Subject: [PATCH 03/18] add doc --- docs/advance/fp8.md | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/docs/advance/fp8.md b/docs/advance/fp8.md index 958d3e6eb62..94670a34fa7 100644 --- a/docs/advance/fp8.md +++ b/docs/advance/fp8.md @@ -1,6 +1,6 @@ # FP8 for verl -Last updated: 09/18/2025. +Last updated: 11/7/2025. This module is still in development. Currently we support FP8 rollout, using FP8 blockwise scaling (used in Deepseek, which is 1x128 quantization for activations and 128x128 quantization for model weights). @@ -12,7 +12,6 @@ function to handle model weights loading after quantization. **Notes**: - Currently, we only support VLLM rollout with Megatron training. SGLang rollout with Megatron training is on the roadmap. -- Only support Batch generate sequences. - We also support FP8 per tensor quantization, but after preliminary testing, there were issues with the accuracy and it is not recommended to use it. 
## Usage @@ -35,4 +34,4 @@ Or it can be enabled by command line: - add performance of FP8 rollout - add accuracy curves of FP8 rollout - support SGLang rollout with FP8 quantization -- enable FP8 training in megatron \ No newline at end of file +- enable FP8 training in megatron From 972f287bd93636776f2912737fb7c00d77419767 Mon Sep 17 00:00:00 2001 From: Xue Huang Date: Fri, 7 Nov 2025 09:59:39 +0000 Subject: [PATCH 04/18] some fix --- verl/workers/rollout/vllm_rollout/vllm_async_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py index 975e6d27bd2..ab416d2346b 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py +++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py @@ -296,7 +296,7 @@ async def launch_server(self, master_address: str = None, master_port: int = Non if isinstance(v, bool): if v: server_args.append(f"--{k}") - else: + elif v is not None: server_args.append(f"--{k}") # Use json.dumps for dict to ensure valid JSON format server_args.append(json.dumps(v) if isinstance(v, dict) else str(v)) From 360b0f58ae77beffee3b6d5fb6a521683b236188 Mon Sep 17 00:00:00 2001 From: Xue Huang Date: Wed, 19 Nov 2025 02:34:49 +0000 Subject: [PATCH 05/18] udpate vllm_rollout_spmd for async server --- .../rollout/vllm_rollout/vllm_rollout_spmd.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py index 0fdd3815c89..47ede4c72d2 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py +++ b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py @@ -75,6 +75,7 @@ from verl.third_party.vllm import VLLM_SLEEP_LEVEL, get_version from verl.utils.device import is_npu_available from verl.utils.distributed import initialize_global_process_group_ray +from verl.utils.fp8_utils 
import apply_vllm_fp8_patches, is_fp8_model, load_quanted_weights from verl.utils.import_utils import deprecated from verl.utils.model import get_lora_rank_from_adapter from verl.utils.profiler import GPUMemoryLogger @@ -603,6 +604,8 @@ def _init_worker(self, all_kwargs: list[dict[str, Any]]): if self.lora_config: lora_dtype = getattr(torch, self.config.dtype) self.vllm_config.lora_config = LoRAConfig(lora_dtype=lora_dtype, **self.lora_config) + if self.config.quantization: + apply_vllm_fp8_patches(block_quant=self.config.use_block_quant_rollout) self.inference_engine = WorkerWrapperBase(vllm_config=self.vllm_config) self.inference_engine.init_worker(all_kwargs) @@ -656,9 +659,20 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None else: from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader - model = self.inference_engine.worker.model_runner.model + model_runner = self.inference_engine.worker.model_runner + model = model_runner.model patch_vllm_moe_model_weight_loader(model) - model.load_weights(weights) + + # Add the FP8 related logic here as sharding manager has been deprecated. 
+ # Check if FP8 quantization is enabled and apply appropriate weight loading + if is_fp8_model(model_runner.vllm_config): + logger.info(f"FP8 model detected (async): {model_runner.vllm_config.quant_config}") + # Convert bf16 weights to fp8 format before loading + loaded_params = load_quanted_weights(weights, model_runner) + logger.info(f"FP8 weights loaded (async), loaded_params: {len(loaded_params)}") + else: + logger.debug("Loading standard weights (non-FP8, async)") + model.load_weights(weights) def generate_sequences(self, prompts: DataProto) -> DataProto: """Batch generate sequences in sync mode.""" From 00b6596c29933e11a58b56112ae6a0c4f59c2f7f Mon Sep 17 00:00:00 2001 From: Xue Huang Date: Wed, 19 Nov 2025 03:59:49 +0000 Subject: [PATCH 06/18] update vllm quant --- .../{fp8_utils.py => vllm/vllm_fp8_utils.py} | 155 ++++-------------- .../rollout/vllm_rollout/vllm_async_server.py | 2 +- .../rollout/vllm_rollout/vllm_rollout_spmd.py | 7 +- 3 files changed, 36 insertions(+), 128 deletions(-) rename verl/utils/{fp8_utils.py => vllm/vllm_fp8_utils.py} (76%) diff --git a/verl/utils/fp8_utils.py b/verl/utils/vllm/vllm_fp8_utils.py similarity index 76% rename from verl/utils/fp8_utils.py rename to verl/utils/vllm/vllm_fp8_utils.py index 9df5e99bad7..1d20b543921 100644 --- a/verl/utils/fp8_utils.py +++ b/verl/utils/vllm/vllm_fp8_utils.py @@ -185,16 +185,7 @@ def quant_weights(weights, model, quant_config): weights_quantized.append([k + "_scale_inv", param_scale]) else: - logger.info("Using Per tensor quantization") - original_shape = v.shape - # Use per tensor quantization - quantized_tensor, scale = scaled_fp8_quant(v) - # Reshape back to original shape - quantized_tensor = quantized_tensor.view(original_shape) - - scale_k = k.replace(".weight", ".weight_scale") - scale = scale.view(1) - weights_quantized.extend([(k, quantized_tensor), (scale_k, scale)]) + raise ValueError("Currently only support blockwise quantization, please set weight_block_size in quant_config") 
return weights_quantized @@ -221,6 +212,11 @@ def load_quanted_weights(weights, model_runner): def process_weights_after_loading_for_vllm10(self, layer) -> None: + """This function is used to process the weights after loading for a Linear layer, it is used for vllm v0.10 + + Compared to the original process_weights_after_loading in vllm, we just avoid creation of + new torch.nn.Parameter objects, because that removes the weight_loader attribute which we need for refit. + """ logger.debug("Applying patch process_weights_after_loading") try: from vllm.model_executor.layers.quantization.utils.w8a8_utils import requantize_with_max_scale @@ -230,15 +226,7 @@ def process_weights_after_loading_for_vllm10(self, layer) -> None: PerTensorScaleParameter, ) except Exception: - try: - from sglang.srt.layers.parameter import ( - BlockQuantScaleParameter, - ModelWeightParameter, - PerTensorScaleParameter, - ) - from sglang.srt.layers.quantization.utils import requantize_with_max_scale - except Exception: - print("error") + print("error") from torch.nn import Parameter def _create_param_from_subclass_attributes(custom_param): @@ -256,68 +244,32 @@ def _create_param_from_subclass_attributes(custom_param): param.subclass_type = type(custom_param) return param - if self.block_quant: - assert self.block_quant and self.quant_config.is_checkpoint_fp8_serialized - assert self.quant_config.activation_scheme == "dynamic" - weight = layer.weight.data - weight_scale_inv = layer.weight_scale_inv.data - weight = self._maybe_pad_weight(weight) - - layer.weight = _create_param_from_subclass_attributes( - ModelWeightParameter( - data=weight, - output_dim=0, - input_dim=1, - weight_loader=layer.weight.weight_loader, - ) - ) - layer.weight_scale_inv = _create_param_from_subclass_attributes( - BlockQuantScaleParameter( - data=weight_scale_inv, - output_dim=0, - input_dim=1, - weight_loader=layer.weight_scale_inv.weight_loader, - ) - ) - - else: - weight = layer.weight.data - weight_scale = 
layer.weight_scale.data - - # # If using w8a8, torch._scaled_mm needs per tensor, so - # # requantize the logical shards as a single weight. - if not self.use_marlin: - # Dequant -> Quant with max scale so we can run per tensor. - - weight_scale, weight = requantize_with_max_scale( - weight=weight, - weight_scale=weight_scale, - logical_widths=layer.logical_widths, - ) + assert self.block_quant and self.quant_config.is_checkpoint_fp8_serialized + assert self.quant_config.activation_scheme == "dynamic" + weight = layer.weight.data + weight_scale_inv = layer.weight_scale_inv.data + weight = self._maybe_pad_weight(weight) - weight = self._maybe_pad_weight(weight) - # Update layer with new values. - # layer.weight = Parameter(weight.t(), requires_grad=False) - # layer.weight_scale = Parameter(weight_scale, requires_grad=False) - - layer.weight = _create_param_from_subclass_attributes( - ModelWeightParameter( - data=weight, - output_dim=0, - input_dim=1, - weight_loader=layer.weight.weight_loader, - ) + layer.weight = _create_param_from_subclass_attributes( + ModelWeightParameter( + data=weight, + output_dim=0, + input_dim=1, + weight_loader=layer.weight.weight_loader, ) - layer.weight_scale = _create_param_from_subclass_attributes( - PerTensorScaleParameter( - data=weight_scale.repeat(len(layer.logical_widths)), - weight_loader=layer.weight_scale.weight_loader, - ) + ) + layer.weight_scale_inv = _create_param_from_subclass_attributes( + BlockQuantScaleParameter( + data=weight_scale_inv, + output_dim=0, + input_dim=1, + weight_loader=layer.weight_scale_inv.weight_loader, ) + ) def process_weights_after_loading_for_vllm11(self, layer) -> None: - """This function is used to process the weights after loading for a Linear layer. 
+ """This function is used to process the weights after loading for a Linear layer, it is used for vllm 0.11 Compared to the original process_weights_after_loading in vllm, we just avoid creation of new torch.nn.Parameter objects, because that removes the weight_loader attribute which we need for refit. @@ -387,7 +339,7 @@ def process_weights_after_loading_moe(self, layer) -> None: from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used from vllm.model_executor.layers.quantization.utils.fp8_utils import ( get_col_major_tma_aligned_tensor, requant_weight_ue8m0_inplace) - + self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() assert self.quant_config.activation_scheme == "dynamic" if self.flashinfer_moe_enabled: @@ -456,51 +408,9 @@ def _create_param_from_subclass_attributes(custom_data, custom_weight): layer.w2_weight_scale_inv).contiguous() -def apply(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import apply_fp8_marlin_linear - from vllm.model_executor.layers.quantization.utils.w8a8_utils import requantize_with_max_scale - - if self.use_marlin: - return apply_fp8_marlin_linear( - input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - workspace=layer.workspace, - size_n=layer.output_size_per_partition, - size_k=layer.input_size_per_partition, - bias=bias, - ) - - if self.block_quant: - assert self.quant_config.weight_block_size is not None - return torch.ops.vllm.apply_w8a8_block_fp8_linear( - input=x, - weight=layer.weight, - block_size=self.quant_config.weight_block_size, - weight_scale=layer.weight_scale_inv, - input_scale=layer.input_scale, - bias=bias, - cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, - use_aiter_and_is_supported=self.use_aiter_and_is_supported, - ) - - weight_scale, weight = requantize_with_max_scale( - weight=layer.weight, - weight_scale=layer.weight_scale, - 
logical_widths=layer.logical_widths, - ) - return self.fp8_linear.apply( - input=x, - weight=weight.t(), - weight_scale=weight_scale, - out_dtype=self.out_dtype, - input_scale=layer.input_scale, - bias=bias, - ) - - def apply_vllm_fp8_patches(block_quant=True): if block_quant: + print("Applying vllm fp8 patches for blockwise quantization") func1_path = "vllm.model_executor.layers.quantization.fp8.Fp8LinearMethod.process_weights_after_loading" patcher1 = patch(func1_path, process_weights_after_loading_for_vllm11 if vllm.__version__ >= "0.11.0" else process_weights_after_loading_for_vllm10) patcher1.start() @@ -508,9 +418,4 @@ def apply_vllm_fp8_patches(block_quant=True): patcher2 = patch(func2_path, process_weights_after_loading_moe) patcher2.start() else: - func1_path = "vllm.model_executor.layers.quantization.fp8.Fp8LinearMethod.process_weights_after_loading" - patcher1 = patch(func1_path, process_weights_after_loading_for_vllm10) - patcher1.start() - func2_path = "vllm.model_executor.layers.quantization.fp8.Fp8LinearMethod.apply" - patcher2 = patch(func2_path, apply) - patcher2.start() + raise ValueError("Only blockwise quantization is supported for FP8 rollout") diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py index ab416d2346b..bbafc815dcd 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py +++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py @@ -43,11 +43,11 @@ from verl.single_controller.ray import RayClassWithInitArgs from verl.utils.config import omega_conf_to_dataclass +from verl.utils.vllm.vllm_fp8_utils import apply_vllm_fp8_patches from verl.workers.config import HFModelConfig, RewardModelConfig, RolloutConfig from verl.workers.rollout.replica import RolloutMode, RolloutReplica, TokenOutput from verl.workers.rollout.utils import get_free_port, is_valid_ipv6_address, run_unvicorn from verl.workers.rollout.vllm_rollout import vLLMAsyncRollout -from 
verl.utils.fp8_utils import apply_vllm_fp8_patches from verl.workers.rollout.vllm_rollout.utils import ( VLLM_LORA_INT_ID, VLLM_LORA_NAME, diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py index 47ede4c72d2..1973d3cd6d6 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py +++ b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py @@ -75,13 +75,13 @@ from verl.third_party.vllm import VLLM_SLEEP_LEVEL, get_version from verl.utils.device import is_npu_available from verl.utils.distributed import initialize_global_process_group_ray -from verl.utils.fp8_utils import apply_vllm_fp8_patches, is_fp8_model, load_quanted_weights from verl.utils.import_utils import deprecated from verl.utils.model import get_lora_rank_from_adapter from verl.utils.profiler import GPUMemoryLogger from verl.utils.ray_utils import ray_noset_visible_devices from verl.utils.torch_functional import get_response_mask, pad_2d_list_to_length from verl.utils.vllm import TensorLoRARequest, VLLMHijack, is_version_ge +from verl.utils.vllm.vllm_fp8_utils import apply_vllm_fp8_patches, is_fp8_model, load_quanted_weights from verl.workers.config import HFModelConfig, RolloutConfig from verl.workers.rollout.base import BaseRollout from verl.workers.rollout.utils import get_free_port, is_valid_ipv6_address @@ -665,13 +665,16 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None # Add the FP8 related logic here as sharding manager has been deprecated. 
# Check if FP8 quantization is enabled and apply appropriate weight loading + print("Checking if FP8 model is detected") if is_fp8_model(model_runner.vllm_config): + print("Loading FP8 weights (async)") logger.info(f"FP8 model detected (async): {model_runner.vllm_config.quant_config}") # Convert bf16 weights to fp8 format before loading loaded_params = load_quanted_weights(weights, model_runner) logger.info(f"FP8 weights loaded (async), loaded_params: {len(loaded_params)}") else: - logger.debug("Loading standard weights (non-FP8, async)") + print("Loading standard weights (non-FP8, async)") + logger.info("Loading standard weights (non-FP8, async)") model.load_weights(weights) def generate_sequences(self, prompts: DataProto) -> DataProto: From bece9e574b54b41293131e38b0782d3b53db0cab Mon Sep 17 00:00:00 2001 From: Xue Huang Date: Wed, 19 Nov 2025 08:14:24 +0000 Subject: [PATCH 07/18] update doc --- docs/advance/fp8.md | 97 +++++++++++++++++++++++++++++++++++++++------ 1 file changed, 85 insertions(+), 12 deletions(-) diff --git a/docs/advance/fp8.md b/docs/advance/fp8.md index 94670a34fa7..28b9d2e3b4d 100644 --- a/docs/advance/fp8.md +++ b/docs/advance/fp8.md @@ -1,18 +1,93 @@ -# FP8 for verl +# FP8 rollout for verl -Last updated: 11/7/2025. +Last updated: 11/19/2025 -This module is still in development. Currently we support FP8 rollout, using FP8 blockwise scaling (used in Deepseek, -which is 1x128 quantization for activations and 128x128 quantization for model weights). +This document introduces FP8 rollout with vllm inference backend in verl. -We monkey patches several vLLM functions to enable FP8 rollout for reinforcement learning. + +We monkey patch several vLLM functions to enable FP8 rollout for reinforcement learning. 1. **load_weights**: A custom `load_weights` function to quantize the on-the-fly model weights from a higher-precision format to FP8. 2. 
**process weights after loading**: Replace `vllm.model_executor.layers.quantization.fp8.Fp8LinearMethod.process_weights_after_loading` function to handle model weights loading after quantization. -**Notes**: -- Currently, we only support VLLM rollout with Megatron training. SGLang rollout with Megatron training is on the roadmap. -- We also support FP8 per tensor quantization, but after preliminary testing, there were issues with the accuracy and it is not recommended to use it. +## Support Matrix +- FP8 blockwise quantization for rollout + - Used in Deepseek, +which is 1x128 quantization for activations and 128x128 quantization for model weights +- Dense models and MoE models +- Async rollout interfaces +- vLLM 0.10.x & vLLM 0.11 +- FSDP and Megatron training backends + +## Experiments and Outcomes +### Qwen3-8B-Base Dense Model + +**Configuration** +- DAPO recipe. AIME24 online validation. +- vLLM(FP8 spmd rollout) + FSDP + - Note that SPMD rollout has been deprecated, so we removed the FP8 SPMD rollout. +- Prompt batch size 32, n=16. 
+- Rollout batch size: 32\*3*16 +- Train_batch_size & ppo_mini_batch_size 32 +- Max response length 20K +- Token-level TIS, C=2 +- 8*H100 +- vLLM 0.10.0+CUDA 12.6 vs vLLM 0.11.0+CUDA 12.9 + +**Accuracy** +![Qwen3-8b-base_fp8_acc]( +https://github.com/Agoniii/verl/blob/xueh/fp8_pr_images/docs/advance/images/Qwen3-8b-base_fp8_acc.png?raw=true) +*dark green: BF16, orange: FP8 rollout + token-level TIS, light green: FP8 rollout without TIS* + +Results and observations: +- With TIS, FP8 rollout aligns with BF16 +- Obvious accuracy drop when TIS is not enabled +- Higher mismatch kl but within acceptable range throughout the training + + +**Performance** + +![Qwen3-8b-base_fp8_rollout_perf]( +https://github.com/Agoniii/verl/blob/xueh/fp8_pr_images/docs/advance/images/Qwen3-8b-base_fp8_rollout_perf.png?raw=true) +*green: BF16, orange: FP8 rollout + CUDA12.6 + DeepGemm, purple: FP8 rollout + CUDA 12.9 + DeepGemm* + +Results and observations: +- FP8 rollout leads to around ~12% rollout speedup with CUDA 12.6 + DeepGemm +- When upgrading to CUDA 12.9, speedup can be up to ~18% + +### Qwen3-30B-A3B-Base MoE Model + +**Configuration** +- DAPO recipe. AIME24 online validation. 
+- FP8 async rollout, vLLM+FSDP +- Prompt batch size 32 +- Rollout batch size: 32\*3*16 +- Train_batch_size & ppo_mini_batch_size 32 +- Max response length 20K +- Token-level TIS, C=2 +- 2\*8*H100 +- vLLM 0.10.0+CUDA 12.6 + +**Accuracy** +![Qwen3-30b-a3b_fp8_acc]( +https://github.com/Agoniii/verl/blob/xueh/fp8_pr_images/docs/advance/images/Qwen3-30b-a3b_fp8_acc.png?raw=true) +*grey: BF16 + token-level TIS, red: FP8 rollout + token-level TIS* + +Results and observations: +- Rollout & training distribution mismatch is in general higher for MoE +- Rollout correction required even for BF16 +- FP8 rollout with token-level TIS aligns with BF16 + + +**Performance** + +![Qwen3-30b-a3b_fp8_perf]( +https://github.com/Agoniii/verl/blob/xueh/fp8_pr_images/docs/advance/images/Qwen3-30b-a3b_fp8_perf.png?raw=true) +*grey: BF16 + token-level TIS, red: FP8 rollout + token-level TIS​* + +Results and observations: +- FP8 rollout : over 35% rollout speedup +- Expecting more perf gain with CUDA 12.9 ## Usage @@ -31,7 +106,5 @@ Or it can be enabled by command line: ## Plans -- add performance of FP8 rollout -- add accuracy curves of FP8 rollout -- support SGLang rollout with FP8 quantization -- enable FP8 training in megatron +- will open another PR to support FP8 rollout in SGLang +- further to enable FP8 training in megatron From f5634930e3d62bef5b04c1f7b547166cc0a4c6ab Mon Sep 17 00:00:00 2001 From: Xue Huang Date: Wed, 19 Nov 2025 22:02:13 +0800 Subject: [PATCH 08/18] update comments --- verl/workers/rollout/vllm_rollout/vllm_async_server.py | 2 ++ verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py | 5 ++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py index bbafc815dcd..ceb95cdd7e2 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py +++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py @@ -227,6 +227,8 @@ async def 
launch_server(self, master_address: str = None, master_port: int = Non "weight_block_size": [128, 128], } fp8_block_quant_kwargs = dict(FP8_BLOCK_QUANT_KWARGS) + # Apply vllm fp8 patches + # Will remove the patch after vllm support on-the-fly quant for rollout natively. apply_vllm_fp8_patches(block_quant=use_block_quant) args = { "dtype": self.config.dtype, diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py index 1973d3cd6d6..52c1563ab9e 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py +++ b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py @@ -605,6 +605,8 @@ def _init_worker(self, all_kwargs: list[dict[str, Any]]): lora_dtype = getattr(torch, self.config.dtype) self.vllm_config.lora_config = LoRAConfig(lora_dtype=lora_dtype, **self.lora_config) if self.config.quantization: + # Apply vllm fp8 patches + # Will remove the patch after vllm support on-the-fly quant for rollout natively. apply_vllm_fp8_patches(block_quant=self.config.use_block_quant_rollout) self.inference_engine = WorkerWrapperBase(vllm_config=self.vllm_config) self.inference_engine.init_worker(all_kwargs) @@ -665,15 +667,12 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None # Add the FP8 related logic here as sharding manager has been deprecated. 
# Check if FP8 quantization is enabled and apply appropriate weight loading - print("Checking if FP8 model is detected") if is_fp8_model(model_runner.vllm_config): - print("Loading FP8 weights (async)") logger.info(f"FP8 model detected (async): {model_runner.vllm_config.quant_config}") # Convert bf16 weights to fp8 format before loading loaded_params = load_quanted_weights(weights, model_runner) logger.info(f"FP8 weights loaded (async), loaded_params: {len(loaded_params)}") else: - print("Loading standard weights (non-FP8, async)") logger.info("Loading standard weights (non-FP8, async)") model.load_weights(weights) From 4d4f7b69649194b97874a8f2d4515b2804d4d4da Mon Sep 17 00:00:00 2001 From: Xue Huang Date: Wed, 19 Nov 2025 22:19:15 +0800 Subject: [PATCH 09/18] update dtype --- verl/utils/vllm/vllm_fp8_utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/verl/utils/vllm/vllm_fp8_utils.py b/verl/utils/vllm/vllm_fp8_utils.py index 1d20b543921..0e4f928269f 100644 --- a/verl/utils/vllm/vllm_fp8_utils.py +++ b/verl/utils/vllm/vllm_fp8_utils.py @@ -161,7 +161,7 @@ def scaled_fp8_blockwise( return fp_data, descale_fp -def quant_weights(weights, model, quant_config): +def quant_weights(weights, model, quant_config, dtype=torch.bfloat16): weights_quantized = [] for k, v in weights: if not is_fp8_weight(k, model): @@ -171,7 +171,7 @@ def quant_weights(weights, model, quant_config): if quant_config.weight_block_size is not None: logger.info("Using blockwise quantization") param_lp, param_scale = scaled_fp8_blockwise( - v.to(torch.bfloat16), + v.to(dtype), weight_block_size=quant_config.weight_block_size, ) param_scale = param_scale.squeeze(-1) @@ -193,8 +193,9 @@ def quant_weights(weights, model, quant_config): def load_quanted_weights(weights, model_runner): model = model_runner.model quant_config = model_runner.vllm_config.quant_config + vllm_dtype = model_runner.vllm_config.model_config.dtype - weights_quantized = quant_weights(weights, model, 
quant_config) + weights_quantized = quant_weights(weights, model, quant_config, dtype=vllm_dtype) # Monkey patch the param class to their subclass, as certain models # will check the param type to call the proper weightloader From 96a28325b1ce0c2c5eb9bc5b26e65995a7d65524 Mon Sep 17 00:00:00 2001 From: Xue Huang Date: Wed, 19 Nov 2025 23:08:20 +0800 Subject: [PATCH 10/18] update scripts for fp8 rollout --- docs/advance/fp8.md | 4 + ...run_dapo_qwen3_moe_30b_vllm_fp8_rollout.sh | 172 ++++++++++++++++++ 2 files changed, 176 insertions(+) create mode 100644 recipe/dapo/run_dapo_qwen3_moe_30b_vllm_fp8_rollout.sh diff --git a/docs/advance/fp8.md b/docs/advance/fp8.md index 28b9d2e3b4d..ab023c46289 100644 --- a/docs/advance/fp8.md +++ b/docs/advance/fp8.md @@ -68,6 +68,8 @@ Results and observations: - 2\*8*H100 - vLLM 0.10.0+CUDA 12.6 +Please refer to `recipe/dapo/run_dapo_qwen3_moe_30b_vllm_fp8_rollout.sh` + **Accuracy** ![Qwen3-30b-a3b_fp8_acc]( https://github.com/Agoniii/verl/blob/xueh/fp8_pr_images/docs/advance/images/Qwen3-30b-a3b_fp8_acc.png?raw=true) @@ -104,6 +106,8 @@ Or it can be enabled by command line: - `actor_rollout_ref.rollout.quantization=True` - `actor_rollout_ref.rollout.use_block_quant_rollout=True` +Please refer to `recipe/dapo/run_dapo_qwen3_moe_30b_vllm_fp8_rollout.sh` + ## Plans - will open another PR to support FP8 rollout in SGLang diff --git a/recipe/dapo/run_dapo_qwen3_moe_30b_vllm_fp8_rollout.sh b/recipe/dapo/run_dapo_qwen3_moe_30b_vllm_fp8_rollout.sh new file mode 100644 index 00000000000..d3273980099 --- /dev/null +++ b/recipe/dapo/run_dapo_qwen3_moe_30b_vllm_fp8_rollout.sh @@ -0,0 +1,172 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +project_name='DAPO-FP8-ROLLOUT' +exp_name='DAPO-Qwen3-MOE-30B-VLLM-FP8-ROLLOUT' + + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +# Rollout Correction parameters for FP8 rollout +rollout_is=token 
+rollout_is_threshold=2.0 +rollout_rs=null +rollout_rs_threshold=null +rollout_rs_threshold_lower=null +rollout_token_veto_threshold=null + +max_prompt_length=$((1024)) +max_response_length=$((1024 * 20)) +enable_overlong_buffer=True +overlong_buffer_len=512 +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +enable_filter_groups=True +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=32 +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 +gen_prompt_bsz=96 + +WORKING_DIR=${WORKING_DIR:-"${PWD}"} +echo "WORKING_DIR: ${WORKING_DIR}" +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env_fp8.yaml"} +echo "RUNTIME_ENV: ${RUNTIME_ENV}" +NNODES=${NNODES:-2} +echo "NNODES: ${NNODES}" + +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH="Qwen/Qwen3-30B-A3B-Base" +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=1.0 + +# Performance Related Parameter +sp_size=4 +use_dynamic_bsz=True +actor_ppo_max_token_len=$((max_prompt_length + max_response_length)) +infer_ppo_max_token_len=$((max_prompt_length + max_response_length)) +offload=true +gen_tp=1 +train_tp=1 +train_pp=1 + +# Set Flash-RL environment variables +export VERL_LOGGING_LEVEL=DEBUG +export VLLM_LOGGING_LEVEL=DEBUG +export VLLM_CONFIGURE_LOGGING=1 +export VLLM_USE_V1=1 +export VLLM_USE_DEEP_GEMM=1 +export TORCH_NCCL_AVOID_RECORD_STREAMS=1 + +RAY_ADDRESS='http://127.0.0.1:8265' ray job submit --runtime-env=${RUNTIME_ENV} \ +-- python3 -m recipe.dapo.main_dapo \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.return_raw_chat=True \ + data.filter_overlong_prompts=True \ + data.max_prompt_length=${max_prompt_length} \ 
+ data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.gen_batch_size=${gen_prompt_bsz} \ + actor_rollout_ref.nccl_timeout=1800 \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.rollout_correction.rollout_is=${rollout_is} \ + algorithm.rollout_correction.rollout_is_threshold=${rollout_is_threshold} \ + algorithm.rollout_correction.rollout_rs=${rollout_rs} \ + algorithm.rollout_correction.rollout_rs_threshold=${rollout_rs_threshold} \ + algorithm.rollout_correction.rollout_rs_threshold_lower=${rollout_rs_threshold_lower} \ + algorithm.rollout_correction.rollout_token_veto_threshold=${rollout_token_veto_threshold} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + 
actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.optim.clip_grad=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.9 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$(( 1024 * 32 )) \ + actor_rollout_ref.rollout.max_num_seqs=256 \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.temperature=0.6 \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + +actor_rollout_ref.rollout.quantization=True \ + +actor_rollout_ref.rollout.use_block_quant_rollout=True \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.mode=async \ + actor_rollout_ref.rollout.calculate_log_probs=True \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + reward_model.reward_manager=dapo \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + reward_model.overlong_buffer.log=False \ + trainer.logger='["console","wandb"]' \ + trainer.project_name="${project_name}" \ + 
trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=False \ + trainer.test_freq=5 \ + trainer.save_freq=5 \ + trainer.total_epochs=100 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ + trainer.log_val_generations=1 \ + trainer.total_training_steps=500 \ + trainer.max_actor_ckpt_to_keep=5 \ + +trainer.dump_high_diff_tokens=False \ + +trainer.dump_high_diff_dir="${CKPTS_DIR}/30B_logprob_diff_dumps" \ + actor_rollout_ref.rollout.enforce_eager=False From c4757ef4783921840695e8227ee4cc9b88730319 Mon Sep 17 00:00:00 2001 From: Xue Huang Date: Wed, 19 Nov 2025 23:15:33 +0800 Subject: [PATCH 11/18] add ci for fp8 rollout --- .github/workflows/e2e_ppo_trainer_megatron_vllm.yml | 4 ++++ tests/special_e2e/run_ppo_trainer_megatron.sh | 5 ++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/.github/workflows/e2e_ppo_trainer_megatron_vllm.yml b/.github/workflows/e2e_ppo_trainer_megatron_vllm.yml index 78054ad3deb..d5d184b7776 100644 --- a/.github/workflows/e2e_ppo_trainer_megatron_vllm.yml +++ b/.github/workflows/e2e_ppo_trainer_megatron_vllm.yml @@ -193,6 +193,10 @@ jobs: exp_name="qwen3-0.6b-megatron-gsm8k-minimal" python -m verl.model_merger test --backend megatron --tie-word-embedding --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface python -m verl.model_merger test --backend megatron --is-value-model --local_dir checkpoints/verl-test/${exp_name}/global_step_1/critic --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/critic/huggingface + - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with FP8 rollout + run: | + ray stop --force + FP8_ROLLOUT=1 TOTAL_TRAIN_STEPS=2 MODEL_ID=Qwen/Qwen3-0.6B bash tests/special_e2e/run_ppo_trainer_megatron.sh - name: clean up run: | rm -rf checkpoints diff --git 
a/tests/special_e2e/run_ppo_trainer_megatron.sh b/tests/special_e2e/run_ppo_trainer_megatron.sh index 1e79784f3a2..2332bfa3da1 100644 --- a/tests/special_e2e/run_ppo_trainer_megatron.sh +++ b/tests/special_e2e/run_ppo_trainer_megatron.sh @@ -18,7 +18,7 @@ if [ "$USE_DUMMY_MODEL" = "True" ]; then echo "[ERROR] DUMMY_MODEL_CONFIG_PATH not set" exit 1 fi - + python scripts/init_random_model.py \ --hf_model_path "${MODEL_PATH}" \ --new_config_path "${DUMMY_MODEL_CONFIG_PATH}" \ @@ -128,6 +128,7 @@ ENGINE=${ENGINE:-"vllm"} exp_name="$(basename "${MODEL_ID,,}")-megatron-gsm8k-minimal" ROLLOUT_MODE=${ROLLOUT_MODE:-sync} +FP8_ROLLOUT=${FP8_ROLLOUT:-False} RETURN_RAW_CHAT="False" SKIP_TOKENIZER_INIT=${SKIP_TOKENIZER_INIT:-False} @@ -190,6 +191,8 @@ python3 -m verl.trainer.main_ppo --config-path=config \ actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ actor_rollout_ref.rollout.update_weights_bucket_megabytes=128 \ + ++actor_rollout_ref.rollout.quantization=${FP8_ROLLOUT} \ + ++actor_rollout_ref.rollout.use_block_quant_rollout=${FP8_ROLLOUT} \ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ actor_rollout_ref.ref.megatron.use_mbridge=${USE_MBRIDGE} \ From 00bbce4fefc1e2ea2e2f262f8c2217bdb5a168ee Mon Sep 17 00:00:00 2001 From: Xue Huang Date: Wed, 19 Nov 2025 23:34:12 +0800 Subject: [PATCH 12/18] update format --- .../_generated_ppo_megatron_trainer.yaml | 2 + verl/utils/vllm/vllm_fp8_utils.py | 66 +++++++++---------- .../rollout/vllm_rollout/vllm_async_server.py | 4 +- 3 files changed, 36 insertions(+), 36 deletions(-) diff --git a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml index 85ef78b99cc..841f8b4089f 100644 --- a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml +++ 
b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml @@ -288,6 +288,8 @@ data: use_shm: false train_files: ~/data/rlhf/gsm8k/train.parquet val_files: ~/data/rlhf/gsm8k/test.parquet + train_max_samples: -1 + val_max_samples: -1 prompt_key: prompt reward_fn_key: data_source max_prompt_length: 512 diff --git a/verl/utils/vllm/vllm_fp8_utils.py b/verl/utils/vllm/vllm_fp8_utils.py index 0e4f928269f..6c6b5db7b6c 100644 --- a/verl/utils/vllm/vllm_fp8_utils.py +++ b/verl/utils/vllm/vllm_fp8_utils.py @@ -15,16 +15,14 @@ import logging from dataclasses import dataclass, field -from typing import Optional from unittest.mock import patch -import vllm import torch +import vllm try: - from vllm._custom_ops import scaled_fp8_quant - from vllm.model_executor.layers.linear import LinearBase from vllm.model_executor.layers.fused_moe.layer import FusedMoE + from vllm.model_executor.layers.linear import LinearBase except ImportError as e: raise ImportError("FP8 quantization not available") from e @@ -98,11 +96,10 @@ def is_fp8_weight(name, model): module = get_module_from_param_name(model, name) # We currently only quantize linear layers - if ( - (isinstance(module, LinearBase) and module.weight.dtype == torch.float8_e4m3fn) - or (isinstance(module, FusedMoE) - and module.w13_weight.dtype == torch.float8_e4m3fn - and module.w2_weight.dtype == torch.float8_e4m3fn) + if (isinstance(module, LinearBase) and module.weight.dtype == torch.float8_e4m3fn) or ( + isinstance(module, FusedMoE) + and module.w13_weight.dtype == torch.float8_e4m3fn + and module.w2_weight.dtype == torch.float8_e4m3fn ): fp8_state.fp8_param_names.add(name) return name in fp8_state.fp8_param_names @@ -185,7 +182,9 @@ def quant_weights(weights, model, quant_config, dtype=torch.bfloat16): weights_quantized.append([k + "_scale_inv", param_scale]) else: - raise ValueError("Currently only support blockwise quantization, please set weight_block_size in quant_config") + raise ValueError( + "Currently only support 
blockwise quantization, please set weight_block_size in quant_config" + ) return weights_quantized @@ -220,11 +219,9 @@ def process_weights_after_loading_for_vllm10(self, layer) -> None: """ logger.debug("Applying patch process_weights_after_loading") try: - from vllm.model_executor.layers.quantization.utils.w8a8_utils import requantize_with_max_scale from vllm.model_executor.parameter import ( BlockQuantScaleParameter, ModelWeightParameter, - PerTensorScaleParameter, ) except Exception: print("error") @@ -294,9 +291,7 @@ def _create_param_from_subclass_attributes(custom_param): custom_param_dir = dir(custom_param) # Find the attributes that are unique to the custom parameter custom_attributes = [ - attr - for attr in custom_param_dir - if attr not in base_param_dir and not attr.startswith("__") + attr for attr in custom_param_dir if attr not in base_param_dir and not attr.startswith("__") ] # Set the custom attributes into the base parameter object for attr in custom_attributes: @@ -305,11 +300,7 @@ def _create_param_from_subclass_attributes(custom_param): param.subclass_type = type(custom_param) return param - weight_scale = ( - layer.weight_scale_inv - if hasattr(layer, "weight_scale_inv") - else layer.weight_scale - ) + weight_scale = layer.weight_scale_inv if hasattr(layer, "weight_scale_inv") else layer.weight_scale weight, weight_scale = process_fp8_weight_block_strategy(layer.weight, weight_scale) layer.weight = _create_param_from_subclass_attributes( @@ -336,17 +327,18 @@ def _create_param_from_subclass_attributes(custom_param): def process_weights_after_loading_moe(self, layer) -> None: from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled - from vllm.model_executor.layers.quantization.fp8 import _swap_w13_to_w31, _is_col_major - from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used + from vllm.model_executor.layers.quantization.fp8 import _is_col_major, _swap_w13_to_w31 from 
vllm.model_executor.layers.quantization.utils.fp8_utils import ( - get_col_major_tma_aligned_tensor, requant_weight_ue8m0_inplace) + get_col_major_tma_aligned_tensor, + requant_weight_ue8m0_inplace, + ) + from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() assert self.quant_config.activation_scheme == "dynamic" if self.flashinfer_moe_enabled: w13_weight = _swap_w13_to_w31(layer.w13_weight.data) - w13_weight_scale_inv = _swap_w13_to_w31( - layer.w13_weight_scale_inv.data) + w13_weight_scale_inv = _swap_w13_to_w31(layer.w13_weight_scale_inv.data) w2_weight = layer.w2_weight.data w2_weight_scale_inv = layer.w2_weight_scale_inv.data else: @@ -356,6 +348,7 @@ def process_weights_after_loading_moe(self, layer) -> None: w2_weight_scale_inv = layer.w2_weight_scale_inv from torch.nn import Parameter + def _create_param_from_subclass_attributes(custom_data, custom_weight): param = Parameter(custom_data, requires_grad=False) base_param_dir = dir(torch.nn.Parameter) @@ -371,7 +364,9 @@ def _create_param_from_subclass_attributes(custom_data, custom_weight): return param layer.w13_weight = _create_param_from_subclass_attributes(w13_weight, layer.w13_weight) - layer.w13_weight_scale_inv = _create_param_from_subclass_attributes(w13_weight_scale_inv, layer.w13_weight_scale_inv) + layer.w13_weight_scale_inv = _create_param_from_subclass_attributes( + w13_weight_scale_inv, layer.w13_weight_scale_inv + ) layer.w2_weight = _create_param_from_subclass_attributes(w2_weight, layer.w2_weight) layer.w2_weight_scale_inv = _create_param_from_subclass_attributes(w2_weight_scale_inv, layer.w2_weight_scale_inv) @@ -380,11 +375,9 @@ def _create_param_from_subclass_attributes(custom_data, custom_weight): if self.allow_deep_gemm and not is_blackwell_deep_gemm_used(): # Lazy import to avoid CUDA initialization problems. 
if _is_col_major(layer.w13_weight_scale_inv): - layer.w13_weight_scale_inv = \ - get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv).contiguous() + layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv).contiguous() if _is_col_major(layer.w2_weight_scale_inv): - layer.w2_weight_scale_inv = \ - get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv).contiguous() + layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv).contiguous() if is_blackwell_deep_gemm_used(): assert layer.weight_block_size is not None @@ -402,18 +395,21 @@ def _create_param_from_subclass_attributes(custom_data, custom_weight): ) if _is_col_major(layer.w13_weight_scale_inv): - layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor( - layer.w13_weight_scale_inv).contiguous() + layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv).contiguous() if _is_col_major(layer.w2_weight_scale_inv): - layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor( - layer.w2_weight_scale_inv).contiguous() + layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv).contiguous() def apply_vllm_fp8_patches(block_quant=True): if block_quant: print("Applying vllm fp8 patches for blockwise quantization") func1_path = "vllm.model_executor.layers.quantization.fp8.Fp8LinearMethod.process_weights_after_loading" - patcher1 = patch(func1_path, process_weights_after_loading_for_vllm11 if vllm.__version__ >= "0.11.0" else process_weights_after_loading_for_vllm10) + patcher1 = patch( + func1_path, + process_weights_after_loading_for_vllm11 + if vllm.__version__ >= "0.11.0" + else process_weights_after_loading_for_vllm10, + ) patcher1.start() func2_path = "vllm.model_executor.layers.quantization.fp8.Fp8MoEMethod.process_weights_after_loading" patcher2 = patch(func2_path, process_weights_after_loading_moe) diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py 
b/verl/workers/rollout/vllm_rollout/vllm_async_server.py index ceb95cdd7e2..11fb207855d 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py +++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py @@ -249,7 +249,9 @@ async def launch_server(self, master_address: str = None, master_port: int = Non "seed": self.config.get("seed", 0), "override_generation_config": json.dumps(override_generation_config), "quantization": "fp8" if quantization else None, - "hf_overrides": {"quantization_config": fp8_block_quant_kwargs} if quantization and use_block_quant else None, + "hf_overrides": {"quantization_config": fp8_block_quant_kwargs} + if quantization and use_block_quant + else None, **engine_kwargs, } From aa8a4f38b1c745c1730102734764d36103e9c1b1 Mon Sep 17 00:00:00 2001 From: Xue Huang Date: Thu, 20 Nov 2025 14:27:15 +0800 Subject: [PATCH 13/18] modify flag for fp8 rollout --- .../e2e_ppo_trainer_megatron_vllm.yml | 3 +- docs/advance/fp8.md | 7 ++--- ...run_dapo_qwen3_moe_30b_vllm_fp8_rollout.sh | 3 +- tests/special_e2e/run_ppo_trainer_megatron.sh | 5 ++-- .../_generated_ppo_megatron_trainer.yaml | 3 +- verl/trainer/config/ppo_megatron_trainer.yaml | 4 +-- verl/utils/vllm/vllm_fp8_utils.py | 29 +++++++++---------- verl/workers/config/rollout.py | 4 +-- .../rollout/vllm_rollout/vllm_async_server.py | 19 ++++++------ .../rollout/vllm_rollout/vllm_rollout_spmd.py | 11 ++++--- 10 files changed, 39 insertions(+), 49 deletions(-) diff --git a/.github/workflows/e2e_ppo_trainer_megatron_vllm.yml b/.github/workflows/e2e_ppo_trainer_megatron_vllm.yml index d5d184b7776..c4ec6e664c6 100644 --- a/.github/workflows/e2e_ppo_trainer_megatron_vllm.yml +++ b/.github/workflows/e2e_ppo_trainer_megatron_vllm.yml @@ -196,7 +196,8 @@ jobs: - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with FP8 rollout run: | ray stop --force - FP8_ROLLOUT=1 TOTAL_TRAIN_STEPS=2 MODEL_ID=Qwen/Qwen3-0.6B bash tests/special_e2e/run_ppo_trainer_megatron.sh + export 
VLLM_USE_V1=1 + ROLLOUT_QUANTIZATION=fp8 ROLLOUT_MODE=async TOTAL_TRAIN_STEPS=2 MODEL_ID=Qwen/Qwen3-0.6B bash tests/special_e2e/run_ppo_trainer_megatron.sh - name: clean up run: | rm -rf checkpoints diff --git a/docs/advance/fp8.md b/docs/advance/fp8.md index ab023c46289..62183a04a84 100644 --- a/docs/advance/fp8.md +++ b/docs/advance/fp8.md @@ -97,14 +97,11 @@ FP8 can be enabled in the config file `verl/trainer/config/ppo_megatron_trainer. ``` rollout: - quantization: True - - use_block_quant_rollout: True + quantization: "fp8" ``` Or it can be enabled by command line: -- `actor_rollout_ref.rollout.quantization=True` -- `actor_rollout_ref.rollout.use_block_quant_rollout=True` +- `actor_rollout_ref.rollout.quantization=fp8` Please refer to `recipe/dapo/run_dapo_qwen3_moe_30b_vllm_fp8_rollout.sh` diff --git a/recipe/dapo/run_dapo_qwen3_moe_30b_vllm_fp8_rollout.sh b/recipe/dapo/run_dapo_qwen3_moe_30b_vllm_fp8_rollout.sh index d3273980099..c8860c6aea2 100644 --- a/recipe/dapo/run_dapo_qwen3_moe_30b_vllm_fp8_rollout.sh +++ b/recipe/dapo/run_dapo_qwen3_moe_30b_vllm_fp8_rollout.sh @@ -140,8 +140,7 @@ RAY_ADDRESS='http://127.0.0.1:8265' ray job submit --runtime-env=${RUNTIME_ENV} actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ actor_rollout_ref.rollout.val_kwargs.do_sample=True \ actor_rollout_ref.rollout.val_kwargs.n=1 \ - +actor_rollout_ref.rollout.quantization=True \ - +actor_rollout_ref.rollout.use_block_quant_rollout=True \ + +actor_rollout_ref.rollout.quantization=fp8 \ actor_rollout_ref.rollout.name=vllm \ actor_rollout_ref.rollout.mode=async \ actor_rollout_ref.rollout.calculate_log_probs=True \ diff --git a/tests/special_e2e/run_ppo_trainer_megatron.sh b/tests/special_e2e/run_ppo_trainer_megatron.sh index 2332bfa3da1..e977d365f05 100644 --- a/tests/special_e2e/run_ppo_trainer_megatron.sh +++ b/tests/special_e2e/run_ppo_trainer_megatron.sh @@ -128,7 +128,7 @@ ENGINE=${ENGINE:-"vllm"} exp_name="$(basename "${MODEL_ID,,}")-megatron-gsm8k-minimal" 
ROLLOUT_MODE=${ROLLOUT_MODE:-sync} -FP8_ROLLOUT=${FP8_ROLLOUT:-False} +ROLLOUT_QUANTIZATION=${ROLLOUT_QUANTIZATION:-null} RETURN_RAW_CHAT="False" SKIP_TOKENIZER_INIT=${SKIP_TOKENIZER_INIT:-False} @@ -191,8 +191,7 @@ python3 -m verl.trainer.main_ppo --config-path=config \ actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ actor_rollout_ref.rollout.update_weights_bucket_megabytes=128 \ - ++actor_rollout_ref.rollout.quantization=${FP8_ROLLOUT} \ - ++actor_rollout_ref.rollout.use_block_quant_rollout=${FP8_ROLLOUT} \ + ++actor_rollout_ref.rollout.quantization=${ROLLOUT_QUANTIZATION} \ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ actor_rollout_ref.ref.megatron.use_mbridge=${USE_MBRIDGE} \ diff --git a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml index 841f8b4089f..f6a0706d389 100644 --- a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml @@ -265,8 +265,7 @@ actor_rollout_ref: port: 9090 file: /tmp/ray/session_latest/metrics/prometheus/prometheus.yml served_model_name: ${oc.select:actor_rollout_ref.model.path,null} - quantization: false - use_block_quant_rollout: true + quantization: null layer_name_map: qkv_layer_name: qkv gate_proj_layer_name: gate_up diff --git a/verl/trainer/config/ppo_megatron_trainer.yaml b/verl/trainer/config/ppo_megatron_trainer.yaml index 0c1b77afcd5..454395cd2e2 100644 --- a/verl/trainer/config/ppo_megatron_trainer.yaml +++ b/verl/trainer/config/ppo_megatron_trainer.yaml @@ -45,9 +45,7 @@ actor_rollout_ref: use_remove_padding: false rollout: - quantization: False - - use_block_quant_rollout: True + quantization: null layer_name_map: qkv_layer_name: qkv diff --git a/verl/utils/vllm/vllm_fp8_utils.py 
b/verl/utils/vllm/vllm_fp8_utils.py index 6c6b5db7b6c..eab4e4d8d77 100644 --- a/verl/utils/vllm/vllm_fp8_utils.py +++ b/verl/utils/vllm/vllm_fp8_utils.py @@ -400,19 +400,16 @@ def _create_param_from_subclass_attributes(custom_data, custom_weight): layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv).contiguous() -def apply_vllm_fp8_patches(block_quant=True): - if block_quant: - print("Applying vllm fp8 patches for blockwise quantization") - func1_path = "vllm.model_executor.layers.quantization.fp8.Fp8LinearMethod.process_weights_after_loading" - patcher1 = patch( - func1_path, - process_weights_after_loading_for_vllm11 - if vllm.__version__ >= "0.11.0" - else process_weights_after_loading_for_vllm10, - ) - patcher1.start() - func2_path = "vllm.model_executor.layers.quantization.fp8.Fp8MoEMethod.process_weights_after_loading" - patcher2 = patch(func2_path, process_weights_after_loading_moe) - patcher2.start() - else: - raise ValueError("Only blockwise quantization is supported for FP8 rollout") +def apply_vllm_fp8_patches(): + logger.info("Applying vllm fp8 patches for blockwise quantization") + func1_path = "vllm.model_executor.layers.quantization.fp8.Fp8LinearMethod.process_weights_after_loading" + patcher1 = patch( + func1_path, + process_weights_after_loading_for_vllm11 + if vllm.__version__ >= "0.11.0" + else process_weights_after_loading_for_vllm10, + ) + patcher1.start() + func2_path = "vllm.model_executor.layers.quantization.fp8.Fp8MoEMethod.process_weights_after_loading" + patcher2 = patch(func2_path, process_weights_after_loading_moe) + patcher2.start() diff --git a/verl/workers/config/rollout.py b/verl/workers/config/rollout.py index cd92bc24df4..3e8276e38eb 100644 --- a/verl/workers/config/rollout.py +++ b/verl/workers/config/rollout.py @@ -198,9 +198,7 @@ class RolloutConfig(BaseConfig): skip_tokenizer_init: bool = False - quantization: bool = False - - use_block_quant_rollout: bool = False + quantization: 
Optional[str] = None def __post_init__(self): """Validate the rollout config""" diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py index 11fb207855d..db9e234f3fb 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py +++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py @@ -217,9 +217,8 @@ async def launch_server(self, master_address: str = None, master_port: int = Non ) logger.info(f"override_generation_config: {override_generation_config}") quantization = self.config.quantization - use_block_quant = self.config.use_block_quant_rollout - if quantization: - if use_block_quant: + if quantization is not None: + if quantization == "fp8": FP8_BLOCK_QUANT_KWARGS = { "activation_scheme": "dynamic", "fmt": "e4m3", @@ -227,9 +226,11 @@ async def launch_server(self, master_address: str = None, master_port: int = Non "weight_block_size": [128, 128], } fp8_block_quant_kwargs = dict(FP8_BLOCK_QUANT_KWARGS) - # Apply vllm fp8 patches - # Will remove the patch after vllm support on-the-fly quant for rollout natively. - apply_vllm_fp8_patches(block_quant=use_block_quant) + # Apply vllm fp8 patches + # Will remove the patch after vllm support on-the-fly quant for rollout natively. 
+ apply_vllm_fp8_patches() + else: + raise ValueError(f"Currently only support fp8 quantization, got: {quantization}") args = { "dtype": self.config.dtype, "load_format": self.config.load_format, @@ -248,10 +249,8 @@ async def launch_server(self, master_address: str = None, master_port: int = Non "tensor_parallel_size": self.config.tensor_model_parallel_size, "seed": self.config.get("seed", 0), "override_generation_config": json.dumps(override_generation_config), - "quantization": "fp8" if quantization else None, - "hf_overrides": {"quantization_config": fp8_block_quant_kwargs} - if quantization and use_block_quant - else None, + "quantization": quantization, + "hf_overrides": {"quantization_config": fp8_block_quant_kwargs} if quantization == "fp8" else None, **engine_kwargs, } diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py index 52c1563ab9e..a14314188af 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py +++ b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py @@ -604,10 +604,13 @@ def _init_worker(self, all_kwargs: list[dict[str, Any]]): if self.lora_config: lora_dtype = getattr(torch, self.config.dtype) self.vllm_config.lora_config = LoRAConfig(lora_dtype=lora_dtype, **self.lora_config) - if self.config.quantization: - # Apply vllm fp8 patches - # Will remove the patch after vllm support on-the-fly quant for rollout natively. - apply_vllm_fp8_patches(block_quant=self.config.use_block_quant_rollout) + if self.config.quantization is not None: + if self.config.quantization == "fp8": + # Apply vllm fp8 patches + # Will remove the patch after vllm support on-the-fly quant for rollout natively. 
+ apply_vllm_fp8_patches() + else: + raise ValueError(f"Currently only support fp8 quantization, got: {self.config.quantization}") self.inference_engine = WorkerWrapperBase(vllm_config=self.vllm_config) self.inference_engine.init_worker(all_kwargs) From 311c0821b9d0069ecf24f20c32054843f3f1ab0e Mon Sep 17 00:00:00 2001 From: Xue Huang Date: Wed, 26 Nov 2025 14:47:24 +0000 Subject: [PATCH 14/18] sglang fp8 rollout --- verl/utils/sglang/sglang_fp8_utils.py | 187 ++++++++++++++++++ .../sglang_rollout/async_sglang_server.py | 17 ++ .../rollout/sglang_rollout/sglang_rollout.py | 36 +++- 3 files changed, 232 insertions(+), 8 deletions(-) create mode 100644 verl/utils/sglang/sglang_fp8_utils.py diff --git a/verl/utils/sglang/sglang_fp8_utils.py b/verl/utils/sglang/sglang_fp8_utils.py new file mode 100644 index 00000000000..cc09f450454 --- /dev/null +++ b/verl/utils/sglang/sglang_fp8_utils.py @@ -0,0 +1,187 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging + +import torch + +logger = logging.getLogger(__name__) + + +def should_quantize_param(param_name: str) -> bool: + """根据参数名判断是否应该量化为 FP8 + + 量化规则: + - 必须以 .weight 结尾(排除 bias) + - 排除 embedding 层 + - 排除 normalization 层 + - 排除输出层(lm_head) + """ + # 必须是权重参数 + if not param_name.endswith(".weight"): + return False + + # 排除的层类型 + exclude_patterns = [ + "embed_tokens", # Embedding 层 + "lm_head", # 输出层 + "layernorm", # LayerNorm + "norm", # 各种 Norm 层 + "ln_", # LayerNorm 变体 + "embeddings", # Embeddings + ] + + # 检查是否匹配排除模式 + param_lower = param_name.lower() + for pattern in exclude_patterns: + if pattern in param_lower: + return False + + # 包含的层类型(Linear 层) + include_patterns = [ + "q_proj", # Query projection + "k_proj", # Key projection + "v_proj", # Value projection + "o_proj", # Output projection + "gate_proj", # Gate projection (for MLP) + "up_proj", # Up projection (for MLP) + "down_proj", # Down projection (for MLP) + "fc1", # Fully connected 1 + "fc2", # Fully connected 2 + "gate", # Gate (for MoE) + "mlp", # MLP layers + ] + + # 检查是否匹配包含模式 + for pattern in include_patterns: + if pattern in param_lower: + logger.info(f"Will quantize FP8: {param_name}") + return True + + # 默认不量化 + logger.debug(f"Skip quantization: {param_name}") + return False + + +def quant_weights_by_name(weights, quant_config, dtype=torch.bfloat16, vllm_version="0.11.0"): + """基于参数名的 FP8 量化 + + Args: + weights: Generator of (name, tensor) pairs + quant_config: Quantization configuration + dtype: Data type for intermediate computation + vllm_version: vLLM version string for scale naming + + Returns: + List of (name, tensor) pairs with quantized weights + """ + from verl.utils.sglang.sglang_fp8_utils import scaled_fp8_blockwise + + weights_quantized = [] + quantized_count = 0 + skipped_count = 0 + + if isinstance(quant_config, dict): + weight_block_size = quant_config.get("weight_block_size") + else: + weight_block_size = getattr(quant_config, "weight_block_size", None) + + if 
weight_block_size is None: + raise ValueError("weight_block_size not found in quant_config") + + for k, v in weights: + # 判断是否需要量化 + if not should_quantize_param(k): + weights_quantized.append((k, v)) + skipped_count += 1 + continue + + # 量化为 FP8 + try: + if weight_block_size is not None: + print(f"Quantizing to FP8 blockwise: {k}") + param_lp, param_scale = scaled_fp8_blockwise( + v.to(dtype), + weight_block_size=weight_block_size, + ) + param_scale = param_scale.squeeze(-1) + weights_quantized.append([k, param_lp]) + weights_quantized.append([k + "_scale_inv", param_scale]) + quantized_count += 1 + else: + raise ValueError( + "Only blockwise quantization is supported. Please set weight_block_size in quant_config" + ) + except Exception as e: + logger.error(f"Failed to quantize {k}: {e}") + # 如果量化失败,使用原始权重 + weights_quantized.append((k, v)) + skipped_count += 1 + + print(f"FP8 quantization complete: {quantized_count} quantized, {skipped_count} skipped") + return weights_quantized + + +def scaled_fp8_blockwise( + data_hp, + weight_block_size, +): + # cast tensor from high precision to FP8 with 128*128 blockwise quantization. + assert len(data_hp.shape) == 2, "Only 2d input tensor is supported" + + block_size1 = weight_block_size[1] + block_size0 = weight_block_size[0] + assert data_hp.shape[1] % block_size1 == 0, ( + f"data_hp.shape[1] {data_hp.shape[1]} must be a multiple of block_size1: {block_size1}." + ) + assert data_hp.shape[0] % block_size0 == 0, ( + f"data_hp.shape[0] {data_hp.shape[0]} must be a multiple of block_size0: {block_size0}." 
+ ) + + # FP8 + max_dtype = torch.finfo(torch.float8_e4m3fn).max + + original_shape = data_hp.shape + blk_m, blk_n = data_hp.shape[0] // block_size0, data_hp.shape[1] // block_size1 + + assert block_size1 == block_size0 + data_hp = data_hp.reshape(blk_m, block_size0, blk_n, block_size1) + + # Permute to (BLK_M, BLK_N, BLOCK_SIZE_M, BLOCK_SIZE_N) + data_hp = data_hp.permute(0, 2, 1, 3) + # Flatten to (BLK_M, BLK_N, BLOCK_SIZE_M * BLOCK_SIZE_N) + data_hp = data_hp.to(torch.float32).contiguous().flatten(start_dim=2) + + # Calculate max absolute value per block + max_abs = torch.amax(torch.abs(data_hp), dim=-1, keepdim=True) + + # Use FP32 scale + scale_fp = max_dtype / max_abs + scale_fp = torch.where(max_abs == 0, 1.0, scale_fp) + # preserve the behavior for 0 amax case + scale_fp = torch.where(max_abs == torch.inf, 1.0, scale_fp) + + descale_fp = torch.reciprocal(scale_fp) + + # Scale and saturate cast the data elements to max of target dtype + data_lp = torch.clamp(data_hp * scale_fp, min=-1 * max_dtype, max=max_dtype) + + fp_data = data_lp.to(torch.float8_e4m3fn) + + # (BLK_M, BLK_N, BLOCK_SIZE_M * BLOCK_SIZE_N) to (M, N) + fp_data = fp_data.reshape(blk_m, blk_n, block_size0, block_size1).permute(0, 2, 1, 3).reshape(original_shape) + + # Convert to target format, but still in original precision container + return fp_data, descale_fp diff --git a/verl/workers/rollout/sglang_rollout/async_sglang_server.py b/verl/workers/rollout/sglang_rollout/async_sglang_server.py index 1d3d657d925..b3355bf4ccb 100644 --- a/verl/workers/rollout/sglang_rollout/async_sglang_server.py +++ b/verl/workers/rollout/sglang_rollout/async_sglang_server.py @@ -14,6 +14,7 @@ # limitations under the License. 
import asyncio import dataclasses +import json import logging import os from typing import Any, Optional @@ -125,6 +126,18 @@ async def launch_server(self, master_address: str = None, master_port: int = Non engine_kwargs = self.config.get("engine_kwargs", {}).get("sglang", {}) or {} attention_backend = engine_kwargs.pop("attention_backend", None) + quantization = self.config.get("quantization", None) + if quantization is not None: + if quantization == "fp8": + FP8_BLOCK_QUANT_KWARGS = { + "activation_scheme": "dynamic", + "fmt": "e4m3", + "quant_method": "fp8", + "weight_block_size": [128, 128], + } + fp8_block_quant_kwargs = dict(FP8_BLOCK_QUANT_KWARGS) + else: + raise ValueError(f"Currently only support fp8 quantization, got: {quantization}") dist_init_addr = ( f"[{self._master_address}]:{self._master_port}" if is_valid_ipv6_address(self._master_address) @@ -153,6 +166,10 @@ async def launch_server(self, master_address: str = None, master_port: int = Non "attention_backend": attention_backend if attention_backend is not None else "fa3", "skip_tokenizer_init": self.config.skip_tokenizer_init, "skip_server_warmup": True, + "quantization": quantization, + "json_model_override_args": json.dumps({"quantization_config": fp8_block_quant_kwargs}) + if quantization == "fp8" + else json.dumps({}), **engine_kwargs, } diff --git a/verl/workers/rollout/sglang_rollout/sglang_rollout.py b/verl/workers/rollout/sglang_rollout/sglang_rollout.py index fa5ca879eb9..bf91d73b21b 100644 --- a/verl/workers/rollout/sglang_rollout/sglang_rollout.py +++ b/verl/workers/rollout/sglang_rollout/sglang_rollout.py @@ -363,7 +363,7 @@ def _verify_config(self, model_hf_config): self.config.max_model_len = self.config.prompt_length + self.config.response_length assert ( self.config.max_model_len >= self.config.prompt_length + self.config.response_length - ), f"""max_model_len should be greater than total sequence length (prompt_length + response_length): + ), f"""max_model_len should be greater 
than total sequence length (prompt_length + response_length): {self.config.max_model_len} >= {self.config.prompt_length} + {self.config.response_length}""" max_position_embeddings = None if hasattr(model_hf_config, "max_position_embeddings"): @@ -1218,8 +1218,8 @@ async def run_with_cancellation(): == req.attention_mask.shape[-1] == req.position_ids.shape[-1] == req.loss_mask.shape[-1] - ), f"""Request {req.request_id} has different length of - {req.input_ids.shape[-1]=}, {req.attention_mask.shape[-1]=}, + ), f"""Request {req.request_id} has different length of + {req.input_ids.shape[-1]=}, {req.attention_mask.shape[-1]=}, {req.position_ids.shape[-1]=}, {req.loss_mask.shape[-1]=}""" error_message_lines = [ f"""Request {req.request_id} has input_ids length {req.input_ids.shape[-1]} @@ -1237,7 +1237,7 @@ async def run_with_cancellation(): response_ids.append(req.response_ids.to(tgt_device).squeeze(0)) if req.response_ids.shape[-1] > self.config.response_length: logger.warning( - f"""{req.request_id=} has response_ids length {req.response_ids.shape[-1]} + f"""{req.request_id=} has response_ids length {req.response_ids.shape[-1]} greater than max_response_len {self.config.response_length},\n{req=}""" ) prompt_attention_mask.append(req.prompt_attention_mask.to(tgt_device).squeeze(0)) @@ -1486,10 +1486,10 @@ def _preprocess_prompt_to_async_rollout_requests(self, prompts: DataProto, n: in tokenization_sanity_check_mode=self.config.multi_turn.tokenization_sanity_check_mode, processing_class=self.processing_class, ) - error_message = f"""Request {req.request_id} has mismatched lengths: - input_ids={req.input_ids.shape[-1]}, - attention_mask={req.attention_mask.shape[-1]}, - position_ids={req.position_ids.shape[-1]}, + error_message = f"""Request {req.request_id} has mismatched lengths: + input_ids={req.input_ids.shape[-1]}, + attention_mask={req.attention_mask.shape[-1]}, + position_ids={req.position_ids.shape[-1]}, loss_mask={req.loss_mask.shape[-1]}""" assert ( 
req.input_ids.shape[-1] @@ -1555,6 +1555,15 @@ def __init__( model_config: HFModelConfig, device_mesh: DeviceMesh, ): + if config.get("quantization", None) == "fp8": + FP8_BLOCK_QUANT_KWARGS = { + "activation_scheme": "dynamic", + "fmt": "e4m3", + "quant_method": "fp8", + "weight_block_size": [128, 128], + } + fp8_block_quant_kwargs = dict(FP8_BLOCK_QUANT_KWARGS) + model_config.hf_config.quantization_config = fp8_block_quant_kwargs super().__init__(config, model_config, device_mesh) self._engine: AsyncHttpServerAdapter = None @@ -1615,6 +1624,17 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None await self._init_server_adapter() update_weights_bucket_bytes = int(self.config.update_weights_bucket_megabytes) << 20 + if self.config.get("quantization", None) == "fp8": + from verl.utils.sglang.sglang_fp8_utils import quant_weights_by_name + + weights = quant_weights_by_name( + weights, + self.model_config.hf_config.quantization_config, + dtype=self.model_config.hf_config.dtype, + ) + else: + weights = weights + for params_batch in get_named_tensor_buckets(weights, update_weights_bucket_bytes): await sgl_update_weights( engine=self._engine, From d772aef2f3f47001bf1734653bd74dc60dddf414 Mon Sep 17 00:00:00 2001 From: Xue Huang Date: Wed, 26 Nov 2025 18:34:50 -0800 Subject: [PATCH 15/18] update --- verl/utils/sglang/sglang_fp8_utils.py | 49 +++++++++---------- .../rollout/sglang_rollout/sglang_rollout.py | 2 +- 2 files changed, 23 insertions(+), 28 deletions(-) diff --git a/verl/utils/sglang/sglang_fp8_utils.py b/verl/utils/sglang/sglang_fp8_utils.py index cc09f450454..c8c94794a56 100644 --- a/verl/utils/sglang/sglang_fp8_utils.py +++ b/verl/utils/sglang/sglang_fp8_utils.py @@ -21,35 +21,35 @@ def should_quantize_param(param_name: str) -> bool: - """根据参数名判断是否应该量化为 FP8 + """Determine whether to quantize to FP8 based on parameter name - 量化规则: - - 必须以 .weight 结尾(排除 bias) - - 排除 embedding 层 - - 排除 normalization 层 - - 排除输出层(lm_head) + 
Quantization rules: + - Must end with .weight (exclude bias) + - Exclude embedding layers + - Exclude normalization layers + - Exclude output layer (lm_head) """ - # 必须是权重参数 + # Must be a weight parameter if not param_name.endswith(".weight"): return False - # 排除的层类型 + # Layer types to exclude exclude_patterns = [ - "embed_tokens", # Embedding 层 - "lm_head", # 输出层 + "embed_tokens", # Embedding layer + "lm_head", # Output layer "layernorm", # LayerNorm - "norm", # 各种 Norm 层 - "ln_", # LayerNorm 变体 + "norm", # Various Norm layers + "ln_", # LayerNorm variants "embeddings", # Embeddings ] - # 检查是否匹配排除模式 + # Check if matches exclude patterns param_lower = param_name.lower() for pattern in exclude_patterns: if pattern in param_lower: return False - # 包含的层类型(Linear 层) + # Layer types to include (Linear layers) include_patterns = [ "q_proj", # Query projection "k_proj", # Key projection @@ -64,19 +64,19 @@ def should_quantize_param(param_name: str) -> bool: "mlp", # MLP layers ] - # 检查是否匹配包含模式 + # Check if matches include patterns for pattern in include_patterns: if pattern in param_lower: logger.info(f"Will quantize FP8: {param_name}") return True - # 默认不量化 + # Do not quantize by default logger.debug(f"Skip quantization: {param_name}") return False def quant_weights_by_name(weights, quant_config, dtype=torch.bfloat16, vllm_version="0.11.0"): - """基于参数名的 FP8 量化 + """FP8 quantization based on parameter name Args: weights: Generator of (name, tensor) pairs @@ -90,8 +90,6 @@ def quant_weights_by_name(weights, quant_config, dtype=torch.bfloat16, vllm_vers from verl.utils.sglang.sglang_fp8_utils import scaled_fp8_blockwise weights_quantized = [] - quantized_count = 0 - skipped_count = 0 if isinstance(quant_config, dict): weight_block_size = quant_config.get("weight_block_size") @@ -102,16 +100,15 @@ def quant_weights_by_name(weights, quant_config, dtype=torch.bfloat16, vllm_vers raise ValueError("weight_block_size not found in quant_config") for k, v in weights: - # 判断是否需要量化 + 
# Check if quantization is needed if not should_quantize_param(k): weights_quantized.append((k, v)) - skipped_count += 1 continue - # 量化为 FP8 + # Quantize to FP8 try: if weight_block_size is not None: - print(f"Quantizing to FP8 blockwise: {k}") + logger.debug(f"Quantizing to FP8 blockwise: {k}") param_lp, param_scale = scaled_fp8_blockwise( v.to(dtype), weight_block_size=weight_block_size, @@ -119,18 +116,15 @@ def quant_weights_by_name(weights, quant_config, dtype=torch.bfloat16, vllm_vers param_scale = param_scale.squeeze(-1) weights_quantized.append([k, param_lp]) weights_quantized.append([k + "_scale_inv", param_scale]) - quantized_count += 1 else: raise ValueError( "Only blockwise quantization is supported. Please set weight_block_size in quant_config" ) except Exception as e: logger.error(f"Failed to quantize {k}: {e}") - # 如果量化失败,使用原始权重 + # If quantization fails, use original weights weights_quantized.append((k, v)) - skipped_count += 1 - print(f"FP8 quantization complete: {quantized_count} quantized, {skipped_count} skipped") return weights_quantized @@ -185,3 +179,4 @@ def scaled_fp8_blockwise( # Convert to target format, but still in original precision container return fp_data, descale_fp + diff --git a/verl/workers/rollout/sglang_rollout/sglang_rollout.py b/verl/workers/rollout/sglang_rollout/sglang_rollout.py index bf91d73b21b..4b123ddb9df 100644 --- a/verl/workers/rollout/sglang_rollout/sglang_rollout.py +++ b/verl/workers/rollout/sglang_rollout/sglang_rollout.py @@ -1626,7 +1626,7 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None update_weights_bucket_bytes = int(self.config.update_weights_bucket_megabytes) << 20 if self.config.get("quantization", None) == "fp8": from verl.utils.sglang.sglang_fp8_utils import quant_weights_by_name - + logger.info(f"Convert bf16 weights to fp8 format before loading") weights = quant_weights_by_name( weights, self.model_config.hf_config.quantization_config, From 
32d0c66e595c7214ce9169a83f635e06934ef34b Mon Sep 17 00:00:00 2001 From: Xue Huang Date: Thu, 4 Dec 2025 01:19:26 -0800 Subject: [PATCH 16/18] small update --- verl/utils/sglang/sglang_fp8_utils.py | 110 +++++++++--------- .../rollout/sglang_rollout/sglang_rollout.py | 3 +- 2 files changed, 57 insertions(+), 56 deletions(-) diff --git a/verl/utils/sglang/sglang_fp8_utils.py b/verl/utils/sglang/sglang_fp8_utils.py index c8c94794a56..1833c02abb8 100644 --- a/verl/utils/sglang/sglang_fp8_utils.py +++ b/verl/utils/sglang/sglang_fp8_utils.py @@ -14,10 +14,12 @@ # limitations under the License. import logging +import os import torch -logger = logging.getLogger(__name__) +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "INFO")) def should_quantize_param(param_name: str) -> bool: @@ -67,7 +69,7 @@ def should_quantize_param(param_name: str) -> bool: # Check if matches include patterns for pattern in include_patterns: if pattern in param_lower: - logger.info(f"Will quantize FP8: {param_name}") + logger.debug(f"Will quantize FP8: {param_name}") return True # Do not quantize by default @@ -75,59 +77,6 @@ def should_quantize_param(param_name: str) -> bool: return False -def quant_weights_by_name(weights, quant_config, dtype=torch.bfloat16, vllm_version="0.11.0"): - """FP8 quantization based on parameter name - - Args: - weights: Generator of (name, tensor) pairs - quant_config: Quantization configuration - dtype: Data type for intermediate computation - vllm_version: vLLM version string for scale naming - - Returns: - List of (name, tensor) pairs with quantized weights - """ - from verl.utils.sglang.sglang_fp8_utils import scaled_fp8_blockwise - - weights_quantized = [] - - if isinstance(quant_config, dict): - weight_block_size = quant_config.get("weight_block_size") - else: - weight_block_size = getattr(quant_config, "weight_block_size", None) - - if weight_block_size is None: - raise ValueError("weight_block_size not found in 
quant_config") - - for k, v in weights: - # Check if quantization is needed - if not should_quantize_param(k): - weights_quantized.append((k, v)) - continue - - # Quantize to FP8 - try: - if weight_block_size is not None: - logger.debug(f"Quantizing to FP8 blockwise: {k}") - param_lp, param_scale = scaled_fp8_blockwise( - v.to(dtype), - weight_block_size=weight_block_size, - ) - param_scale = param_scale.squeeze(-1) - weights_quantized.append([k, param_lp]) - weights_quantized.append([k + "_scale_inv", param_scale]) - else: - raise ValueError( - "Only blockwise quantization is supported. Please set weight_block_size in quant_config" - ) - except Exception as e: - logger.error(f"Failed to quantize {k}: {e}") - # If quantization fails, use original weights - weights_quantized.append((k, v)) - - return weights_quantized - - def scaled_fp8_blockwise( data_hp, weight_block_size, @@ -180,3 +129,54 @@ def scaled_fp8_blockwise( # Convert to target format, but still in original precision container return fp_data, descale_fp + +def quant_weights_by_name(weights, quant_config, dtype=torch.bfloat16): + """FP8 quantization based on parameter name + + Args: + weights: Generator of (name, tensor) pairs + quant_config: Quantization configuration + dtype: Data type for intermediate computation + + Returns: + List of (name, tensor) pairs with quantized weights + """ + + weights_quantized = [] + + if isinstance(quant_config, dict): + weight_block_size = quant_config.get("weight_block_size") + else: + weight_block_size = getattr(quant_config, "weight_block_size", None) + + if weight_block_size is None: + raise ValueError("weight_block_size not found in quant_config") + + for k, v in weights: + # Check if quantization is needed + if not should_quantize_param(k): + weights_quantized.append((k, v)) + continue + + # Quantize to FP8 + try: + if weight_block_size is not None: + if torch.distributed.get_rank() == 0: + logger.debug(f" Quantizing to FP8 blockwise: {k}") + param_lp, param_scale 
= scaled_fp8_blockwise( + v.to(dtype), + weight_block_size=weight_block_size, + ) + param_scale = param_scale.squeeze(-1) + weights_quantized.append([k, param_lp]) + weights_quantized.append([k + "_scale_inv", param_scale]) + else: + raise ValueError( + "Only blockwise quantization is supported. Please set weight_block_size in quant_config" + ) + except Exception as e: + logger.error(f"Failed to quantize {k}: {e}") + # If quantization fails, use original weights + weights_quantized.append((k, v)) + + return weights_quantized diff --git a/verl/workers/rollout/sglang_rollout/sglang_rollout.py b/verl/workers/rollout/sglang_rollout/sglang_rollout.py index 4b123ddb9df..4d5e8a0720f 100644 --- a/verl/workers/rollout/sglang_rollout/sglang_rollout.py +++ b/verl/workers/rollout/sglang_rollout/sglang_rollout.py @@ -1626,7 +1626,8 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None update_weights_bucket_bytes = int(self.config.update_weights_bucket_megabytes) << 20 if self.config.get("quantization", None) == "fp8": from verl.utils.sglang.sglang_fp8_utils import quant_weights_by_name - logger.info(f"Convert bf16 weights to fp8 format before loading") + + logger.info("Convert bf16 weights to fp8 format before loading") weights = quant_weights_by_name( weights, self.model_config.hf_config.quantization_config, From bf2ed72d8ae1ca6975512489a66940c6f8bb6fa6 Mon Sep 17 00:00:00 2001 From: Xue Huang Date: Thu, 4 Dec 2025 01:52:57 -0800 Subject: [PATCH 17/18] update doc --- docs/advance/fp8.md | 20 ++++++++----------- .../sglang_rollout/async_sglang_server.py | 2 ++ .../rollout/sglang_rollout/sglang_rollout.py | 3 +++ 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/docs/advance/fp8.md b/docs/advance/fp8.md index 62183a04a84..0006392d7cd 100644 --- a/docs/advance/fp8.md +++ b/docs/advance/fp8.md @@ -1,14 +1,15 @@ # FP8 rollout for verl -Last updated: 11/19/2025 +Last updated: 12/4/2025 -This document introduces FP8 rollout with vllm 
inference backend in verl. +This document introduces FP8 rollout in verl. -We monkey patch several vLLM functions to enable FP8 rollout for reinforcement learning. -1. **load_weights**: A custom `load_weights` function to quantize the on-the-fly model weights from a higher-precision format to FP8. -2. **process weights after loading**: Replace `vllm.model_executor.layers.quantization.fp8.Fp8LinearMethod.process_weights_after_loading` -function to handle model weights loading after quantization. +We monkey patch several vLLM functions to enable FP8 rollout for reinforcement learning: + +1. **Quantize weights**: Quantize model weights on-the-fly from higher-precision formats to FP8. +2. **Process weights after loading**: For vLLM, we replace the `vllm.model_executor.layers.quantization.fp8.Fp8LinearMethod.process_weights_after_loading` function to handle weight processing after quantization. For SGLang, this patch is not needed as it natively supports loading quantized weights. + ## Support Matrix - FP8 blockwise quantization for rollout @@ -16,7 +17,7 @@ function to handle model weights loading after quantization. 
which is 1x128 quantization for activations and 128x128 quantization for model weights - Dense models and MoE models - Async rollout interfaces -vLLM 0.10.x & vLLM 0.11 +vLLM 0.10.x & vLLM 0.11 & SGLang 0.5.5 - FSDP and Megatron training backends ## Experiments and Outcomes @@ -104,8 +105,3 @@ Or it can be enabled by command line: - `actor_rollout_ref.rollout.quantization=fp8` Please refer to `recipe/dapo/run_dapo_qwen3_moe_30b_vllm_fp8_rollout.sh` - -## Plans - -- will open another PR to support FP8 rollout in SGLang -- further to enable FP8 training in megatron diff --git a/verl/workers/rollout/sglang_rollout/async_sglang_server.py b/verl/workers/rollout/sglang_rollout/async_sglang_server.py index b3355bf4ccb..6af810e7763 100644 --- a/verl/workers/rollout/sglang_rollout/async_sglang_server.py +++ b/verl/workers/rollout/sglang_rollout/async_sglang_server.py @@ -20,6 +20,7 @@ from typing import Any, Optional import ray +import sglang import sglang.srt.entrypoints.engine import torch from ray.actor import ActorHandle @@ -129,6 +130,7 @@ async def launch_server(self, master_address: str = None, master_port: int = Non quantization = self.config.get("quantization", None) if quantization is not None: if quantization == "fp8": + assert sglang.__version__ >= "0.5.5", "sglang>=0.5.5 is required for FP8 quantization" FP8_BLOCK_QUANT_KWARGS = { "activation_scheme": "dynamic", "fmt": "e4m3", diff --git a/verl/workers/rollout/sglang_rollout/sglang_rollout.py b/verl/workers/rollout/sglang_rollout/sglang_rollout.py index 4d5e8a0720f..ed020b26160 100644 --- a/verl/workers/rollout/sglang_rollout/sglang_rollout.py +++ b/verl/workers/rollout/sglang_rollout/sglang_rollout.py @@ -1556,6 +1556,9 @@ def __init__( device_mesh: DeviceMesh, ): if config.get("quantization", None) == "fp8": + import sglang + + assert sglang.__version__ >= "0.5.5", "sglang>=0.5.5 is required for FP8 quantization" FP8_BLOCK_QUANT_KWARGS = { "activation_scheme": "dynamic", "fmt": "e4m3", From 
65d4f87360eabd5c6725eab24508d543fd11b474 Mon Sep 17 00:00:00 2001 From: Xue Huang Date: Thu, 4 Dec 2025 02:28:29 -0800 Subject: [PATCH 18/18] add ci --- .../e2e_ppo_trainer_megatron_sglang_2.yml | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/.github/workflows/e2e_ppo_trainer_megatron_sglang_2.yml b/.github/workflows/e2e_ppo_trainer_megatron_sglang_2.yml index f4e9cc061b3..9d87600dc9b 100644 --- a/.github/workflows/e2e_ppo_trainer_megatron_sglang_2.yml +++ b/.github/workflows/e2e_ppo_trainer_megatron_sglang_2.yml @@ -254,6 +254,33 @@ jobs: ACTOR_FSDP_OPTIMIZER_OFFLOAD=True REF_FSDP_PARAM_OFFLOAD=True \ bash tests/special_e2e/ppo_trainer/run_function_reward.sh + e2e_ppo_trainer_megatron-sglang-fp8: + needs: setup + runs-on: ["${{ needs.setup.outputs.runner-label || 'L20x8' }}"] + timeout-minutes: 60 # Increase this timeout value as needed + env: + HTTP_PROXY: ${{ secrets.PROXY_HTTP }} + HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} + NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" + HF_ENDPOINT: "https://hf-mirror.com" + HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 0 + - name: Install the current repository + run: | + pip3 install --no-deps -e .[test] + - name: Prepare GSM8K dataset + run: | + python3 examples/data_preprocess/gsm8k.py --local_dataset_path ${HOME}/models/hf_data/gsm8k + - name: Running GSM8K E2E training tests on 8 L20 GPUs with SGLang (FP8) + run: | + ray stop --force + ENGINE=sglang ROLLOUT_QUANTIZATION=fp8 ROLLOUT_MODE=async TOTAL_TRAIN_STEPS=2 MODEL_ID=Qwen/Qwen3-0.6B bash tests/special_e2e/run_ppo_trainer_megatron.sh + - name: clean up + run: | + rm -rf checkpoints cleanup: runs-on: ubuntu-latest