diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index d3a29852bd6..d52be032707 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -219,7 +219,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
             weakref.proxy(self))
         self.attn_mask_builder = AttentionMaskBuilder(
             min(self.model_config.max_model_len,
-                int(os.getenv("PAGED_ATTENTION_MASK_LEN", 10000))), self.dtype)
+                int(os.getenv("PAGED_ATTENTION_MASK_LEN", 128))), self.dtype)
 
         # Set up speculative decoding.
         self.use_aux_hidden_state_outputs = False
@@ -834,9 +834,11 @@ def _make_attention_mask(self, seq_lens, query_lens, position,
                 seq_lens, query_lens, position, self.dtype, self.device)
         # Prefill without cache situation.
         elif attn_state == AscendAttentionState.PrefillNoCache:
-            max_seq_len = max(seq_lens, default=0)
+            # Note: `torch_npu._npu_flash_attention` only requires a 128x128 mask,
+            # so it is hardcoded here. Once a new attention operator is added for
+            # the prefill-only state, this mask logic must be updated to match it.
             return self.attn_mask_builder.get_attn_mask(
-                max_seq_len, self.dtype, self.device)
+                128, self.dtype, self.device)
         # Prefill with cache hit.
         elif attn_state == AscendAttentionState.PrefillCacheHit:
             return self.attn_mask_builder.get_attn_mask(
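
For context, here is a minimal sketch of the mask-builder interface this patch touches. Only the constructor and `get_attn_mask` call shapes are taken from the diff above; the class name `SketchAttentionMaskBuilder`, its caching strategy, and the fill convention are assumptions for illustration, not the real vllm-ascend implementation.

```python
# Illustrative sketch only -- not the real vllm_ascend AttentionMaskBuilder.
# It mirrors the call shapes seen in the diff: builder(max_len, dtype) and
# builder.get_attn_mask(seq_len, dtype, device).
import torch


class SketchAttentionMaskBuilder:

    def __init__(self, max_seq_len: int, dtype: torch.dtype):
        # Pre-build one boolean causal mask of size max_seq_len x max_seq_len;
        # True marks positions that must be masked out (above the diagonal).
        self._mask = torch.triu(
            torch.ones(max_seq_len, max_seq_len, dtype=torch.bool), diagonal=1)

    def get_attn_mask(self, seq_len: int, dtype: torch.dtype,
                      device: torch.device) -> torch.Tensor:
        # Slice the cached mask instead of rebuilding it per request.
        return self._mask[:seq_len, :seq_len].to(device=device, dtype=dtype)


# With the new default of 128, the cached mask holds 128 * 128 entries instead
# of up to 10000 * 10000 (~190 MiB at fp16), which is all the PrefillNoCache
# path needs once the operator only consumes a 128x128 mask.
builder = SketchAttentionMaskBuilder(128, torch.float16)
mask = builder.get_attn_mask(128, torch.float16, torch.device("cpu"))
print(mask.shape)  # torch.Size([128, 128])
```

The design point of the patch follows from this shape: since the prefill operator reads only a fixed 128x128 mask, preallocating a mask sized by `max_model_len` (or the old 10000 default) wastes memory, so both the builder's default size and the `PrefillNoCache` lookup are pinned to 128.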