Support fp8 block gemm with fp8_e8m0 scale #398
Pull request overview
Note
Copilot was unable to run its full agentic suite in this review.
This PR adds optional UE8M0 (float8 e8m0) scale support to the FP8 quantization reference utilities and expands the oneDNN FP8 GEMM test matrix to cover that mode.
Changes:
- Parameterize FP8 GEMM per-block tests to run with/without UE8M0 scales.
- Extend `fp8_block_quant_2d` to optionally round scales to powers of two and return UE8M0 scale tensors; make dequant paths robust by casting scales to FP32.
- Update oneDNN FP8 matmul scale handling and primitive caching key to account for UE8M0 usage.
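
The power-of-two rounding mentioned in the list above can be sketched in plain Python. This is a minimal illustration, not the code from `fp8_block_quant_2d`: the helper names are hypothetical, and rounding the scale *up* to the next power of two is an assumption (the PR only says scales are rounded to powers of two). The E8M0 layout used here (8 exponent bits, bias 127, no mantissa, `0xFF` reserved for NaN) and the e4m3fn maximum of 448 are standard properties of those formats.

```python
import math

FP8_E4M3_MAX = 448.0  # largest finite value of torch.float8_e4m3fn
E8M0_BIAS = 127       # exponent bias of the E8M0 format


def round_scale_to_pow2(amax, fp8_max=FP8_E4M3_MAX):
    """Hypothetical helper: smallest power of two >= amax / fp8_max.

    Rounding up is an assumption; it guarantees amax / scale <= fp8_max.
    """
    if amax <= 0.0:
        return 1.0
    mantissa, exponent = math.frexp(amax / fp8_max)  # raw = mantissa * 2**exponent
    if mantissa == 0.5:       # raw is already an exact power of two
        exponent -= 1
    return math.ldexp(1.0, exponent)


def encode_e8m0(scale):
    """Hypothetical helper: biased exponent byte of a power-of-two scale."""
    mantissa, exponent = math.frexp(scale)
    if mantissa != 0.5:
        raise ValueError("scale must be an exact power of two")
    biased = (exponent - 1) + E8M0_BIAS
    if not 0 <= biased <= 254:  # 255 (0xFF) encodes NaN in E8M0
        raise ValueError("scale is outside the E8M0 range")
    return biased


def decode_e8m0(byte):
    """Inverse of encode_e8m0: 2 ** (byte - 127)."""
    return math.ldexp(1.0, byte - E8M0_BIAS)
```

Because an E8M0 byte stores only an exponent, a scale survives the round trip exactly only if it is already a power of two, which is why the rounding step comes before the cast.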
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| tests/test_fp8_gemm_onednn.py | Adds use_ue8m0 parametrization and threads it into quantization helpers. |
| tests/ops/fp8_quant_op.py | Adds UE8M0 scale option for block quantization and ensures dequant casts scales to FP32. |
| csrc/xpu/onednn/fp8_gemm_w8a8.h | Adjusts scale-mode detection, scale attribute setup, and matmul primitive cache key. |
    int m1_sc_group_size = m1_sc.numel();
    int m2_sc_group_size = m2_sc.numel();
    int sc_group_size = (m1_sc_group_size << 8) | m2_sc_group_size;
    if (m1_sc_dtype == at::ScalarType::Float8_e8m0fnu) {
      sc_group_size |= (1 << 30);
    }
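
The key packing in the excerpt above can be modelled with Python integers to see what the extra bit buys. This is a sketch of the bit arithmetic only: `pack_scale_key` is a hypothetical name, and Python integers are unbounded, whereas the C++ code uses a 32-bit `int`.

```python
UE8M0_FLAG = 1 << 30  # bit set when the scale dtype is Float8_e8m0fnu


def pack_scale_key(m1_numel, m2_numel, use_ue8m0):
    """Hypothetical model of the sc_group_size component of the cache key."""
    key = (m1_numel << 8) | m2_numel
    if use_ue8m0:
        key |= UE8M0_FLAG
    return key


# Same scale sizes, different scale dtype: the flag keeps the keys distinct.
fp32_key = pack_scale_key(4, 8, use_ue8m0=False)
ue8m0_key = pack_scale_key(4, 8, use_ue8m0=True)
```

With the flag, the FP32-scale and UE8M0-scale variants of the same shape no longer map to the same cached primitive. The packing is only unambiguous while `m2_numel` fits in 8 bits: for example, `pack_scale_key(1, 0, False)` and `pack_scale_key(0, 256, False)` both evaluate to 256. Whether such sizes occur in practice is not stated in the thread.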
    // MXFP8 weight scale is [k/32, n]
    bool is_block_quant = (m1_sc.dim() == 2) && (m2_sc.dim() == 2) &&
                          (m1_sc.size(1) != 1) && (m2_sc.size(0) != 1) &&
                          (m2_sc.size(1) != n);
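
The predicate above can be restated over plain shape tuples to show which layouts it accepts. Only the condition itself comes from the excerpt; the function name and the example shapes (128x128 blocks, per-row activation scales) are assumptions chosen for illustration.

```python
def is_block_quant(m1_sc_shape, m2_sc_shape, n):
    """Hypothetical Python mirror of the C++ is_block_quant condition."""
    return (
        len(m1_sc_shape) == 2
        and len(m2_sc_shape) == 2
        and m1_sc_shape[1] != 1   # more than one activation scale per row
        and m2_sc_shape[0] != 1   # more than one weight scale along k
        and m2_sc_shape[1] != n   # [k/32, n] weight scales mean MXFP8 instead
    )


# Assumed example: m=4, k=256, n=256.
block_case = is_block_quant((4, 2), (2, 2), n=256)       # 128x128 block scales
mxfp8_case = is_block_quant((4, 8), (8, 256), n=256)     # weight scale is [k/32, n]
per_row_case = is_block_quant((4, 1), (2, 2), n=256)     # one activation scale per row
```

The last clause is what separates the two 2-D layouts: a weight scale with `n` columns is treated as MXFP8 rather than as block quantization.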
    block_m: block rows
    block_n: block cols
    fp8_dtype: torch.float8_e4m3fn
    use_ue8m0: return scales as torch.float8_e8m0fnu
jikunshang left a comment
LGTM. thanks for fixing.
cc @wuxun-zhang
Signed-off-by: Lai, Yejing <yejing.lai@intel.com>
ce6d92e to 9c69b50
@zufangzhu Please review~
zufangzhu left a comment
LGTM. BTW, do we need to update w8a16 as well?
w8a16 does not include the MXFP8 path, so it can be kept as it is.
Support FP8 block GEMM with fp8_e8m0 scale and add related unit tests.
Clean up the code where the if-else path is too long.