vllm_ascend/worker/model_runner_v1.py (8 changes: 5 additions & 3 deletions)

```diff
@@ -219,7 +219,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
                  weakref.proxy(self))
         self.attn_mask_builder = AttentionMaskBuilder(
             min(self.model_config.max_model_len,
-                int(os.getenv("PAGED_ATTENTION_MASK_LEN", 10000))), self.dtype)
+                int(os.getenv("PAGED_ATTENTION_MASK_LEN", 128))), self.dtype)
 
         # Set up speculative decoding.
         self.use_aux_hidden_state_outputs = False
@@ -834,9 +834,11 @@ def _make_attention_mask(self, seq_lens, query_lens, position,
                 seq_lens, query_lens, position, self.dtype, self.device)
         # Prefill without cache situation.
         elif attn_state == AscendAttentionState.PrefillNoCache:
-            max_seq_len = max(seq_lens, default=0)
+            # Note: `torch_npu._npu_flash_attention` only requires a 128x128 mask, so we hardcode it here.
+            # Once a new attention operator for the prefill-only state is added,
+            # the mask generation logic here must be updated to match that operator.
             return self.attn_mask_builder.get_attn_mask(
-                max_seq_len, self.dtype, self.device)
+                128, self.dtype, self.device)
         # Prefill with cache hit.
         elif attn_state == AscendAttentionState.PrefillCacheHit:
             return self.attn_mask_builder.get_attn_mask(
```
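For context on why a fixed 128x128 request is cheap here: the constructor and call signatures above suggest the builder materializes one causal mask up front and returns square slices of it, so shrinking the default from 10000 to 128 also shrinks the cached tensor. The sketch below illustrates that pattern; `MaskBuilderSketch` and its internals are assumptions for illustration, not the actual `AttentionMaskBuilder` implementation.

```python
# Hypothetical sketch of a cache-and-slice mask builder (assumed pattern,
# not the real vllm-ascend AttentionMaskBuilder).
import torch


class MaskBuilderSketch:
    """Builds one causal mask up front and serves square slices of it."""

    def __init__(self, max_len: int, dtype: torch.dtype):
        # Positions j > i (strict upper triangle) are masked; the rest is 0.
        fill = float("-inf") if dtype.is_floating_point else 1
        self._mask = torch.triu(
            torch.full((max_len, max_len), fill, dtype=dtype), diagonal=1)

    def get_attn_mask(self, size: int, dtype: torch.dtype,
                      device: torch.device) -> torch.Tensor:
        # Slices the cached tensor; only the device copy scales with `size`.
        return self._mask[:size, :size].to(dtype=dtype, device=device)


# With the change above, the PrefillNoCache path always requests a 128x128
# slice, so a 128x128 cache suffices by default: ~32 KB in float16 versus
# ~190 MB for the old 10000x10000 default.
builder = MaskBuilderSketch(max_len=128, dtype=torch.float16)
mask = builder.get_attn_mask(128, torch.float16, torch.device("cpu"))
assert mask.shape == (128, 128)
```

Since `PAGED_ATTENTION_MASK_LEN` is still read from the environment, a deployment can raise the mask length again if a future operator needs more than a 128x128 mask.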