Support fp8 block gemm with fp8_e8m0 scale#398

Open
Yejing-Lai wants to merge 1 commit into
vllm-project:mainfrom
Yejing-Lai:lyj/update_fp8_gemm

@Yejing-Lai

Contributor

Support FP8 block GEMM with fp8_e8m0 scales and add related unit tests.
Clean up the code where the if-else path was too long.

Copilot AI review requested due to automatic review settings June 5, 2026 09:22

Copilot AI left a comment

Contributor

Pull request overview

Note

Copilot was unable to run its full agentic suite in this review.

This PR adds optional UE8M0 (float8 e8m0) scale support to the FP8 quantization reference utilities and expands the oneDNN FP8 GEMM test matrix to cover that mode.

Changes:

  • Parameterize FP8 GEMM per-block tests to run with/without UE8M0 scales.
  • Extend fp8_block_quant_2d to optionally round scales to powers of two and return UE8M0 scale tensors; make the dequant paths robust by casting scales to FP32.
  • Update oneDNN FP8 matmul scale handling and primitive caching key to account for UE8M0 usage.
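
The power-of-two scale rounding mentioned above can be sketched in plain Python. This is only an illustration and not the code from `tests/ops/fp8_quant_op.py`: the helper names `round_scale_to_pow2` and `block_scales` are hypothetical, the input is a list of lists rather than a tensor, and rounding *up* (so that scaled values stay within the FP8 range) is an assumption about the rounding direction.

```python
import math

FP8_E4M3_MAX = 448.0  # largest finite value of float8_e4m3fn


def round_scale_to_pow2(scale: float) -> float:
    """Round a positive scale up to the nearest power of two (assumed direction)."""
    return 2.0 ** math.ceil(math.log2(scale))


def block_scales(matrix, block_m, block_n, use_ue8m0=False, eps=1e-12):
    """Return one scale per (block_m x block_n) tile of a 2-D list of floats.

    Hypothetical helper: it only computes the scales, not the FP8 payload.
    """
    rows, cols = len(matrix), len(matrix[0])
    scales = []
    for r0 in range(0, rows, block_m):
        row_scales = []
        for c0 in range(0, cols, block_n):
            # Largest magnitude inside the current tile.
            amax = max(
                abs(matrix[r][c])
                for r in range(r0, min(r0 + block_m, rows))
                for c in range(c0, min(c0 + block_n, cols))
            )
            scale = max(amax, eps) / FP8_E4M3_MAX
            if use_ue8m0:
                # A power-of-two scale is exactly representable in E8M0.
                scale = round_scale_to_pow2(scale)
            row_scales.append(scale)
        scales.append(row_scales)
    return scales


example = [[1.0, 2.0, 3.0, 4.0],
           [5.0, 6.0, 7.0, 700.0]]
print(block_scales(example, 2, 2, use_ue8m0=True))
```

With `use_ue8m0=False` the same call returns the unrounded `amax / 448` scales, which is the distinction the parametrized tests exercise.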

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.

| File | Description |
| --- | --- |
| tests/test_fp8_gemm_onednn.py | Adds use_ue8m0 parametrization and threads it into quantization helpers. |
| tests/ops/fp8_quant_op.py | Adds UE8M0 scale option for block quantization and ensures dequant casts scales to FP32. |
| csrc/xpu/onednn/fp8_gemm_w8a8.h | Adjusts scale-mode detection, scale attribute setup, and matmul primitive cache key. |

Comment on lines 166 to +171
    int m1_sc_group_size = m1_sc.numel();
    int m2_sc_group_size = m2_sc.numel();
    int sc_group_size = (m1_sc_group_size << 8) | m2_sc_group_size;
    if (m1_sc_dtype == at::ScalarType::Float8_e8m0fnu) {
      sc_group_size |= (1 << 30);
    }
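
The bit packing in the excerpt above can be mirrored in Python to see what the resulting key looks like. This is a sketch of the arithmetic only: `pack_scale_key` and `UE8M0_FLAG` are hypothetical names, and Python integers are unbounded whereas the C++ code uses `int`.

```python
UE8M0_FLAG = 1 << 30  # bit used in the excerpt to mark E8M0 scales


def pack_scale_key(m1_sc_numel: int, m2_sc_numel: int, is_ue8m0: bool) -> int:
    """Mirror of the excerpt: m1 count shifted left by 8, OR-ed with the m2 count."""
    key = (m1_sc_numel << 8) | m2_sc_numel
    if is_ue8m0:
        key |= UE8M0_FLAG
    return key


# Same scale counts, different scale dtype: the flag keeps the keys apart.
print(pack_scale_key(1, 1, False), pack_scale_key(1, 1, True))

# The two fields only stay separate while m2_sc_numel fits in 8 bits.
print(pack_scale_key(1, 257, False) == pack_scale_key(1, 1, False))
```

The second print shows that two different `m2_sc` element counts can map to the same key once the count exceeds 255. Whether such sizes reach this code path depends on call sites that the excerpt does not show.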
Comment thread csrc/xpu/onednn/fp8_gemm_w8a8.h Outdated
Comment on lines +28 to +31
    // MXFP8 weight scale is [k/32, n]
    bool is_block_quant = (m1_sc.dim() == 2) && (m2_sc.dim() == 2) &&
                          (m1_sc.size(1) != 1) && (m2_sc.size(0) != 1) &&
                          (m2_sc.size(1) != n);
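
To see which scale layouts the condition above accepts, here is a shape-only Python sketch. `is_block_quant` takes plain tuples instead of tensors, and the example shapes (128x128 weight blocks, per-token and per-channel scales) are assumptions chosen for illustration; only the `[k/32, n]` MXFP8 weight-scale layout comes from the code comment.

```python
def is_block_quant(m1_sc_shape, m2_sc_shape, n):
    """Shape-only mirror of the C++ predicate in the excerpt."""
    return (
        len(m1_sc_shape) == 2
        and len(m2_sc_shape) == 2
        and m1_sc_shape[1] != 1
        and m2_sc_shape[0] != 1
        and m2_sc_shape[1] != n
    )


m, k, n = 256, 512, 1024

# Assumed block-quant layout: activation scales [m, k/128], weight scales [k/128, n/128].
print(is_block_quant((m, k // 128), (k // 128, n // 128), n))

# MXFP8 layout from the comment: weight scales [k/32, n], so the last check fails.
print(is_block_quant((m, k // 32), (k // 32, n), n))

# Assumed per-token / per-channel layout: [m, 1] and [1, n].
print(is_block_quant((m, 1), (1, n), n))
```

Only the first call returns `True`; the `m2_sc.size(1) != n` term is what routes the MXFP8 layout away from the block-quant path.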
Comment thread tests/ops/fp8_quant_op.py
block_m: block rows
block_n: block cols
fp8_dtype: torch.float8_e4m3fn
use_ue8m0: return scales as torch.float8_e8m0fnu
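
For context on what a `float8_e8m0fnu` scale stores, the following sketch encodes and decodes the format by hand. E8M0 is an unsigned 8-bit exponent with bias 127 and no mantissa, with `0xFF` reserved for NaN, so it can only hold powers of two. The function names are hypothetical and this is not how PyTorch implements the dtype.

```python
import math

E8M0_BIAS = 127
E8M0_NAN = 0xFF  # the only non-numeric encoding


def encode_e8m0(scale: float) -> int:
    """Encode a positive power-of-two scale as an E8M0 byte."""
    # frexp returns (mantissa, exponent) with scale == mantissa * 2**exponent
    # and mantissa in [0.5, 1) for positive finite inputs.
    mantissa, exponent = math.frexp(scale)
    if scale <= 0 or mantissa != 0.5:
        raise ValueError("E8M0 can only represent positive powers of two")
    biased = (exponent - 1) + E8M0_BIAS
    if not 0 <= biased < E8M0_NAN:
        raise ValueError("exponent out of E8M0 range")
    return biased


def decode_e8m0(byte: int) -> float:
    """Decode an E8M0 byte to a Python float (NaN for 0xFF)."""
    if byte == E8M0_NAN:
        return math.nan
    return 2.0 ** (byte - E8M0_BIAS)


print(encode_e8m0(1.0), encode_e8m0(0.015625), decode_e8m0(128))
```

Decoding to an ordinary float before multiplying plays the same role as casting the scales to FP32 in the dequant path.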

@jikunshang jikunshang left a comment

Member

LGTM. Thanks for fixing.
cc @wuxun-zhang

Signed-off-by: Lai, Yejing <yejing.lai@intel.com>
@Yejing-Lai Yejing-Lai force-pushed the lyj/update_fp8_gemm branch from ce6d92e to 9c69b50 Compare June 8, 2026 08:19
@Yejing-Lai Yejing-Lai marked this pull request as ready for review June 8, 2026 08:19
@Yejing-Lai

Contributor Author

@zufangzhu Please review~

@zufangzhu zufangzhu left a comment

Collaborator

LGTM. BTW, do we need to update w8a16 as well?

@Yejing-Lai

Contributor Author

> LGTM. BTW, do we need to update w8a16 as well?

w8a16 does not include the MXFP8 path, so it can be kept as it is.
