118 changes: 13 additions & 105 deletions python/triton_kernels/bench/bench_mlp.py
@@ -7,14 +7,12 @@
import triton_kernels.roofline as roofline
from triton_kernels.swiglu import swiglu_fn
from triton_kernels.matmul import matmul, PrecisionConfig, FlexCtx, FnSpecs, FusedActivation
from triton_kernels.matmul_details.opt_flags import make_opt_flags, scoped_opt_flags_constraints
from triton_kernels.matmul_details.opt_flags import scoped_opt_flags_constraints
from triton_kernels.target_info import get_cdna_version
from triton_kernels.tensor_details import layout
from triton_kernels.tensor_details.layout import BlackwellMX4ValueShuffledLayout
from triton_kernels.reduce import reduce
from triton_kernels.topk import topk
from triton_kernels.tensor import make_ragged_tensor_metadata, remap_ragged_tensor_metadata # ragged tensor
from triton_kernels.tensor import is_tma_compliant, Tensor, torch_dtype_to_dtype
from triton_kernels.distributed import convert_dp_to_ep, convert_ep_to_dp, make_expt_dict_uniform, make_expt_assignment, SymmetricMemoryPool
from triton_kernels.distributed_details.mesh import Mesh
# quantization
@@ -23,60 +21,6 @@
from triton_kernels.numerics_details.mxfp import MXFP_BLOCK_SIZE, downcast_to_mxfp


def _shuffle_mx4_weights(tensor, block_k, block_n):
"""
Convert MX4 weights from BlackwellMXValueLayout to BlackwellMX4ValueShuffledLayout.

Works directly with the column-major storage data, bypassing the canonical format
round-trip which uses a different byte-level packing convention.
"""
from triton_kernels.tensor import Storage, Tensor as TKTensor
storage = tensor.storage
# The stored data is column-major [E, K_packed_padded, N] with stride(-2)==1
data = storage.data
E = tensor.shape[0]
K_logical = tensor.shape[-2]
N = tensor.shape[-1]
K_packed = K_logical // 2 # 2 FP4 values per byte
# Trim any padding from the BlackwellMXValueLayout
data = data[:, :K_packed, :N].contiguous()
# Now apply the shuffled layout's tiling
shuffled_layout = BlackwellMX4ValueShuffledLayout(block_k=block_k, block_n=block_n)
transformation = shuffled_layout.make_transformation([E, K_logical, N], True)
shuffled_data = transformation.swizzle_data(data)
return TKTensor(Storage(shuffled_data, shuffled_layout), shape=list(tensor.shape), dtype=tensor.dtype)


def _infer_opt_flags(x, w, ragged_metadata, pc):
"""
Infer opt_flags by calling make_opt_flags with the same parameters matmul would use.
This ensures the block shapes match what the kernel will actually select.
"""
if not isinstance(w, Tensor):
raise TypeError("w must be a Tensor for block shape inference")
K = w.shape[-2]
N = w.shape[-1]
M = x.shape[-2]
batch_size = 1
if not isinstance(x, Tensor):
x = wrap_torch_tensor(x)
# Convert out_dtype from torch.dtype to triton dtype (make_opt_flags expects .bitwidth)
out_dtype = pc.out_dtype or x.dtype
out_dtype = torch_dtype_to_dtype(out_dtype)
x_transpose = x.stride(-1) != 1
b_scale = pc.b_mx_scale
can_use_tma = (x.numel() > 0 and is_tma_compliant(x) and w.numel() > 0 and is_tma_compliant(w)
and (b_scale is None or is_tma_compliant(b_scale)))
# Respects any constraints set by the caller via scoped_opt_flags_constraints
opt_flags = make_opt_flags(out_dtype, x.dtype, w.dtype, pc, batch_size, M, N, K, ragged_metadata, can_use_tma,
False, # can_use_split_k=False for MoE
None, # epilogue_effective_itemsize
x_transpose, False, # has_y_acc_in
None, # block_k
)
return opt_flags


def was_launched_with_torchrun():
required = ["RANK", "WORLD_SIZE", "LOCAL_RANK", "MASTER_ADDR", "MASTER_PORT"]
return all(k in os.environ for k in required)
@@ -102,8 +46,14 @@ def quantize_weight(w, dtype, **opt):
assert dtype == FP4, f"{dtype=}"
w, w_scale = downcast_to_mxfp(w.to(torch.bfloat16), torch.uint8, axis=1)
if opt:
w = convert_layout(wrap_torch_tensor(w, dtype=FP4), opt["value_layout"])
w_scale = convert_layout(wrap_torch_tensor(w_scale), opt["scale_layout"])
w = wrap_torch_tensor(w, dtype=FP4)
value_layout = opt.get("value_layout")
if value_layout is not None:
w = convert_layout(w, value_layout)
w_scale = wrap_torch_tensor(w_scale)
scale_layout = opt.get("scale_layout")
if scale_layout is not None:
w_scale = convert_layout(w_scale, scale_layout)
return w, InFlexData(), w_scale
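
For context, a hedged usage sketch of the updated quantize_weight: the [n_expts, K, N] shape and device below are illustrative assumptions, and value_layout / scale_layout can now be supplied independently (an omitted layout leaves that tensor in its default strided layout).

# Illustrative only: MoE weight of assumed shape [n_expts, K, N] = [8, 256, 128].
w = torch.randn(8, 256, 128, dtype=torch.bfloat16, device="cuda")
w_q, w_flex, w_scale = quantize_weight(w, FP4)  # no layout kwargs: defaults throughout
value_layout = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
w_q, w_flex, w_scale = quantize_weight(w, FP4, value_layout=value_layout)  # scales stay strided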


@@ -186,7 +136,10 @@ def bench_mlp(batch_per_expt, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_d
opt2 = dict()
if w_dtype == FP4:
num_warps = 4 if batch <= 512 else 8
value_layout = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
value_layout = layout.make_default_matmul_mxfp4_w_layout(
mx_axis=1,
allow_blackwell_value_shuffle=shuffle_mx4,
)
scale_layout = layout.make_default_matmul_mxfp4_w_scale_layout(mx_axis=1, num_warps=num_warps)
opt1 = {
"value_layout": value_layout,
@@ -223,58 +176,13 @@ def bench_mlp(batch_per_expt, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_d
expt_dict = make_expt_dict_uniform(EP, n_expts_tot)
expt_assignment = make_expt_assignment(EP, n_expts_tot, expt_dict, torch.device(dev))

# For MX4 shuffling: run one dry-run iteration to collect routing data, then infer block shapes
if shuffle_mx4 and w_dtype == FP4:
# Disable block swap: with shuffled weights, tile loads are already contiguous,
# so the swap's cacheline optimization is unnecessary. More importantly, disabling
# the swap gives block_k=128 (vs 256), halving per-stage smem footprint, which
# enables fitting 5 pipeline stages instead of 4 — a bigger win than the swap.
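# (Illustrative arithmetic, assuming block_n=128 and FP4 packed two values per byte:
# a 256x128 B tile needs 256*128/2 = 16 KiB of smem per stage, whereas block_k=128
# needs 128*128/2 = 8 KiB.)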
dry_run_constraints = {"disable_mx4_block_swap": True}
if epilogue_subtile_fc1 is not None:
dry_run_constraints["epilogue_subtile"] = epilogue_subtile_fc1
with scoped_opt_flags_constraints(dry_run_constraints):
# Dry-run routing to get ragged metadata and dispatched activations
l_dry = matmul(x_dp_local_bf16, wg_global, bg_global, precision_config=pcg)
l_active_dry = topk(l_dry, n_expts_act, apply_softmax=True, all_gather=True, symm_mem_pool=symm_mem_pool)
active_indx_dry = l_active_dry.indx
expt_sizes_dry = l_active_dry.mask_metadata.col_sum
dispatch_indx_dry = l_active_dry.mask_metadata.row_sorted_indx
x_global_meta_dry = make_ragged_tensor_metadata(expt_sizes_dry, dispatch_indx_dry.shape[0])
y_dry = convert_dp_to_ep(x_dp_local_fp8, expt_assignment, active_indx_dry, dispatch_indx_dry, symm_mem_pool)
y_meta_dry = remap_ragged_tensor_metadata(x_global_meta_dry, expt_assignment.expt_map[rank, :])

if y_dry.nelement() > 0:
# Infer block shapes for W1 (includes the block swap)
opt_flags_w1 = _infer_opt_flags(y_dry, w1_ep_local, y_meta_dry, pc1)
w1_block_k, w1_block_n = opt_flags_w1.block_k, opt_flags_w1.block_n
w1_ep_local = _shuffle_mx4_weights(w1_ep_local, w1_block_k, w1_block_n)

# Run W1 once to get intermediate for W2 block shape inference
y_fc1_dry = matmul(y_dry, w1_ep_local, b1_ep_local, a_ragged_metadata=y_meta_dry, precision_config=pc1,
fused_activation=act1)

# Infer block shapes for W2 (includes the block swap)
opt_flags_w2 = _infer_opt_flags(y_fc1_dry, w2_ep_local, y_meta_dry, pc2)
w2_block_k, w2_block_n = opt_flags_w2.block_k, opt_flags_w2.block_n
w2_ep_local = _shuffle_mx4_weights(w2_ep_local, w2_block_k, w2_block_n)

print(f"Shuffled layout: FC1 block_k={w1_block_k}, block_n={w1_block_n}, "
f"stages={opt_flags_w1.num_stages}, subtile={opt_flags_w1.epilogue_subtile}; "
f"FC2 block_k={w2_block_k}, block_n={w2_block_n}, "
f"stages={opt_flags_w2.num_stages}, subtile={opt_flags_w2.epilogue_subtile}")
torch.cuda.synchronize()

# Build per-kernel constraints
fc1_constraints = {}
if shuffle_mx4:
fc1_constraints["disable_mx4_block_swap"] = True
if num_stages_fc1 is not None:
fc1_constraints["num_stages"] = num_stages_fc1
if epilogue_subtile_fc1 is not None:
fc1_constraints["epilogue_subtile"] = epilogue_subtile_fc1
fc2_constraints = {}
if shuffle_mx4:
fc2_constraints["disable_mx4_block_swap"] = True
if num_stages_fc2 is not None:
fc2_constraints["num_stages"] = num_stages_fc2

35 changes: 31 additions & 4 deletions python/triton_kernels/tests/test_matmul.py
@@ -91,6 +91,7 @@ class Case:
split_k: int = 1
a_hbm_swizzling: bool = False
b_hbm_swizzling: bool = False
shuffle_mxfp4_w_layout: bool = False
epilogue_subtile: Union[int, None] = None
a_transpose: bool = False
b_transpose: bool = False
@@ -148,12 +149,16 @@ def _build_test_op_cases():
# float8 x mxfloat
test_cases.extend([
Case(16, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", b_hbm_swizzling=True),
Case(16, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", b_hbm_swizzling=True, shuffle_mxfp4_w_layout=True),
Case(1024, 1024, 1024, "batched", "float8_e5m2", "mxfloat4_e2m1", b_hbm_swizzling=True),
Case(1024, 1024, 1024, "batched", "float8_e5m2", "mxfloat4_e2m1", b_hbm_swizzling=True, shuffle_mxfp4_w_layout=True),
Case(1024, 1024, 1024, "batched", "float8_e5m2", "mxfloat4_e2m1"),
Case(1024, 1024, 1024, "ragged", "float8_e5m2", "mxfloat4_e2m1", split_k=9),
Case(1024, 1024, 1024, "ragged", "float8_e5m2", "mxfloat4_e2m1", split_k=9, b_hbm_swizzling=True),
Case(1024, 1024, 1024, "ragged", "float8_e5m2", "mxfloat4_e2m1", split_k=9, b_hbm_swizzling=True, shuffle_mxfp4_w_layout=True),
Case(300, 400, 416, "ragged", "float8_e5m2", "mxfloat8_e4m3fn"),
Case(300, 400, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1"),
Case(300, 400, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1", b_hbm_swizzling=True, shuffle_mxfp4_w_layout=True),
Case(300, 400, 416, "batched", "float8_e5m2", "mxfloat8_e4m3fn"),
])
# mxfloat x mxfloat
@@ -236,15 +241,15 @@ def _build_test_op_cases():
@pytest.mark.parametrize("is_persistent", [False,True])
@pytest.mark.parametrize("num_warps", [4, 8] if is_hopper() else [None])
def test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma, is_persistent, num_warps, n_slices,
mode, act_dtype_str, weight_dtype_str, output_dtype_str, block_m, b_hbm_swizzling, a_hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile,
mode, act_dtype_str, weight_dtype_str, output_dtype_str, block_m, b_hbm_swizzling, shuffle_mxfp4_w_layout, a_hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile,
a_transpose, b_transpose, c_transpose,
swiglu_opts, device, opt_flags_scope):
# We catch and re-invoke pytest.skip(), because otherwise pytest may hold a reference to
# the frame that called pytest.skip, including all the tensors, leading to OOM.
skip_message = None
try:
_test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma, is_persistent, num_warps, n_slices,
mode, act_dtype_str, weight_dtype_str, output_dtype_str, block_m, b_hbm_swizzling, a_hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile,
mode, act_dtype_str, weight_dtype_str, output_dtype_str, block_m, b_hbm_swizzling, shuffle_mxfp4_w_layout, a_hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile,
a_transpose, b_transpose, c_transpose,
swiglu_opts, device, opt_flags_scope)
except pytest.skip.Exception as e:
@@ -254,7 +259,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma, i
pytest.skip(skip_message)

def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma, is_persistent, num_warps, n_slices,
mode, act_dtype_str, weight_dtype_str, output_dtype_str, block_m, b_hbm_swizzling, a_hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile,
mode, act_dtype_str, weight_dtype_str, output_dtype_str, block_m, b_hbm_swizzling, shuffle_mxfp4_w_layout, a_hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile,
a_transpose, b_transpose, c_transpose,
swiglu_opts, device, opt_flags_scope):
act_uses_mx = act_dtype_str.startswith("mx") or act_dtype_str == "nvfp4_e2m1"
@@ -349,6 +354,20 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma,

# set opt flags constraints
constraints = make_constraints(block_m, split_k, is_persistent, epilogue_subtile, b_hbm_swizzling, weight_dtype_str, num_warps)
use_blackwell_shuffled_w_layout = shuffle_mxfp4_w_layout and b_hbm_swizzling
if shuffle_mxfp4_w_layout:
if not b_hbm_swizzling:
pytest.skip("Shuffled MXFP4 weight layout only applies with b_hbm_swizzling")
if is_hip() or torch.cuda.get_device_capability()[0] < 10:
pytest.skip("Shuffled MXFP4 weight layout requires Blackwell or newer")
if weight_dtype_str != "mxfloat4_e2m1":
pytest.skip("Shuffled MXFP4 weight layout only supports mxfloat4_e2m1 weights")
if not act_dtype_str.startswith("float8"):
pytest.skip("Shuffled MXFP4 weight layout is only tested with FP8 activations")
if not colmajor_mxfp_weight:
pytest.skip("Shuffled MXFP4 weight layout requires column-major MXFP weights")
if not is_persistent:
pytest.skip("Shuffled MXFP4 weight layout requires the persistent TMA kernel")
opt_flags.update_opt_flags_constraints(constraints)

a_dtype = DType(act_dtype_str)
@@ -359,6 +378,12 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma,
do_bias = inner_expt_opt is None
do_gather = do_gather and mode != "batched"
do_scatter = do_scatter and mode != "batched"
b_value_hbm_swizzling = None
if b_hbm_swizzling and colmajor_mxfp_weight and b_dtype.is_mxfloat4:
b_value_hbm_swizzling = layout.make_default_matmul_mxfp4_w_layout(
mx_axis=-2,
allow_blackwell_value_shuffle=use_blackwell_shuffled_w_layout,
)

# --- create inputs ---
a, a_scales, a_ragged_metadata = make_random_tensor(
@@ -384,9 +409,11 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma,
ragged_padding = inner_expt_opt is not None and "pad_b" in inner_expt_opt,
squeeze_batch_dim = mode == "plain",
is_mx_rowmajor = not colmajor_mxfp_weight,
value_hbm_swizzling = layout.make_default_matmul_mxfp4_w_layout(mx_axis=-2) if b_hbm_swizzling and colmajor_mxfp_weight and b_dtype.is_mxfloat4 else None,
value_hbm_swizzling = b_value_hbm_swizzling,
scale_hbm_swizzling = layout.make_default_matmul_mxfp4_w_scale_layout(mx_axis=-2, num_warps=num_warps) if b_hbm_swizzling and colmajor_mxfp_weight and b_dtype.is_mxfloat4 else None,
)
if use_blackwell_shuffled_w_layout:
assert isinstance(b.storage.layout, layout.BlackwellMX4ValueShuffledLayout)
gather_indx = None if not do_gather else torch.randint(0, max(m, 1), (m, ), dtype=torch.int32, device=device)
scatter_indx = None if not do_scatter else torch.randperm(m, dtype=torch.int32, device=device)
bias = None if not do_bias else torch.randn(b.shape[:-2] + b.shape[-1:], dtype=torch.float32, device=device)
@@ -37,11 +37,9 @@ def _make_batched_blackwell_mxfp4_weight(device, batch_size, k, n):
return weight_val, weight_scale


def _shuffle_blackwell_mxfp4_weight(weight, block_k, block_n):
shuffled_layout = BlackwellMX4ValueShuffledLayout(block_k=block_k, block_n=block_n)
transformation = shuffled_layout.make_transformation(weight.shape, is_fp4=True)
shuffled_data = transformation.swizzle_data(weight.storage.data)
return Tensor(Storage(shuffled_data, shuffled_layout), dtype=weight.dtype, shape=weight.shape)
def _shuffle_blackwell_mxfp4_weight(weight):
shuffled_layout = BlackwellMX4ValueShuffledLayout()
return convert_layout(weight, shuffled_layout)


@pytest.mark.parametrize("n, expected", [(64, 128), (200, 256)])
@@ -70,7 +68,7 @@ def test_matmul_blackwell_scale_small_n(device):
out_dtype=a.dtype,
)
tri_y = matmul(a, b, None, precision_config=precision_config)
ref_y = matmul_torch(a, b, None, precision_config=precision_config)
ref_y = matmul_torch(a.to(torch.bfloat16), b, None, precision_config=precision_config)
assert_close(ref_y, tri_y, maxtol=3e-2, rmstol=None)


@@ -82,30 +80,25 @@ def test_matmul_blackwell_shuffled_mxfp4_weight(device):

torch.manual_seed(0)
batch_size, m, n, k = 2, 128, 128, 128
block_k, block_n = 128, 128
a = torch.randn((batch_size, m, k), device=device, dtype=torch.bfloat16)
a = torch.randn((batch_size, m, k), device=device, dtype=torch.bfloat16).to(torch.float8_e5m2)
b, b_scale = _make_batched_blackwell_mxfp4_weight(device, batch_size, k, n)
b_shuffled = _shuffle_blackwell_mxfp4_weight(b, block_k, block_n)
b_shuffled = _shuffle_blackwell_mxfp4_weight(b)

# Sanity-check the host-side packing; this is the layout consumed by the
# W_SHUFFLED TMA load path in _p_matmul.
transformation = b_shuffled.storage.layout.make_transformation(b.shape, is_fp4=True)
assert torch.equal(b.storage.data, transformation.unswizzle_data(b_shuffled.storage.data))
assert torch.equal(b.storage.data, convert_layout(b_shuffled, b.storage.layout).storage.data)

precision_config = PrecisionConfig(
b_mx_scale=b_scale,
b_microblock_size=MXFP_BLOCK_SIZE.value,
out_dtype=a.dtype,
out_dtype=torch.bfloat16,
)
constraints = {
"is_persistent": True,
"block_m": 128,
"block_n": block_n,
"block_k": block_k,
"disable_mx4_block_swap": True,
}
with scoped_opt_flags_constraints(constraints):
tri_y = matmul(a, b_shuffled, None, precision_config=precision_config)

ref_y = matmul_torch(a, b, None, precision_config=precision_config)
ref_y = matmul_torch(a.to(torch.bfloat16), b, None, precision_config=precision_config)
assert_close(ref_y, tri_y, maxtol=3e-2, rmstol=None)
6 changes: 4 additions & 2 deletions python/triton_kernels/triton_kernels/matmul.py
@@ -401,6 +401,8 @@ def matmul(a, b, bias,
block_k = block_k,
mx_block_size = mx_block_size,
x_uses_tma_when_persistent = a_uses_tma_when_persistent,
rhs_layout=b.storage.layout,
epilogue_reduction_n=fused_activation.specs.reduction_n,
)
if b_is_shuffled:
if b.dtype.bitwidth != 4:
@@ -701,7 +703,7 @@ def apply(x, scale):

if precision_config.a_mx_scale is not None:
a_scale = precision_config.a_mx_scale
mx_axis = x_tri.storage.data.ndim -1
mx_axis = x_tri.ndim - 1
canonical_layout = layout.StridedLayout(major_dim=mx_axis)
x_tri = convert_layout(x_tri, canonical_layout)
x_tri_scale = convert_layout(a_scale, canonical_layout)
Expand All @@ -711,7 +713,7 @@ def apply(x, scale):

if precision_config.b_mx_scale is not None:
b_scale = precision_config.b_mx_scale
mx_axis = w_tri.storage.data.ndim - 2
mx_axis = w_tri.ndim - 2
canonical_layout = layout.StridedLayout(major_dim=mx_axis)
w_tri = convert_layout(w_tri, canonical_layout)
w_tri_scale = convert_layout(b_scale, canonical_layout)