
[NPU] perf update with kvcache nz & w4a8 quant#14423

Merged
iforgetmyname merged 6 commits into sgl-project:main from iforgetmyname:fia_nz
Dec 13, 2025

Conversation


@liupeng374 liupeng374 commented Dec 4, 2025

Motivation

1. Use the NZ format for the KV cache; this accelerates the FIA operator.
2. Use per-channel quantization for MoE w4a8.
3. Accelerate MHA preprocessing in prefill with the npu_interleave_rope operator.
4. Fix a bug with num_token_non_padded_cpu.

Modifications

Set export SGLANG_USE_FIA_NZ=1 to enable FIA NZ. This feature must be enabled together with MLAPO via export SGLANG_NPU_USE_MLAPO=1.
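Concretely, enabling the feature amounts to exporting both flags before launching the server (the launch command itself is omitted here; only the environment flags named in this PR are shown):

```shell
# Both flags are required together, per the PR description.
export SGLANG_USE_FIA_NZ=1     # enable the NZ KV-cache format for the FIA operator
export SGLANG_NPU_USE_MLAPO=1  # MLAPO must be enabled alongside FIA NZ

# Sanity check: confirm both flags are visible in the environment.
env | grep -E 'SGLANG_(USE_FIA_NZ|NPU_USE_MLAPO)'
```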

Accuracy Tests

[Screenshot: accuracy test results]

Benchmarking and Profiling

FIA NZ reduces TPOT by about 2 ms, and the prefill optimizations improve performance by more than 10%.

Checklist

@gemini-code-assist
Contributor

Summary of Changes

Hello @liupeng374, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates support for an 'FIA NZ' optimization specifically tailored for Ascend NPUs. The primary goal is to enhance the performance and efficiency of attention operations by modifying KV cache management and Rotary Positional Embedding (RoPE) application within the DeepseekV2 attention mechanism. This is achieved through the introduction of a new cache mode and a refactored RoPE calculation that leverages precomputed values and a unified NPU kernel.

Highlights

  • FIA NZ Cache Mode: Introduced a new 'FIA NZ' cache mode for NPU operations, controlled by the SGLANG_USE_FIA_NZ environment variable, enabling conditional logic for KV cache handling and attention preprocessing.
  • Refactored KV Cache and RoPE for DeepseekV2: The KV cache management and Rotary Positional Embedding (RoPE) application for DeepseekV2 attention on NPU have been refactored, integrating a new torch_npu.npu_kv_rmsnorm_rope_cache function for unified processing.
  • Optimized Rotary Embedding Calculation: The Rotary Positional Embedding calculation now precomputes and caches cos and sin values, providing a new utility method (get_cos_sin_cache) for efficient retrieval and formatting, simplifying the forward_npu method.
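The precompute-then-index pattern described in the last highlight can be sketched as follows. The class and method names below are illustrative stand-ins, not the actual sglang implementation; the shapes mirror the (B*S, N, 1, D) layout noted in the diff.

```python
import torch

# Illustrative sketch: precompute the full cos/sin tables once, then gather
# rows by token position at runtime (optionally shifted by per-request offsets).
class CosSinCache:
    def __init__(self, max_pos: int, rotary_dim: int, base: float = 10000.0):
        inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
        t = torch.arange(max_pos, dtype=torch.float32)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)  # (max_pos, rotary_dim)
        self.cos_cached_total = torch.cos(emb)
        self.sin_cached_total = torch.sin(emb)

    def get_cos_sin_cache(self, positions, dtype, offsets=None):
        idx = positions + offsets if offsets is not None else positions
        # Two singleton dims so the result broadcasts over (B*S, N, 1, D).
        cos = self.cos_cached_total[idx].unsqueeze(-2).unsqueeze(-2).to(dtype)
        sin = self.sin_cached_total[idx].unsqueeze(-2).unsqueeze(-2).to(dtype)
        return cos, sin

cache = CosSinCache(max_pos=128, rotary_dim=64)
cos, sin = cache.get_cos_sin_cache(torch.tensor([0, 5, 7]), torch.float16)
print(cos.shape)  # torch.Size([3, 1, 1, 64])
```

The point of the refactor is that the trigonometric tables are computed once at init time, so each forward pass reduces to an index-and-cast, which the unified NPU kernel can then consume directly.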


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds support for the NZ format on Ascend NPUs, controlled by the SGLANG_USE_FIA_NZ environment variable. The changes primarily affect attention preprocessing and rotary embeddings to leverage NPU-specific optimizations. My review focuses on ensuring these changes are correctly implemented without introducing regressions. I've identified a critical issue in rotary_embedding.py that could break other models, a likely typo in mla_preprocess.py that could cause runtime errors, and some code duplication that could be refactored for better maintainability. Addressing these points will improve the robustness and quality of the code.

) # (B*S, N, 1, D)
- cache_mode = "PA_BNSD"
+ cache_mode = ("PA_NZ" if _use_fia_nz else "PA_BNSD",)


critical

The cache_mode is being assigned a tuple with a trailing comma. This is likely a typo, as the torch.ops.npu.npu_kv_rmsnorm_rope_cache operator probably expects a string, similar to other parts of the code. This could cause a runtime error or unexpected behavior.

Suggested change
- cache_mode = ("PA_NZ" if _use_fia_nz else "PA_BNSD",)
+ cache_mode = "PA_NZ" if _use_fia_nz else "PA_BNSD"
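For reference, the pitfall flagged here is easy to reproduce in isolation: wrapping a conditional expression in parentheses is harmless, but adding a trailing comma inside them turns the value into a one-element tuple rather than a string.

```python
# A trailing comma inside parentheses creates a tuple, not a string.
_use_fia_nz = True  # stand-in for the env-var flag

as_tuple = ("PA_NZ" if _use_fia_nz else "PA_BNSD",)  # the buggy form
as_str = "PA_NZ" if _use_fia_nz else "PA_BNSD"       # the intended form

print(type(as_tuple).__name__, as_tuple)  # tuple ('PA_NZ',)
print(type(as_str).__name__, as_str)      # str PA_NZ
```

An operator expecting a string mode argument would receive `('PA_NZ',)` here, which is why the review marks this as a likely runtime error.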

Comment on lines +172 to +204
emb = torch.cat((freqs, freqs), dim=-1)
self.cos_cached_total = torch.cos(emb) * self.mscale
self.sin_cached_total = torch.sin(emb) * self.mscale
return cache

def get_cos_cached_total(self):
return self.cos_cached_total

def get_sin_cached_total(self):
return self.sin_cached_total

def get_cos_sin_cache(
self, positions, dtype, offsets: Optional[torch.Tensor] = None
):
self.cos_cached = (
self.cos_cached_total[
torch.add(positions, offsets) if offsets is not None else positions
]
.unsqueeze(-2)
.unsqueeze(-2)
.to(dtype)
)
self.sin_cached = (
self.sin_cached_total[
torch.add(positions, offsets) if offsets is not None else positions
]
.unsqueeze(-2)
.unsqueeze(-2)
.to(dtype)
)
cos = self.cos_cached.to(positions.device)
sin = self.sin_cached.to(positions.device)
return cos, sin

critical

The new logic to compute cos_cached_total and sin_cached_total in _compute_cos_sin_cache, and the new get_cos_sin_cache method, depend on self.mscale. This attribute is not present in the base RotaryEmbedding class or some of its subclasses (e.g., LinearScalingRotaryEmbedding), which will cause an AttributeError for models using them.

This logic appears to be specific to DeepseekScalingRotaryEmbedding. To fix this and prevent breaking other models, please:

  1. Revert the changes in RotaryEmbedding._compute_cos_sin_cache.
  2. Move the logic for computing self.cos_cached_total and self.sin_cached_total into DeepseekScalingRotaryEmbedding._compute_cos_sin_cache.
  3. Move the new methods (get_cos_cached_total, get_sin_cached_total, and get_cos_sin_cache) from RotaryEmbedding to DeepseekScalingRotaryEmbedding.

Here's a suggested implementation for DeepseekScalingRotaryEmbedding._compute_cos_sin_cache:

def _compute_cos_sin_cache(self) -> torch.Tensor:
    inv_freq = self._compute_inv_freq(self.scaling_factor)
    t = torch.arange(
        self.max_position_embeddings * self.scaling_factor,
        device=self.device,
        dtype=torch.float32,
    )
    freqs = torch.einsum("i,j -> ij", t, inv_freq)
    cos = freqs.cos() * self.mscale
    sin = freqs.sin() * self.mscale
    cache = torch.cat((cos, sin), dim=-1)

    emb = torch.cat((freqs, freqs), dim=-1)
    self.cos_cached_total = torch.cos(emb) * self.mscale
    self.sin_cached_total = torch.sin(emb) * self.mscale

    return cache

from sglang.srt.hardware_backend.npu.utils import npu_format_cast
from sglang.srt.utils import get_bool_env_var

_use_fia_nz = get_bool_env_var("SGLANG_USE_FIA_NZ")

medium

The _use_fia_nz flag is also defined in python/sglang/srt/hardware_backend/npu/modules/deepseek_v2_attention_mla_npu.py. To avoid code duplication and improve maintainability, consider defining this flag once in a shared utility module (e.g., sglang.srt.hardware_backend.npu.utils) and importing it where needed.
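A minimal sketch of the suggested refactor: compute the flag once in the shared utility module and import it from both call sites. The helper below is an assumption; `get_bool_env_var`'s exact semantics in sglang.srt.utils may differ.

```python
# Hypothetical shared module, e.g. sglang/srt/hardware_backend/npu/utils.py:
# evaluate the env flag once at import time and reuse it everywhere.
import os

def get_bool_env_var(name: str, default: str = "false") -> bool:
    # Assumed semantics: common truthy spellings count as enabled.
    return os.getenv(name, default).strip().lower() in ("1", "true", "yes", "y")

USE_FIA_NZ = get_bool_env_var("SGLANG_USE_FIA_NZ")

# Both call sites (mla_preprocess.py and deepseek_v2_attention_mla_npu.py)
# would then import the single definition:
# from sglang.srt.hardware_backend.npu.utils import USE_FIA_NZ
```

Centralizing the flag keeps the two modules from drifting if the env-var name or parsing rules ever change.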

from sglang.srt.models.deepseek_v2 import DeepseekV2AttentionMLA
from sglang.srt.utils import BumpAllocator

_use_fia_nz = get_bool_env_var("SGLANG_USE_FIA_NZ")

medium

The _use_fia_nz flag is also defined in python/sglang/srt/hardware_backend/npu/attention/mla_preprocess.py. To avoid code duplication and improve maintainability, consider defining this flag once in a shared utility module (e.g., sglang.srt.hardware_backend.npu.utils) and importing it where needed.

@ping1jing2 ping1jing2 changed the title [ascend] FIA support NZ [NPU] FIA support NZ Dec 4, 2025
@ping1jing2 ping1jing2 self-assigned this Dec 4, 2025
@liupeng374 liupeng374 force-pushed the fia_nz branch 10 times, most recently from 8fe44cc to 0f17578 on December 11, 2025 01:45
@iforgetmyname
Collaborator

/tag-and-rerun-ci

@liupeng374 liupeng374 force-pushed the fia_nz branch 2 times, most recently from d415a2c to 869df3c on December 11, 2025 14:13
@liupeng374 liupeng374 force-pushed the fia_nz branch 3 times, most recently from 0123a0a to 4a3e989 on December 12, 2025 06:55
@iforgetmyname
Collaborator

/rerun-failed-ci

@iforgetmyname iforgetmyname changed the title [NPU] FIA support NZ [NPU] perf update with kvcache nz & w4a8 quant Dec 13, 2025
@iforgetmyname iforgetmyname merged commit d36299a into sgl-project:main Dec 13, 2025
171 of 183 checks passed
@iforgetmyname iforgetmyname deleted the fia_nz branch December 13, 2025 09:40
Liwansi added a commit to iforgetmyname/sglang that referenced this pull request Dec 13, 2025
…n_eagle3_npu

* 'main' of https://github.com/sgl-project/sglang: (25 commits)
  [NPU] perf update with kvcache nz & w4a8 quant (sgl-project#14423)
  [PP Prefill][NIXL] Fix PP mode transfer completion tracking to wait for all ranks (sgl-project#15027)
  Fix GLM-4.6 tool calls don't support streaming output for arguments i… (sgl-project#13989)
  feature: adding nightly wheel workflow and indexer (sgl-project#14924)
  [diffusion] feat: Improve LoRA compatibility by adding unified format detection and diffusers-based normalization (sgl-project#14659)
  [Fix] Disable trtllm moe backend for draft model for a qucik fix (sgl-project#15002)
  [diffusion] fix: use NDRotaryEmbedding in flux_2   (sgl-project#15034)
  Mistral Large 3 NVFP4 support (sgl-project#14485)
  call check_quantized_moe_compatibility after initialize (sgl-project#13876)
  Add sgl_router_attempt_http_responses_total for single attempt information (sgl-project#15037)
  Add error code in prometheus metrics and add X-SMG-Error-Code header (sgl-project#15036)
  Provide more fine grained error reason for reqwest error (sgl-project#15032)
  Tiny change http router response format to unify (sgl-project#15031)
  Tiny unify grpc existing error responses into new format (sgl-project#15030)
  Add `code` field and unify error responses for router (sgl-project#15028)
  Super tiny remove unused log_request (sgl-project#15035)
  Fix decode OOM caused by retraction (sgl-project#14939)
  [CI]Add gb200 runner back (sgl-project#15024)
  Add a special label for b200 CI runner that can run kernel tests (sgl-project#15033)
  Fix regression caused by fa3 block_table (sgl-project#15009)
  ...

# Conflicts:
#	python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py
Prozac614 pushed a commit to Prozac614/sglang that referenced this pull request Dec 17, 2025
YChange01 pushed a commit to YChange01/sglang that referenced this pull request Jan 13, 2026


3 participants