-
Notifications
You must be signed in to change notification settings - Fork 3.7k
[rollout, sglang] feat: support blockwise fp8 rollout #4415
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
1d2dca2
e02c310
9c843bf
972f287
360b0f5
00b6596
bece9e5
f563493
4d4f7b6
96a2832
c4757ef
00bbce4
aa8a4f3
311c082
d772aef
32d0c66
bf2ed72
65d4f87
a9b9d90
037986e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,182 @@ | ||
| # Copyright 2025 Bytedance Ltd. and/or its affiliates | ||
| # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import logging | ||
| import os | ||
|
|
||
| import torch | ||
|
|
||
| logger = logging.getLogger(__file__) | ||
| logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "INFO")) | ||
|
|
||
|
|
||
def should_quantize_param(param_name: str) -> bool:
    """Decide from the parameter name whether a tensor should be FP8-quantized.

    A parameter qualifies only when all of the following hold:
      * its name ends with ``.weight`` (biases are never quantized),
      * it is not an embedding / normalization / lm_head parameter,
      * it matches one of the known Linear-layer name patterns.
    """
    # Biases and other non-weight tensors are never quantized.
    if not param_name.endswith(".weight"):
        return False

    name = param_name.lower()

    # Embeddings, norms and the output head stay in high precision.
    excluded_markers = (
        "embed_tokens",  # Embedding layer
        "lm_head",  # Output layer
        "layernorm",  # LayerNorm
        "norm",  # Various Norm layers
        "ln_",  # LayerNorm variants
        "embeddings",  # Embeddings
    )
    if any(marker in name for marker in excluded_markers):
        return False

    # Linear projections (attention, MLP, MoE gates) are quantized.
    included_markers = (
        "q_proj",  # Query projection
        "k_proj",  # Key projection
        "v_proj",  # Value projection
        "o_proj",  # Output projection
        "gate_proj",  # Gate projection (for MLP)
        "up_proj",  # Up projection (for MLP)
        "down_proj",  # Down projection (for MLP)
        "fc1",  # Fully connected 1
        "fc2",  # Fully connected 2
        "gate",  # Gate (for MoE)
        "mlp",  # MLP layers
    )
    for marker in included_markers:
        if marker in name:
            logging.getLogger(__file__).debug(f"Will quantize FP8: {param_name}")
            return True

    # Anything unrecognized is left untouched.
    logging.getLogger(__file__).debug(f"Skip quantization: {param_name}")
    return False
