1 | 1 | import pytest |
2 | 2 | import torch |
3 | | -import numpy as np |
4 | 3 | from flashinfer.fused_moe import trtllm_fp8_block_scale_moe, WeightLayout |
5 | 4 | from flashinfer.autotuner import autotune |
6 | 5 |
@@ -243,28 +242,36 @@ def _fp8_block_quant_2d(w_bf16: torch.Tensor, block: int = 128): |
243 | 242 | max_fp8 = finfo.max |
244 | 243 |
245 | 244 | w_f32 = w_bf16.to(torch.float32).contiguous() |
246 | | - w_fp8 = torch.empty_like(w_f32, dtype=torch.float8_e4m3fn) |
247 | | - scales = torch.empty( |
248 | | - (*prefix, nb_r, nb_c), dtype=torch.float32, device=w_bf16.device |
| 245 | + prefix_ndim = len(prefix) |
| 246 | + |
| 247 | + # Reshape weights into 128x128 blocks and move block dims to the tail: |
| 248 | + # [..., nb_r, block, nb_c, block] -> [..., nb_r, nb_c, block, block] |
| 249 | + reshaped = w_f32.reshape(*prefix, nb_r, block, nb_c, block) |
| 250 | + permute_dims = tuple(range(prefix_ndim)) + ( |
| 251 | + prefix_ndim, |
| 252 | + prefix_ndim + 2, |
| 253 | + prefix_ndim + 1, |
| 254 | + prefix_ndim + 3, |
249 | 255 | ) |
| 256 | + blocks = reshaped.permute(permute_dims).contiguous() |
| 257 | + |
| 258 | + # Compute per-block scales |
| 259 | + amax = torch.amax(torch.abs(blocks), dim=(-1, -2)) |
| 260 | + scales = torch.where( |
| 261 | + amax > 0, |
| 262 | + amax / max_fp8, |
| 263 | + torch.ones_like(amax, dtype=torch.float32), |
| 264 | + ) |
| 265 | + |
| 266 | + # Quantize blocks in parallel |
| 267 | + q_blocks = (blocks / scales.unsqueeze(-1).unsqueeze(-1)).to(torch.float8_e4m3fn) |
| 268 | + |
| 269 | + # Restore original layout |
| 270 | + inv_permute = [0] * (prefix_ndim + 4) |
| 271 | + for i, d in enumerate(permute_dims): |
| 272 | + inv_permute[d] = i |
| 273 | + w_fp8 = q_blocks.permute(*inv_permute).reshape(*prefix, R, C) |
250 | 274 |
251 | | - it = np.ndindex(*prefix) if prefix else [()] |
252 | | - for idx in it: |
253 | | - sel = idx if isinstance(idx, tuple) else (idx,) |
254 | | - for i in range(nb_r): |
255 | | - rs = slice(i * block, (i + 1) * block) |
256 | | - for j in range(nb_c): |
257 | | - cs = slice(j * block, (j + 1) * block) |
258 | | - blk = w_f32[(*sel, rs, cs)] # [128, 128] |
259 | | - amax = torch.amax(torch.abs(blk)) |
260 | | - s = ( |
261 | | - (amax / max_fp8) |
262 | | - if amax > 0 |
263 | | - else torch.tensor(1.0, device=w_bf16.device) |
264 | | - ) |
265 | | - q = (blk / s).to(torch.float8_e4m3fn) |
266 | | - w_fp8[(*sel, rs, cs)] = q |
267 | | - scales[(*sel, i, j)] = s |
268 | 275 | return w_fp8, scales |
269 | 276 |
270 | 277 |
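The subtle part of the new path is the reshape/permute that materializes one contiguous `[block, block]` tile per scale entry. Here is a minimal standalone sketch of that indexing, with toy sizes chosen purely for illustration (not part of the test):

```python
import torch

# Toy sizes: one prefix dim and block=2, so a [1, 4, 6] weight has 2x3 blocks.
prefix, block = (1,), 2
R, C = 4, 6
nb_r, nb_c = R // block, C // block
w = torch.arange(R * C, dtype=torch.float32).reshape(*prefix, R, C)

p = len(prefix)
# [..., nb_r, block, nb_c, block] -> [..., nb_r, nb_c, block, block]
blocks = (
    w.reshape(*prefix, nb_r, block, nb_c, block)
    .permute(*range(p), p, p + 2, p + 1, p + 3)
    .contiguous()
)

# blocks[..., i, j] is the block x block tile starting at (i * block, j * block).
assert torch.equal(blocks[0, 1, 2], w[0, 2:4, 4:6])
```

Applying the inverse permute at the end (the `inv_permute` loop in the patch) is what lets the quantized tiles be written back with a single `reshape` instead of per-block slice assignment.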
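A round-trip sanity check of the whole helper is sketched below. `_dequant_blocks` is a hypothetical helper written only for this check; it assumes `scales` comes back shaped `[..., nb_r, nb_c]` as in the patch, and the tolerances are rough guesses for e4m3's 3 mantissa bits:

```python
import torch

def _dequant_blocks(w_fp8: torch.Tensor, scales: torch.Tensor, block: int = 128):
    # Expand each per-block scale over its block x block tile, then multiply.
    s = scales.repeat_interleave(block, dim=-2).repeat_interleave(block, dim=-1)
    return w_fp8.to(torch.float32) * s

w = torch.randn(4, 256, 384, dtype=torch.bfloat16)
w_fp8, scales = _fp8_block_quant_2d(w)
torch.testing.assert_close(
    _dequant_blocks(w_fp8, scales),
    w.to(torch.float32),
    rtol=0.13,  # coarse bound; e4m3 rounding error can approach 2**-3
    atol=1e-2,  # absorbs values that quantize to fp8 subnormals or zero
)
```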