
Commit 541b80b

committed: decrease test time

1 parent 2471a8a · commit 541b80b

File tree

1 file changed: +28 −21 lines changed


tests/moe/test_dpsk_fused_moe_fp8.py

Lines changed: 28 additions & 21 deletions
@@ -1,6 +1,5 @@
 import pytest
 import torch
-import numpy as np
 from flashinfer.fused_moe import trtllm_fp8_block_scale_moe, WeightLayout
 from flashinfer.autotuner import autotune

@@ -243,28 +242,36 @@ def _fp8_block_quant_2d(w_bf16: torch.Tensor, block: int = 128):
     max_fp8 = finfo.max
 
     w_f32 = w_bf16.to(torch.float32).contiguous()
-    w_fp8 = torch.empty_like(w_f32, dtype=torch.float8_e4m3fn)
-    scales = torch.empty(
-        (*prefix, nb_r, nb_c), dtype=torch.float32, device=w_bf16.device
+    prefix_ndim = len(prefix)
+
+    # Reshape weights into 128x128 blocks and move block dims to the tail:
+    # [..., nb_r, block, nb_c, block] -> [..., nb_r, nb_c, block, block]
+    reshaped = w_f32.reshape(*prefix, nb_r, block, nb_c, block)
+    permute_dims = tuple(range(prefix_ndim)) + (
+        prefix_ndim,
+        prefix_ndim + 2,
+        prefix_ndim + 1,
+        prefix_ndim + 3,
     )
+    blocks = reshaped.permute(permute_dims).contiguous()
+
+    # Compute per-block scales
+    amax = torch.amax(torch.abs(blocks), dim=(-1, -2))
+    scales = torch.where(
+        amax > 0,
+        amax / max_fp8,
+        torch.ones_like(amax, dtype=torch.float32),
+    )
+
+    # Quantize blocks in parallel
+    q_blocks = (blocks / scales.unsqueeze(-1).unsqueeze(-1)).to(torch.float8_e4m3fn)
+
+    # Restore original layout
+    inv_permute = [0] * (prefix_ndim + 4)
+    for i, d in enumerate(permute_dims):
+        inv_permute[d] = i
+    w_fp8 = q_blocks.permute(*inv_permute).reshape(*prefix, R, C)
 
-    it = np.ndindex(*prefix) if prefix else [()]
-    for idx in it:
-        sel = idx if isinstance(idx, tuple) else (idx,)
-        for i in range(nb_r):
-            rs = slice(i * block, (i + 1) * block)
-            for j in range(nb_c):
-                cs = slice(j * block, (j + 1) * block)
-                blk = w_f32[(*sel, rs, cs)]  # [128, 128]
-                amax = torch.amax(torch.abs(blk))
-                s = (
-                    (amax / max_fp8)
-                    if amax > 0
-                    else torch.tensor(1.0, device=w_bf16.device)
-                )
-                q = (blk / s).to(torch.float8_e4m3fn)
-                w_fp8[(*sel, rs, cs)] = q
-                scales[(*sel, i, j)] = s
     return w_fp8, scales
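For readers who want to check the vectorized path, the sketch below is a minimal standalone version of the same reshape-permute-quantize scheme, not part of the commit. It specializes the quantizer to a 2-D weight with no leading expert dimensions, assumes R and C are exact multiples of the block size, and the helper name and round-trip tolerances are assumptions, not the test's values.

import torch

def fp8_block_quant_2d_sketch(w_bf16: torch.Tensor, block: int = 128):
    # Hypothetical standalone helper; mirrors the test's vectorized scheme
    # for a 2-D weight (no leading expert dims).
    R, C = w_bf16.shape
    nb_r, nb_c = R // block, C // block
    max_fp8 = torch.finfo(torch.float8_e4m3fn).max

    w_f32 = w_bf16.to(torch.float32)
    # [nb_r, block, nb_c, block] -> [nb_r, nb_c, block, block]
    blocks = w_f32.reshape(nb_r, block, nb_c, block).permute(0, 2, 1, 3)

    # One scale per 128x128 block; all-zero blocks get scale 1.0 to avoid 0/0
    amax = torch.amax(torch.abs(blocks), dim=(-1, -2))
    scales = torch.where(amax > 0, amax / max_fp8, torch.ones_like(amax))

    q = (blocks / scales[..., None, None]).to(torch.float8_e4m3fn)
    # With no prefix dims, permute(0, 2, 1, 3) is its own inverse
    w_fp8 = q.permute(0, 2, 1, 3).reshape(R, C)
    return w_fp8, scales

# Round trip: dequantizing should recover the input to fp8 e4m3 precision
w = torch.randn(256, 384, dtype=torch.bfloat16)
w_fp8, scales = fp8_block_quant_2d_sketch(w)
deq = w_fp8.to(torch.float32).reshape(2, 128, 3, 128).permute(0, 2, 1, 3)
deq = (deq * scales[..., None, None]).permute(0, 2, 1, 3).reshape(256, 384)
torch.testing.assert_close(deq, w.to(torch.float32), atol=1e-2, rtol=0.08)

The test itself must handle arbitrary leading dims, which is why the diff builds inv_permute with an explicit loop; inverting a permutation this way is equivalent to argsorting permute_dims.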
