[Triton] DS FP4/FP8 Triton fusion and GEMM optimization#119

Merged
k50112113 merged 61 commits into main from shaoclee/ds_fp4_gemm
Jan 14, 2026

Conversation

@k50112113 (Contributor) commented Jan 9, 2026

This PR is co-authored by @k50112113, @omuhamma (#61) and @farlukas (#116)

This PR provides Triton fusion/GEMM optimizations for DS FP4 and FP8. Some of the required PRs have not yet been merged into AITER main, so please use the following AITER branch for testing for now:
https://github.com/ROCm/aiter/tree/shaoclee/atom_triton_tmp_0106

The required AITER PRs include:

  1. [Triton] Triton A16WFP4 GEMM prequant aiter#1777
  2. [Triton] Triton a16w8 gemm preshuffle aiter#1778
  3. [Triton] Add Fused GEMM A8W8 + Split + Concat Triton Kernel aiter#1553 (review)

To activate the optimizations on ATOM, the following env variables are required:

# for concurrency > 4, use AR + RMS_Quant + GEMM optimizations:
export ATOM_USE_TRITON_GEMM=1
# note: ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION is turned on automatically when ATOM_USE_TRITON_GEMM is on

# for concurrency = 4, use AR_RMS + Quant_GEMM optimizations:
export ATOM_USE_TRITON_GEMM=1
export ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION=0
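
The interaction of the two env vars above can be sketched as follows. This is illustrative only: the helper name `select_fusion_mode` and the mode labels are hypothetical, not ATOM API; only the env-var semantics follow this PR's description.

```python
import os

def select_fusion_mode() -> str:
    """Hypothetical helper: pick the fusion path implied by the env vars."""
    use_triton_gemm = os.environ.get("ATOM_USE_TRITON_GEMM", "0") == "1"
    # ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION defaults to on when the
    # Triton GEMM path is enabled, unless explicitly set to 0.
    rms_quant_fusion = os.environ.get(
        "ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION",
        "1" if use_triton_gemm else "0",
    ) == "1"

    if use_triton_gemm and rms_quant_fusion:
        return "AR + RMS_Quant + GEMM"   # recommended for concurrency > 4
    if use_triton_gemm:
        return "AR_RMS + Quant_GEMM"     # recommended for concurrency = 4
    return "baseline"
```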

The following commands, together with the env vars above, were used to derive e2e performance results:

# for DS FP8
python -m atom.entrypoints.openai_server \
    --model /data/deepseek-ai/DeepSeek-R1-0528/ \
    -tp 8 \
    --block-size 1 \
    --server-port 8989 2>&1 | tee server.out

# for DS FP4
export ATOM_USE_TRITON_MXFP4_BMM=1
export AMDGCN_USE_BUFFER_OPS=1
python -m atom.entrypoints.openai_server \
    --model /data/DeepSeek-R1-0528-MXFP4-Preview \
    -tp 8 \
    --block-size 16 \
    --kv_cache_dtype fp8 \
    --server-port 8989 \
    2>&1 | tee server.out
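
Once the server is up, a quick smoke test can be sent before benchmarking. The snippet below only builds the request body; the field names follow the generic OpenAI `/v1/completions` schema that vLLM-compatible servers typically accept, so verify against ATOM's actual endpoint before relying on it.

```python
import json

def completion_payload(model: str, prompt: str, max_tokens: int = 32) -> str:
    # Minimal OpenAI-style completions request body (assumed schema).
    return json.dumps(
        {"model": model, "prompt": prompt, "max_tokens": max_tokens}
    )

# Example usage (paths hypothetical):
# curl -s http://localhost:8989/v1/completions \
#      -H 'Content-Type: application/json' \
#      -d "$(python -c 'print(completion_payload(...))')"
```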

Client benchmark command:

MODEL=<DS FP4 or FP8 model paths>
ISL=3500
OSL=1500
PORT=8989
for CONC in 4 256 128 64 32 16 8; do
    RESULT_FILENAME=${ISL}_${OSL}_${CONC}
    python /root/ATOM/atom/benchmarks/benchmark_serving.py \
        --model=$MODEL --backend=vllm --base-url=http://localhost:$PORT \
        --dataset-name=random \
        --random-input-len=$ISL --random-output-len=$OSL \
        --random-range-ratio 1.0 \
        --num-prompts=$(( $CONC * 8 )) \
        --max-concurrency=$CONC \
        --request-rate=inf --ignore-eos \
        --save-result --percentile-metrics="ttft,tpot,itl,e2el" \
        --result-dir=./ --result-filename=$RESULT_FILENAME.json 2>&1 | tee -a ${RESULT_FILENAME}.log
done
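
For reference, the sweep that the shell loop above performs can be enumerated in Python. The dict keys here are descriptive labels for readability, not `benchmark_serving.py` arguments:

```python
# Reconstructs the benchmark matrix from the shell loop above.
ISL, OSL = 3500, 1500

sweep = [
    {
        "max_concurrency": conc,
        "num_prompts": conc * 8,  # mirrors $(( $CONC * 8 ))
        "result_filename": f"{ISL}_{OSL}_{conc}.json",
    }
    for conc in (4, 256, 128, 64, 32, 16, 8)
]

for cfg in sweep:
    print(cfg)
```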

DS FP8 performance comparisons and uplift
[image: DS FP8 benchmark results]

DS FP4 performance comparisons and uplift
[image: DS FP4 benchmark results]

k50112113 and others added 30 commits December 11, 2025 19:32
@k50112113 k50112113 requested a review from valarLip January 9, 2026 15:47
@ChuanLi1101 (Collaborator) previously approved these changes Jan 10, 2026, commenting:


Overall LGTM, approved for benchmark testing.

@k50112113 (Author) commented:

Addressed all the comments

@k50112113 k50112113 requested review from ChuanLi1101 and valarLip and removed request for valarLip January 12, 2026 15:04
@ChuanLi1101 (Collaborator) left a comment:


Approved; the comments have been addressed.

@k50112113 k50112113 merged commit c5fab3e into main Jan 14, 2026
4 checks passed
@k50112113 k50112113 deleted the shaoclee/ds_fp4_gemm branch January 14, 2026 18:26
PerryZhang01 pushed a commit that referenced this pull request Jan 30, 2026
* tmp

* fix

* clean

* Making the BMM use fp4 weights

* add ATOM_USE_TRITON_GEMM and a16wfp4 gemm for o_proj

* Cleaning up the code and ensuring other weights won't crash

* add import check for gemm_a16wfp4_preshuffle

* clean

* clean

* disable FP4 triton GEMM on o_proj on DS FP4

* Fused rms for fp4

* Adding the x_scale change in linear.py to choose when to quantize or not

* Enabling the second fused rms before attention

* Fixing issue where there was a shape mismatch when running the second fused rms

* Marking shuffle and shuffle padding as true temporarily always

* Working implementation of fused_rms for fp4

* Formatting fixes

* Fix syntax error

* Remove some commented code from the fp4 section

* disable only AR + input layernorm with ATOM_ENABLE_RMSNORM_QUANT_FUSION=1

* add _fuse_qkv_a_proj_reduce_rmsnorm_quant for DS FP4

* add gemm split + cat for DS FP4

* Integrated fused rmsnorm + quant in decoder layer

* No need to fuse post attention

* Refactored fusion condition

* Transpose scales for input layernorm

* Added torch compile guards on fusion to enable torch compiler

* Refactored fp8 fused rms quant function

* Added fp8 triton preshuffled gemm

* Fixed triton gemm condition

* Added fused rmsnorm quant fp8 back in

* Added transpose_scale back to fp8 fake function

* Remove duplicate env

* Implemented fp8 gemm preshuffled + split + cat

* add back triton fusk_rope_kv_cache

* consider both AR_RMS + Quant and AR + RMS_Quant condition via ATOM_ENABLE_RMSNORM_QUANT_FUSION

* Implemented fp8 fused reduce rms quant

* change boundary

* Removed unreachable branch

* Added transpose_scale to fused reduce rms quant

* fix

* clean

* add a16w8 preshuffle gemm

* clean

* change fp8 gemm boundary

* triton fp8 gemm rename

* remove loader change

* remove comments

* address comments

---------

Co-authored-by: Omar Muhammad <omar.muhammad@amd.com>
Co-authored-by: Farel Lukas <farlukas@amd.com>
