
[AMD] Fix GLM-5 fp8 KV quant path dispatch on MI300#22314

Merged
HaiShaw merged 4 commits into sgl-project:main from 1am9trash:fix-mi300-quant-path
Apr 8, 2026
Conversation

Collaborator

@1am9trash 1am9trash commented Apr 8, 2026

Motivation

On MI300, running GLM-5-fp8 with an FP8 KV cache can fail (see the CI log).
The root cause is that the quant path does not dispatch the correct fused kernel, set_mla_kv_buffer_triton_fp8_quant.

Modifications

The flag self.nsa_kv_cache_store_fp8 is true only when the KV cache is stored in fp8 with scaling. Our attention path uses an fp8 KV cache without scaling, so it should not be gated by this flag.
This change moves the HIP + fp8 quant path out of the scaling-specific branch, ensuring MI300 dispatches the correct fused kernel (set_mla_kv_buffer_triton_fp8_quant).

Only the MI300 code path is affected by this change.
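The dispatch change described above can be sketched as follows. The flag names (_is_hip, use_nsa, nsa_kv_cache_store_fp8) and the fused kernel name follow the PR description, but the function itself and the other kernel names are illustrative stand-ins, not the actual sglang source; dtypes are represented as strings so the sketch runs without torch.

```python
# Hypothetical sketch of the dispatch fix. Before the fix, the HIP + fp8
# branch was nested under the scaling-specific nsa_kv_cache_store_fp8
# flag, so the unscaled fp8 KV-cache path on MI300 never reached the
# fused kernel. After the fix, the HIP + fp8 check stands on its own.

FP8_KV_DTYPES = ("float8_e4m3fn", "float8_e4m3fnuz")


def select_kv_write_kernel(
    is_hip: bool,
    use_nsa: bool,
    kv_dtype: str,
    nsa_kv_cache_store_fp8: bool,
) -> str:
    """Pick the kernel that writes the MLA KV buffer (illustrative)."""
    if is_hip and use_nsa and kv_dtype in FP8_KV_DTYPES:
        # Fused BF16/FP16 -> FP8 cast with the paged KV write,
        # no per-block scales (the path this PR fixes on MI300).
        return "set_mla_kv_buffer_triton_fp8_quant"
    if nsa_kv_cache_store_fp8:
        # fp8 storage *with* scaling; before the fix, the branch above
        # was gated behind this flag and therefore never taken.
        return "scaled_fp8_kv_store_kernel"  # hypothetical name
    return "default_kv_store_kernel"  # hypothetical name
```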

Accuracy Tests

GLM-5-fp8 with fp8 KV cache accuracy: 0.945
Also validated with the new CI script test_glm5_perf_amd.py prepared in PR #21710.



if (
    _is_hip
    and self.use_nsa
    and self.dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz)
Collaborator

You can import fp8_dtype via "from sglang.srt.layers.quantization.fp8_kernel import fp8_dtype"
and use "self.dtype == fp8_dtype" for the condition check.

fp8_dtype is torch.float8_e4m3fnuz on MI300X and torch.float8_e4m3fn on MI35X.
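The suggestion above replaces the two-dtype tuple check with sglang's platform-aware fp8_dtype constant. A minimal stand-alone illustration of that selection logic (the helper below is hypothetical; the real constant lives in sglang.srt.layers.quantization.fp8_kernel, and dtypes are strings here so the sketch runs without torch):

```python
# Illustrative stand-in for sglang's fp8_kernel.fp8_dtype: the fnuz
# variant applies on fnuz-FP8 hardware such as MI300X, the fn variant
# elsewhere (e.g. MI35X).
def select_fp8_dtype(is_fp8_fnuz: bool) -> str:
    return "float8_e4m3fnuz" if is_fp8_fnuz else "float8_e4m3fn"


# With the real platform-aware constant, the dispatch condition
# collapses to a single equality check instead of tuple membership:
#   if _is_hip and self.use_nsa and self.dtype == fp8_dtype: ...
```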

Collaborator Author

@1am9trash 1am9trash Apr 8, 2026


Fixed, and the rerun passed. Really appreciate the reminder.

):
    # HIP FP8 path uses raw MLA KV layout (nope + rope) without per-block scales.
    # Fuse BF16/FP16 -> FP8 cast with paged KV write.
    fp8_dtype = torch.float8_e4m3fnuz if _is_fp8_fnuz else torch.float8_e4m3fn
Collaborator

@kkHuang-amd kkHuang-amd Apr 8, 2026


Remove line 1585 once you use "from sglang.srt.layers.quantization.fp8_kernel import fp8_dtype".

@HaiShaw HaiShaw merged commit 729b74d into sgl-project:main Apr 8, 2026
54 of 62 checks passed
michaelzhang-ai added a commit that referenced this pull request Apr 8, 2026
Now that #22314 (MI300 FP8 KV quant dispatch fix) and #22232 (NSA
indexer clone fix) are merged, re-enable FP8 KV cache for both
MI30x and MI35x perf tests.
JustinTong0323 pushed a commit to JustinTong0323/sglang that referenced this pull request Apr 8, 2026