
[Quantization][RL] Support Online Blockwise FP8 Quantization #15440

Open
AniZpZ wants to merge 28 commits into sgl-project:main from AniZpZ:dev/blockwise-fp8-rollout

Conversation

@AniZpZ (Collaborator) commented Dec 19, 2025

Motivation

Following #9650, this PR adds support for blockwise FP8 rollout with flashrl.

Modifications

Accuracy Tests

Benchmarking and Profiling

Checklist


@AniZpZ AniZpZ changed the title [Quantization][RL] Support Online Blockwise FP8 Quantization [WIP][Quantization][RL] Support Online Blockwise FP8 Quantization Dec 19, 2025
@AniZpZ AniZpZ changed the title [WIP][Quantization][RL] Support Online Blockwise FP8 Quantization [Quantization][RL] Support Online Blockwise FP8 Quantization Dec 22, 2025
@Wilboludriver (Contributor) commented Dec 22, 2025

Experimental Details

Model: Qwen/Qwen3-8B-Base
Training Recipe: DAPO
Configuration:

  • Training dataset: DAPO-Math-17k
  • Quantization scheme: dynamic blockwise FP8
  • Validation set: AIME-2024
  • Prompt batch size: 32, n = 16
  • Rollout batch size: 32316
  • train_batch_size & ppo_mini_batch_size: 32
  • Token-level TIS, C = 2
  • Max response length: 20K
  • Hardware: 8*H20; framework: veRL; CUDA 12.9

Results (2026.01.12 Updated)

Observations and Outlook

Accuracy of Quantization: The current blockwise FP8 rollout implementation, which converts weights by FP32 -> BF16 -> FP8, shows only minor training-inference discrepancies and maintains training metrics consistent with the BF16 baseline. In contrast, per-channel FP8 quantization leads to notable precision loss during text generation. Further experiments indicate that direct FP32-to-FP8 quantization results in a larger performance gap and an elevated final validation score, which is attributed to longer generated responses.
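The FP32 -> BF16 -> FP8 path above can be illustrated with a minimal NumPy sketch. This is not the PR's implementation: the BF16 step is simulated with a bit trick, the final cast to float8_e4m3 is omitted, and only the per-block scale computation (one fp32 scale per 128x128 block) is shown; function names here are illustrative.

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # largest finite value of float8_e4m3

def to_bf16(x):
    """Simulate the FP32 -> BF16 step by rounding away the low 16 mantissa
    bits (round-to-nearest-even), without needing a native bf16 dtype."""
    u = x.astype(np.float32).view(np.uint32)
    u = (u + np.uint32(0x7FFF) + ((u >> np.uint32(16)) & np.uint32(1))) & np.uint32(0xFFFF0000)
    return u.view(np.float32)

def blockwise_fp8_scales(w, block=(128, 128)):
    """Compute one fp32 scale per (128, 128) weight block, the core of
    blockwise FP8 quantization; the actual cast to float8 is omitted here."""
    m, n = w.shape
    bm, bn = block
    assert m % bm == 0 and n % bn == 0
    # Group the weight into (m // bm, n // bn) blocks of shape (bm, bn)
    blocks = w.reshape(m // bm, bm, n // bn, bn).transpose(0, 2, 1, 3)
    amax = np.abs(blocks.astype(np.float32)).max(axis=(2, 3))
    return amax / FP8_E4M3_MAX  # shape (m // bm, n // bn)

rng = np.random.default_rng(0)
w = rng.standard_normal((256, 256)).astype(np.float32)
# FP32 -> BF16 -> per-block scale, mirroring the conversion order in this PR
scales = blockwise_fp8_scales(to_bf16(w))
```

Quantizing the rounded BF16 values (rather than the raw FP32 weights) keeps the rollout weights consistent with what a BF16 trainer actually holds, which matches the observation that this path tracks the BF16 baseline most closely.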

Gen. Throughput: FP8 rollout initially delivers slightly higher throughput than BF16 but is later surpassed as training progresses. In additional runs with a maximum response length of 10K (compared to the current 20K setting), FP8 rollout achieves significantly higher throughput than BF16 when long-tail generation is constrained.

@zhaochenyang20 (Collaborator)

/rerun-failed-ci

@Hecate0821 (Contributor)

TODOs:

  • Investigate why FP8 achieves higher accuracy and determine if it is purely due to noise.
  • Analyze why FP8 throughput becomes lower than BF16 in subsequent steps.

@zhaochenyang20 (Collaborator)

TODOs:

  • Investigate why FP8 achieves higher accuracy and determine if it is purely due to noise.
  • Analyze why FP8 throughput becomes lower than BF16 in subsequent steps.

Shall we finish these todos, then merge this PR?

@AniZpZ (Collaborator, Author) commented Jan 12, 2026

TODOs:

  • Investigate why FP8 achieves higher accuracy and determine if it is purely due to noise.
  • Analyze why FP8 throughput becomes lower than BF16 in subsequent steps.

Shall we finish these todos, then merge this PR?

@zhaochenyang20 @Hecate0821 @FlamingoPg The TODOs above have been resolved and CI has passed.

Comment on lines +882 to +883
# Note: only [128, 128] block size is available for now
default_block_size = [128, 128]
Collaborator:

default_block_size = [128, 128] is set twice. We should set it only once, with [128, 128] as the default value.

if quant_method is not None:
quant_method.process_weights_after_loading(module)
logger.info(
f"[QuantizedRL] Fllback to per-channel quantization for module: {name}; "
Collaborator:

Typo: "Fllback" should be "Fallback".



# Adapt from https://github.com/volcengine/verl/pull/4415/files#diff-79538cec3426fe5c75d07b39a15e90971f19e98404755792f9b28859b8902ae1
def scaled_fp8_blockwise(
Collaborator:

Could we add dedicated comments and a return type hint to this function?

logger.debug(
f"[QuantizedRL] Set quant_method weight_block_size={default_block_size} for module: {name}"
)
except Exception as e:
Collaborator:

Please do not catch errors this broadly; a bare except Exception may swallow unexpected errors. Could we catch only RuntimeError/ValueError?

# Permute to (BLK_M, BLK_N, BLOCK_SIZE_M, BLOCK_SIZE_N)
data_hp = data_hp.permute(0, 2, 1, 3)
# Flatten to (BLK_M, BLK_N, BLOCK_SIZE_M * BLOCK_SIZE_N)
data_hp = data_hp.to(torch.float32).contiguous().flatten(start_dim=2)
Collaborator:

Is converting to fp32 a must-have here?

Contributor:

Converting to fp32 ensures the precision for scale calculations, as the scales are also in fp32.
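For readers following along, the permute-and-flatten pattern under discussion can be sketched in NumPy. The 128x128 block shape comes from the PR, the variable names mirror the snippet, but the input shape is an arbitrary example:

```python
import numpy as np

BLOCK_M, BLOCK_N = 128, 128  # only [128, 128] block size is available for now

rng = np.random.default_rng(1)
w = rng.standard_normal((256, 512)).astype(np.float16)  # example half-precision weight

blk_m, blk_n = w.shape[0] // BLOCK_M, w.shape[1] // BLOCK_N
data_hp = w.reshape(blk_m, BLOCK_M, blk_n, BLOCK_N)
# Permute to (BLK_M, BLK_N, BLOCK_SIZE_M, BLOCK_SIZE_N)
data_hp = data_hp.transpose(0, 2, 1, 3)
# Flatten to (BLK_M, BLK_N, BLOCK_SIZE_M * BLOCK_SIZE_N); upcast to fp32 so the
# per-block amax, and hence the scale, is computed at the scales' precision
data_hp = np.ascontiguousarray(data_hp).astype(np.float32).reshape(blk_m, blk_n, -1)
amax = np.abs(data_hp).max(axis=-1)  # one amax per block
```

After the flatten, each of the blk_m * blk_n rows holds exactly one block, so a single reduction over the last axis yields the per-block amax that the scales are derived from.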

)
logger.info(
"FP8 approach: Model loads with native SGLang FP8 quantization. "
"FP8 approach: Model loads and gets blockwise fp8 quantization on . "
Collaborator:

This log seems strange: the two messages are redundant, and the second one is truncated ("quantization on . ").


def _get_tp_sharded_scale(full_scale_tensor):
"""Get tp sharded scale from full scale tensor"""
def _get_tp_sharded_scale(full_scale_tensor, is_blockwise=False):
Collaborator:

This _get_tp_sharded_scale function is too long and seems to handle multiple conversions at once. Could we split it into several functions?

5 participants