diff --git a/.github/workflows/e2e_ppo_trainer_megatron_sglang_2.yml b/.github/workflows/e2e_ppo_trainer_megatron_sglang_2.yml index f832ac42834..d5e5efad222 100644 --- a/.github/workflows/e2e_ppo_trainer_megatron_sglang_2.yml +++ b/.github/workflows/e2e_ppo_trainer_megatron_sglang_2.yml @@ -220,6 +220,33 @@ jobs: ACTOR_FSDP_OPTIMIZER_OFFLOAD=True REF_FSDP_PARAM_OFFLOAD=True \ bash tests/special_e2e/ppo_trainer/run_function_reward.sh + e2e_ppo_trainer_megatron-sglang-fp8: + needs: setup + runs-on: ["${{ needs.setup.outputs.runner-label || 'L20x8' }}"] + timeout-minutes: 60 # Increase this timeout value as needed + env: + HTTP_PROXY: ${{ secrets.PROXY_HTTP }} + HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} + NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" + HF_ENDPOINT: "https://hf-mirror.com" + HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 0 + - name: Install the current repository + run: | + pip3 install --no-deps -e .[test] + - name: Prepare GSM8K dataset + run: | + python3 examples/data_preprocess/gsm8k.py --local_dataset_path ${HOME}/models/hf_data/gsm8k + - name: Running GSM8K E2E training tests on 8 L20 GPUs with SGLang (FP8) + run: | + ray stop --force + ENGINE=sglang ROLLOUT_QUANTIZATION=fp8 ROLLOUT_MODE=async TOTAL_TRAIN_STEPS=2 MODEL_ID=Qwen/Qwen3-0.6B bash tests/special_e2e/run_ppo_trainer_megatron.sh + - name: clean up + run: | + rm -rf checkpoints cleanup: runs-on: ubuntu-latest diff --git a/docs/advance/fp8.md b/docs/advance/fp8.md index 62183a04a84..0006392d7cd 100644 --- a/docs/advance/fp8.md +++ b/docs/advance/fp8.md @@ -1,14 +1,15 @@ # FP8 rollout for verl -Last updated: 11/19/2025 +Last updated: 12/4/2025 -This document introduces FP8 rollout with vllm inference backend in verl. +This document introduces FP8 rollout in verl. -We monkey patch several vLLM functions to enable FP8 rollout for reinforcement learning. -1. 
**load_weights**: A custom `load_weights` function to quantize the on-the-fly model weights from a higher-precision format to FP8. -2. **process weights after loading**: Replace `vllm.model_executor.layers.quantization.fp8.Fp8LinearMethod.process_weights_after_loading` -function to handle model weights loading after quantization. +We monkey patch several vLLM functions to enable FP8 rollout for reinforcement learning: + +1. **Quantize weights**: Quantize model weights on-the-fly from higher-precision formats to FP8. +2. **Process weights after loading**: For vLLM, we replace the `vllm.model_executor.layers.quantization.fp8.Fp8LinearMethod.process_weights_after_loading` function to handle weight processing after quantization. For SGLang, this patch is not needed as it natively supports loading quantized weights. + ## Support Matrix - FP8 blockwise quantization for rollout @@ -16,7 +17,7 @@ function to handle model weights loading after quantization. which is 1x128 quantization for activations and 128x128 quantization for model weights - Dense models and MoE models - Async rollout interfaces -- vLLM 0.10.x & vLLM 0.11 +- vLLM 0.10.x & vLLM 0.11 & SGLang 0.5.5 - FSDP and Megatron training backends ## Experiments and Outcomes @@ -104,8 +105,3 @@ Or it can be enabled by command line: - `actor_rollout_ref.rollout.quantization=fp8` Please refer to `recipe/dapo/run_dapo_qwen3_moe_30b_vllm_fp8_rollout.sh` - -## Plans - -- will open another PR to support FP8 rollout in SGLang - further to enable FP8 training in megatron diff --git a/verl/utils/sglang/sglang_fp8_utils.py b/verl/utils/sglang/sglang_fp8_utils.py new file mode 100644 index 00000000000..1833c02abb8 --- /dev/null +++ b/verl/utils/sglang/sglang_fp8_utils.py @@ -0,0 +1,182 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os + +import torch + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "INFO")) + + +def should_quantize_param(param_name: str) -> bool: + """Determine whether to quantize to FP8 based on parameter name + + Quantization rules: + - Must end with .weight (exclude bias) + - Exclude embedding layers + - Exclude normalization layers + - Exclude output layer (lm_head) + """ + # Must be a weight parameter + if not param_name.endswith(".weight"): + return False + + # Layer types to exclude + exclude_patterns = [ + "embed_tokens", # Embedding layer + "lm_head", # Output layer + "layernorm", # LayerNorm + "norm", # Various Norm layers + "ln_", # LayerNorm variants + "embeddings", # Embeddings + ] + + # Check if matches exclude patterns + param_lower = param_name.lower() + for pattern in exclude_patterns: + if pattern in param_lower: + return False + + # Layer types to include (Linear layers) + include_patterns = [ + "q_proj", # Query projection + "k_proj", # Key projection + "v_proj", # Value projection + "o_proj", # Output projection + "gate_proj", # Gate projection (for MLP) + "up_proj", # Up projection (for MLP) + "down_proj", # Down projection (for MLP) + "fc1", # Fully connected 1 + "fc2", # Fully connected 2 + "gate", # Gate (for MoE) + "mlp", # MLP layers + ] + + # Check if matches include patterns + for pattern in include_patterns: + if pattern in param_lower: + 
logger.debug(f"Will quantize FP8: {param_name}") + return True + + # Do not quantize by default + logger.debug(f"Skip quantization: {param_name}") + return False + + +def scaled_fp8_blockwise( + data_hp, + weight_block_size, +): + # cast tensor from high precision to FP8 with 128*128 blockwise quantization. + assert len(data_hp.shape) == 2, "Only 2d input tensor is supported" + + block_size1 = weight_block_size[1] + block_size0 = weight_block_size[0] + assert data_hp.shape[1] % block_size1 == 0, ( + f"data_hp.shape[1] {data_hp.shape[1]} must be a multiple of block_size1: {block_size1}." + ) + assert data_hp.shape[0] % block_size0 == 0, ( + f"data_hp.shape[0] {data_hp.shape[0]} must be a multiple of block_size0: {block_size0}." + ) + + # FP8 + max_dtype = torch.finfo(torch.float8_e4m3fn).max + + original_shape = data_hp.shape + blk_m, blk_n = data_hp.shape[0] // block_size0, data_hp.shape[1] // block_size1 + + assert block_size1 == block_size0 + data_hp = data_hp.reshape(blk_m, block_size0, blk_n, block_size1) + + # Permute to (BLK_M, BLK_N, BLOCK_SIZE_M, BLOCK_SIZE_N) + data_hp = data_hp.permute(0, 2, 1, 3) + # Flatten to (BLK_M, BLK_N, BLOCK_SIZE_M * BLOCK_SIZE_N) + data_hp = data_hp.to(torch.float32).contiguous().flatten(start_dim=2) + + # Calculate max absolute value per block + max_abs = torch.amax(torch.abs(data_hp), dim=-1, keepdim=True) + + # Use FP32 scale + scale_fp = max_dtype / max_abs + scale_fp = torch.where(max_abs == 0, 1.0, scale_fp) + # preserve the behavior for 0 amax case + scale_fp = torch.where(max_abs == torch.inf, 1.0, scale_fp) + + descale_fp = torch.reciprocal(scale_fp) + + # Scale and saturate cast the data elements to max of target dtype + data_lp = torch.clamp(data_hp * scale_fp, min=-1 * max_dtype, max=max_dtype) + + fp_data = data_lp.to(torch.float8_e4m3fn) + + # (BLK_M, BLK_N, BLOCK_SIZE_M * BLOCK_SIZE_N) to (M, N) + fp_data = fp_data.reshape(blk_m, blk_n, block_size0, block_size1).permute(0, 2, 1, 3).reshape(original_shape) + + # 
Convert to target format, but still in original precision container + return fp_data, descale_fp + + +def quant_weights_by_name(weights, quant_config, dtype=torch.bfloat16): + """FP8 quantization based on parameter name + + Args: + weights: Generator of (name, tensor) pairs + quant_config: Quantization configuration + dtype: Data type for intermediate computation + + Returns: + List of (name, tensor) pairs with quantized weights + """ + + weights_quantized = [] + + if isinstance(quant_config, dict): + weight_block_size = quant_config.get("weight_block_size") + else: + weight_block_size = getattr(quant_config, "weight_block_size", None) + + if weight_block_size is None: + raise ValueError("weight_block_size not found in quant_config") + + for k, v in weights: + # Check if quantization is needed + if not should_quantize_param(k): + weights_quantized.append((k, v)) + continue + + # Quantize to FP8 + try: + if weight_block_size is not None: + if torch.distributed.get_rank() == 0: + logger.debug(f" Quantizing to FP8 blockwise: {k}") + param_lp, param_scale = scaled_fp8_blockwise( + v.to(dtype), + weight_block_size=weight_block_size, + ) + param_scale = param_scale.squeeze(-1) + weights_quantized.append([k, param_lp]) + weights_quantized.append([k + "_scale_inv", param_scale]) + else: + raise ValueError( + "Only blockwise quantization is supported. 
Please set weight_block_size in quant_config" + ) + except Exception as e: + logger.error(f"Failed to quantize {k}: {e}") + # If quantization fails, use original weights + weights_quantized.append((k, v)) + + return weights_quantized diff --git a/verl/workers/config/rollout.py b/verl/workers/config/rollout.py index a615e8df018..bea1bd4520d 100644 --- a/verl/workers/config/rollout.py +++ b/verl/workers/config/rollout.py @@ -204,6 +204,7 @@ class RolloutConfig(BaseConfig): skip_tokenizer_init: bool = False quantization: Optional[str] = None + enable_rollout_routing_replay: bool = False def __post_init__(self): diff --git a/verl/workers/rollout/sglang_rollout/async_sglang_server.py b/verl/workers/rollout/sglang_rollout/async_sglang_server.py index f9b9cea314d..e78700d9f7a 100644 --- a/verl/workers/rollout/sglang_rollout/async_sglang_server.py +++ b/verl/workers/rollout/sglang_rollout/async_sglang_server.py @@ -14,11 +14,13 @@ # limitations under the License. import asyncio import dataclasses +import json import logging import os from typing import Any, Optional import ray +import sglang import sglang.srt.entrypoints.engine import torch from ray.actor import ActorHandle @@ -125,6 +127,19 @@ async def launch_server(self, master_address: str = None, master_port: int = Non engine_kwargs = self.config.get("engine_kwargs", {}).get("sglang", {}) or {} attention_backend = engine_kwargs.pop("attention_backend", None) + quantization = self.config.get("quantization", None) + if quantization is not None: + if quantization == "fp8": + assert sglang.__version__ >= "0.5.5", "sglang>=0.5.5 is required for FP8 quantization" + FP8_BLOCK_QUANT_KWARGS = { + "activation_scheme": "dynamic", + "fmt": "e4m3", + "quant_method": "fp8", + "weight_block_size": [128, 128], + } + fp8_block_quant_kwargs = dict(FP8_BLOCK_QUANT_KWARGS) + else: + raise ValueError(f"Currently only support fp8 quantization, got: {quantization}") dist_init_addr = ( f"[{self._master_address}]:{self._master_port}" if 
is_valid_ipv6_address(self._master_address) @@ -153,6 +168,10 @@ async def launch_server(self, master_address: str = None, master_port: int = Non "attention_backend": attention_backend if attention_backend is not None else "fa3", "skip_tokenizer_init": self.config.skip_tokenizer_init, "skip_server_warmup": True, + "quantization": quantization, + "json_model_override_args": json.dumps({"quantization_config": fp8_block_quant_kwargs}) + if quantization == "fp8" + else json.dumps({}), **engine_kwargs, } diff --git a/verl/workers/rollout/sglang_rollout/sglang_rollout.py b/verl/workers/rollout/sglang_rollout/sglang_rollout.py index 299230dc6c9..63d3b0c36af 100644 --- a/verl/workers/rollout/sglang_rollout/sglang_rollout.py +++ b/verl/workers/rollout/sglang_rollout/sglang_rollout.py @@ -97,6 +97,18 @@ def __init__( model_config: HFModelConfig, device_mesh: DeviceMesh, ): + if config.get("quantization", None) == "fp8": + import sglang + + assert sglang.__version__ >= "0.5.5", "sglang>=0.5.5 is required for FP8 quantization" + FP8_BLOCK_QUANT_KWARGS = { + "activation_scheme": "dynamic", + "fmt": "e4m3", + "quant_method": "fp8", + "weight_block_size": [128, 128], + } + fp8_block_quant_kwargs = dict(FP8_BLOCK_QUANT_KWARGS) + model_config.hf_config.quantization_config = fp8_block_quant_kwargs super().__init__(config, model_config, device_mesh) self._engine: AsyncHttpServerAdapter = None @@ -157,6 +169,18 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None await self._init_server_adapter() update_weights_bucket_bytes = int(self.config.update_weights_bucket_megabytes) << 20 + if self.config.get("quantization", None) == "fp8": + from verl.utils.sglang.sglang_fp8_utils import quant_weights_by_name + + logger.info("Convert bf16 weights to fp8 format before loading") + weights = quant_weights_by_name( + weights, + self.model_config.hf_config.quantization_config, + dtype=self.model_config.hf_config.dtype, + ) + else: + weights = weights + for 
params_batch in get_named_tensor_buckets(weights, update_weights_bucket_bytes): await sgl_update_weights( engine=self._engine,