[Quantization][RL] Support Online Blockwise FP8 Quantization #15440
AniZpZ wants to merge 28 commits into sgl-project:main
Conversation
reset the author info
Warning: You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!
Experimental Details

Model: Qwen/Qwen3-8B-Base

Results (2026.01.12 Updated)

Observations and Outlook

- Accuracy of Quantization: The current blockwise FP8 rollout, which converts weights FP32 -> BF16 -> FP8, shows only minor training-inference discrepancies and keeps training metrics consistent with the BF16 baseline. In contrast, per-channel FP8 quantization leads to notable precision loss during text generation. Further experiments indicate that direct FP32-to-FP8 quantization widens the performance gap and elevates the final validation score, which we attribute to longer generated responses.
- Gen. Throughput: FP8 rollout initially delivers slightly higher throughput than BF16 but is surpassed as training progresses. In additional runs with a maximum response length of 10K (compared to the current 20K setting), FP8 rollout achieves significantly higher throughput than BF16 once long-tail generation is constrained.
/rerun-failed-ci
TODOs:
Shall we finish these TODOs, then merge this PR?
@zhaochenyang20 @Hecate0821 @FlamingoPg The TODOs have been resolved and CI has passed.
# Note: only [128, 128] block size is available for now
default_block_size = [128, 128]
default_block_size = [128, 128] is set twice. We should define it only once and use [128, 128] as the default value.
if quant_method is not None:
    quant_method.process_weights_after_loading(module)
    logger.info(
        f"[QuantizedRL] Fallback to per-channel quantization for module: {name}; "
# Adapted from https://github.com/volcengine/verl/pull/4415/files#diff-79538cec3426fe5c75d07b39a15e90971f19e98404755792f9b28859b8902ae1
def scaled_fp8_blockwise(
Could we add dedicated comments and a return type hint to this function?
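For illustration, a hedged sketch of what such documentation could look like, with numpy standing in for the torch implementation. The (128, 128) block shape and the 448.0 e4m3 maximum follow the usual blockwise-FP8 convention; the function body here is illustrative, not the PR's actual code:

```python
import numpy as np
from typing import Tuple

FP8_E4M3_MAX = 448.0  # largest magnitude representable in float8_e4m3fn

def scaled_fp8_blockwise(
    weight: np.ndarray, block_size: Tuple[int, int] = (128, 128)
) -> Tuple[np.ndarray, np.ndarray]:
    """Quantize a 2-D weight per (128, 128) block.

    Returns:
        q: quantized values (emulated in float32 here; a real kernel
           would store torch.float8_e4m3fn)
        scales: float32 scales, one per block, shape (M // 128, N // 128)
    """
    bm, bn = block_size
    m, n = weight.shape
    tiles = weight.reshape(m // bm, bm, n // bn, bn)
    # One scale per block: amax over the tile, mapped onto the FP8 range
    # (an all-zero tile would need an epsilon guard, omitted here).
    scales = (np.abs(tiles).max(axis=(1, 3)) / FP8_E4M3_MAX).astype(np.float32)
    q = np.clip(tiles / scales[:, None, :, None], -FP8_E4M3_MAX, FP8_E4M3_MAX)
    return q.reshape(m, n), scales
```

Dequantization is the inverse: multiply each block of q by its scale to recover the original values (up to FP8 rounding, which this emulation omits).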
    logger.debug(
        f"[QuantizedRL] Set quant_method weight_block_size={default_block_size} for module: {name}"
    )
except Exception as e:
Please do not catch errors like this; a bare `except Exception` may swallow unexpected errors. Could we catch only RuntimeError/ValueError?
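One way to narrow the handler, sketched with hypothetical names (quantize_with_fallback and _FailingMethod are illustrative, not the PR's code): expected quantization failures trigger the per-channel fallback, while anything else propagates.

```python
import logging

logger = logging.getLogger(__name__)

def quantize_with_fallback(name, module, quant_method, fallback):
    """Catch only the error types quantization is known to raise;
    unexpected errors (bugs, OOM, interrupts) propagate to the caller."""
    try:
        quant_method.process_weights_after_loading(module)
        return "blockwise"
    except (RuntimeError, ValueError) as e:
        logger.warning(
            "[QuantizedRL] Fallback to per-channel quantization for %s: %s",
            name, e,
        )
        fallback(module)
        return "per-channel"

class _FailingMethod:
    def process_weights_after_loading(self, module):
        raise ValueError("unsupported weight shape")

result = quantize_with_fallback(
    "layers.0.qkv", object(), _FailingMethod(), lambda m: None
)
print(result)  # prints "per-channel"
```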
# Permute to (BLK_M, BLK_N, BLOCK_SIZE_M, BLOCK_SIZE_N)
data_hp = data_hp.permute(0, 2, 1, 3)
# Flatten to (BLK_M, BLK_N, BLOCK_SIZE_M * BLOCK_SIZE_N)
data_hp = data_hp.to(torch.float32).contiguous().flatten(start_dim=2)
Is converting to fp32 a must-have here?
Converting to fp32 ensures precision in the scale calculations, as the scales are also stored in fp32.
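The permute-then-flatten step can be mirrored in numpy to check the shapes and the fp32 upcast. A sketch under stated assumptions: BLOCK_M/BLOCK_N and tile_blocks are illustrative stand-ins for the PR's torch code, with transpose playing the role of permute.

```python
import numpy as np

BLOCK_M, BLOCK_N = 128, 128

def tile_blocks(data_hp: np.ndarray) -> np.ndarray:
    """Rearrange (M, N) weights into (BLK_M, BLK_N, BLOCK_M * BLOCK_N),
    mirroring the permute/flatten in the quoted diff (numpy stand-in
    for torch). Scales are computed in float32, so upcast first."""
    m, n = data_hp.shape
    x = data_hp.reshape(m // BLOCK_M, BLOCK_M, n // BLOCK_N, BLOCK_N)
    # Permute to (BLK_M, BLK_N, BLOCK_M, BLOCK_N), then flatten each block.
    x = x.transpose(0, 2, 1, 3).astype(np.float32)
    return np.ascontiguousarray(x).reshape(m // BLOCK_M, n // BLOCK_N, -1)

w = np.random.randn(256, 256).astype(np.float16)
blocks = tile_blocks(w)
print(blocks.shape)  # prints (2, 2, 16384)
```

After this layout, a per-block amax reduces over the last axis, which is what makes the single-axis scale computation possible.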
)
logger.info(
    "FP8 approach: Model loads with native SGLang FP8 quantization. "
    "FP8 approach: Model loads and gets blockwise fp8 quantization on . "
This log message seems strange.
def _get_tp_sharded_scale(full_scale_tensor):
    """Get tp sharded scale from full scale tensor"""
def _get_tp_sharded_scale(full_scale_tensor, is_blockwise=False):
This _get_tp_sharded_scale function is too long and seems to do multiple conversions at once. Could we split it into several functions?
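A sketch of one possible decomposition, with hypothetical helper names: the slicing logic here is illustrative only; the real function operates on torch tensors and the actual tensor-parallel layout.

```python
import numpy as np

def _shard_per_channel_scale(scale: np.ndarray, tp_rank: int, tp_size: int):
    """Slice a 1-D per-channel scale: one entry per output channel."""
    assert scale.ndim == 1
    chunk = scale.shape[0] // tp_size
    return scale[tp_rank * chunk:(tp_rank + 1) * chunk]

def _shard_blockwise_scale(scale: np.ndarray, tp_rank: int, tp_size: int):
    """Slice a 2-D blockwise scale grid along its block-row dimension;
    each grid row covers one 128-channel block, so row-sharding the
    grid matches channel-sharding the weight."""
    assert scale.ndim == 2
    chunk = scale.shape[0] // tp_size
    return scale[tp_rank * chunk:(tp_rank + 1) * chunk]

def get_tp_sharded_scale(scale, tp_rank, tp_size, is_blockwise=False):
    """Thin dispatcher: each shape-specific concern lives in its own helper."""
    shard = _shard_blockwise_scale if is_blockwise else _shard_per_channel_scale
    return shard(scale, tp_rank, tp_size)

per_channel = np.arange(512, dtype=np.float32)
blockwise = np.arange(4 * 3, dtype=np.float32).reshape(4, 3)  # 512 / 128 = 4 block rows
print(get_tp_sharded_scale(per_channel, 1, 2).shape)                     # prints (256,)
print(get_tp_sharded_scale(blockwise, 1, 2, is_blockwise=True).shape)   # prints (2, 3)
```

Splitting this way keeps each helper testable on its own and leaves the public function as a small dispatcher.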
Motivation
Following #9650, support blockwise FP8 rollout with flashrl.
Modifications
Accuracy Tests
Benchmarking and Profiling
Checklist