Skip to content

[AMD] Use aiter CK layernorm2d for LayerNorm to reduce NSA indexer kernel launches#22424

Merged
HaiShaw merged 4 commits intosgl-project:mainfrom
1am9trash:use-aiter-ck-layernorm
Apr 9, 2026
Merged

[AMD] Use aiter CK layernorm2d for LayerNorm to reduce NSA indexer kernel launches#22424
HaiShaw merged 4 commits intosgl-project:mainfrom
1am9trash:use-aiter-ck-layernorm

Conversation

@1am9trash
Copy link
Copy Markdown
Collaborator

@1am9trash 1am9trash commented Apr 9, 2026

cc @Jacob0226

Motivation

The current LayerNorm on HIP uses the torch implementation and triggers an extra dtype cast at both entry and exit. This results in 3 kernels per LayerNorm call (cast -> layernorm -> cast), hurting the performance of operations like k_norm() in the GLM-5-FP8 NSA indexer.

Modifications

  • In LayerNorm.forward_hip(), use the aiter CK kernel layernorm2d_fwd() when the dtype is bf16 or fp16. For other dtypes, fall back to the original torch code path.
  • Change k_norm dtype in the NSA indexer from fp32 to bf16 when aiter is enabled, so it can take the CK kernel path.

Accuracy Tests

LayerNorm unit test:

  • Command: python -m pytest python/sglang/test/test_layernorm.py::TestLayerNorm -v
  • Result: all 384 subtests passed

Model test:

  • GLM-5-FP8 on MI355 GSM8k (TP8): 0.946

Speed Tests and Profiling

GLM-5-FP8 server command on MI355:

export SGLANG_ROCM_FUSED_DECODE_MLA=0
export ROCM_QUICK_REDUCE_QUANTIZATION=INT4
export SAFETENSORS_FAST_GPU=1
python3 -m sglang.launch_server \
  --model-path GLM-5-FP8 \
  --tp 8 --port 9000 --trust-remote-code \
  --tool-call-parser glm47 --reasoning-parser glm45 \
  --mem-fraction-static 0.85 \
  --model-loader-extra-config '{"enable_multithread_load": true, "num_threads": 8}' \
  --nsa-prefill-backend tilelang --nsa-decode-backend tilelang --disable-radix-cache \
  --kv-cache-dtype fp8_e4m3

Benchmark on MI355X TP8, concurrency 4/8/16/32/64 averaged (baseline: sglang PR #22258 + aiter PR #2575):

  • ISL/OSL 1k/1k: Throughput +1.4%, TPOT -0.9%
  • ISL/OSL 8k/1k: Throughput +1.2%, TPOT -2.8%

Per-layer profiling:

  • Time: ~12us -> ~4us
  • Kernel: 3 -> 1

Checklist

Review and Merge Process

  1. Ping Merge Oncalls to start the process. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • Common commands include /tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci
  4. After green CI and required approvals, ask Merge Oncalls or people with Write permission to merge the PR.

@gemini-code-assist
Copy link
Copy Markdown
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@HaiShaw HaiShaw merged commit 628df31 into sgl-project:main Apr 9, 2026
54 of 62 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants