-
Notifications
You must be signed in to change notification settings - Fork 6.9k
[ModelOpt MXFP8] #18449
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ModelOpt MXFP8] #18449
Changes from all commits
640ade1
cce9f2f
403ab94
92ea923
10eb3a5
48f64f7
39b6b1d
608da34
34f54cd
8308802
cb270eb
d036876
6c41d92
ec9ff1c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -654,6 +654,31 @@ def _pack_mxfp8_scales(scale_u8: torch.Tensor) -> torch.Tensor: | |
| return packed.view(1, scale_m, scale_k, 2, 256) | ||
|
|
||
|
|
||
| def dequantize_mxfp8( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You could mirror implementation here: |
||
| data: torch.Tensor, scale_u8: torch.Tensor, group_size: int = 32 | ||
| ) -> torch.Tensor: | ||
| """Dequantize MXFP8 tensor with UE8M0 scales back to bf16. | ||
|
|
||
| Applies per-group scaling: fp8_val * 2^(scale - 127). | ||
|
|
||
| Args: | ||
| data: FP8 tensor of shape (M, K). | ||
| scale_u8: uint8 UE8M0 scales of shape (M, K // group_size). | ||
| group_size: Number of elements per scale group (default 32). | ||
|
|
||
| Returns: | ||
| bf16 tensor of shape (M, K). | ||
| """ | ||
| m, k = data.shape | ||
| n_groups = k // group_size | ||
| scales_f32 = torch.pow( | ||
| 2.0, scale_u8.to(dtype=torch.float32, device=data.device) - 127.0 | ||
| ) | ||
| data_f32 = data.to(torch.float32).view(m, n_groups, group_size) | ||
| scales_f32 = scales_f32.view(m, n_groups, 1) | ||
| return (data_f32 * scales_f32).view(m, k).to(torch.bfloat16) | ||
|
|
||
|
|
||
| def triton_mxfp8_blockscaled_linear( | ||
| input: torch.Tensor, | ||
| weight: torch.Tensor, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can borrow from triton instead of reimplementing this:
sglang/python/sglang/test/test_block_fp8.py
Lines 30 to 37 in bf08d3f