From 67519c7d8f440adb5a0bbbfe9aa28602661b8ec6 Mon Sep 17 00:00:00 2001 From: tanhaoan333 Date: Wed, 4 Mar 2026 10:57:22 +0800 Subject: [PATCH 1/2] Update mm_encoder_attention.py Signed-off-by: tanhaoan333 --- vllm_ascend/ops/mm_encoder_attention.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/vllm_ascend/ops/mm_encoder_attention.py b/vllm_ascend/ops/mm_encoder_attention.py index 889b88c42da..b3d389d94cd 100644 --- a/vllm_ascend/ops/mm_encoder_attention.py +++ b/vllm_ascend/ops/mm_encoder_attention.py @@ -62,7 +62,6 @@ def __init__( prefix=prefix, ) - self.layer_index = int("".join(filter(str.isdigit, prefix))) self.enable_pad = self.head_size > MIN_PAD_SIZE and self.head_size < MAX_PAD_SIZE self.scale_value = self.head_size**-0.5 @@ -104,11 +103,9 @@ def forward_oot( # Directly use seq_lens cpu cache to avoid d2h copy. global seq_lens_cpu_cache - if self.layer_index == 0: - if cu_seqlens is None: - cu_seqlens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device="cpu") - # Update seq_lens cpu cache. - seq_lens_cpu_cache = torch.diff(cu_seqlens).to("cpu") + if cu_seqlens is None: + cu_seqlens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device="cpu") + seq_lens_cpu = torch.diff(cu_seqlens).to("cpu") # q, k, v: [b, s, head, head_dim] -> [b * s, head, head_dim] q, k, v = self._reshape_qkv_to_3d(query, key, value, bsz, q_len, kv_len) @@ -128,7 +125,7 @@ def forward_oot( query=q, key=k, value=v, - seq_len=seq_lens_cpu_cache, + seq_len=seq_lens_cpu, scale_value=self.scale_value, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, From 4b2f257a3cb9c7fe86738a03ee73376206222c32 Mon Sep 17 00:00:00 2001 From: tanhaoan333 Date: Wed, 4 Mar 2026 11:07:26 +0800 Subject: [PATCH 2/2] Update mm_encoder_attention.py Signed-off-by: tanhaoan333 --- vllm_ascend/ops/mm_encoder_attention.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm_ascend/ops/mm_encoder_attention.py b/vllm_ascend/ops/mm_encoder_attention.py index b3d389d94cd..4122c7f178e 100644 --- a/vllm_ascend/ops/mm_encoder_attention.py +++ b/vllm_ascend/ops/mm_encoder_attention.py @@ -102,7 +102,6 @@ def forward_oot( is_reshaped = query.dim() == 4 # Directly use seq_lens cpu cache to avoid d2h copy. - global seq_lens_cpu_cache if cu_seqlens is None: cu_seqlens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device="cpu") seq_lens_cpu = torch.diff(cu_seqlens).to("cpu")