67 changes: 50 additions & 17 deletions aiter/ops/triton/_triton_kernels/quant.py
@@ -94,6 +94,16 @@ def _mxfp4_quant_op(
     x: [BLOCK_SIZE_M, BLOCK_SIZE_N], fp32

     """
+    EXP_BIAS_FP32: tl.constexpr = 127
+    EXP_BIAS_FP4: tl.constexpr = 1
+    EBITS_F32: tl.constexpr = 8
+    EBITS_FP4: tl.constexpr = 2
+    MBITS_F32: tl.constexpr = 23
+    MBITS_FP4: tl.constexpr = 1
+
+    max_normal: tl.constexpr = 6
+    min_normal: tl.constexpr = 1
+
     NUM_QUANT_BLOCKS: tl.constexpr = BLOCK_SIZE_N // MXFP4_QUANT_BLOCK_SIZE
     x = x.reshape(BLOCK_SIZE_M, NUM_QUANT_BLOCKS, MXFP4_QUANT_BLOCK_SIZE)
     # Calculate scale
@@ -125,26 +135,49 @@ def _mxfp4_quant_op(
     # S111 -> +/- 6.0
     qx = qx.to(tl.uint32, bitcast=True)

-    # Extract sign, exponents and mantissa fields from FP32
+    # Extract sign
     s = qx & 0x80000000
-    e = (qx >> 23) & 0xFF
-    m = qx & 0x7FFFFF
-    E8_BIAS: tl.constexpr = 127
-    E2_BIAS: tl.constexpr = 1
+    # Set everything to positive, will add sign back at the end
+    qx = qx ^ s

+    qx_fp32 = qx.to(tl.float32, bitcast=True)
+    saturate_mask = qx_fp32 >= max_normal
+    denormal_mask = (not saturate_mask) & (qx_fp32 < min_normal)
+    normal_mask = not (saturate_mask | denormal_mask)

-    # Denormal numbers
-    # If exponent is less than 127, then it's a denormal number
-    # See above, for denormal number mantissa is always 1 and we set bit 1 of mantissa
-    adjusted_exponents = tl.core.sub(E8_BIAS, e + 1, sanitize_overflow=False)
-    m = tl.where(e < E8_BIAS, (0x400000 | (m >> 1)) >> adjusted_exponents, m)
-    # For normal numbers, bias is changed from 127 to 1, and for subnormals, we keep exponent as 0.
-    # Note: E8_BIAS - E2_BIAS = 126, so for normals we subtract that.
-    e = tl.maximum(e, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS)

-    # Combine sign, exponent, and mantissa, while saturating
-    # rounding nearest with tie breaking up by adding +1 to one bit right of the LSB, then shift right
-    e2m1_tmp = tl.minimum((((e << 2) | (m >> 21)) + 1) >> 1, 0x7)
-    e2m1_value = ((s >> 28) | e2m1_tmp).to(tl.uint8)
+    denorm_exp: tl.constexpr = (
+        (EXP_BIAS_FP32 - EXP_BIAS_FP4) + (MBITS_F32 - MBITS_FP4) + 1
+    )
+    denorm_mask_int: tl.constexpr = denorm_exp << MBITS_F32
+    denorm_mask_float: tl.constexpr = tl.cast(denorm_mask_int, tl.float32, bitcast=True)
+
+    denormal_x = qx_fp32 + denorm_mask_float
+    denormal_x = denormal_x.to(tl.uint32, bitcast=True)
+    denormal_x -= denorm_mask_int
+    denormal_x = denormal_x.to(tl.uint8)
+
+    # Normal numbers
+    normal_x = qx
+    # resulting mantissa is odd
+    mant_odd = (normal_x >> (MBITS_F32 - MBITS_FP4)) & 1
+    # update exponent, rounding bias part 1
+    val_to_add = ((EXP_BIAS_FP4 - EXP_BIAS_FP32) << MBITS_F32) + (1 << 21) - 1
+    normal_x += val_to_add
+    # rounding bias part 2
+    normal_x += mant_odd
+    # take the bits!
+    normal_x = normal_x >> (MBITS_F32 - MBITS_FP4)
+    normal_x = normal_x.to(tl.uint8)
+
+    # Merge results
+    e2m1_value = tl.full(qx.type.get_block_shapes(), 0x7, dtype=tl.uint8)
+    e2m1_value = tl.where(normal_mask, normal_x, e2m1_value)
+    e2m1_value = tl.where(denormal_mask, denormal_x, e2m1_value)
+    # add sign back
+    sign_lp = s >> (MBITS_F32 + EBITS_F32 - MBITS_FP4 - EBITS_FP4)
+    sign_lp = sign_lp.to(tl.uint8)
+    e2m1_value = e2m1_value | sign_lp
     e2m1_value = tl.reshape(
         e2m1_value, [BLOCK_SIZE_M, NUM_QUANT_BLOCKS, MXFP4_QUANT_BLOCK_SIZE // 2, 2]
     )
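Both new conversion paths are branchless bit tricks: for normals, a biased add makes the truncating right shift round to nearest even; for subnormals, adding a "magic" float (2^22 here) lets the FPU itself leave the rounded 2-bit code in the low mantissa bits. The scalar Python sketch below walks the same bits for a single value; it is an illustration only, and fp32_to_e2m1_code is a hypothetical name, not part of this PR.

import struct

def fp32_to_e2m1_code(x: float) -> int:
    # Reinterpret the float32 as raw bits and peel off the sign.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    sign = bits & 0x80000000
    bits ^= sign
    ax = struct.unpack("<f", struct.pack("<I", bits))[0]  # |x|
    if ax >= 6.0:
        code = 0x7  # saturate to the largest E2M1 normal, +/-6.0
    elif ax < 1.0:
        # Subnormal path: adding 2^22 makes the float add round |x| into
        # the low mantissa bits; subtracting the magic bits leaves the code.
        magic_int = ((127 - 1) + (23 - 1) + 1) << 23  # bit pattern of 2.0**22
        magic = struct.unpack("<f", struct.pack("<I", magic_int))[0]
        code = struct.unpack("<I", struct.pack("<f", ax + magic))[0] - magic_int
    else:
        # Normal path: round to nearest even, then keep the top E2M1 bits.
        mant_odd = (bits >> 22) & 1
        bits += ((1 - 127) << 23) + (1 << 21) - 1 + mant_odd
        code = bits >> 22
    return code | (sign >> 28)  # re-attach the sign as bit 3

# e.g. fp32_to_e2m1_code(0.6) -> 1 (0.5); 1.25 -> 2 (ties to even, 1.0); 7.0 -> 0x7 (6.0)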
76 changes: 57 additions & 19 deletions op_tests/triton_tests/test_quant_mxfp4.py
@@ -29,6 +29,16 @@ def torch_dynamic_mxfp4_quant(
"""
# Create padded x. Needed because mxfp4 works with block of 32 elements
MXFP4_QUANT_BLOCK_SIZE = 32
EXP_BIAS_FP32 = 127
EXP_BIAS_FP4 = 1
EBITS_F32 = 8
EBITS_FP4 = 2
MBITS_F32 = 23
MBITS_FP4 = 1
max_normal = 6
min_normal = 1
sign_mask = 1 << (EBITS_FP4 + MBITS_FP4)

x_shape = x.shape
if x.shape[-1] % MXFP4_QUANT_BLOCK_SIZE != 0:
shape = list(x_shape)
@@ -78,29 +88,57 @@ def torch_dynamic_mxfp4_quant(
     # Convert quantized fp32 tensor to int32 before converting to mxfp4 format
     qx = qx.view(torch.int32)

-    # Extract sign, exponents and mantissa fields from int32
+    # Extract sign
     s = qx & 0x80000000
-    e = (qx >> 23) & 0xFF
-    m = qx & 0x7FFFFF
+    # Set everything to positive, will add sign back at the end
+    qx = qx ^ s

-    E8_BIAS = 127
-    E2_BIAS = 1
+    qx_fp32 = qx.view(torch.float32)
+    saturate_mask = qx_fp32 >= max_normal
+    denormal_mask = torch.logical_and(
+        torch.logical_not(saturate_mask), qx_fp32 < min_normal
+    )
+    normal_mask = torch.logical_not(torch.logical_or(saturate_mask, denormal_mask))

-    # Denormal numbers
-    # If exponent is less than 127, then it's a denormal number
-    # See above, for denormal number mantissa is always 1 and we set bit 1 of mantissa
-    adjusted_exponents = E8_BIAS - e - 1
-    m = torch.where(e < E8_BIAS, (0x400000 | (m >> 1)) >> adjusted_exponents, m)
-
-    # For normal numbers, bias is changed from 127 to 1, and for subnormals, we keep exponent as 0.
-    # Note: E8_BIAS - E2_BIAS = 126, so for normals we subtract that.
-    e = torch.where(e > E8_BIAS - E2_BIAS, e, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS)

-    # Combine sign, exponent, and mantissa, while saturating
-    # rounding nearest with tie breaking up by adding +1 to one bit right of the LSB, then shift right
-    combined_val = (((e << 2) | (m >> 21)) + 1) >> 1
-    e2m1_tmp = torch.where(combined_val < 0x7, combined_val, 0x7)
-    e2m1_value = (((s >> 28) & 0xF) | e2m1_tmp).to(torch.uint8)
+    denorm_exp = (EXP_BIAS_FP32 - EXP_BIAS_FP4) + (MBITS_F32 - MBITS_FP4) + 1
+    denorm_mask_int = denorm_exp << MBITS_F32
+    denorm_mask_float = torch.tensor(denorm_mask_int, dtype=torch.int32).view(
+        torch.float32
+    )
+
+    denormal_x = qx_fp32 + denorm_mask_float
+    denormal_x = denormal_x.view(torch.int32)
+    denormal_x -= denorm_mask_int
+    denormal_x = denormal_x.to(torch.uint8)
+
+    # Normal numbers
+    normal_x = qx
+    # resulting mantissa is odd
+    mant_odd = (normal_x >> (MBITS_F32 - MBITS_FP4)) & 1
+    # update exponent, rounding bias part 1
+    val_to_add = ((EXP_BIAS_FP4 - EXP_BIAS_FP32) << MBITS_F32) + (1 << 21) - 1
+    normal_x += val_to_add
+    # rounding bias part 2
+    normal_x += mant_odd
+    # take the bits!
+    normal_x = normal_x >> (MBITS_F32 - MBITS_FP4)
+    normal_x = normal_x.to(torch.uint8)
+
+    # Merge results
+    e2m1_value = torch.full_like(qx, 0x7, dtype=torch.uint8)
+    e2m1_value = torch.where(normal_mask, normal_x, e2m1_value)
+    e2m1_value = torch.where(denormal_mask, denormal_x, e2m1_value)
+
+    # add sign back
+    sign_lp = s >> (MBITS_F32 + EBITS_F32 - MBITS_FP4 - EBITS_FP4)
+    sign_lp = sign_lp.to(torch.uint8)
+    # Right shift of a negative signed integer can fill the least significant
+    # bits with either 1s or 0s, depending on the implementation. Since PyTorch
+    # doesn't have an uint32 dtype, we mask out these bits to get just the
+    # f4 sign bit
+    sign_lp = sign_lp & sign_mask
+    e2m1_value = e2m1_value | sign_lp

     # Pack 2 4-bit values into 8-bit
     x_mxfp4 = e2m1_value[..., ::2] | (e2m1_value[..., 1::2] << 4)
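The final packing stores even-indexed elements in the low nibble and odd-indexed elements in the high nibble of each byte. For a round-trip sanity check, a decode can index a 16-entry table of the E2M1 code points; the sketch below is an illustration under that assumption (unpack_mxfp4 and E2M1_LUT are hypothetical helpers, not part of this PR).

import torch

# The 16 E2M1 code points: bit 3 is the sign, bits 2:1 the exponent, bit 0 the mantissa.
E2M1_LUT = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]
)

def unpack_mxfp4(x_mxfp4: torch.Tensor) -> torch.Tensor:
    lo = x_mxfp4 & 0xF  # even-indexed elements
    hi = x_mxfp4 >> 4   # odd-indexed elements
    codes = torch.stack([lo, hi], dim=-1).flatten(-2)  # re-interleave pairs
    return E2M1_LUT[codes.long()]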