|
|
||
|
|
||
| def scaled_fp8_blockwise( | ||
| data_hp, | ||
| weight_block_size, | ||
| ): | ||
| # cast tensor from high precision to FP8 with 128*128 blockwise quantization. | ||
| assert len(data_hp.shape) == 2, "Only 2d input tensor is supported" | ||
|
|
||
| block_size1 = weight_block_size[1] | ||
| block_size0 = weight_block_size[0] | ||
| assert data_hp.shape[1] % block_size1 == 0, ( | ||
| f"data_hp.shape[1] {data_hp.shape[1]} must be a multiple of block_size1: {block_size1}." | ||
| ) | ||
| assert data_hp.shape[0] % block_size0 == 0, ( | ||
| f"data_hp.shape[0] {data_hp.shape[0]} must be a multiple of block_size0: {block_size0}." | ||
| ) | ||
|
|
||
| # FP8 | ||
| max_dtype = torch.finfo(torch.float8_e4m3fn).max | ||
|
|
||
| original_shape = data_hp.shape | ||
| blk_m, blk_n = data_hp.shape[0] // block_size0, data_hp.shape[1] // block_size1 | ||
|
|
||
| assert block_size1 == block_size0 | ||
| data_hp = data_hp.reshape(blk_m, block_size0, blk_n, block_size1) | ||
|
|
||
| # Permute to (BLK_M, BLK_N, BLOCK_SIZE_M, BLOCK_SIZE_N) | ||
| data_hp = data_hp.permute(0, 2, 1, 3) | ||
| # Flatten to (BLK_M, BLK_N, BLOCK_SIZE_M * BLOCK_SIZE_N) | ||
| data_hp = data_hp.to(torch.float32).contiguous().flatten(start_dim=2) | ||
|
|
||
| # Calculate max absolute value per block | ||
| max_abs = torch.amax(torch.abs(data_hp), dim=-1, keepdim=True) | ||
|
|
||
| # Use FP32 scale | ||
| scale_fp = max_dtype / max_abs | ||
| scale_fp = torch.where(max_abs == 0, 1.0, scale_fp) | ||
| # preserve the behavior for 0 amax case | ||
| scale_fp = torch.where(max_abs == torch.inf, 1.0, scale_fp) | ||
|
|
||
| descale_fp = torch.reciprocal(scale_fp) | ||
|
|
||
| # Scale and saturate cast the data elements to max of target dtype | ||
| data_lp = torch.clamp(data_hp * scale_fp, min=-1 * max_dtype, max=max_dtype) | ||
|
|
||
| fp_data = data_lp.to(torch.float8_e4m3fn) | ||
|
|
||
| # (BLK_M, BLK_N, BLOCK_SIZE_M * BLOCK_SIZE_N) to (M, N) | ||
| fp_data = fp_data.reshape(blk_m, blk_n, block_size0, block_size1).permute(0, 2, 1, 3).reshape(original_shape) | ||
|
|
||
| # Convert to target format, but still in original precision container | ||
| return fp_data, descale_fp | ||
|
|
||
|
|
||
def quant_weights_by_name(weights, quant_config, dtype=torch.bfloat16):
    """FP8 quantization based on parameter name.

    Args:
        weights: Generator (or iterable) of (name, tensor) pairs.
        quant_config: Quantization configuration; a dict or object carrying
            ``weight_block_size`` (e.g. [128, 128]).
        dtype: Data type for intermediate computation before the FP8 cast.

    Returns:
        List of (name, tensor) pairs. Each quantized parameter contributes two
        entries: ``name`` (fp8 data) and ``name + "_scale_inv"`` (per-block
        inverse scales). Skipped or failed parameters pass through unchanged.

    Raises:
        ValueError: if ``quant_config`` carries no ``weight_block_size`` —
            only blockwise quantization is supported.
    """
    if isinstance(quant_config, dict):
        weight_block_size = quant_config.get("weight_block_size")
    else:
        weight_block_size = getattr(quant_config, "weight_block_size", None)

    # Only blockwise quantization is supported; fail fast before iterating.
    if weight_block_size is None:
        raise ValueError("weight_block_size not found in quant_config")

    weights_quantized = []
    for k, v in weights:
        # Pass through parameters that must stay in high precision.
        if not should_quantize_param(k):
            weights_quantized.append((k, v))
            continue

        # Quantize to FP8. The per-loop None check and its `else` branch were
        # removed: weight_block_size is validated above, so they were dead code.
        try:
            # Guard the rank query: torch.distributed.get_rank() raises when the
            # default process group is not initialized, which previously tripped
            # the except below and silently skipped quantization in
            # single-process runs.
            if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
                logger.debug(f" Quantizing to FP8 blockwise: {k}")
            param_lp, param_scale = scaled_fp8_blockwise(
                v.to(dtype),
                weight_block_size=weight_block_size,
            )
            weights_quantized.append((k, param_lp))
            # Drop the trailing singleton dim: (BLK_M, BLK_N, 1) -> (BLK_M, BLK_N).
            weights_quantized.append((k + "_scale_inv", param_scale.squeeze(-1)))
        except Exception as e:
            logger.error(f"Failed to quantize {k}: {e}")
            # Best effort: if quantization fails, keep the original weights.
            weights_quantized.append((k, v))

    return weights_quantized
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,11 +14,13 @@ | |
| # limitations under the License. | ||
| import asyncio | ||
| import dataclasses | ||
| import json | ||
| import logging | ||
| import os | ||
| from typing import Any, Optional | ||
|
|
||
| import ray | ||
| import sglang | ||
| import sglang.srt.entrypoints.engine | ||
| import torch | ||
| from ray.actor import ActorHandle | ||
|
|
@@ -125,6 +127,19 @@ async def launch_server(self, master_address: str = None, master_port: int = Non | |
|
|
||
| engine_kwargs = self.config.get("engine_kwargs", {}).get("sglang", {}) or {} | ||
| attention_backend = engine_kwargs.pop("attention_backend", None) | ||
| quantization = self.config.get("quantization", None) | ||
| if quantization is not None: | ||
| if quantization == "fp8": | ||
| assert sglang.__version__ >= "0.5.5", "sglang>=0.5.5 is required for FP8 quantization" | ||
| FP8_BLOCK_QUANT_KWARGS = { | ||
| "activation_scheme": "dynamic", | ||
| "fmt": "e4m3", | ||
| "quant_method": "fp8", | ||
| "weight_block_size": [128, 128], | ||
| } | ||
| fp8_block_quant_kwargs = dict(FP8_BLOCK_QUANT_KWARGS) | ||
|
Comment on lines
+133
to
+140
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The FP8 quantization configuration logic, including the version check and |
||
| else: | ||
| raise ValueError(f"Currently only support fp8 quantization, got: {quantization}") | ||
| dist_init_addr = ( | ||
| f"[{self._master_address}]:{self._master_port}" | ||
| if is_valid_ipv6_address(self._master_address) | ||
|
|
@@ -153,6 +168,10 @@ async def launch_server(self, master_address: str = None, master_port: int = Non | |
| "attention_backend": attention_backend if attention_backend is not None else "fa3", | ||
| "skip_tokenizer_init": self.config.skip_tokenizer_init, | ||
| "skip_server_warmup": True, | ||
| "quantization": quantization, | ||
| "json_model_override_args": json.dumps({"quantization_config": fp8_block_quant_kwargs}) | ||
| if quantization == "fp8" | ||
| else json.dumps({}), | ||
| **engine_kwargs, | ||
| } | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -97,6 +97,18 @@ def __init__( | |
| model_config: HFModelConfig, | ||
| device_mesh: DeviceMesh, | ||
| ): | ||
| if config.get("quantization", None) == "fp8": | ||
| import sglang | ||
|
|
||
| assert sglang.__version__ >= "0.5.5", "sglang>=0.5.5 is required for FP8 quantization" | ||
| FP8_BLOCK_QUANT_KWARGS = { | ||
| "activation_scheme": "dynamic", | ||
| "fmt": "e4m3", | ||
| "quant_method": "fp8", | ||
| "weight_block_size": [128, 128], | ||
| } | ||
| fp8_block_quant_kwargs = dict(FP8_BLOCK_QUANT_KWARGS) | ||
|
Comment on lines
+103
to
+110
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The FP8 quantization configuration logic, including the version check and |
||
| model_config.hf_config.quantization_config = fp8_block_quant_kwargs | ||
| super().__init__(config, model_config, device_mesh) | ||
| self._engine: AsyncHttpServerAdapter = None | ||
|
|
||
|
|
@@ -157,6 +169,18 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None | |
| await self._init_server_adapter() | ||
|
|
||
| update_weights_bucket_bytes = int(self.config.update_weights_bucket_megabytes) << 20 | ||
| if self.config.get("quantization", None) == "fp8": | ||
| from verl.utils.sglang.sglang_fp8_utils import quant_weights_by_name | ||
|
|
||
| logger.info("Convert bf16 weights to fp8 format before loading") | ||
| weights = quant_weights_by_name( | ||
| weights, | ||
| self.model_config.hf_config.quantization_config, | ||
| dtype=self.model_config.hf_config.dtype, | ||
| ) | ||
| else: | ||
| weights = weights | ||
|
|
||
| for params_batch in get_named_tensor_buckets(weights, update_weights_bucket_bytes): | ||
| await sgl_update_weights( | ||
| engine=self._engine, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This `else` block is unreachable. `weight_block_size` is checked for `None` on line 152 before the loop begins, and an exception is raised if it is `None`. Consequently, the condition `weight_block_size is not None` on line 163 will always evaluate to true inside the loop, rendering the `else` branch dead code. Removing the conditional wrapper and the unreachable `else` block will improve code clarity and maintainability.