27 changes: 17 additions & 10 deletions vllm_ascend/attention/sfa_v1.py
@@ -152,6 +152,8 @@ class AscendSFAMetadata:
dsa_cp_context: DSACPContext | None = None
reshape_cache_event: torch.npu.Event = None
sfa_cp_metadata: AscendPCPMetadata | None = None
# Shared top-k indices reused by IndexCache layers in the same forward pass.
shared_topk_indices: torch.Tensor | None = None
num_decodes: int = 0
num_decode_tokens: int = 0
num_prefills: int = 0
@@ -447,6 +449,7 @@ def __init__(
if self.vllm_config.model_config.hf_config.model_type in ["glm_moe_dsa"]:
self.is_rope_neox_style = False
self.use_torch_npu_lightning_indexer = True
self.skip_topk = kwargs.get("skip_topk", False)

# dsa c8
self.use_sparse_c8_indexer = ascend_config.is_sparse_c8_layer(self.layer_name)
@@ -1229,16 +1232,20 @@ def forward(
if self.is_kv_producer:
attn_metadata.reshape_cache_event.record()

topk_indices = self.indexer_select_post_process(
x=hidden_states,
q_c=q_c,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
cos=cos,
sin=sin,
actual_seq_lengths_query=actual_seq_lengths_query,
actual_seq_lengths_key=actual_seq_lengths_key,
)
if self.skip_topk and attn_metadata.shared_topk_indices is not None:
topk_indices = attn_metadata.shared_topk_indices
else:
topk_indices = self.indexer_select_post_process(
x=hidden_states,
q_c=q_c,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
cos=cos,
sin=sin,
actual_seq_lengths_query=actual_seq_lengths_query,
actual_seq_lengths_key=actual_seq_lengths_key,
)
attn_metadata.shared_topk_indices = topk_indices
Comment on lines +1235 to +1248
Contributor


critical

The current implementation of IndexCache appears to have a critical flaw. The topk_indices are cached and reused across layers, but their computation in indexer_select_post_process depends on hidden_states, which is unique to each layer.

Specifically, indexer_select_post_process uses x (which is hidden_states) to compute weights:

weights, _ = self.weights_proj(x)

These weights are then used to determine topk_indices. Since hidden_states differ from one layer to the next, the topk_indices will also be different. Reusing them will lead to incorrect attention calculations.

For IndexCache to work correctly, the computation of topk_indices must be based on tensors that are shared across the layers intended to use the cache. This might require passing a shared tensor to indexer_select_post_process instead of the per-layer hidden_states.
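A toy, self-contained sketch of this concern (hypothetical names and shapes, not the actual vllm_ascend code): top-k indices derived from different per-layer hidden states generally diverge, so reusing them across layers is an approximation rather than an exact equivalence.

import torch

torch.manual_seed(0)
weights_proj = torch.nn.Linear(16, 16, bias=False)  # stand-in for self.weights_proj

h_layer0 = torch.randn(8, 16)  # hidden_states entering layer 0
h_layer1 = torch.randn(8, 16)  # hidden_states entering layer 1 (a different tensor)

idx0 = weights_proj(h_layer0).topk(4, dim=-1).indices
idx1 = weights_proj(h_layer1).topk(4, dim=-1).indices
print(torch.equal(idx0, idx1))  # typically False: per-layer top-k indices differ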

Contributor Author


Thanks for the review. This behavior is intentional and matches upstream IndexCache semantics (vLLM PR #37735, issue #37684).
IndexCache is an approximate optimization: “full” layers compute top-k indices, while “shared” layers reuse cached indices to reduce redundant computation.
In our implementation, reuse is only enabled when skip_topk=True; otherwise indices are computed per layer as usual.
We’ll also attach accuracy/performance results to quantify the tradeoff.
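For clarity, a minimal sketch of the full/shared layer pattern described above (hypothetical, simplified; not the actual vllm_ascend code paths): a "full" layer computes top-k indices and publishes them on the shared metadata, while a skip_topk ("shared") layer reuses them instead of recomputing.

from dataclasses import dataclass

import torch


@dataclass
class Metadata:
    # Populated during the forward pass by the most recent "full" layer.
    shared_topk_indices: torch.Tensor | None = None


class SparseLayer:
    def __init__(self, skip_topk: bool = False):
        self.skip_topk = skip_topk

    def _compute_topk(self, hidden_states: torch.Tensor, k: int) -> torch.Tensor:
        # Stand-in for indexer_select_post_process: score key positions and keep the top k.
        scores = hidden_states @ hidden_states.transpose(-1, -2)
        return scores.topk(k, dim=-1).indices

    def forward(self, hidden_states: torch.Tensor, meta: Metadata, k: int = 4) -> torch.Tensor:
        if self.skip_topk and meta.shared_topk_indices is not None:
            # "Shared" layer: reuse cached indices (approximate).
            return meta.shared_topk_indices
        # "Full" layer: compute indices and cache them for later layers.
        meta.shared_topk_indices = self._compute_topk(hidden_states, k)
        return meta.shared_topk_indices


# Usage: the first layer computes and caches, the second reuses the cached result.
meta = Metadata()
full = SparseLayer(skip_topk=False)
shared = SparseLayer(skip_topk=True)
idx_full = full.forward(torch.randn(8, 16), meta)
idx_shared = shared.forward(torch.randn(8, 16), meta)
assert torch.equal(idx_full, idx_shared)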


attn_output = self._execute_sparse_flash_attention_process(
ql_nope, q_pe, kv_cache, topk_indices, attn_metadata, actual_seq_lengths_query, actual_seq_lengths_key
2 changes: 2 additions & 0 deletions vllm_ascend/ops/mla.py
@@ -80,6 +80,7 @@ def __init__(
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
skip_topk: bool = False,
) -> None:
nn.Module.__init__(self)
self.hidden_size = hidden_size
@@ -122,6 +123,7 @@ def __init__(
kv_a_layernorm=mla_modules.kv_a_layernorm,
o_proj=mla_modules.o_proj,
layer_name=f"{prefix}.attn",
skip_topk=skip_topk,
)

original_process_weights = self.mla_attn.process_weights_after_loading