diff --git a/python/triton_kernels/bench/bench_mlp.py b/python/triton_kernels/bench/bench_mlp.py index 9e36efd17dfc..76930bcc85ca 100644 --- a/python/triton_kernels/bench/bench_mlp.py +++ b/python/triton_kernels/bench/bench_mlp.py @@ -7,14 +7,12 @@ import triton_kernels.roofline as roofline from triton_kernels.swiglu import swiglu_fn from triton_kernels.matmul import matmul, PrecisionConfig, FlexCtx, FnSpecs, FusedActivation -from triton_kernels.matmul_details.opt_flags import make_opt_flags, scoped_opt_flags_constraints +from triton_kernels.matmul_details.opt_flags import scoped_opt_flags_constraints from triton_kernels.target_info import get_cdna_version from triton_kernels.tensor_details import layout -from triton_kernels.tensor_details.layout import BlackwellMX4ValueShuffledLayout from triton_kernels.reduce import reduce from triton_kernels.topk import topk from triton_kernels.tensor import make_ragged_tensor_metadata, remap_ragged_tensor_metadata # ragged tensor -from triton_kernels.tensor import is_tma_compliant, Tensor, torch_dtype_to_dtype from triton_kernels.distributed import convert_dp_to_ep, convert_ep_to_dp, make_expt_dict_uniform, make_expt_assignment, SymmetricMemoryPool from triton_kernels.distributed_details.mesh import Mesh # quantization @@ -23,60 +21,6 @@ from triton_kernels.numerics_details.mxfp import MXFP_BLOCK_SIZE, downcast_to_mxfp -def _shuffle_mx4_weights(tensor, block_k, block_n): - """ - Convert MX4 weights from BlackwellMXValueLayout to BlackwellMX4ValueShuffledLayout. - - Works directly with the column-major storage data, bypassing the canonical format - round-trip which uses a different byte-level packing convention. - """ - from triton_kernels.tensor import Storage, Tensor as TKTensor - storage = tensor.storage - # The stored data is column-major [E, K_packed_padded, N] with stride(-2)==1 - data = storage.data - E = tensor.shape[0] - K_logical = tensor.shape[-2] - N = tensor.shape[-1] - K_packed = K_logical // 2 # 2 FP4 values per byte - # Trim any padding from the BlackwellMXValueLayout - data = data[:, :K_packed, :N].contiguous() - # Now apply the shuffled layout's tiling - shuffled_layout = BlackwellMX4ValueShuffledLayout(block_k=block_k, block_n=block_n) - transformation = shuffled_layout.make_transformation([E, K_logical, N], True) - shuffled_data = transformation.swizzle_data(data) - return TKTensor(Storage(shuffled_data, shuffled_layout), shape=list(tensor.shape), dtype=tensor.dtype) - - -def _infer_opt_flags(x, w, ragged_metadata, pc): - """ - Infer opt_flags by calling make_opt_flags with the same parameters matmul would use. - This ensures the block shapes match what the kernel will actually select. - """ - if not isinstance(w, Tensor): - raise TypeError("w must be a Tensor for block shape inference") - K = w.shape[-2] - N = w.shape[-1] - M = x.shape[-2] - batch_size = 1 - if not isinstance(x, Tensor): - x = wrap_torch_tensor(x) - # Convert out_dtype from torch.dtype to triton dtype (make_opt_flags expects .bitwidth) - out_dtype = pc.out_dtype or x.dtype - out_dtype = torch_dtype_to_dtype(out_dtype) - x_transpose = x.stride(-1) != 1 - b_scale = pc.b_mx_scale - can_use_tma = (x.numel() > 0 and is_tma_compliant(x) and w.numel() > 0 and is_tma_compliant(w) - and (b_scale is None or is_tma_compliant(b_scale))) - # Respects any constraints set by the caller via scoped_opt_flags_constraints - opt_flags = make_opt_flags(out_dtype, x.dtype, w.dtype, pc, batch_size, M, N, K, ragged_metadata, can_use_tma, - False, # can_use_split_k=False for MoE - None, # epilogue_effective_itemsize - x_transpose, False, # has_y_acc_in - None, # block_k - ) - return opt_flags - - def was_launched_with_torchrun(): required = ["RANK", "WORLD_SIZE", "LOCAL_RANK", "MASTER_ADDR", "MASTER_PORT"] return all(k in os.environ for k in required) @@ -102,8 +46,14 @@ def quantize_weight(w, dtype, **opt): assert dtype == FP4, f"{dtype=}" w, w_scale = downcast_to_mxfp(w.to(torch.bfloat16), torch.uint8, axis=1) if opt: - w = convert_layout(wrap_torch_tensor(w, dtype=FP4), opt["value_layout"]) - w_scale = convert_layout(wrap_torch_tensor(w_scale), opt["scale_layout"]) + w = wrap_torch_tensor(w, dtype=FP4) + value_layout = opt.get("value_layout") + if value_layout is not None: + w = convert_layout(w, value_layout) + w_scale = wrap_torch_tensor(w_scale) + scale_layout = opt.get("scale_layout") + if scale_layout is not None: + w_scale = convert_layout(w_scale, scale_layout) return w, InFlexData(), w_scale @@ -186,7 +136,10 @@ def bench_mlp(batch_per_expt, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_d opt2 = dict() if w_dtype == FP4: num_warps = 4 if batch <= 512 else 8 - value_layout = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1) + value_layout = layout.make_default_matmul_mxfp4_w_layout( + mx_axis=1, + allow_blackwell_value_shuffle=shuffle_mx4, + ) scale_layout = layout.make_default_matmul_mxfp4_w_scale_layout(mx_axis=1, num_warps=num_warps) opt1 = { "value_layout": value_layout, @@ -223,58 +176,13 @@ def bench_mlp(batch_per_expt, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_d expt_dict = make_expt_dict_uniform(EP, n_expts_tot) expt_assignment = make_expt_assignment(EP, n_expts_tot, expt_dict, torch.device(dev)) - # For MX4 shuffling: run one dry-run iteration to collect routing data, then infer block shapes - if shuffle_mx4 and w_dtype == FP4: - # Disable block swap: with shuffled weights, tile loads are already contiguous, - # so the swap's cacheline optimization is unnecessary. More importantly, disabling - # the swap gives block_k=128 (vs 256), halving per-stage smem footprint, which - # enables fitting 5 pipeline stages instead of 4 — a bigger win than the swap. - dry_run_constraints = {"disable_mx4_block_swap": True} - if epilogue_subtile_fc1 is not None: - dry_run_constraints["epilogue_subtile"] = epilogue_subtile_fc1 - with scoped_opt_flags_constraints(dry_run_constraints): - # Dry-run routing to get ragged metadata and dispatched activations - l_dry = matmul(x_dp_local_bf16, wg_global, bg_global, precision_config=pcg) - l_active_dry = topk(l_dry, n_expts_act, apply_softmax=True, all_gather=True, symm_mem_pool=symm_mem_pool) - active_indx_dry = l_active_dry.indx - expt_sizes_dry = l_active_dry.mask_metadata.col_sum - dispatch_indx_dry = l_active_dry.mask_metadata.row_sorted_indx - x_global_meta_dry = make_ragged_tensor_metadata(expt_sizes_dry, dispatch_indx_dry.shape[0]) - y_dry = convert_dp_to_ep(x_dp_local_fp8, expt_assignment, active_indx_dry, dispatch_indx_dry, symm_mem_pool) - y_meta_dry = remap_ragged_tensor_metadata(x_global_meta_dry, expt_assignment.expt_map[rank, :]) - - if y_dry.nelement() > 0: - # Infer block shapes for W1 (includes the block swap) - opt_flags_w1 = _infer_opt_flags(y_dry, w1_ep_local, y_meta_dry, pc1) - w1_block_k, w1_block_n = opt_flags_w1.block_k, opt_flags_w1.block_n - w1_ep_local = _shuffle_mx4_weights(w1_ep_local, w1_block_k, w1_block_n) - - # Run W1 once to get intermediate for W2 block shape inference - y_fc1_dry = matmul(y_dry, w1_ep_local, b1_ep_local, a_ragged_metadata=y_meta_dry, precision_config=pc1, - fused_activation=act1) - - # Infer block shapes for W2 (includes the block swap) - opt_flags_w2 = _infer_opt_flags(y_fc1_dry, w2_ep_local, y_meta_dry, pc2) - w2_block_k, w2_block_n = opt_flags_w2.block_k, opt_flags_w2.block_n - w2_ep_local = _shuffle_mx4_weights(w2_ep_local, w2_block_k, w2_block_n) - - print(f"Shuffled layout: FC1 block_k={w1_block_k}, block_n={w1_block_n}, " - f"stages={opt_flags_w1.num_stages}, subtile={opt_flags_w1.epilogue_subtile}; " - f"FC2 block_k={w2_block_k}, block_n={w2_block_n}, " - f"stages={opt_flags_w2.num_stages}, subtile={opt_flags_w2.epilogue_subtile}") - torch.cuda.synchronize() - # Build per-kernel constraints fc1_constraints = {} - if shuffle_mx4: - fc1_constraints["disable_mx4_block_swap"] = True if num_stages_fc1 is not None: fc1_constraints["num_stages"] = num_stages_fc1 if epilogue_subtile_fc1 is not None: fc1_constraints["epilogue_subtile"] = epilogue_subtile_fc1 fc2_constraints = {} - if shuffle_mx4: - fc2_constraints["disable_mx4_block_swap"] = True if num_stages_fc2 is not None: fc2_constraints["num_stages"] = num_stages_fc2 diff --git a/python/triton_kernels/tests/test_matmul.py b/python/triton_kernels/tests/test_matmul.py index 6cbe32d7291d..f7542501da0e 100644 --- a/python/triton_kernels/tests/test_matmul.py +++ b/python/triton_kernels/tests/test_matmul.py @@ -91,6 +91,7 @@ class Case: split_k: int = 1 a_hbm_swizzling: bool = False b_hbm_swizzling: bool = False + shuffle_mxfp4_w_layout: bool = False epilogue_subtile: Union[int, None] = None a_transpose: bool = False b_transpose: bool = False @@ -148,12 +149,16 @@ def _build_test_op_cases(): # float8 x mxfloat test_cases.extend([ Case(16, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", b_hbm_swizzling=True), + Case(16, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", b_hbm_swizzling=True, shuffle_mxfp4_w_layout=True), Case(1024, 1024, 1024, "batched", "float8_e5m2", "mxfloat4_e2m1", b_hbm_swizzling=True), + Case(1024, 1024, 1024, "batched", "float8_e5m2", "mxfloat4_e2m1", b_hbm_swizzling=True, shuffle_mxfp4_w_layout=True), Case(1024, 1024, 1024, "batched", "float8_e5m2", "mxfloat4_e2m1"), Case(1024, 1024, 1024, "ragged", "float8_e5m2", "mxfloat4_e2m1", split_k=9), Case(1024, 1024, 1024, "ragged", "float8_e5m2", "mxfloat4_e2m1", split_k=9, b_hbm_swizzling=True), + Case(1024, 1024, 1024, "ragged", "float8_e5m2", "mxfloat4_e2m1", split_k=9, b_hbm_swizzling=True, shuffle_mxfp4_w_layout=True), Case(300, 400, 416, "ragged", "float8_e5m2", "mxfloat8_e4m3fn"), Case(300, 400, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1"), + Case(300, 400, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1", b_hbm_swizzling=True, shuffle_mxfp4_w_layout=True), Case(300, 400, 416, "batched", "float8_e5m2", "mxfloat8_e4m3fn"), ]) # mxfloat x mxfloat @@ -236,7 +241,7 @@ def _build_test_op_cases(): @pytest.mark.parametrize("is_persistent", [False,True]) @pytest.mark.parametrize("num_warps", [4, 8] if is_hopper() else [None]) def test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma, is_persistent, num_warps, n_slices, - mode, act_dtype_str, weight_dtype_str, output_dtype_str, block_m, b_hbm_swizzling, a_hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile, + mode, act_dtype_str, weight_dtype_str, output_dtype_str, block_m, b_hbm_swizzling, shuffle_mxfp4_w_layout, a_hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile, a_transpose, b_transpose, c_transpose, swiglu_opts, device, opt_flags_scope): # We catch and re-invoke pytest.skip(), because otherwise pytest may hold a reference to @@ -244,7 +249,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma, i skip_message = None try: _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma, is_persistent, num_warps, n_slices, - mode, act_dtype_str, weight_dtype_str, output_dtype_str, block_m, b_hbm_swizzling, a_hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile, + mode, act_dtype_str, weight_dtype_str, output_dtype_str, block_m, b_hbm_swizzling, shuffle_mxfp4_w_layout, a_hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile, a_transpose, b_transpose, c_transpose, swiglu_opts, device, opt_flags_scope) except pytest.skip.Exception as e: @@ -254,7 +259,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma, i pytest.skip(skip_message) def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma, is_persistent, num_warps, n_slices, - mode, act_dtype_str, weight_dtype_str, output_dtype_str, block_m, b_hbm_swizzling, a_hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile, + mode, act_dtype_str, weight_dtype_str, output_dtype_str, block_m, b_hbm_swizzling, shuffle_mxfp4_w_layout, a_hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile, a_transpose, b_transpose, c_transpose, swiglu_opts, device, opt_flags_scope): act_uses_mx = act_dtype_str.startswith("mx") or act_dtype_str == "nvfp4_e2m1" @@ -349,6 +354,20 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma, # set opt flags constraints constraints = make_constraints(block_m, split_k, is_persistent, epilogue_subtile, b_hbm_swizzling, weight_dtype_str, num_warps) + use_blackwell_shuffled_w_layout = shuffle_mxfp4_w_layout and b_hbm_swizzling + if shuffle_mxfp4_w_layout: + if not b_hbm_swizzling: + pytest.skip("Shuffled MXFP4 weight layout only applies with b_hbm_swizzling") + if is_hip() or torch.cuda.get_device_capability()[0] < 10: + pytest.skip("Shuffled MXFP4 weight layout requires Blackwell or newer") + if weight_dtype_str != "mxfloat4_e2m1": + pytest.skip("Shuffled MXFP4 weight layout only supports mxfloat4_e2m1 weights") + if not act_dtype_str.startswith("float8"): + pytest.skip("Shuffled MXFP4 weight layout is only tested with FP8 activations") + if not colmajor_mxfp_weight: + pytest.skip("Shuffled MXFP4 weight layout requires column-major MXFP weights") + if not is_persistent: + pytest.skip("Shuffled MXFP4 weight layout requires the persistent TMA kernel") opt_flags.update_opt_flags_constraints(constraints) a_dtype = DType(act_dtype_str) @@ -359,6 +378,12 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma, do_bias = inner_expt_opt is None do_gather = do_gather and mode != "batched" do_scatter = do_scatter and mode != "batched" + b_value_hbm_swizzling = None + if b_hbm_swizzling and colmajor_mxfp_weight and b_dtype.is_mxfloat4: + b_value_hbm_swizzling = layout.make_default_matmul_mxfp4_w_layout( + mx_axis=-2, + allow_blackwell_value_shuffle=use_blackwell_shuffled_w_layout, + ) # --- create inputs --- a, a_scales, a_ragged_metadata = make_random_tensor( @@ -384,9 +409,11 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma, ragged_padding = inner_expt_opt is not None and "pad_b" in inner_expt_opt, squeeze_batch_dim = mode == "plain", is_mx_rowmajor = not colmajor_mxfp_weight, - value_hbm_swizzling = layout.make_default_matmul_mxfp4_w_layout(mx_axis=-2) if b_hbm_swizzling and colmajor_mxfp_weight and b_dtype.is_mxfloat4 else None, + value_hbm_swizzling = b_value_hbm_swizzling, scale_hbm_swizzling = layout.make_default_matmul_mxfp4_w_scale_layout(mx_axis=-2, num_warps=num_warps) if b_hbm_swizzling and colmajor_mxfp_weight and b_dtype.is_mxfloat4 else None, ) + if use_blackwell_shuffled_w_layout: + assert isinstance(b.storage.layout, layout.BlackwellMX4ValueShuffledLayout) gather_indx = None if not do_gather else torch.randint(0, max(m, 1), (m, ), dtype=torch.int32, device=device) scatter_indx = None if not do_scatter else torch.randperm(m, dtype=torch.int32, device=device) bias = None if not do_bias else torch.randn(b.shape[:-2] + b.shape[-1:], dtype=torch.float32, device=device) diff --git a/python/triton_kernels/tests/test_matmul_details/test_opt_flags_nvidia.py b/python/triton_kernels/tests/test_matmul_details/test_opt_flags_nvidia.py index 8894684535d6..909ca7172302 100644 --- a/python/triton_kernels/tests/test_matmul_details/test_opt_flags_nvidia.py +++ b/python/triton_kernels/tests/test_matmul_details/test_opt_flags_nvidia.py @@ -37,11 +37,9 @@ def _make_batched_blackwell_mxfp4_weight(device, batch_size, k, n): return weight_val, weight_scale -def _shuffle_blackwell_mxfp4_weight(weight, block_k, block_n): - shuffled_layout = BlackwellMX4ValueShuffledLayout(block_k=block_k, block_n=block_n) - transformation = shuffled_layout.make_transformation(weight.shape, is_fp4=True) - shuffled_data = transformation.swizzle_data(weight.storage.data) - return Tensor(Storage(shuffled_data, shuffled_layout), dtype=weight.dtype, shape=weight.shape) +def _shuffle_blackwell_mxfp4_weight(weight): + shuffled_layout = BlackwellMX4ValueShuffledLayout() + return convert_layout(weight, shuffled_layout) @pytest.mark.parametrize("n, expected", [(64, 128), (200, 256)]) @@ -70,7 +68,7 @@ def test_matmul_blackwell_scale_small_n(device): out_dtype=a.dtype, ) tri_y = matmul(a, b, None, precision_config=precision_config) - ref_y = matmul_torch(a, b, None, precision_config=precision_config) + ref_y = matmul_torch(a.to(torch.bfloat16), b, None, precision_config=precision_config) assert_close(ref_y, tri_y, maxtol=3e-2, rmstol=None) @@ -82,30 +80,25 @@ def test_matmul_blackwell_shuffled_mxfp4_weight(device): torch.manual_seed(0) batch_size, m, n, k = 2, 128, 128, 128 - block_k, block_n = 128, 128 - a = torch.randn((batch_size, m, k), device=device, dtype=torch.bfloat16) + a = torch.randn((batch_size, m, k), device=device, dtype=torch.bfloat16).to(torch.float8_e5m2) b, b_scale = _make_batched_blackwell_mxfp4_weight(device, batch_size, k, n) - b_shuffled = _shuffle_blackwell_mxfp4_weight(b, block_k, block_n) + b_shuffled = _shuffle_blackwell_mxfp4_weight(b) # Sanity-check the host-side packing; this is the layout consumed by the # W_SHUFFLED TMA load path in _p_matmul. - transformation = b_shuffled.storage.layout.make_transformation(b.shape, is_fp4=True) - assert torch.equal(b.storage.data, transformation.unswizzle_data(b_shuffled.storage.data)) + assert torch.equal(b.storage.data, convert_layout(b_shuffled, b.storage.layout).storage.data) precision_config = PrecisionConfig( b_mx_scale=b_scale, b_microblock_size=MXFP_BLOCK_SIZE.value, - out_dtype=a.dtype, + out_dtype=torch.bfloat16, ) constraints = { "is_persistent": True, "block_m": 128, - "block_n": block_n, - "block_k": block_k, - "disable_mx4_block_swap": True, } with scoped_opt_flags_constraints(constraints): tri_y = matmul(a, b_shuffled, None, precision_config=precision_config) - ref_y = matmul_torch(a, b, None, precision_config=precision_config) + ref_y = matmul_torch(a.to(torch.bfloat16), b, None, precision_config=precision_config) assert_close(ref_y, tri_y, maxtol=3e-2, rmstol=None) diff --git a/python/triton_kernels/triton_kernels/matmul.py b/python/triton_kernels/triton_kernels/matmul.py index 9de9a0e0bac2..25b0fdff90d4 100644 --- a/python/triton_kernels/triton_kernels/matmul.py +++ b/python/triton_kernels/triton_kernels/matmul.py @@ -401,6 +401,8 @@ def matmul(a, b, bias, block_k = block_k, mx_block_size = mx_block_size, x_uses_tma_when_persistent = a_uses_tma_when_persistent, + rhs_layout=b.storage.layout, + epilogue_reduction_n=fused_activation.specs.reduction_n, ) if b_is_shuffled: if b.dtype.bitwidth != 4: @@ -701,7 +703,7 @@ def apply(x, scale): if precision_config.a_mx_scale is not None: a_scale = precision_config.a_mx_scale - mx_axis = x_tri.storage.data.ndim -1 + mx_axis = x_tri.ndim - 1 canonical_layout = layout.StridedLayout(major_dim=mx_axis) x_tri = convert_layout(x_tri, canonical_layout) x_tri_scale = convert_layout(a_scale, canonical_layout) @@ -711,7 +713,7 @@ def apply(x, scale): if precision_config.b_mx_scale is not None: b_scale = precision_config.b_mx_scale - mx_axis = w_tri.storage.data.ndim - 2 + mx_axis = w_tri.ndim - 2 canonical_layout = layout.StridedLayout(major_dim=mx_axis) w_tri = convert_layout(w_tri, canonical_layout) w_tri_scale = convert_layout(b_scale, canonical_layout) diff --git a/python/triton_kernels/triton_kernels/matmul_details/opt_flags.py b/python/triton_kernels/triton_kernels/matmul_details/opt_flags.py index f309a1246592..42ba6b3a3d40 100644 --- a/python/triton_kernels/triton_kernels/matmul_details/opt_flags.py +++ b/python/triton_kernels/triton_kernels/matmul_details/opt_flags.py @@ -12,6 +12,7 @@ from triton_kernels.tensor_details.layout_details.hopper_scale import HopperMXScaleLayout from triton_kernels.tensor_details.layout_details.strided import StridedLayout from triton_kernels.tensor_details.layout_details.base import Layout +from triton_kernels.tensor_details.layout_details.blackwell_value_shuffled import BlackwellMX4ValueShuffledLayout from .opt_flags_details import opt_flags_amd, opt_flags_nvidia @dataclass @@ -194,6 +195,7 @@ def make_default_opt_flags_nvidia( constraints, x_uses_tma_when_persistent=True, mx_block_size=None, + epilogue_reduction_n=1, ): constraints_supported = {"block_m", "block_n", "block_k", "split_k", "is_persistent", "epilogue_subtile", "num_stages", "idle_sms", "max_allowable_mn", "num_warps", "disable_mx4_block_swap"} unsupported = set(constraints.keys()) - constraints_supported @@ -299,6 +301,10 @@ def _is_layout_strided(layout: Layout | None) -> bool: elif can_use_split_k and not enforce_bitwise_invariance: estimated_actual_grid_size = opt_flags_nvidia.compute_grid_size(None, batch_size, m, n, block_m, block_n) split_k = opt_flags_nvidia.compute_split_k(block_k, k, estimated_actual_grid_size) + if split_k > 1: + # Split-K writes full-N fp32 scratch and applies fused reductions in the + # reduce kernel, not in the matmul epilogue. + epilogue_reduction_n = 1 compute_num_stages_args = ( precision_config, is_persistent, @@ -312,6 +318,7 @@ def _is_layout_strided(layout: Layout | None) -> bool: epilogue_effective_itemsize, has_y_acc_in, mx_block_size, + epilogue_reduction_n, ) num_warps = opt_flags_nvidia.compute_num_warps(block_m, block_n, is_persistent, precision_config, constraints) @@ -436,6 +443,8 @@ def make_opt_flags( block_k, mx_block_size=None, x_uses_tma_when_persistent=True, + rhs_layout=None, + epilogue_reduction_n=1, ): opt_flags_constraints = _get_opt_flags_constraints() if opt_flags_constraints.get("is_persistent", False) and not can_use_persistent_tma: @@ -451,6 +460,11 @@ def make_opt_flags( assert not opt_flags_constraints assert block_k is None return opt_flags + if isinstance(rhs_layout, BlackwellMX4ValueShuffledLayout): + opt_flags_constraints = opt_flags_constraints.copy() + opt_flags_constraints.setdefault("block_k", rhs_layout.block_k) + opt_flags_constraints.setdefault("block_n", rhs_layout.block_n) + opt_flags_constraints.setdefault("disable_mx4_block_swap", True) if block_k is not None: opt_flags_constraints = opt_flags_constraints.copy() opt_flags_constraints.update(block_k=block_k, split_k=1) @@ -466,5 +480,6 @@ def make_opt_flags( *args, x_uses_tma_when_persistent=x_uses_tma_when_persistent, mx_block_size=mx_block_size, + epilogue_reduction_n=epilogue_reduction_n, ) assert False diff --git a/python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_nvidia.py b/python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_nvidia.py index c200afb0bde3..0e574326ba91 100644 --- a/python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_nvidia.py +++ b/python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_nvidia.py @@ -113,6 +113,7 @@ def compute_num_stages( epilogue_effective_itemsize, has_y_acc_in, mx_block_size=None, + epilogue_reduction_n=1, *, epilogue_subtile, occupancy_target, @@ -140,6 +141,10 @@ def compute_num_stages( # https://docs.nvidia.com/cuda/parallel-thread-execution/#packing-format-used-for-matrix-a-and-b-by-kind-mxf8f6f4-in-shared-memory stage_size += block_k * block_n * weight_size + if precision_config.a_mx_scale is not None: + scale_block_size = mx_block_size or int(MXFP_BLOCK_SIZE) + stage_size += block_m * (block_k // scale_block_size) + if precision_config.b_mx_scale is not None: # mx scales scale_block_size = mx_block_size or int(MXFP_BLOCK_SIZE) @@ -154,22 +159,19 @@ def compute_num_stages( else: acc_size = out_itemsize if target_info.cuda_capability_geq(10, 0) and epilogue_subtile is not None: - acc_block_n = block_n // epilogue_subtile + acc_block_n = block_n // epilogue_subtile // epilogue_reduction_n else: - acc_block_n = block_n + acc_block_n = block_n // epilogue_reduction_n # pipelined TMA store local to global, or # pipelined layout conversion before store of the accumulator # note: layout conversion has some padding epilogue_smem = int((block_m + 4) * acc_block_n * acc_size) if compute_swap_xw(precision_config, block_m, is_persistent): - # SWAP_XW Blackwell kernels stage the full transposed TMEM - # accumulator tile through fp32 smem before converting/storing it. - # If the output is narrower, the final TMA-store tile is a separate - # smem allocation. - acc_smem = block_m * block_n * (FP32.bitwidth // 8) - if out_itemsize < (FP32.bitwidth // 8): - acc_smem += int(block_m * acc_block_n * out_itemsize) - epilogue_smem = max(epilogue_smem, acc_smem) + # The fp32 accumulator stays in TMEM for the Blackwell SWAP_XW + # persistent path. Fused reductions such as swiglu still need smem + # for the unreduced output tile before the narrower TMA-store tile. + if epilogue_reduction_n > 1: + epilogue_smem += int(block_m * block_n * out_itemsize) smem_capacity -= epilogue_smem if x_transpose: smem_capacity -= block_m * block_k * (max(8, lhs_dtype.bitwidth) // 8) diff --git a/python/triton_kernels/triton_kernels/tensor_details/layout.py b/python/triton_kernels/triton_kernels/tensor_details/layout.py index e63f51831b2b..8dd8e0984536 100644 --- a/python/triton_kernels/triton_kernels/tensor_details/layout.py +++ b/python/triton_kernels/triton_kernels/tensor_details/layout.py @@ -24,8 +24,15 @@ ] -def make_default_matmul_mxfp4_w_layout(mx_axis: int): +def make_default_matmul_mxfp4_w_layout( + mx_axis: int, + allow_blackwell_value_shuffle: bool = False, + block_k: int = 128, + block_n: int = 256, +): if cuda_capability_geq(10): + if allow_blackwell_value_shuffle: + return BlackwellMX4ValueShuffledLayout(block_k=block_k, block_n=block_n) return BlackwellMXValueLayout() elif cuda_capability_geq(9): return HopperMXValueLayout(mx_axis=mx_axis, mma_version=3) diff --git a/python/triton_kernels/triton_kernels/tensor_details/layout_details/blackwell_value_shuffled.py b/python/triton_kernels/triton_kernels/tensor_details/layout_details/blackwell_value_shuffled.py index d88db10620c8..0bdff7b3ed8e 100644 --- a/python/triton_kernels/triton_kernels/tensor_details/layout_details/blackwell_value_shuffled.py +++ b/python/triton_kernels/triton_kernels/tensor_details/layout_details/blackwell_value_shuffled.py @@ -2,6 +2,7 @@ from dataclasses import dataclass import torch from .base import Layout, LayoutTransformation +from .torch_utils import repack # ------------------- Blackwell MX4 Value Shuffled Layout ------------------- @@ -25,8 +26,8 @@ class BlackwellMX4ValueShuffledLayout(Layout): The inner dimensions [tile_n, tile_k_packed] match the baseline TMA block shape after swapping, so no transpose is needed after TMA load. """ - block_k: int = 256 - block_n: int = 128 + block_k: int = 128 + block_n: int = 256 @property def name(self): @@ -57,8 +58,8 @@ def swizzle_block_shape(self, block_shape): class BlackwellMX4ValueShuffledTransformation(LayoutTransformation): """Transformation for the shuffled MX4 weight layout.""" - block_k: int = 256 - block_n: int = 128 + block_k: int = 128 + block_n: int = 256 def _compute_params(self, E, K_packed, N): """Compute tiling parameters from the physical shape.""" @@ -75,13 +76,40 @@ def _compute_params(self, E, K_packed, N): return tile_k_packed, tile_n, padded_K_packed, padded_N, num_tiles_k, num_tiles_n + def _canonical_to_physical(self, data: torch.Tensor) -> torch.Tensor: + """Repack canonical [..., K, N_packed] storage to physical [E, K_packed, N].""" + if not self.is_fp4: + raise ValueError("BlackwellMX4ValueShuffledLayout only supports fp4 values") + assert data.stride(-1) == 1 + out_shape = list(data.shape) + out_shape[-1] *= 2 + out_shape[-2] //= 2 + out = torch.empty(out_shape, dtype=data.dtype, device=data.device) + return repack(data, -1, -2, self.is_fp4, out=out) + + def _physical_to_canonical(self, data: torch.Tensor) -> torch.Tensor: + """Repack physical [E, K_packed, N] storage to canonical [..., K, N_packed].""" + if not self.is_fp4: + raise ValueError("BlackwellMX4ValueShuffledLayout only supports fp4 values") + out_shape = list(data.shape) + out_shape[-2] *= 2 + out_shape[-1] //= 2 + out = torch.empty(out_shape, dtype=data.dtype, device=data.device) + return repack(data, -2, -1, self.is_fp4, out=out) + def swizzle_data(self, data: torch.Tensor) -> torch.Tensor: """ - Convert data from physical [E, K_packed, N] to 5D shuffled layout. + Convert data from canonical [..., K, N_packed] to 5D shuffled layout. Target layout: [E, num_tiles_k, num_tiles_n, tile_n, tile_k_packed] This matches the baseline TMA block shape [block_n, packed_block_k] after swapping. """ + if data.ndim == 2: + data = data.unsqueeze(0) + if data.ndim != 3: + raise ValueError(f"Expected 2D or 3D canonical data, got {data.ndim}D") + + data = self._canonical_to_physical(data) E, K_packed, N = data.shape tile_k_packed, tile_n, padded_K_packed, padded_N, num_tiles_k, num_tiles_n = \ self._compute_params(E, K_packed, N) @@ -101,11 +129,12 @@ def swizzle_data(self, data: torch.Tensor) -> torch.Tensor: # Permute to [E, num_tiles_k, num_tiles_n, tile_n, tile_k_packed] # This puts K tiles first (for inner loop locality) and arranges # inner dims as [tile_n, tile_k_packed] to match baseline TMA block. - return data.permute(0, 3, 1, 2, 4).contiguous() + data = data.permute(0, 3, 1, 2, 4).contiguous() + return data def unswizzle_data(self, data: torch.Tensor) -> torch.Tensor: """ - Convert data from shuffled back to physical [E, K_packed, N]. + Convert data from shuffled back to canonical [..., K, N_packed]. Input layout: [E, num_tiles_k, num_tiles_n, tile_n, tile_k_packed] """ @@ -128,4 +157,6 @@ def unswizzle_data(self, data: torch.Tensor) -> torch.Tensor: data = data.transpose(1, 2).contiguous() # Trim padding back to original shape - return data[:, :orig_K_packed, :orig_N].contiguous() + data = data[:, :orig_K_packed, :orig_N].contiguous() + data = self._physical_to_canonical(data) + return data if len(self.shape) == 3 else data.squeeze(0)