From 1998ba1d61b0c7a320701faf0fccef77ea1eeabb Mon Sep 17 00:00:00 2001 From: Jokeren Date: Fri, 22 May 2026 22:24:33 -0400 Subject: [PATCH] Enhance matmul_torch to support K-ragged activations with appropriate assertions and output scaling --- .../triton_kernels/triton_kernels/matmul.py | 42 +++++++++++++++++-- 1 file changed, 39 insertions(+), 3 deletions(-) diff --git a/python/triton_kernels/triton_kernels/matmul.py b/python/triton_kernels/triton_kernels/matmul.py index 4ff4cc2ea360..911e566e0d5d 100644 --- a/python/triton_kernels/triton_kernels/matmul.py +++ b/python/triton_kernels/triton_kernels/matmul.py @@ -798,7 +798,43 @@ def matmul_torch(a, b, bias, ): a, b = apply_precision(a, b, precision_config) + if ( + a_ragged_metadata is not None + and b_ragged_metadata is None + and a.shape[-1] != b.shape[-2] + ): + assert gather_indx is None, "gather not supported with K-ragged activations" + assert scatter_indx is None, "scatter not supported with K-ragged activations" + n_expts_tot = a_ragged_metadata.slice_sizes.shape[0] + m, n = a.shape[-2], b.shape[-1] + out = torch.zeros((n_expts_tot, m, n), dtype=torch.float32, device=a.device) + x_slice_offs = a_ragged_metadata.slice_offs + w_slice_offs = torch.zeros(n_expts_tot + 1, dtype=torch.int32, device=b.device) + w_slice_offs[1:] = torch.cumsum(a_ragged_metadata.slice_sizes, 0) + for expt in range(n_expts_tot): + k = int(a_ragged_metadata.slice_sizes[expt].item()) + if k == 0: + continue + x_start = int(x_slice_offs[expt].item()) + w_start = int(w_slice_offs[expt].item()) + x_slice = a[:, x_start:x_start + k] + w_base = b[expt] if b.ndim == 3 else b + w_slice = w_base[w_start:w_start + k, :] + out_expt = matmul_torch( + x_slice, w_slice, None if bias is None else bias[expt], + None, None, None, None, PrecisionConfig(), + betas, gammas, + round_x, round_y, + ) + out[expt] = out_expt.to(out.dtype) + actual_scale = precision_config.flex_ctx.out_data.actual_scale + if actual_scale is not None: + actual_scale.copy_(compute_actual_scale(out, precision_config.out_dtype)) + return scale(out, precision_config.flex_ctx.out_data.expected_scale) + if b_ragged_metadata is not None: + assert gather_indx is None, "gather not supported with K-ragged activations" + assert scatter_indx is None, "scatter not supported with K-ragged activations" n_expts_tot = b_ragged_metadata.slice_sizes.shape[0] m, n = a.shape[-2], b.shape[-1] out = torch.zeros((n_expts_tot, m, n), dtype=torch.float32, device=a.device) @@ -813,8 +849,8 @@ def matmul_torch(a, b, bias, x_slice = a[:, x_start:x_start + k] w_slice = b[w_start:w_start + k, :] out_expt = matmul_torch( - x_slice, w_slice, None, None, - None, None, None, PrecisionConfig(), + x_slice, w_slice, None, + None, None, None, None, PrecisionConfig(), betas, gammas, round_x, round_y, ) @@ -858,7 +894,7 @@ def matmul_torch(a, b, bias, else: idx = gather_indx[lo:hi] batch = i if is_input_batched else 0 - out = torch.matmul(round_x(a[batch, idx, :], torch.arange(lo, hi, device="cuda")).float(), + out = torch.matmul(round_x(a[batch, idx, :], torch.arange(lo, hi, device=a.device)).float(), b[i].float()) if bias is not None: out += bias[i, :] if betas is None else bias[i, :] * betas[lo:hi, None]