diff --git a/aiter/ops/triton/_triton_kernels/quant.py b/aiter/ops/triton/_triton_kernels/quant.py
index ed03a861df..3773fb077b 100644
--- a/aiter/ops/triton/_triton_kernels/quant.py
+++ b/aiter/ops/triton/_triton_kernels/quant.py
@@ -94,6 +94,16 @@ def _mxfp4_quant_op(
         x: [BLOCK_SIZE_M, BLOCK_SIZE_N], fp32
     """
+    EXP_BIAS_FP32: tl.constexpr = 127
+    EXP_BIAS_FP4: tl.constexpr = 1
+    EBITS_F32: tl.constexpr = 8
+    EBITS_FP4: tl.constexpr = 2
+    MBITS_F32: tl.constexpr = 23
+    MBITS_FP4: tl.constexpr = 1
+
+    max_normal: tl.constexpr = 6
+    min_normal: tl.constexpr = 1
+
     NUM_QUANT_BLOCKS: tl.constexpr = BLOCK_SIZE_N // MXFP4_QUANT_BLOCK_SIZE
     x = x.reshape(BLOCK_SIZE_M, NUM_QUANT_BLOCKS, MXFP4_QUANT_BLOCK_SIZE)
     # Calculate scale
@@ -125,26 +135,49 @@ def _mxfp4_quant_op(
     # S111 -> +/- 6.0
     qx = qx.to(tl.uint32, bitcast=True)
 
-    # Extract sign, exponents and mantissa fields from FP32
+    # Extract sign
     s = qx & 0x80000000
-    e = (qx >> 23) & 0xFF
-    m = qx & 0x7FFFFF
-    E8_BIAS: tl.constexpr = 127
-    E2_BIAS: tl.constexpr = 1
+    # Set everything to positive, will add sign back at the end
+    qx = qx ^ s
+
+    qx_fp32 = qx.to(tl.float32, bitcast=True)
+    saturate_mask = qx_fp32 >= max_normal
+    denormal_mask = (not saturate_mask) & (qx_fp32 < min_normal)
+    normal_mask = not (saturate_mask | denormal_mask)
 
     # Denormal numbers
-    # If exponent is less than 127, then it's a denormal number
-    # See above, for denormal number mantissa is always 1 and we set bit 1 of mantissa
-    adjusted_exponents = tl.core.sub(E8_BIAS, e + 1, sanitize_overflow=False)
-    m = tl.where(e < E8_BIAS, (0x400000 | (m >> 1)) >> adjusted_exponents, m)
-    # For normal numbers, bias is changed from 127 to 1, and for subnormals, we keep exponent as 0.
-    # Note: E8_BIAS - E2_BIAS = 126, so for normals we subtract that.
-    e = tl.maximum(e, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS)
-
-    # Combine sign, exponent, and mantissa, while saturating
-    # rounding nearest with tie breaking up by adding +1 to one bit right of the LSB, then shift right
-    e2m1_tmp = tl.minimum((((e << 2) | (m >> 21)) + 1) >> 1, 0x7)
-    e2m1_value = ((s >> 28) | e2m1_tmp).to(tl.uint8)
+    denorm_exp: tl.constexpr = (
+        (EXP_BIAS_FP32 - EXP_BIAS_FP4) + (MBITS_F32 - MBITS_FP4) + 1
+    )
+    denorm_mask_int: tl.constexpr = denorm_exp << MBITS_F32
+    denorm_mask_float: tl.constexpr = tl.cast(denorm_mask_int, tl.float32, bitcast=True)
+
+    denormal_x = qx_fp32 + denorm_mask_float
+    denormal_x = denormal_x.to(tl.uint32, bitcast=True)
+    denormal_x -= denorm_mask_int
+    denormal_x = denormal_x.to(tl.uint8)
+
+    # Normal numbers
+    normal_x = qx
+    # is the LSB of the resulting mantissa odd? (used to round half to even)
+    mant_odd = (normal_x >> (MBITS_F32 - MBITS_FP4)) & 1
+    # update exponent, rounding bias part 1
+    val_to_add = ((EXP_BIAS_FP4 - EXP_BIAS_FP32) << MBITS_F32) + (1 << 21) - 1
+    normal_x += val_to_add
+    # rounding bias part 2
+    normal_x += mant_odd
+    # take the bits!
+    normal_x = normal_x >> (MBITS_F32 - MBITS_FP4)
+    normal_x = normal_x.to(tl.uint8)
+
+    # Merge results
+    e2m1_value = tl.full(qx.type.get_block_shapes(), 0x7, dtype=tl.uint8)
+    e2m1_value = tl.where(normal_mask, normal_x, e2m1_value)
+    e2m1_value = tl.where(denormal_mask, denormal_x, e2m1_value)
+    # add sign back
+    sign_lp = s >> (MBITS_F32 + EBITS_F32 - MBITS_FP4 - EBITS_FP4)
+    sign_lp = sign_lp.to(tl.uint8)
+    e2m1_value = e2m1_value | sign_lp
 
     e2m1_value = tl.reshape(
         e2m1_value, [BLOCK_SIZE_M, NUM_QUANT_BLOCKS, MXFP4_QUANT_BLOCK_SIZE // 2, 2]
     )
diff --git a/op_tests/triton_tests/test_quant_mxfp4.py b/op_tests/triton_tests/test_quant_mxfp4.py
index d217d2544c..b77ebfc8a1 100644
--- a/op_tests/triton_tests/test_quant_mxfp4.py
+++ b/op_tests/triton_tests/test_quant_mxfp4.py
@@ -29,6 +29,16 @@ def torch_dynamic_mxfp4_quant(
     """
     # Create padded x. Needed because mxfp4 works with block of 32 elements
     MXFP4_QUANT_BLOCK_SIZE = 32
+    EXP_BIAS_FP32 = 127
+    EXP_BIAS_FP4 = 1
+    EBITS_F32 = 8
+    EBITS_FP4 = 2
+    MBITS_F32 = 23
+    MBITS_FP4 = 1
+    max_normal = 6
+    min_normal = 1
+    sign_mask = 1 << (EBITS_FP4 + MBITS_FP4)
+
     x_shape = x.shape
     if x.shape[-1] % MXFP4_QUANT_BLOCK_SIZE != 0:
         shape = list(x_shape)
@@ -78,29 +88,57 @@ def torch_dynamic_mxfp4_quant(
     # Convert quantized fp32 tensor to int32 before converting to mxfp4 format
     qx = qx.view(torch.int32)
 
-    # Extract sign, exponents and mantissa fields from int32
+    # Extract sign
     s = qx & 0x80000000
-    e = (qx >> 23) & 0xFF
-    m = qx & 0x7FFFFF
+    # Set everything to positive, will add sign back at the end
+    qx = qx ^ s
 
-    E8_BIAS = 127
-    E2_BIAS = 1
+    qx_fp32 = qx.view(torch.float32)
+    saturate_mask = qx_fp32 >= max_normal
+    denormal_mask = torch.logical_and(
+        torch.logical_not(saturate_mask), qx_fp32 < min_normal
+    )
+    normal_mask = torch.logical_not(torch.logical_or(saturate_mask, denormal_mask))
 
     # Denormal numbers
-    # If exponent is less than 127, then it's a denormal number
-    # See above, for denormal number mantissa is always 1 and we set bit 1 of mantissa
-    adjusted_exponents = E8_BIAS - e - 1
-    m = torch.where(e < E8_BIAS, (0x400000 | (m >> 1)) >> adjusted_exponents, m)
-
-    # For normal numbers, bias is changed from 127 to 1, and for subnormals, we keep exponent as 0.
-    # Note: E8_BIAS - E2_BIAS = 126, so for normals we subtract that.
-    e = torch.where(e > E8_BIAS - E2_BIAS, e, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS)
-
-    # Combine sign, exponent, and mantissa, while saturating
-    # rounding nearest with tie breaking up by adding +1 to one bit right of the LSB, then shift right
-    combined_val = (((e << 2) | (m >> 21)) + 1) >> 1
-    e2m1_tmp = torch.where(combined_val < 0x7, combined_val, 0x7)
-    e2m1_value = (((s >> 28) & 0xF) | e2m1_tmp).to(torch.uint8)
+    denorm_exp = (EXP_BIAS_FP32 - EXP_BIAS_FP4) + (MBITS_F32 - MBITS_FP4) + 1
+    denorm_mask_int = denorm_exp << MBITS_F32
+    denorm_mask_float = torch.tensor(denorm_mask_int, dtype=torch.int32).view(
+        torch.float32
+    )
+
+    denormal_x = qx_fp32 + denorm_mask_float
+    denormal_x = denormal_x.view(torch.int32)
+    denormal_x -= denorm_mask_int
+    denormal_x = denormal_x.to(torch.uint8)
+
+    # Normal numbers
+    normal_x = qx
+    # is the LSB of the resulting mantissa odd? (used to round half to even)
+    mant_odd = (normal_x >> (MBITS_F32 - MBITS_FP4)) & 1
+    # update exponent, rounding bias part 1
+    val_to_add = ((EXP_BIAS_FP4 - EXP_BIAS_FP32) << MBITS_F32) + (1 << 21) - 1
+    normal_x += val_to_add
+    # rounding bias part 2
+    normal_x += mant_odd
+    # take the bits!
+    normal_x = normal_x >> (MBITS_F32 - MBITS_FP4)
+    normal_x = normal_x.to(torch.uint8)
+
+    # Merge results
+    e2m1_value = torch.full_like(qx, 0x7, dtype=torch.uint8)
+    e2m1_value = torch.where(normal_mask, normal_x, e2m1_value)
+    e2m1_value = torch.where(denormal_mask, denormal_x, e2m1_value)
+
+    # add sign back
+    sign_lp = s >> (MBITS_F32 + EBITS_F32 - MBITS_FP4 - EBITS_FP4)
+    sign_lp = sign_lp.to(torch.uint8)
+    # Right shift of a negative signed integer fills the vacated high-order
+    # bits with the sign bit. Since PyTorch doesn't have a uint32 dtype, we
+    # mask those bits out to keep just the fp4 sign bit
+    sign_lp = sign_lp & sign_mask
+    e2m1_value = e2m1_value | sign_lp
 
     # Pack 2 4-bit values into 8-bit
     x_mxfp4 = e2m1_value[..., ::2] | (e2m1_value[..., 1::2] << 4)
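
For reference, here is a minimal scalar Python sketch (not part of the patch) of the same fp32 -> E2M1 conversion, handy for checking the two bit tricks by hand: the "denorm mask" addition that lets float32 rounding snap subnormal-range values onto the E2M1 grid, and the integer rebias-and-round for the normal range. The helper names (encode_e2m1, f32_bits, bits_f32) are made up for this sketch, and float32 rounding is emulated with struct packing:

import struct

EXP_BIAS_FP32, EXP_BIAS_FP4 = 127, 1
EBITS_F32, EBITS_FP4 = 8, 2
MBITS_F32, MBITS_FP4 = 23, 1
MAX_NORMAL, MIN_NORMAL = 6.0, 1.0  # largest/smallest positive normal E2M1 values


def f32_bits(x: float) -> int:
    # bit pattern of x rounded to float32 (struct.pack rounds half to even)
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bits_f32(b: int) -> float:
    return struct.unpack("<f", struct.pack("<I", b))[0]


def encode_e2m1(val: float) -> int:
    bits = f32_bits(val)
    s = bits & 0x80000000
    bits ^= s  # work on the magnitude, re-attach the sign at the end
    mag = bits_f32(bits)
    if mag >= MAX_NORMAL:  # saturate to +/-6.0 (code 0b111)
        code = 0x7
    elif mag < MIN_NORMAL:  # subnormal range: let float addition do the rounding
        denorm_exp = (EXP_BIAS_FP32 - EXP_BIAS_FP4) + (MBITS_F32 - MBITS_FP4) + 1
        mask_int = denorm_exp << MBITS_F32
        # adding 2**22 makes the sum's ulp exactly 0.5, so float32 addition
        # snaps mag onto the E2M1 grid; the low mantissa bits are the code
        code = f32_bits(mag + bits_f32(mask_int)) - mask_int
    else:  # normal range: rebias the exponent, round half to even in int space
        mant_odd = (bits >> (MBITS_F32 - MBITS_FP4)) & 1
        bits += (EXP_BIAS_FP4 - EXP_BIAS_FP32) << MBITS_F32
        bits += (1 << (MBITS_F32 - MBITS_FP4 - 1)) - 1 + mant_odd
        code = bits >> (MBITS_F32 - MBITS_FP4)
    return code | (s >> (MBITS_F32 + EBITS_F32 - MBITS_FP4 - EBITS_FP4))


# every representable E2M1 value round-trips, including -0.0
E2M1 = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
for c in range(16):
    v = E2M1[c & 0x7] * (-1.0 if c & 0x8 else 1.0)
    assert encode_e2m1(v) == c
# ties round to even: 0.25 -> 0.0, 0.75 -> 1.0, 1.25 -> 1.0, 1.75 -> 2.0, 5.0 -> 4.0
assert [encode_e2m1(t) for t in (0.25, 0.75, 1.25, 1.75, 5.0)] == [0x0, 0x2, 0x2, 0x4, 0x6]

The assertions spell out the rounding behavior that the Triton kernel and the Torch reference in this diff are expected to agree on (modulo the per-block scaling that happens before this conversion).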