
[AMD] Fix Aiter RMSNorm layout handling#23974

Merged
HaiShaw merged 2 commits into sgl-project:main from hubertlu-tw:aiter_rmsnorm_fix on Apr 29, 2026

Conversation

@hubertlu-tw (Collaborator)

Motivation

This PR fixes a ROCm/Aiter RMSNorm crash seen on MI325/gfx942 when Qwen3-style Q/K normalization passes strided higher-rank views into Aiter's RMSNorm path.

The SGLang-side fix keeps the Aiter backend enabled, but normalizes the input layout before calling Aiter:

  • Already-safe 2D contiguous inputs stay on the zero-copy fast path.
  • Strided or higher-rank inputs are made contiguous and flattened to 2D.
  • The output is reshaped back to the original higher-rank shape.
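
The three cases above can be sketched as a small layout-normalization helper. This is a hedged illustration, not the SGLang API: the name normalize_layout and its return convention are hypothetical.

```python
import torch

def normalize_layout(x: torch.Tensor):
    # Hypothetical helper (not SGLang API) mirroring the three cases above.
    # Returns a 2D contiguous tensor plus a callable that restores the
    # caller's original shape on the output.
    if x.dim() == 2 and x.is_contiguous():
        return x, lambda out: out                    # zero-copy fast path
    if x.dim() != 2:
        shape = x.shape                              # remember higher-rank shape
        x2d = x.contiguous().reshape(-1, shape[-1])  # one copy, flatten to 2D
        return x2d, lambda out: out.reshape(shape)   # restore shape afterwards
    return x.contiguous(), lambda out: out           # strided 2D: copy only
```

The restore callable keeps the reshape bookkeeping next to the flattening decision, so callers cannot forget to undo it.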

Root Cause

Qwen3 splits a packed QKV projection into Q/K/V views. Q and K can be shaped like:

q: shape=(2048, 32, 128), stride=(6144, 128, 1)
k: shape=(2048,  8, 128), stride=(6144, 128, 1)

These are valid PyTorch views, but Aiter's RMSNorm kernels expect 2D contiguous input. Passing the strided 3D views directly to Aiter can memory-fault on MI325/gfx942.
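
A hedged sketch of how such views arise: a packed QKV projection of width (32 + 2*8) * 128 = 6144 is split along the hidden dimension, and each chunk is viewed as (tokens, heads, head_dim). The sizes match the shapes quoted above; the variable names are illustrative.

```python
import torch

tokens, head_dim = 2048, 128
q_heads, kv_heads = 32, 8

# Packed QKV projection output: (2048, 6144), contiguous.
qkv = torch.randn(tokens, (q_heads + 2 * kv_heads) * head_dim)

# Split along the last dim, then view each chunk per-head. The views
# inherit the packed row stride of 6144, so they are NOT contiguous.
q, k, _v = qkv.split(
    [q_heads * head_dim, kv_heads * head_dim, kv_heads * head_dim], dim=-1
)
q = q.view(tokens, q_heads, head_dim)   # shape (2048, 32, 128), stride (6144, 128, 1)
k = k.view(tokens, kv_heads, head_dim)  # shape (2048,  8, 128), stride (6144, 128, 1)
```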

Fix

In RMSNorm.forward_aiter():

needs_reshape = x.dim() != 2 and residual is None
if needs_reshape:
    # Higher-rank input: copy once and flatten to 2D for the Aiter kernel
    original_shape = x.shape
    x = x.contiguous().reshape(-1, original_shape[-1])
elif not x.is_contiguous():
    # 2D but strided: a contiguous copy is sufficient
    x = x.contiguous()

...

if needs_reshape:
    # Restore the caller's original higher-rank shape
    output = output.reshape(original_shape)

This is layout-gated, not architecture-gated. It avoids overhead for the normal 2D contiguous path and only copies when the input layout is unsafe for Aiter.
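
The guard can be exercised end to end with a pure-PyTorch stand-in for the kernel. This is a hedged sketch of the residual-free path only: rmsnorm_2d is an illustrative substitute for aiter.rmsnorm2d_fwd that merely shares its 2D-contiguous input contract, and forward_aiter_guarded is a hypothetical free function, not the SGLang method.

```python
import torch

def rmsnorm_2d(x, weight, eps=1e-6):
    # Stand-in for aiter.rmsnorm2d_fwd sharing its contract:
    # the kernel assumes a 2D contiguous input.
    assert x.dim() == 2 and x.is_contiguous()
    var = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(var + eps) * weight.float()).to(x.dtype)

def forward_aiter_guarded(x, weight, eps=1e-6):
    # Mirrors the layout guard above (residual-free path).
    needs_reshape = x.dim() != 2
    if needs_reshape:
        original_shape = x.shape
        x = x.contiguous().reshape(-1, original_shape[-1])
    elif not x.is_contiguous():
        x = x.contiguous()
    output = rmsnorm_2d(x, weight, eps)
    if needs_reshape:
        output = output.reshape(original_shape)
    return output

# Strided 3D view like Qwen3's q: the guard copies once; the math is unchanged.
qkv = torch.randn(2048, 6144)
q = qkv[:, :4096].view(2048, 32, 128)  # stride (6144, 128, 1), not contiguous
w = torch.ones(128)
out = forward_aiter_guarded(q, w)
assert out.shape == q.shape
ref = rmsnorm_2d(q.reshape(-1, 128).contiguous(), w).reshape(q.shape)
assert torch.allclose(out, ref)
```

Because the 2D contiguous fast path never triggers needs_reshape or the contiguous copy, typical decode-time inputs incur no extra work.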

Validation

Direct Reproducer

Before the fix, directly calling the Aiter RMSNorm path on strided Q/K views reproduced:

Memory access fault by GPU node-2
Fatal Python error: Aborted

After the fix, the same strided Q/K RMSNorm probe succeeds:

sglang aiter RMSNorm strided ok torch.Size([2048, 32, 128]) torch.Size([2048, 8, 128])

Qwen3-8B Server

Launched Qwen3-8B on MI325/gfx942 with:

python -m sglang.launch_server \
    --model-path Qwen/Qwen3-8B \
    --host 0.0.0.0 --port 9000 \
    --dtype bfloat16 \
    --disable-radix-cache \
    --cuda-graph-max-bs 2048 \
    --max-mamba-cache-size 15000 \
    --mem-fraction-static 0.88

The server completed CUDA graph capture and served requests without the memory fault.

Result:

Server launched successfully.
CUDA graph capture completed.
No GPU memory fault.
/v1/models returned 200.

GSM8K benchmark:

python3 benchmark/gsm8k/bench_sglang.py  --num-questions 1319 --parallel 1319 --num-shots 5 --port 9000
Accuracy: 0.902
Invalid: 0.000
Latency: 48.308 s
Output throughput: 3385.019 token/s

Notes

  • This is not a fallback away from Aiter.
  • This does not disable CUDA graphs.
  • This is a caller-side compatibility guard for Aiter's 2D contiguous RMSNorm kernel contract.
  • An Aiter-side public API hardening PR is still recommended so aiter.rmsnorm2d_fwd() itself safely handles higher-rank or strided inputs.


@HaiShaw HaiShaw merged commit e5da200 into sgl-project:main Apr 29, 2026
89 of 111 checks passed
@hnyls2002 hnyls2002 mentioned this pull request Apr 29, 2026
vguduruTT pushed a commit to vguduruTT/sglang that referenced this pull request May 2, 2026