Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 49 additions & 1 deletion python/triton_kernels/tests/test_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from triton_kernels.swiglu import swiglu, swiglu_fn
from triton_kernels.swiglu import PrecisionConfig as SwiGLUPrecisionConfig
from triton_kernels.tensor_details import layout
from triton_kernels.tensor import Tensor, convert_layout, wrap_torch_tensor
from triton_kernels.tensor import Tensor, convert_layout, make_ragged_tensor_metadata, wrap_torch_tensor
from triton_kernels.tensor_details.dtype import FP32

# ---------------
Expand Down Expand Up @@ -579,6 +579,54 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, do_gamma,
f"ref_y_scale: {ref_y_scale}, tri_y_scale: {tri_y_scale.item()}"


def test_k_ragged_mxfp8_act_scale_swizzling(device):
    """Compare a K-ragged MXFP8 matmul using swizzled vs. strided activation scales.

    The same seeded activation tensor is generated twice, once with a
    ``StridedLayout`` scale layout and once with the default matmul MX
    activation scale layout, and the two matmul outputs are required to be
    close. Skipped unless running on CUDA with compute capability >= 10.
    """
    has_required_gpu = is_cuda() and torch.cuda.get_device_capability()[0] >= 10
    if not has_required_gpu:
        pytest.skip("requires Blackwell or newer")

    m = 64
    n = 128
    k = 96
    a_dtype = DType("mxfloat8_e4m3fn")

    # Everything except the scale layout is shared between the two activations.
    activation_kwargs = dict(
        shape=(m, k),
        n_slices=10,
        ragged_dim=1,
        ragged_padding=True,
        device=device,
        dtype=a_dtype,
        mxfp_dim=-1,
        transpose=False,
        squeeze_batch_dim=False,
    )

    def build_activation(scale_layout):
        # Re-seed before every build so both layouts are generated from the same RNG state.
        torch.manual_seed(0)
        return make_random_tensor(**activation_kwargs, scale_hbm_swizzling=scale_layout)

    # A scale layout is supplied in both cases so K-ragged values get identical padding.
    canonical_a, canonical_scale, canonical_metadata = build_activation(layout.StridedLayout(-1))
    swizzled_a, swizzled_scale, swizzled_metadata = build_activation(
        layout.make_default_matmul_mx_act_scale_layout
    )
    b = torch.randn((k, n), dtype=torch.bfloat16, device=device)
    b_metadata = make_ragged_tensor_metadata(canonical_metadata.slice_sizes, k)

    def matmul_with_scale(a, scale, metadata):
        precision = PrecisionConfig(
            a_mx_scale=scale,
            a_microblock_size=MXFP_BLOCK_SIZE.value,
            out_dtype=torch.bfloat16,
        )
        return matmul(a, b, None, metadata, b_metadata, precision_config=precision)

    # Both layouts run under the same opt-flag constraints.
    constraints = {"block_m": 128, "is_persistent": True}
    with opt_flags.scoped_opt_flags_constraints(constraints):
        swizzled = matmul_with_scale(swizzled_a, swizzled_scale, swizzled_metadata)
        canonical = matmul_with_scale(canonical_a, canonical_scale, canonical_metadata)
    torch.testing.assert_close(swizzled, canonical)


def test_set_idle_sms():
if not is_cuda():
pytest.skip("Only supported on CUDA")
Expand Down
3 changes: 2 additions & 1 deletion python/triton_kernels/triton_kernels/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,7 @@ def make_random_tensor(shape, n_slices, ragged_dim, ragged_padding, device, dtyp
if scale_hbm_swizzling is not None:
# hack to avoid circular dependency
if callable(scale_hbm_swizzling):
scale_hbm_swizzling = scale_hbm_swizzling(ragged_metadata)
# Segment metadata describes scale rows, never its inner axis.
scale_hbm_swizzling = scale_hbm_swizzling(ragged_metadata if ragged_dim == 0 else None)
scales = convert_layout(scales, scale_hbm_swizzling)
return buffer, scales, ragged_metadata
Loading