27 changes: 27 additions & 0 deletions .github/workflows/e2e_ppo_trainer_megatron_sglang_2.yml
@@ -220,6 +220,33 @@ jobs:
ACTOR_FSDP_OPTIMIZER_OFFLOAD=True REF_FSDP_PARAM_OFFLOAD=True \
bash tests/special_e2e/ppo_trainer/run_function_reward.sh

e2e_ppo_trainer_megatron-sglang-fp8:
needs: setup
runs-on: ["${{ needs.setup.outputs.runner-label || 'L20x8' }}"]
timeout-minutes: 60 # Increase this timeout value as needed
env:
HTTP_PROXY: ${{ secrets.PROXY_HTTP }}
HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }}
NO_PROXY: "localhost,127.0.0.1,hf-mirror.com"
HF_ENDPOINT: "https://hf-mirror.com"
HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
fetch-depth: 0
- name: Install the current repository
run: |
pip3 install --no-deps -e .[test]
- name: Prepare GSM8K dataset
run: |
python3 examples/data_preprocess/gsm8k.py --local_dataset_path ${HOME}/models/hf_data/gsm8k
- name: Running GSM8K E2E training tests on 8 L20 GPUs with SGLang (FP8)
run: |
ray stop --force
ENGINE=sglang ROLLOUT_QUANTIZATION=fp8 ROLLOUT_MODE=async TOTAL_TRAIN_STEPS=2 MODEL_ID=Qwen/Qwen3-0.6B bash tests/special_e2e/run_ppo_trainer_megatron.sh
- name: clean up
run: |
rm -rf checkpoints

cleanup:
runs-on: ubuntu-latest
20 changes: 8 additions & 12 deletions docs/advance/fp8.md
@@ -1,22 +1,23 @@
# FP8 rollout for verl

Last updated: 11/19/2025
Last updated: 12/4/2025

This document introduces FP8 rollout with vllm inference backend in verl.
This document introduces FP8 rollout in verl.


We monkey patch several vLLM functions to enable FP8 rollout for reinforcement learning.
1. **load_weights**: A custom `load_weights` function to quantize the on-the-fly model weights from a higher-precision format to FP8.
2. **process weights after loading**: Replace `vllm.model_executor.layers.quantization.fp8.Fp8LinearMethod.process_weights_after_loading`
function to handle model weights loading after quantization.
We monkey patch several vLLM functions to enable FP8 rollout for reinforcement learning:

1. **Quantize weights**: Quantize model weights on-the-fly from higher-precision formats to FP8.
2. **Process weights after loading**: For vLLM, we replace the `vllm.model_executor.layers.quantization.fp8.Fp8LinearMethod.process_weights_after_loading` function to handle weight processing after quantization. For SGLang, this patch is not needed as it natively supports loading quantized weights.


## Support Matrix
- FP8 blockwise quantization for rollout
- Used in DeepSeek: 1x128 quantization for activations and 128x128 quantization for model weights
- Dense models and MoE models
- Async rollout interfaces
- vLLM 0.10.x & vLLM 0.11
- vLLM 0.10.x & vLLM 0.11 & SGLang 0.5.5
- FSDP and Megatron training backends

## Experiments and Outcomes
@@ -104,8 +105,3 @@ Or it can be enabled by command line:
- `actor_rollout_ref.rollout.quantization=fp8`

Please refer to `recipe/dapo/run_dapo_qwen3_moe_30b_vllm_fp8_rollout.sh`

## Plans

- will open another PR to support FP8 rollout in SGLang
- further to enable FP8 training in megatron
182 changes: 182 additions & 0 deletions verl/utils/sglang/sglang_fp8_utils.py
@@ -0,0 +1,182 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os

import torch

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "INFO"))


def should_quantize_param(param_name: str) -> bool:
"""Determine whether to quantize to FP8 based on parameter name

Quantization rules:
- Must end with .weight (exclude bias)
- Exclude embedding layers
- Exclude normalization layers
- Exclude output layer (lm_head)
"""
# Must be a weight parameter
if not param_name.endswith(".weight"):
return False

# Layer types to exclude
exclude_patterns = [
"embed_tokens", # Embedding layer
"lm_head", # Output layer
"layernorm", # LayerNorm
"norm", # Various Norm layers
"ln_", # LayerNorm variants
"embeddings", # Embeddings
]

# Check if matches exclude patterns
param_lower = param_name.lower()
for pattern in exclude_patterns:
if pattern in param_lower:
return False

# Layer types to include (Linear layers)
include_patterns = [
"q_proj", # Query projection
"k_proj", # Key projection
"v_proj", # Value projection
"o_proj", # Output projection
"gate_proj", # Gate projection (for MLP)
"up_proj", # Up projection (for MLP)
"down_proj", # Down projection (for MLP)
"fc1", # Fully connected 1
"fc2", # Fully connected 2
"gate", # Gate (for MoE)
"mlp", # MLP layers
]

# Check if matches include patterns
for pattern in include_patterns:
if pattern in param_lower:
logger.debug(f"Will quantize FP8: {param_name}")
return True

# Do not quantize by default
logger.debug(f"Skip quantization: {param_name}")
return False


def scaled_fp8_blockwise(
data_hp,
weight_block_size,
):
# cast tensor from high precision to FP8 with 128*128 blockwise quantization.
assert len(data_hp.shape) == 2, "Only 2d input tensor is supported"

block_size1 = weight_block_size[1]
block_size0 = weight_block_size[0]
assert data_hp.shape[1] % block_size1 == 0, (
f"data_hp.shape[1] {data_hp.shape[1]} must be a multiple of block_size1: {block_size1}."
)
assert data_hp.shape[0] % block_size0 == 0, (
f"data_hp.shape[0] {data_hp.shape[0]} must be a multiple of block_size0: {block_size0}."
)

# FP8
max_dtype = torch.finfo(torch.float8_e4m3fn).max

original_shape = data_hp.shape
blk_m, blk_n = data_hp.shape[0] // block_size0, data_hp.shape[1] // block_size1

assert block_size1 == block_size0
data_hp = data_hp.reshape(blk_m, block_size0, blk_n, block_size1)

# Permute to (BLK_M, BLK_N, BLOCK_SIZE_M, BLOCK_SIZE_N)
data_hp = data_hp.permute(0, 2, 1, 3)
# Flatten to (BLK_M, BLK_N, BLOCK_SIZE_M * BLOCK_SIZE_N)
data_hp = data_hp.to(torch.float32).contiguous().flatten(start_dim=2)

# Calculate max absolute value per block
max_abs = torch.amax(torch.abs(data_hp), dim=-1, keepdim=True)

# Use FP32 scale
scale_fp = max_dtype / max_abs
scale_fp = torch.where(max_abs == 0, 1.0, scale_fp)
# preserve the behavior for 0 amax case
scale_fp = torch.where(max_abs == torch.inf, 1.0, scale_fp)

descale_fp = torch.reciprocal(scale_fp)

# Scale and saturate cast the data elements to max of target dtype
data_lp = torch.clamp(data_hp * scale_fp, min=-1 * max_dtype, max=max_dtype)

fp_data = data_lp.to(torch.float8_e4m3fn)

# (BLK_M, BLK_N, BLOCK_SIZE_M * BLOCK_SIZE_N) to (M, N)
fp_data = fp_data.reshape(blk_m, blk_n, block_size0, block_size1).permute(0, 2, 1, 3).reshape(original_shape)

# Return the FP8 data together with the per-block inverse scales (descale)
return fp_data, descale_fp


def quant_weights_by_name(weights, quant_config, dtype=torch.bfloat16):
"""FP8 quantization based on parameter name

Args:
weights: Generator of (name, tensor) pairs
quant_config: Quantization configuration
dtype: Data type for intermediate computation

Returns:
List of (name, tensor) pairs with quantized weights
"""

weights_quantized = []

if isinstance(quant_config, dict):
weight_block_size = quant_config.get("weight_block_size")
else:
weight_block_size = getattr(quant_config, "weight_block_size", None)

if weight_block_size is None:
raise ValueError("weight_block_size not found in quant_config")

for k, v in weights:
# Check if quantization is needed
if not should_quantize_param(k):
weights_quantized.append((k, v))
continue

# Quantize to FP8
try:
if weight_block_size is not None:
if torch.distributed.get_rank() == 0:
logger.debug(f" Quantizing to FP8 blockwise: {k}")
param_lp, param_scale = scaled_fp8_blockwise(
v.to(dtype),
weight_block_size=weight_block_size,
)
param_scale = param_scale.squeeze(-1)
weights_quantized.append([k, param_lp])
weights_quantized.append([k + "_scale_inv", param_scale])
else:
raise ValueError(
"Only blockwise quantization is supported. Please set weight_block_size in quant_config"
)
Comment on lines +163 to +176 (Contributor review, severity: high)

This else block is unreachable. weight_block_size is checked for None on line 152 before the loop begins, and an exception is raised if it is None. Consequently, the condition weight_block_size is not None on line 163 will always evaluate to true inside the loop, rendering the else branch dead code. Removing the conditional wrapper and the unreachable else block will improve code clarity and maintainability.

            if torch.distributed.get_rank() == 0:
                logger.debug(f"  Quantizing to FP8 blockwise: {k}")
            param_lp, param_scale = scaled_fp8_blockwise(
                v.to(dtype),
                weight_block_size=weight_block_size,
            )
            param_scale = param_scale.squeeze(-1)
            weights_quantized.append([k, param_lp])
            weights_quantized.append([k + "_scale_inv", param_scale])

except Exception as e:
logger.error(f"Failed to quantize {k}: {e}")
# If quantization fails, use original weights
weights_quantized.append((k, v))

return weights_quantized
1 change: 1 addition & 0 deletions verl/workers/config/rollout.py
@@ -204,6 +204,7 @@ class RolloutConfig(BaseConfig):
skip_tokenizer_init: bool = False

quantization: Optional[str] = None

enable_rollout_routing_replay: bool = False

def __post_init__(self):
19 changes: 19 additions & 0 deletions verl/workers/rollout/sglang_rollout/async_sglang_server.py
@@ -14,11 +14,13 @@
# limitations under the License.
import asyncio
import dataclasses
import json
import logging
import os
from typing import Any, Optional

import ray
import sglang
import sglang.srt.entrypoints.engine
import torch
from ray.actor import ActorHandle
@@ -125,6 +127,19 @@ async def launch_server(self, master_address: str = None, master_port: int = Non

engine_kwargs = self.config.get("engine_kwargs", {}).get("sglang", {}) or {}
attention_backend = engine_kwargs.pop("attention_backend", None)
quantization = self.config.get("quantization", None)
if quantization is not None:
if quantization == "fp8":
assert sglang.__version__ >= "0.5.5", "sglang>=0.5.5 is required for FP8 quantization"
FP8_BLOCK_QUANT_KWARGS = {
"activation_scheme": "dynamic",
"fmt": "e4m3",
"quant_method": "fp8",
"weight_block_size": [128, 128],
}
fp8_block_quant_kwargs = dict(FP8_BLOCK_QUANT_KWARGS)
Comment on lines +133 to +140 (Contributor review, severity: high)

The FP8 quantization configuration logic, including the version check and FP8_BLOCK_QUANT_KWARGS dictionary, is duplicated in verl/workers/rollout/sglang_rollout/sglang_rollout.py. To improve maintainability and prevent future inconsistencies, this logic should be centralized. Consider moving FP8_BLOCK_QUANT_KWARGS to verl/utils/sglang/sglang_fp8_utils.py as a constant and creating a helper function there to encapsulate the version check and config creation.
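
The centralization the reviewer proposes could be sketched as follows (hypothetical helper — `get_fp8_block_quant_config` and its placement in `sglang_fp8_utils.py` are the reviewer's suggestion, not merged code; the version string is passed in so the sketch stays self-contained):

```python
# Hypothetical shared constant + helper for verl/utils/sglang/sglang_fp8_utils.py,
# so both call sites stop re-declaring the dict and the version check.
FP8_BLOCK_QUANT_KWARGS = {
    "activation_scheme": "dynamic",
    "fmt": "e4m3",
    "quant_method": "fp8",
    "weight_block_size": [128, 128],
}

def get_fp8_block_quant_config(sglang_version: str) -> dict:
    # Mirrors the PR's lexicographic string check; a more robust helper would
    # compare with packaging.version.parse instead of raw string ordering.
    assert sglang_version >= "0.5.5", "sglang>=0.5.5 is required for FP8 quantization"
    return dict(FP8_BLOCK_QUANT_KWARGS)  # fresh copy so callers cannot mutate the constant

cfg = get_fp8_block_quant_config("0.5.5")
```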

else:
raise ValueError(f"Currently only support fp8 quantization, got: {quantization}")
dist_init_addr = (
f"[{self._master_address}]:{self._master_port}"
if is_valid_ipv6_address(self._master_address)
@@ -153,6 +168,10 @@ async def launch_server(self, master_address: str = None, master_port: int = Non
"attention_backend": attention_backend if attention_backend is not None else "fa3",
"skip_tokenizer_init": self.config.skip_tokenizer_init,
"skip_server_warmup": True,
"quantization": quantization,
"json_model_override_args": json.dumps({"quantization_config": fp8_block_quant_kwargs})
if quantization == "fp8"
else json.dumps({}),
**engine_kwargs,
}

24 changes: 24 additions & 0 deletions verl/workers/rollout/sglang_rollout/sglang_rollout.py
@@ -97,6 +97,18 @@
model_config: HFModelConfig,
device_mesh: DeviceMesh,
):
if config.get("quantization", None) == "fp8":
import sglang

assert sglang.__version__ >= "0.5.5", "sglang>=0.5.5 is required for FP8 quantization"
FP8_BLOCK_QUANT_KWARGS = {
"activation_scheme": "dynamic",
"fmt": "e4m3",
"quant_method": "fp8",
"weight_block_size": [128, 128],
}
fp8_block_quant_kwargs = dict(FP8_BLOCK_QUANT_KWARGS)
Comment on lines +103 to +110 (Contributor review, severity: high)

The FP8 quantization configuration logic, including the version check and FP8_BLOCK_QUANT_KWARGS dictionary, is duplicated in verl/workers/rollout/sglang_rollout/async_sglang_server.py. To improve maintainability and prevent future inconsistencies, this logic should be centralized. Consider moving FP8_BLOCK_QUANT_KWARGS to verl/utils/sglang/sglang_fp8_utils.py as a constant and creating a helper function there to encapsulate the version check and config creation.

model_config.hf_config.quantization_config = fp8_block_quant_kwargs
super().__init__(config, model_config, device_mesh)
self._engine: AsyncHttpServerAdapter = None

@@ -157,6 +169,18 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None
await self._init_server_adapter()

update_weights_bucket_bytes = int(self.config.update_weights_bucket_megabytes) << 20
if self.config.get("quantization", None) == "fp8":
from verl.utils.sglang.sglang_fp8_utils import quant_weights_by_name

logger.info("Convert bf16 weights to fp8 format before loading")
weights = quant_weights_by_name(
weights,
self.model_config.hf_config.quantization_config,
dtype=self.model_config.hf_config.dtype,
)
else:
weights = weights

for params_batch in get_named_tensor_buckets(weights, update_weights_bucket_bytes):
await sgl_update_weights(
engine=self._engine,