[KERNELS] Tuning for small batch mxfp4 matmuls with splitk #9980

Merged
aeng-openai merged 8 commits into triton-lang:main from aeng-openai:aeng/small-batch-mxfp4-again
Apr 9, 2026

@aeng-openai
Collaborator

In general, split-K was disabled in the MLP matmuls because, when the input is ragged, it is not trivial to choose the split-K factor statically (though it would be possible to choose it dynamically).
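For readers less familiar with the technique, here is a minimal pure-Python sketch of what split-K means: the K dimension is cut into `split_k` chunks, each chunk produces a partial output, and a separate reduce step sums the partials. The function names are hypothetical and this is only an illustration of the idea, not the kernel code in this PR.

```python
def matmul(x, w):
    """Reference matmul on nested lists: (M, K) @ (K, N) -> (M, N)."""
    m, k, n = len(x), len(w), len(w[0])
    return [[sum(x[i][p] * w[p][j] for p in range(k)) for j in range(n)]
            for i in range(m)]


def splitk_matmul(x, w, split_k):
    """Hypothetical split-K matmul: one partial product per K chunk,
    followed by a reduce over the partials."""
    k = len(w)
    assert k % split_k == 0, "sketch assumes split_k divides K evenly"
    chunk = k // split_k

    # Each iteration stands in for an independent group of thread blocks
    # that only sees its own slice of K.
    partials = []
    for s in range(split_k):
        lo, hi = s * chunk, (s + 1) * chunk
        x_slice = [row[lo:hi] for row in x]
        w_slice = w[lo:hi]
        partials.append(matmul(x_slice, w_slice))

    # The split-K reduce: sum the partial outputs elementwise.
    m, n = len(x), len(w[0])
    return [[sum(p[i][j] for p in partials) for j in range(n)]
            for i in range(m)]
```

With a small M, the output has few tiles, so splitting K is the main way to create enough independent work; the cost is the extra reduce pass at the end.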

In the small-batch, non-ragged case, though, it is simple to allow split-K. This PR enables it, adds some basic heuristic tuning for such cases, and optimizes the split-K reduce itself.
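To illustrate why a static choice is straightforward when M is known up front, here is a rough sketch of one possible heuristic. It is an assumption for illustration only, not the heuristic this PR implements: `pick_split_k`, its parameters, and the `max_split_k` cap are all hypothetical. The idea is to keep doubling the split-K factor while the launch grid would still underfill the SMs.

```python
import math


def pick_split_k(m, n, k, block_m, block_n, block_k, num_sms, max_split_k=16):
    """Hypothetical heuristic: grow split_k in powers of two while the
    grid (output tiles * split_k) still fits within the number of SMs
    and every split keeps at least one K block."""
    tiles = math.ceil(m / block_m) * math.ceil(n / block_n)
    k_blocks = math.ceil(k / block_k)
    split_k = 1
    while (split_k * 2 <= max_split_k
           and tiles * split_k * 2 <= num_sms
           and k_blocks // (split_k * 2) >= 1):
        split_k *= 2
    return split_k


# Small batch: few output tiles, so K gets split to fill the machine.
small = pick_split_k(m=8, n=1024, k=8192,
                     block_m=16, block_n=128, block_k=128, num_sms=128)

# Large batch: the output tiles alone already oversubscribe the SMs.
large = pick_split_k(m=4096, n=1024, k=8192,
                     block_m=16, block_n=128, block_k=128, num_sms=128)
```

In the ragged case, the per-expert M is not known when the launch parameters are picked, so `tiles` above cannot be computed statically; that is the difficulty the description refers to.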

These changes exposed a bug in the smem accounting heuristics: we weren't counting the smem needed to perform SWAP_XW. This PR fixes that.
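As a rough illustration of this class of bug, the sketch below estimates how many pipeline stages fit in shared memory. Everything here is an assumption: `max_stages`, the `extra_scratch_bytes` parameter (standing in for whatever SWAP_XW actually needs), the 128 KiB limit, and the tile sizes are hypothetical, and mxfp4 scale storage is ignored for simplicity.

```python
def max_stages(block_m, block_n, block_k, x_bits, w_bits,
               smem_limit_bytes, extra_scratch_bytes=0, stage_cap=5):
    """Hypothetical smem accounting: each pipeline stage buffers one X tile
    and one W tile; any extra scratch comes out of the budget first."""
    x_tile_bytes = block_m * block_k * x_bits // 8
    w_tile_bytes = block_k * block_n * w_bits // 8
    per_stage = x_tile_bytes + w_tile_bytes
    budget = smem_limit_bytes - extra_scratch_bytes
    return max(0, min(stage_cap, budget // per_stage))


# 16-bit activations, 4-bit weights, hypothetical 128 KiB smem limit.
cfg = dict(block_m=16, block_n=128, block_k=256, x_bits=16, w_bits=4,
           smem_limit_bytes=128 * 1024)

# Forgetting the scratch term overestimates how many stages fit.
stages_without_scratch = max_stages(**cfg)
stages_with_scratch = max_stages(**cfg, extra_scratch_bytes=16 * 1024)
```

When the scratch term is omitted, the heuristic can select a stage count that does not actually fit once the real kernel allocates its extra buffer.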

Perf will probably improve further after integrating the shuffled mxfp4 weight layout from #9698.

- get up to 5 stages for small batch matmuls with mxfp4 weights
- tune the split k reduce as well
- add a benchmark script for the reduce

perf will probably get even better after integrating the shuffled
mxfp4 weight layout as well
This reverts commit f06e5941c93e5e345cb93ea6d14e82b69557556c.
@aeng-openai aeng-openai requested a review from ptillet as a code owner April 9, 2026 17:18
@aeng-openai aeng-openai changed the title Tuning for small batch mxfp4 matmuls with splitk [KERNELS] Tuning for small batch mxfp4 matmuls with splitk Apr 9, 2026
@aeng-openai aeng-openai merged commit 617cff0 into triton-lang:main Apr 9, 2026
17 of 18 checks passed
plognjen pushed a commit to plognjen/triton that referenced this pull request Apr 14, 2026
…ng#9980)

2 participants