Merged
12 changes: 4 additions & 8 deletions vllm_ascend/ops/mm_encoder_attention.py
```diff
@@ -62,7 +62,6 @@ def __init__(
             prefix=prefix,
         )

-        self.layer_index = int("".join(filter(str.isdigit, prefix)))
         self.enable_pad = self.head_size > MIN_PAD_SIZE and self.head_size < MAX_PAD_SIZE
         self.scale_value = self.head_size**-0.5
```
```diff
@@ -103,12 +102,9 @@ def forward_oot(
         is_reshaped = query.dim() == 4

         # Directly use seq_lens cpu cache to avoid d2h copy.
-        global seq_lens_cpu_cache
-        if self.layer_index == 0:
-            if cu_seqlens is None:
-                cu_seqlens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device="cpu")
-            # Update seq_lens cpu cache.
-            seq_lens_cpu_cache = torch.diff(cu_seqlens).to("cpu")
+        if cu_seqlens is None:
+            cu_seqlens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device="cpu")
+        seq_lens_cpu = torch.diff(cu_seqlens).to("cpu")
```
Comment on lines +105 to +107 (Contributor, severity: high):
The caching mechanism for `seq_lens_cpu` has been removed. The original implementation used a global cache (`seq_lens_cpu_cache`) and updated it only for the first layer (`layer_index == 0`) to avoid repeated, potentially expensive device-to-host copies of `cu_seqlens` on every layer. The original comment at lines 27-35 highlighted that this was a performance optimization.

The new implementation computes `seq_lens_cpu = torch.diff(cu_seqlens).to("cpu")` on every call to `forward_oot`. If `cu_seqlens` resides on the NPU, this introduces a device-to-host copy for each layer, which could cause a performance regression.

If `cu_seqlens` is now guaranteed to be on the CPU, this change is acceptable. However, the comment on line 104 ("Directly use seq_lens cpu cache to avoid d2h copy.") is now misleading and should be removed or updated. Otherwise, please consider restoring the caching logic to prevent performance degradation.
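For intuition, the layer-0-gated caching pattern the comment describes can be sketched in plain Python. This is a hypothetical illustration, not the vllm_ascend code: plain lists stand in for torch tensors, and `diff_to_host` simulates the `torch.diff(...).to("cpu")` call whose device-to-host copy the cache is meant to amortize.

```python
# Hypothetical sketch of the layer-0-gated cache described above.
# Plain Python lists stand in for torch tensors.

_seq_lens_cpu_cache = None  # module-level cache shared by all layers
COPIES = {"d2h": 0}         # counts simulated device-to-host copies


def diff_to_host(cu_seqlens):
    """Simulate torch.diff(cu_seqlens).to("cpu"): costs one d2h copy."""
    COPIES["d2h"] += 1
    return [b - a for a, b in zip(cu_seqlens, cu_seqlens[1:])]


def forward_oot(layer_index, cu_seqlens):
    global _seq_lens_cpu_cache
    if layer_index == 0:
        # Only the first layer refreshes the cache; later layers reuse it,
        # so a model with N layers pays one copy per forward, not N.
        _seq_lens_cpu_cache = diff_to_host(cu_seqlens)
    return _seq_lens_cpu_cache


cu = [0, 4, 9, 15]  # cumulative sequence lengths for 3 sequences
for layer in range(8):
    seq_lens = forward_oot(layer, cu)

print(seq_lens)         # [4, 5, 6]
print(COPIES["d2h"])    # 1 -- a single copy despite 8 layers
```

Removing the gate (as this PR does) is equivalent to calling `diff_to_host` in every layer, which is exactly the per-layer copy the reviewer is flagging.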


```diff
         # q, k, v: [b, s, head, head_dim] -> [b * s, head, head_dim]
         q, k, v = self._reshape_qkv_to_3d(query, key, value, bsz, q_len, kv_len)
@@ -128,7 +124,7 @@ def forward_oot(
             query=q,
             key=k,
             value=v,
-            seq_len=seq_lens_cpu_cache,
+            seq_len=seq_lens_cpu,
             scale_value=self.scale_value,
             num_heads=self.num_heads,
             num_kv_heads=self.num_kv_heads,
```
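For reference, the `cu_seqlens` fallback built by `torch.arange(0, (bsz + 1) * q_len, step=q_len)` and the `torch.diff` that follows it can be mimicked in plain Python. This is a sketch of the semantics only; plain lists stand in for the int32 CPU tensors:

```python
# Sketch of the cu_seqlens fallback for a batch of equal-length sequences:
# bsz sequences of length q_len -> cumulative offsets 0, q_len, 2*q_len, ...

bsz, q_len = 3, 5

# Equivalent of torch.arange(0, (bsz + 1) * q_len, step=q_len)
cu_seqlens = list(range(0, (bsz + 1) * q_len, q_len))
print(cu_seqlens)  # [0, 5, 10, 15]

# Equivalent of torch.diff(cu_seqlens): adjacent differences give the
# per-sequence lengths that the attention kernel consumes as seq_len.
seq_lens = [b - a for a, b in zip(cu_seqlens, cu_seqlens[1:])]
print(seq_lens)    # [5, 5, 5]
```

Since this fallback tensor is created with `device="cpu"`, the subsequent `.to("cpu")` is a no-op in that path; a d2h copy can only arise when a caller passes in a device-resident `cu_seqlens`.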