
[AMD][HIP] NSA: bf16 passthrough from RMSNorm to eliminate FP8 dequantization#22258

Merged
HaiShaw merged 3 commits into sgl-project:main from Jacob0226:jacob/nsa_bf16_passthrough
Apr 10, 2026

Conversation

Contributor

@Jacob0226 Jacob0226 commented Apr 7, 2026

Motivation

When running GLM-5-FP8 on MI355X with FP8 quantization, the NSA indexer's head gate path (weights_proj) requires bf16 input. However, the upstream fused_rms_fp8_group_quant only outputs FP8 tensors, discarding the bf16 intermediate from RMSNorm. This forces a redundant FP8-to-bf16 dequantization (4 PyTorch eager kernels, ~18 us per layer) that reconstructs a value already computed inside RMSNorm.

Modifications

  • communicator.py: When NSA is active on gfx95 with FP8 quantization, set output_unquantized_inp1=True in fused_rms_fp8_group_quant to preserve the bf16 output at near-zero cost. Pack the result as a 3-tuple (x_fp8, x_scale, x_bf16).
  • nsa_indexer.py: In _weights_proj_bf16_in_fp32_out, extract x[2] (bf16) directly from the 3-tuple on HIP, completely bypassing dequantization. Skip the dequant block in forward_cuda for 3-tuple inputs. Also enable torch.compile on HIP for _get_logits_head_gate and _project_and_scale_head_gates (aligned with PR Reduce unnecessary kernels and copies in the NSA indexer #22232).

The 3-tuple flows transparently through the existing tuple-handling paths: Fp8LinearMethod.apply uses only x[0] and x[1] for FP8 GEMMs (e.g., self.wk), so x[2] is naturally ignored by quantized layers.
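The producer/consumer roles described above can be sketched in pure Python. This is a minimal illustration, not the actual sglang code: stand-in strings replace real tensors, and `pack_rmsnorm_output`, `weights_proj_input`, and `fp8_gemm_input` are hypothetical names mirroring `fused_rms_fp8_group_quant`, `_weights_proj_bf16_in_fp32_out`, and `Fp8LinearMethod.apply`.

```python
from typing import Tuple, Union

def pack_rmsnorm_output(x_fp8, x_scale, x_bf16) -> Tuple:
    # Producer side (communicator.py): with output_unquantized_inp1=True,
    # keep the bf16 intermediate alongside the FP8 tensor and its scale.
    return (x_fp8, x_scale, x_bf16)

def weights_proj_input(x: Union[object, Tuple]) -> object:
    # Consumer side (nsa_indexer.py): on the HIP path, take x[2] directly
    # and skip the FP8 -> bf16 dequantization entirely.
    if isinstance(x, tuple) and len(x) == 3:
        return x[2]   # bf16 value already computed inside RMSNorm
    return x          # plain bf16 tensor: nothing to unpack

def fp8_gemm_input(x: Tuple) -> Tuple:
    # Quantized layers only consume x[0] and x[1]; the extra bf16
    # member rides along and is naturally ignored.
    return x[0], x[1]

packed = pack_rmsnorm_output("x_fp8", "x_scale", "x_bf16")
print(weights_proj_input(packed))  # x_bf16
print(fp8_gemm_input(packed))      # ('x_fp8', 'x_scale')
```

The key property is that the 3-tuple is a strict superset of the existing (fp8, scale) pair, so no tuple-consuming call site needs to change.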

This change affects the AMD HIP path only.

Accuracy Tests

GSM8K accuracy:

  • GLM-5-FP8 on MI355X (TP8): 0.945
  • GLM-5-FP8 on B200 (TP8): 0.949

Speed Tests and Profiling

GLM-5-FP8 server cmd on MI355:

export SGLANG_ROCM_FUSED_DECODE_MLA=0
export ROCM_QUICK_REDUCE_QUANTIZATION=INT4
export SAFETENSORS_FAST_GPU=1
python3 -m sglang.launch_server \
  --model-path GLM-5-FP8 \
  --tp 8 \
  --port 9000 --trust-remote-code \
  --tool-call-parser glm47 --reasoning-parser glm45 \
  --mem-fraction-static 0.85 \
  --model-loader-extra-config '{"enable_multithread_load": true, "num_threads": 8}' \
  --nsa-prefill-backend tilelang --nsa-decode-backend tilelang --disable-radix-cache \
  --kv-cache-dtype fp8_e4m3

Benchmark on MI355X TP8, concurrency 4/8/16/32/64 averaged (baseline: sglang PR #22232 + aiter PR #2575):

  • ISL/OSL 1k/1k: Throughput +3.4%, TPOT +3.8%
  • ISL/OSL 8k/1k: Throughput +2.8%, TPOT +1.0%

Per-layer profiling:

  • FP8 dequantization: ~18 us → 0 us (4 kernels eliminated)

Checklist

Review and Merge Process

  1. Ping Merge Oncalls to start the process. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • Common commands include /tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci
  4. After green CI and required approvals, ask Merge Oncalls or people with Write permission to merge the PR.

Made with Cursor

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces support for Native Sparse Attention (NSA) with FP8 quantization on AMD (HIP) platforms. It updates the nsa_indexer to process unquantized activation tuples and modifies the communicator to manage these activations when NSA is active. The review feedback suggests enhancing type safety by updating function signatures in nsa_indexer.py to explicitly include Union[torch.Tensor, Tuple[torch.Tensor, ...]] for parameters that now accept multiple types.

yield

def _weights_proj_bf16_in_fp32_out(self, x: torch.Tensor) -> torch.Tensor:
def _weights_proj_bf16_in_fp32_out(self, x) -> torch.Tensor:
Contributor

medium

For better type safety and readability, please add a type hint for the x parameter. Based on its usage and the caller _get_logits_head_gate, it should be Union[torch.Tensor, Tuple[torch.Tensor, ...]].

Suggested change
def _weights_proj_bf16_in_fp32_out(self, x) -> torch.Tensor:
def _weights_proj_bf16_in_fp32_out(self, x: Union[torch.Tensor, Tuple[torch.Tensor, ...]]) -> torch.Tensor:


@torch.compile(dynamic=True) if not _is_hip else lambda f: f
@torch.compile(dynamic=True)
def _project_and_scale_head_gates(self, x: torch.Tensor):
Contributor

medium

The type hint for x should be updated to Union[torch.Tensor, Tuple[torch.Tensor, ...]] to reflect that it can now receive a tuple, similar to the change made in _get_logits_head_gate. In forward_cuda, this function is called with an x that can be a tuple.

Suggested change
def _project_and_scale_head_gates(self, x: torch.Tensor):
def _project_and_scale_head_gates(self, x: Union[torch.Tensor, Tuple[torch.Tensor, ...]]):
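The decorator diff above also drops the HIP guard (`@torch.compile(dynamic=True) if not _is_hip else lambda f: f`) so the gate projections compile on ROCm too. That guard is the conditional-decorator idiom; a minimal sketch, using a hypothetical `fake_compile` stand-in (the real code uses `torch.compile`, not imported here):

```python
_is_hip = True  # pretend we are on a ROCm/HIP build

def fake_compile(fn):
    """Stand-in for torch.compile: tags the function as compiled."""
    fn.compiled = True
    return fn

# Old pattern: compile only off-HIP; on HIP, fall through to an identity lambda.
def old_style(fn):
    return (fake_compile if not _is_hip else (lambda f: f))(fn)

@old_style
def gate_old(x):
    return x * 2

# New pattern (what the PR enables): decorate unconditionally.
@fake_compile
def gate_new(x):
    return x * 2

print(getattr(gate_old, "compiled", False))  # False: never compiled on HIP
print(getattr(gate_new, "compiled", False))  # True: compiled on HIP too
```

Because a decorator is just a function applied to a function, the conditional expression selects which transformer runs at definition time; removing it makes the compiled path unconditional.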

@HaiShaw
Collaborator

HaiShaw commented Apr 9, 2026

/tag-and-rerun-ci

@github-actions github-actions bot added the run-ci label Apr 9, 2026
Collaborator

@HaiShaw HaiShaw left a comment

Please align these checks: the producer uses _use_aiter and _is_gfx95_supported, while the consumer uses _is_hip. After #22422 you can use _use_aiter_gfx95.

@Jacob0226 Jacob0226 force-pushed the jacob/nsa_bf16_passthrough branch from 0331677 to 9fdf72b Compare April 10, 2026 00:54
Jacob0226 and others added 3 commits April 10, 2026 00:58
Eliminate the FP8 dequantization overhead (Step 10) by passing the
unquantized bf16 tensor from fused_rms_fp8_group_quant directly to
the NSA gate projection (weights_proj).  On gfx95 with NSA enabled,
output_unquantized_inp1=True produces a 3-tuple (fp8, scale, bf16)
that flows through the communicator and indexer unchanged.

Also enable torch.compile on HIP for _get_logits_head_gate and
_project_and_scale_head_gates (aligned with PR sgl-project#22232).

Expected improvement on MI355X:
  - Step 10: ~18 us → 0 us (dequant completely eliminated)

Made-with: Cursor
- communicator.py: Document that fused RMSNorm + FP8 group quant with
  NSA bf16 passthrough is aiter (ROCm gfx95) specific
- nsa_indexer.py: Add Union type hints to _weights_proj_bf16_in_fp32_out
  and _project_and_scale_head_gates per gemini-code-assist review
- nsa_indexer.py: Clarify HIP/aiter context in forward_cuda comments

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Made-with: Cursor
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Made-with: Cursor
@Jacob0226 Jacob0226 force-pushed the jacob/nsa_bf16_passthrough branch from 9fdf72b to 88b7471 Compare April 10, 2026 00:58
@Jacob0226 Jacob0226 requested a review from HaiShaw April 10, 2026 01:22
@Jacob0226
Contributor Author

Please align these checks: the producer uses _use_aiter and _is_gfx95_supported, while the consumer uses _is_hip. After #22422 you can use _use_aiter_gfx95.

Done. Updated the consumer guard in nsa_indexer.py to use _use_aiter and _is_gfx95_supported, aligned with the producer side in communicator.py.

@HaiShaw
Collaborator

HaiShaw commented Apr 10, 2026

@amd-bot ci-status

@amd-bot

amd-bot commented Apr 10, 2026

@HaiShaw

CI Status for PR #22258

PR: [AMD][HIP] NSA: bf16 passthrough from RMSNorm to eliminate FP8 dequantization
Changed files: nsa_indexer.py (+30/-7), communicator.py (+38/-18)

Both changed files modify AMD HIP + gfx95 codepaths only — guarded by _use_aiter and _is_gfx95_supported.

| Job | Error | Related? | Explanation | Log |
| --- | --- | --- | --- | --- |
| stage-b-test-1-gpu-small (0) | Install timeout: pip download of torch-2.9.1+cu129 (1.2GB) exceeded 20min step limit | 🟢 Unlikely | Infrastructure/network issue downloading PyTorch wheel. Not code-related. | Log |
| stage-b-test-1-gpu-small (1) | Same install timeout | 🟢 Unlikely | Same infrastructure issue | Log |
| stage-b-test-1-gpu-small (3-7) | Fast-fail: skipping -- root cause job(s): small (1), small (0) | 🟢 Unlikely | Cascade from install timeout above. No tests ran. | Log |
| stage-b-test-1-gpu-large (5-13) | Fast-fail: skipping -- root cause job(s): small (1), small (0) | 🟢 Unlikely | Same cascade. No tests ran. | Log |
| stage-b-test-4-gpu-b200 | Fast-fail: skipping | 🟢 Unlikely | Same cascade. No tests ran. | Log |
| stage-b-test-1-gpu-small-amd (11) | AssertionError: 88.77 not less than 86 in test_multi_tokenizer_ttft | 🟢 Unlikely | TTFT latency test on tokenizer benchmark (unrelated to NSA/FP8). Threshold is 86ms, measured ~89ms. Flaky perf assertion. | Log |
| build-test (all) | NotImplementedError: flash_attn at sgl-kernel is only supported on sm90 and above | 🟢 Unlikely | sgl-kernel flash_attn test on non-SM90 runner. Unrelated to this PR's changes. | Log |
| build-and-test (XPU) | NotImplementedError: flash_attn at sgl-kernel is only supported on sm90 and above | 🟢 Unlikely | XPU test hits SM90 guard in flash_attn_v3. Unrelated to NSA/communicator changes. | Log |
| pr-test-finish, finish, pr-test-amd-finish | Gate jobs failed due to upstream failures | 🟢 Unlikely | Aggregate finish gates; fail because child jobs failed. | |

Details

No failures are related to this PR. All 24 failures fall into 3 categories, none of which touch the PR's changes:

  1. Install timeout (15 jobs): Two NVIDIA runners (small (0) and small (1)) timed out downloading the PyTorch wheel (1.2GB) over a slow network link. This triggered the fast-fail cascade that aborted all remaining NVIDIA GPU jobs (small 3-7, large 5-13, b200) without running any tests. This is a transient infrastructure issue.

  2. flash_attn SM90 guard (2 jobs): The build-test (all) and build-and-test (XPU) jobs both hit NotImplementedError in flash_attention_v3.py:195 because the test runners lack SM90+ GPUs. This is unrelated to the NSA/communicator code — it's a pre-existing compatibility issue in flash_attn dispatch.

  3. AMD TTFT flaky threshold (1 job): test_multi_tokenizer_ttft measures tokenizer TTFT latency and asserts < 86ms, but measured ~89ms on MI325. This test (test/registered/tokenizer/test_multi_tokenizer.py:73) tests tokenizer/benchmark latency, not the NSA indexer or FP8 dequant paths modified by this PR.

Passing AMD jobs confirm no regressions from PR changes: All 14 AMD small partitions (0-10, 12-13) passed except partition 11 (the unrelated TTFT test). Both AMD large partitions (0-1), AMD MI35x, AMD nondeterministic, and AMD 8-GPU disaggregation tests all passed. The NVIDIA stage-a and partitions (large 0-4, small 2) that weren't killed by the cascade also passed.

Recommendation: Re-run CI to clear the transient install timeout. No code changes needed from this PR.

Generated by amd-bot using Claude Code CLI

@HaiShaw HaiShaw merged commit dd41764 into sgl-project:main Apr 10, 2026
78 of 102 checks passed
Jacob0226 added a commit to Jacob0226/sglang that referenced this pull request Apr 10, 2026
Keep quant_format == "fp8" (exact match) to prevent fp8_per_token
from being intercepted by the fp8 group-quant path, while preserving
the NSA bf16 passthrough logic (_nsa_needs_bf16 / 3-tuple packing)
from upstream PR sgl-project#22258.

Made-with: Cursor