[Triton] DS FP4/FP8 Triton fusion and GEMM optimization#119
Merged
Conversation
Making the BMM use fp4 weights
Enable DSR1 FP8 Optimizations
valarLip
reviewed
Jan 9, 2026
ChuanLi1101
previously approved these changes
Jan 10, 2026
Collaborator
ChuanLi1101
left a comment
There was a problem hiding this comment.
Overall LGTM, approved for benchmark testing.
valarLip
reviewed
Jan 10, 2026
Contributor
Author
|
Addressed all the comments |
ChuanLi1101
approved these changes
Jan 14, 2026
Collaborator
ChuanLi1101
left a comment
There was a problem hiding this comment.
Approved due to comments has addressed.
PerryZhang01
pushed a commit
that referenced
this pull request
Jan 30, 2026
* tmp * fix * clean * Making the BMM use fp4 weights * add ATOM_USE_TRITON_GEMM and a16wfp4 gemm for o_proj * Cleaning up the code and ensuring other weights wont crash * add import check for gemm_a16wfp4_preshuffle * clean * clean * disable FP4 triton GEMM on o_proj on DS FP4 * Fused rms for fp4 * Adding the x_scale change in linear.py to choose when to quantize or not * Enabling the second fused rms before attention * Fixing issue where there was a shape mismatch when running the second fused rms * Marking shuffle and shuffle padding as true temporarily always * Working implemenation of fused_rms for fp4 * Formatting fixes * Fix syntax error * Remove some commented code from the fp4 section * disable only AR + input layernorm with ATOM_ENABLE_RMSNORM_QUANT_FUSION=1 * add _fuse_qkv_a_proj_reduce_rmsnorm_quant for DS FP4 * add gemm split + cat for DS FP4 * Integreated fused rmsnorm + quant in decoder layer * No need to fuse post attention * Refactored fusion condition * Transpose scales for input layernorm * Added torch compile guards on fusion to enable torch compiler * Refactored fp8 fused rms quant function * Added fp8 triton preshuffled gemm * Fixed triton gemm condition * Added fused rmsnorm quant fp8 back in * Added transpose_scale back to fp8 fake function * Remove duplicate env * Implemented fp8 gemm preshuffled + split + cat * add back triton fusk_rope_kv_cache * consider both AR_RMS + Quant and AR + RMS_Quant condition via ATOM_ENABLE_RMSNORM_QUANT_FUSION * Implemented fp8 fused reduce rms quant * change boundary * Removed unreachable branch * Added transpose_scale to fused reduce rms quant * fix * clean * add a16w8 preshuffle gemm * clean * change fp8 gemm boundary * triton fp8 gemm rename * remove loader change * remove comments * address comments --------- Co-authored-by: Omar Muhammad <omar.muhammad@amd.com> Co-authored-by: Farel Lukas <farlukas@amd.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
This PR is co-authored by @k50112113, @omuhamma (#61) and @farlukas (#116)
This PR provides Triton fusion/GEMM optimizations for DS FP4 and FP8,
please use the following AITER branch for testing for now as some of the PRs are yet to be merged to AITER main:
https://github.com/ROCm/aiter/tree/shaoclee/atom_triton_tmp_0106
The required AITER PRs include:
To activate the optimizations on ATOM, the following env variables are required:
The following command along with the above env var are used to derive e2e performance results:
For client command:
DS FP8 performance comparisons and uplift

DS FP4 performance comparisons and uplift
