From 8b4f6ea2667ab380e80faa15f0dc55d378c56356 Mon Sep 17 00:00:00 2001 From: GeisYaO Date: Sun, 12 Apr 2026 03:23:04 -0900 Subject: [PATCH 1/5] Update quant.py --- aiter/ops/quant.py | 68 ++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 66 insertions(+), 2 deletions(-) diff --git a/aiter/ops/quant.py b/aiter/ops/quant.py index b1668c4c11..d1ebc931e5 100644 --- a/aiter/ops/quant.py +++ b/aiter/ops/quant.py @@ -72,7 +72,30 @@ def pertoken_quant( return y, y_scale -def per_1x32_f4_quant(x, scale=None, quant_dtype=dtypes.fp4x2, shuffle=False): +def per_1x32_f4_quant(x, scale=None, quant_dtype=dtypes.fp4x2, shuffle=False, + pack_dim=-1): + """Quantize a tensor to MXFP4 (e2m1) format with per-1x32 block scaling. + + By default, packing is along the last dimension (dim=-1), which produces + output suitable for ``tl.dot_scaled`` **LHS** operand: + A(M, K) -> fp4=(M, K//2), scale=(M, K//32) + + For ``tl.dot_scaled`` **RHS** operand, set ``pack_dim=0`` so the packing + is along the first dimension (the K / contraction dimension): + B(K, N) -> fp4=(K//2, N), scale=(K//32, N) + + Args: + x: Input tensor of shape (..., N) or (M, N). + scale: Pre-computed scale tensor (optional, usually None). + quant_dtype: Target quantized dtype, must be ``dtypes.fp4x2``. + shuffle: Whether to apply e8m0 scale shuffling for hardware. + pack_dim: Dimension along which to pack two FP4 values into one byte. + -1 (default): pack along the last dimension (for dot_scaled LHS). + 0: pack along the first dimension (for dot_scaled RHS). + + Returns: + Tuple of (quantized_tensor, scale_tensor). + """ assert quant_dtype == dtypes.fp4x2 block_size = 32 F8E8M0_EXP_BIAS = 127 @@ -81,6 +104,13 @@ def per_1x32_f4_quant(x, scale=None, quant_dtype=dtypes.fp4x2, shuffle=False): # dtypeMax = F4E2M1_MAX dtypeMax = 2.0**MAX_POW2 + # For pack_dim=0, transpose so packing always happens along last dim internally + transposed = False + if pack_dim == 0: + assert x.dim() == 2, "pack_dim=0 requires a 2D input tensor (K, N)" + x = x.T.contiguous() + transposed = True + shape_original = x.shape x = x.view(-1, shape_original[-1]) @@ -102,7 +132,41 @@ def per_1x32_f4_quant(x, scale=None, quant_dtype=dtypes.fp4x2, shuffle=False): scale = scale_e8m0_biased.view(m, -1).view(torch.uint8) if shuffle: scale = fp4_utils.e8m0_shuffle(scale) - return y, scale.view(dtypes.fp8_e8m0) + scale = scale.view(dtypes.fp8_e8m0) + + # For pack_dim=0, transpose results back: (N, K//2) -> (K//2, N) + if transposed: + y = y.T.contiguous() + scale = scale.view(torch.uint8).T.contiguous().view(dtypes.fp8_e8m0) + + return y, scale + + +def per_1x32_f4_quant_for_dot_scaled(lhs, rhs, quant_dtype=dtypes.fp4x2, + shuffle=False): + """Convenience function: quantize both LHS and RHS for ``tl.dot_scaled``. + + Handles the packing dimension automatically: + - LHS A(M, K): packed along K (dim=-1) -> fp4=(M, K//2), scale=(M, K//32) + - RHS B(K, N): packed along K (dim=0) -> fp4=(K//2, N), scale=(K//32, N) + + Note: Triton 3.6+ expects rhs_scale in transposed form (N, K//32). Users + should transpose the returned rhs_scale accordingly if using Triton >= 3.6. + + Args: + lhs: LHS tensor of shape (M, K). + rhs: RHS tensor of shape (K, N). + + Returns: + Tuple of (lhs_fp4, lhs_scale, rhs_fp4, rhs_scale). + """ + lhs_fp4, lhs_scale = per_1x32_f4_quant( + lhs, quant_dtype=quant_dtype, shuffle=shuffle, pack_dim=-1 + ) + rhs_fp4, rhs_scale = per_1x32_f4_quant( + rhs, quant_dtype=quant_dtype, shuffle=shuffle, pack_dim=0 + ) + return lhs_fp4, lhs_scale, rhs_fp4, rhs_scale def per_1x32_f8_scale_f8_quant( From a8d42d4f32ae20f8cbd76aac7d3c203793eb8582 Mon Sep 17 00:00:00 2001 From: GeisYaO Date: Mon, 13 Apr 2026 04:14:18 -0900 Subject: [PATCH 2/5] Refactor per_1x32_f4_quant function signature --- aiter/ops/quant.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/aiter/ops/quant.py b/aiter/ops/quant.py index d1ebc931e5..eddd4fce14 100644 --- a/aiter/ops/quant.py +++ b/aiter/ops/quant.py @@ -72,8 +72,8 @@ def pertoken_quant( return y, y_scale -def per_1x32_f4_quant(x, scale=None, quant_dtype=dtypes.fp4x2, shuffle=False, - pack_dim=-1): +def per_1x32_f4_quant( + x, scale=None, quant_dtype=dtypes.fp4x2, shuffle=False, pack_dim=-1): """Quantize a tensor to MXFP4 (e2m1) format with per-1x32 block scaling. By default, packing is along the last dimension (dim=-1), which produces @@ -98,7 +98,7 @@ def per_1x32_f4_quant(x, scale=None, quant_dtype=dtypes.fp4x2, shuffle=False, """ assert quant_dtype == dtypes.fp4x2 block_size = 32 - F8E8M0_EXP_BIAS = 127 + F8E8M0_EXP_BIAS = 127 # noqa:F841 F4E2M1_MAX = 6.0 MAX_POW2 = int(torch.log2(torch.tensor(F4E2M1_MAX, dtype=torch.float32)).item()) # dtypeMax = F4E2M1_MAX From dae3fb36afbd3ac227c3dbb5b8616a8cac9658a0 Mon Sep 17 00:00:00 2001 From: GeisYaO Date: Mon, 13 Apr 2026 04:16:38 -0900 Subject: [PATCH 3/5] Fix function definition formatting in quant.py --- aiter/ops/quant.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/aiter/ops/quant.py b/aiter/ops/quant.py index eddd4fce14..e2f360ebd8 100644 --- a/aiter/ops/quant.py +++ b/aiter/ops/quant.py @@ -73,7 +73,8 @@ def pertoken_quant( def per_1x32_f4_quant( - x, scale=None, quant_dtype=dtypes.fp4x2, shuffle=False, pack_dim=-1): + x, scale=None, quant_dtype=dtypes.fp4x2, shuffle=False, pack_dim=-1 +): """Quantize a tensor to MXFP4 (e2m1) format with per-1x32 block scaling. By default, packing is along the last dimension (dim=-1), which produces From 4f28987e0dbbc3a8b029231796221b17a5995df0 Mon Sep 17 00:00:00 2001 From: GeisYaO Date: Mon, 13 Apr 2026 19:02:26 -0900 Subject: [PATCH 4/5] Refactor per_1x32_f4_quant_for_dot_scaled definition --- aiter/ops/quant.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/aiter/ops/quant.py b/aiter/ops/quant.py index e2f360ebd8..a17d52fb03 100644 --- a/aiter/ops/quant.py +++ b/aiter/ops/quant.py @@ -143,8 +143,7 @@ def per_1x32_f4_quant( return y, scale -def per_1x32_f4_quant_for_dot_scaled(lhs, rhs, quant_dtype=dtypes.fp4x2, - shuffle=False): +def per_1x32_f4_quant_for_dot_scaled(lhs, rhs, quant_dtype=dtypes.fp4x2,shuffle=False): """Convenience function: quantize both LHS and RHS for ``tl.dot_scaled``. Handles the packing dimension automatically: From 3a6120ea6b1dbdc3cf4135e49daf7203ded66ebe Mon Sep 17 00:00:00 2001 From: GeisYaO Date: Tue, 14 Apr 2026 06:05:04 -0900 Subject: [PATCH 5/5] Restore semantic.py to match main branch --- aiter/ops/quant.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aiter/ops/quant.py b/aiter/ops/quant.py index a17d52fb03..8af023f537 100644 --- a/aiter/ops/quant.py +++ b/aiter/ops/quant.py @@ -143,7 +143,7 @@ def per_1x32_f4_quant( return y, scale -def per_1x32_f4_quant_for_dot_scaled(lhs, rhs, quant_dtype=dtypes.fp4x2,shuffle=False): +def per_1x32_f4_quant_for_dot_scaled(lhs, rhs, quant_dtype=dtypes.fp4x2, shuffle=False): """Convenience function: quantize both LHS and RHS for ``tl.dot_scaled``. Handles the packing dimension automatically: