Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,9 @@ line-length = 120
# Folder to be modified
exclude = [
"tests/**",
"vllm_ascend/_cann_ops_custom",
"vllm_ascend/attention",
"vllm_ascend/attention/mla_v1.py",
"vllm_ascend/attention/sfa_v1.py",
"vllm_ascend/core",
"vllm_ascend/device",
"vllm_ascend/device_allocator",
"vllm_ascend/distributed",
"vllm_ascend/eplb",
"vllm_ascend/kv_offload",
Expand All @@ -66,8 +64,6 @@ exclude = [
"vllm_ascend/spec_decode",
"vllm_ascend/worker",
"vllm_ascend/xlite",
"vllm_ascend/envs.py",
"vllm_ascend/batch_invariant.py",
]

[tool.ruff.lint]
Expand Down
29 changes: 11 additions & 18 deletions vllm_ascend/attention/attention_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,18 @@

def _generate_attn_mask(max_seq_len, dtype):
# Construct lower triangle matrix.
mask_flag = torch.ones((max_seq_len, max_seq_len),
dtype=torch.bool).tril_()
mask_flag = torch.ones((max_seq_len, max_seq_len), dtype=torch.bool).tril_()
# Create upper triangle matrix used to mark mask positions.
mask_flag = ~mask_flag
# Currently for fp16 dtype, the mask value should be set to -inf.
# TODO: Eliminate this part in the future.
mask_value = float('-inf') if dtype == torch.float16 else 1
attn_mask = torch.zeros(size=(max_seq_len, max_seq_len), dtype=dtype) \
.masked_fill_(mask_flag, mask_value)
mask_value = float("-inf") if dtype == torch.float16 else 1
attn_mask = torch.zeros(size=(max_seq_len, max_seq_len), dtype=dtype).masked_fill_(mask_flag, mask_value)
return attn_mask


@singleton
class AttentionMaskBuilder:

def __init__(self, device: torch.device):
self.attn_mask_cache = None
self._seq_len_cached = 0
Expand All @@ -52,14 +49,13 @@ def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype):
assert self.attn_mask_cache is not None, "Something is wrong in generate_attn_mask."
if self.attn_mask_cache.dtype != dtype:
self.attn_mask_cache = self.attn_mask_cache.to(dtype)
return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous(
).to(self.device, non_blocking=True)
return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous().to(self.device, non_blocking=True)

def get_splitfuse_attn_mask(self) -> torch.Tensor:
if self.chunked_prefill_attn_mask is None:
self.chunked_prefill_attn_mask = torch.triu(
torch.ones(2048,
2048), diagonal=1).to(torch.int8).to(self.device)
self.chunked_prefill_attn_mask = (
torch.triu(torch.ones(2048, 2048), diagonal=1).to(torch.int8).to(self.device)
)
return self.chunked_prefill_attn_mask

def get_mla_mask(self, dtype: torch.dtype) -> torch.Tensor:
Expand All @@ -68,16 +64,13 @@ def get_mla_mask(self, dtype: torch.dtype) -> torch.Tensor:
mask_value = torch.finfo(torch.float32).min
else:
mask_value = 1
prefill_mask = torch.triu(
torch.ones(512, 512, device=self.device, dtype=dtype), 1)
self.mla_mask = torch.where(prefill_mask == 1, mask_value,
0).to(dtype)
prefill_mask = torch.triu(torch.ones(512, 512, device=self.device, dtype=dtype), 1)
self.mla_mask = torch.where(prefill_mask == 1, mask_value, 0).to(dtype)
return self.mla_mask

def get_pcp_mla_mask(self, dtype: torch.dtype):
if self.pcp_mla_mask is None or self.pcp_mla_mask.dtype != dtype:
self.pcp_mla_mask = torch.triu(
torch.ones(512, 512, device=self.device, dtype=dtype), 1)
self.pcp_mla_mask = torch.triu(torch.ones(512, 512, device=self.device, dtype=dtype), 1)
return self.pcp_mla_mask

def get_swa_mask(self, dtype: torch.dtype, sliding_window):
Expand All @@ -99,4 +92,4 @@ def get_final_mla_mask(self, model_config: ModelConfig):
if get_pcp_group().world_size > 1:
return self.get_pcp_mla_mask(model_config.dtype)
# Prefill stages use 512x512 mask with appropriate dtype
return self.get_mla_mask(model_config.dtype)
return self.get_mla_mask(model_config.dtype)
Loading
Loading