[AMD][HIP] NSA: bf16 passthrough from RMSNorm to eliminate FP8 dequantization#22258
Conversation
Code Review
This pull request introduces support for Native Sparse Attention (NSA) with FP8 quantization on AMD (HIP) platforms. It updates the `nsa_indexer` to process unquantized activation tuples and modifies the communicator to manage these activations when NSA is active. The review feedback suggests enhancing type safety by updating function signatures in `nsa_indexer.py` to explicitly include `Union[torch.Tensor, Tuple[torch.Tensor, ...]]` for parameters that now accept multiple types.
```diff
  yield

- def _weights_proj_bf16_in_fp32_out(self, x: torch.Tensor) -> torch.Tensor:
+ def _weights_proj_bf16_in_fp32_out(self, x) -> torch.Tensor:
```
For better type safety and readability, please add a type hint for the `x` parameter. Based on its usage and the caller `_get_logits_head_gate`, it should be `Union[torch.Tensor, Tuple[torch.Tensor, ...]]`.
```diff
- def _weights_proj_bf16_in_fp32_out(self, x) -> torch.Tensor:
+ def _weights_proj_bf16_in_fp32_out(self, x: Union[torch.Tensor, Tuple[torch.Tensor, ...]]) -> torch.Tensor:
```
```diff
- @torch.compile(dynamic=True) if not _is_hip else lambda f: f
+ @torch.compile(dynamic=True)
  def _project_and_scale_head_gates(self, x: torch.Tensor):
```
The type hint for `x` should be updated to `Union[torch.Tensor, Tuple[torch.Tensor, ...]]` to reflect that it can now receive a tuple, similar to the change made in `_get_logits_head_gate`. This function is called with `x`, which can be a tuple in `forward_cuda`.
```diff
- def _project_and_scale_head_gates(self, x: torch.Tensor):
+ def _project_and_scale_head_gates(self, x: Union[torch.Tensor, Tuple[torch.Tensor, ...]]):
```
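The suggested `Union` hint mirrors how the function has to branch at runtime. A minimal sketch of that dispatch, using plain Python strings as stand-ins for tensors so it runs without torch (the helper name `pick_gate_input` is illustrative, not from the PR):

```python
from typing import Tuple, Union

# Stand-in for torch.Tensor so this sketch runs without torch.
Tensor = str

def pick_gate_input(x: Union[Tensor, Tuple[Tensor, ...]]) -> Tensor:
    # On the HIP passthrough path, x arrives as the 3-tuple
    # (x_fp8, x_scale, x_bf16); the gate projection needs only
    # the preserved bf16 element.
    if isinstance(x, tuple):
        return x[2]
    # Plain-tensor path: use the activation as-is.
    return x
```

The `Union` hint documents both shapes at the signature, so a reader does not have to trace callers to learn that a tuple can arrive.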
Force-pushed from 63a918b to e377298.
/tag-and-rerun-ci
Force-pushed from 0331677 to 9fdf72b.
Eliminate the FP8 dequantization overhead (Step 10) by passing the unquantized bf16 tensor from `fused_rms_fp8_group_quant` directly to the NSA gate projection (`weights_proj`). On gfx95 with NSA enabled, `output_unquantized_inp1=True` produces a 3-tuple `(fp8, scale, bf16)` that flows through the communicator and indexer unchanged. Also enable `torch.compile` on HIP for `_get_logits_head_gate` and `_project_and_scale_head_gates` (aligned with PR sgl-project#22232).

Expected improvement on MI355X:
- Step 10: ~18 us → 0 us (dequant completely eliminated)

Made-with: Cursor
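The eliminated Step 10 is, in essence, a group dequantization: each fp8 value is multiplied back by its per-group scale to recover a bf16 approximation of an activation RMSNorm already produced. A toy sketch of that redundant work, with made-up names and a small group size (not the real aiter kernel):

```python
GROUP = 4  # illustrative per-group block size, not the real kernel's

def group_dequant(x_fp8, scales):
    # Reconstruct the activation by scaling each quantized value with
    # its group's scale -- the per-layer work this PR removes.
    return [q * scales[i // GROUP] for i, q in enumerate(x_fp8)]

# With output_unquantized_inp1=True, the bf16 activation is kept
# alongside (x_fp8, x_scale), so the consumer reads it directly
# instead of ever calling group_dequant.
```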
- communicator.py: Document that fused RMSNorm + FP8 group quant with NSA bf16 passthrough is aiter (ROCm gfx95) specific
- nsa_indexer.py: Add Union type hints to `_weights_proj_bf16_in_fp32_out` and `_project_and_scale_head_gates` per gemini-code-assist review
- nsa_indexer.py: Clarify HIP/aiter context in `forward_cuda` comments

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Made-with: Cursor
Force-pushed from 9fdf72b to 88b7471.
Done. Updated the consumer guard in `nsa_indexer.py` to use `_use_aiter` and `_is_gfx95_supported`, aligned with the producer side in `communicator.py`.
@amd-bot ci-status
CI Status for PR #22258

PR: [AMD][HIP] NSA: bf16 passthrough from RMSNorm to eliminate FP8 dequantization

Both changed files modify AMD HIP + gfx95 codepaths only — guarded by
Details

No failures are related to this PR. All 24 failures fall into 3 categories, none of which touch the PR's changes:
Passing AMD jobs confirm no regressions from PR changes: all 14 AMD small partitions (0-10, 12-13) passed except partition 11 (the unrelated TTFT test). Both AMD large partitions (0-1), AMD MI35x, AMD nondeterministic, and AMD 8-GPU disaggregation tests all passed. The NVIDIA stage-a and partitions (large 0-4, small 2) that weren't killed by the cascade also passed.

Recommendation: Re-run CI to clear the transient install timeout. No code changes needed from this PR.

Generated by amd-bot using Claude Code CLI
Keep `quant_format == "fp8"` (exact match) to prevent `fp8_per_token` from being intercepted by the fp8 group-quant path, while preserving the NSA bf16 passthrough logic (`_nsa_needs_bf16` / 3-tuple packing) from upstream PR sgl-project#22258.

Made-with: Cursor
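The exact-match guard can be pictured as a small router. `route_quant` and the path names below are illustrative stand-ins, not actual sglang symbols; the point is only why equality beats a prefix check:

```python
def route_quant(quant_format: str) -> str:
    # Exact equality: "fp8_per_token" must not fall into the
    # group-quant branch, which a startswith("fp8") test would allow.
    if quant_format == "fp8":
        return "fp8_group_quant"
    if quant_format == "fp8_per_token":
        return "fp8_per_token"
    return "unquantized"
```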
Motivation
When running GLM-5-FP8 on MI355X with FP8 quantization, the NSA indexer's head gate path (`weights_proj`) requires bf16 input. However, the upstream `fused_rms_fp8_group_quant` only outputs FP8 tensors, discarding the bf16 intermediate from RMSNorm. This forces a redundant FP8-to-bf16 dequantization (4 PyTorch eager kernels, ~18 us per layer) that reconstructs a value already computed inside RMSNorm.

Modifications
- `communicator.py`: When NSA is active on gfx95 with FP8 quantization, set `output_unquantized_inp1=True` in `fused_rms_fp8_group_quant` to preserve the bf16 output at near-zero cost. Pack the result as a 3-tuple `(x_fp8, x_scale, x_bf16)`.
- `nsa_indexer.py`: In `_weights_proj_bf16_in_fp32_out`, extract `x[2]` (bf16) directly from the 3-tuple on HIP, completely bypassing dequantization. Skip the dequant block in `forward_cuda` for 3-tuple inputs. Also enable `torch.compile` on HIP for `_get_logits_head_gate` and `_project_and_scale_head_gates` (aligned with PR #22232, "Reduce unnecessary kernels and copies in the NSA indexer").

The 3-tuple flows transparently through the existing tuple-handling paths:
- `Fp8LinearMethod.apply` uses only `x[0]` and `x[1]` for FP8 GEMMs (e.g., `self.wk`), so `x[2]` is naturally ignored by quantized layers.

This change affects the AMD HIP path only.
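That transparency can be sketched end to end with plain Python stand-ins. The function names here are simplified placeholders for the communicator packing and for `Fp8LinearMethod.apply`, not the real signatures:

```python
def pack_rmsnorm_outputs(x_fp8, x_scale, x_bf16):
    # Producer side (communicator sketch): with
    # output_unquantized_inp1=True, the bf16 activation rides
    # along as a third tuple element.
    return (x_fp8, x_scale, x_bf16)

def fp8_gemm_inputs(x):
    # Consumer side (quantized-layer sketch): FP8 GEMMs read only
    # the fp8 data and its scale, so a trailing bf16 element is
    # naturally ignored.
    return x[0], x[1]

packed = pack_rmsnorm_outputs("x_fp8", "x_scale", "x_bf16")
```

Because existing tuple consumers index positionally from the front, appending the bf16 element requires no changes on their side; only the NSA gate path reads index 2.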

Accuracy Tests
GSM8K accuracy:
Speed Tests and Profiling
GLM-5-FP8 server cmd on MI355:
Benchmark on MI355X TP8, concurrency 4/8/16/32/64 averaged (baseline: sglang PR #22232 + aiter PR#2575):
Per-layer profiling:
Checklist
Review and Merge Process
/tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci

Made with Cursor