Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 39 additions & 3 deletions python/triton_kernels/triton_kernels/matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,7 +798,43 @@ def matmul_torch(a, b, bias,
):
a, b = apply_precision(a, b, precision_config)

if (
a_ragged_metadata is not None
and b_ragged_metadata is None
and a.shape[-1] != b.shape[-2]
):
Comment on lines +801 to +805
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need a new special case?

assert gather_indx is None, "gather not supported with K-ragged activations"
assert scatter_indx is None, "scatter not supported with K-ragged activations"
n_expts_tot = a_ragged_metadata.slice_sizes.shape[0]
m, n = a.shape[-2], b.shape[-1]
out = torch.zeros((n_expts_tot, m, n), dtype=torch.float32, device=a.device)
x_slice_offs = a_ragged_metadata.slice_offs
w_slice_offs = torch.zeros(n_expts_tot + 1, dtype=torch.int32, device=b.device)
w_slice_offs[1:] = torch.cumsum(a_ragged_metadata.slice_sizes, 0)
for expt in range(n_expts_tot):
k = int(a_ragged_metadata.slice_sizes[expt].item())
if k == 0:
continue
x_start = int(x_slice_offs[expt].item())
w_start = int(w_slice_offs[expt].item())
x_slice = a[:, x_start:x_start + k]
w_base = b[expt] if b.ndim == 3 else b
w_slice = w_base[w_start:w_start + k, :]
out_expt = matmul_torch(
x_slice, w_slice, None if bias is None else bias[expt],
None, None, None, None, PrecisionConfig(),
betas, gammas,
round_x, round_y,
)
out[expt] = out_expt.to(out.dtype)
actual_scale = precision_config.flex_ctx.out_data.actual_scale
if actual_scale is not None:
actual_scale.copy_(compute_actual_scale(out, precision_config.out_dtype))
return scale(out, precision_config.flex_ctx.out_data.expected_scale)

if b_ragged_metadata is not None:
assert gather_indx is None, "gather not supported with K-ragged activations"
assert scatter_indx is None, "scatter not supported with K-ragged activations"
n_expts_tot = b_ragged_metadata.slice_sizes.shape[0]
m, n = a.shape[-2], b.shape[-1]
out = torch.zeros((n_expts_tot, m, n), dtype=torch.float32, device=a.device)
Expand All @@ -813,8 +849,8 @@ def matmul_torch(a, b, bias,
x_slice = a[:, x_start:x_start + k]
w_slice = b[w_start:w_start + k, :]
out_expt = matmul_torch(
x_slice, w_slice, None, None,
None, None, None, PrecisionConfig(),
x_slice, w_slice, None,
None, None, None, None, PrecisionConfig(),
betas, gammas,
round_x, round_y,
)
Expand Down Expand Up @@ -858,7 +894,7 @@ def matmul_torch(a, b, bias,
else:
idx = gather_indx[lo:hi]
batch = i if is_input_batched else 0
out = torch.matmul(round_x(a[batch, idx, :], torch.arange(lo, hi, device="cuda")).float(),
out = torch.matmul(round_x(a[batch, idx, :], torch.arange(lo, hi, device=a.device)).float(),
b[i].float())
if bias is not None:
out += bias[i, :] if betas is None else bias[i, :] * betas[lo:hi, None]
Expand Down
Loading