8 changes: 2 additions & 6 deletions vllm_ascend/attention/mla_v1.py
@@ -1116,9 +1116,7 @@ def forward(
         else:
             decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
                 attn_metadata.decode.input_positions,
-                decode_q_pe.contiguous(),
-                decode_k_pe,
-                max_seq_len=attn_metadata.decode.max_seq_lens)
+                decode_q_pe.contiguous(), decode_k_pe)
         if has_prefill:
             assert attn_metadata.prefill is not None
             prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
@@ -1150,9 +1148,7 @@ def forward(
         else:
             prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
                 attn_metadata.prefill.input_positions,
-                prefill_q_pe.contiguous(),
-                prefill_k_pe,
-                max_seq_len=attn_metadata.prefill.max_seq_lens)
+                prefill_q_pe.contiguous(), prefill_k_pe)

         assert len(
             kv_cache
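The max_seq_len argument can be dropped at these call sites because the cos/sin cache is now sized to the full scaled context once at init (see the rotary_embedding.py hunks below); the forward pass then only gathers cache rows by position. A minimal, self-contained sketch of that pattern (hypothetical shapes and values, not the vllm_ascend implementation):

    import torch

    # Build the cache once, covering the whole supported context.
    max_seq_len, rotary_dim = 32768, 64
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
    freqs = torch.outer(torch.arange(max_seq_len).float(), inv_freq)
    cos, sin = freqs.cos(), freqs.sin()  # each [max_seq_len, rotary_dim // 2]

    def apply_rope(positions: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        # x: [num_tokens, rotary_dim]; rotate-half (neox-style) application.
        c, s = cos[positions], sin[positions]  # gather rows; no cache growth needed
        x1, x2 = x[..., :rotary_dim // 2], x[..., rotary_dim // 2:]
        return torch.cat([x1 * c - x2 * s, x2 * c + x1 * s], dim=-1)

    q_pe = torch.randn(4, rotary_dim)
    print(apply_rope(torch.tensor([0, 1, 2, 3]), q_pe).shape)  # torch.Size([4, 64])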
21 changes: 7 additions & 14 deletions vllm_ascend/ops/rotary_embedding.py
@@ -80,10 +80,7 @@ def native_rope_deepseek_forward(self,
                                  positions: torch.Tensor,
                                  query: torch.Tensor,
                                  key: torch.Tensor,
-                                 offsets: Optional[torch.Tensor] = None,
-                                 max_seq_len: Optional[int] = None):
-    if max_seq_len is not None and max_seq_len > self.max_seq_len:
-        _set_cos_sin_cache(self, max_seq_len, query.device, query.dtype)
+                                 offsets: Optional[torch.Tensor] = None):
     if len(key.shape) == 2:
         key = key[:, None, :]
     # Note: we implement the non neox_style method with shuffle the last dim and neox style
@@ -198,8 +195,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
     return q_embed, k_embed


-def _set_cos_sin_cache(self, seq_len, device, dtype):
-    self.max_seq_len_cached = seq_len
+def _set_cos_sin_cache(self, max_seq_len, device, dtype):
     dim = self.rotary_dim

     freq_extra = 1.0 / (self.base**(
@@ -219,9 +215,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
     inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
     self.register_buffer("inv_freq", inv_freq, persistent=False)

-    t = torch.arange(seq_len * self.scaling_factor,
-                     device=device,
-                     dtype=torch.float32)
+    t = torch.arange(max_seq_len, device=device, dtype=torch.float32)

     freqs = torch.outer(t, inv_freq)
     cos_cached = torch.cat([freqs, freqs], dim=-1).cos() * self.mscale
@@ -266,11 +260,10 @@ def deepseek_rope_init_func(
     super(DeepseekScalingRotaryEmbedding,
           self).__init__(head_size, rotary_dim, max_position_embeddings, base,
                          is_neox_style, dtype)
-    self.max_seq_len = max_position_embeddings
-    _set_cos_sin_cache(self,
-                       max_position_embeddings,
-                       dtype=dtype,
-                       device="npu")
+
+    # NOTE: for Ascend-friendly computing, reorder the sin and cos cache
+    self.max_seq_len = max_position_embeddings * scaling_factor
+    _set_cos_sin_cache(self, self.max_seq_len, dtype=dtype, device="npu")


 RotaryEmbedding.forward_oot = rope_forward_oot
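To make the new sizing concrete: deepseek_rope_init_func now materializes the whole scaled cache up front, so native_rope_deepseek_forward never has to grow it per batch. A sketch with illustrative numbers (the real _set_cos_sin_cache also blends extrapolated and interpolated frequencies per YaRN and applies mscale; both are omitted here):

    import torch

    max_position_embeddings = 4096   # illustrative, as in DeepSeek-style configs
    scaling_factor = 8
    rotary_dim = 64
    base = 10000.0

    # Old code: t = arange(seq_len * scaling_factor) inside _set_cos_sin_cache.
    # New code: the caller passes max_seq_len = max_position_embeddings * scaling_factor.
    # Either way the cache covers 32768 positions; the refactor makes the length
    # explicit in self.max_seq_len so forward can skip the resize check entirely.
    max_seq_len = max_position_embeddings * scaling_factor

    inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
    t = torch.arange(max_seq_len, dtype=torch.float32)
    freqs = torch.outer(t, inv_freq)
    cos_cached = torch.cat([freqs, freqs], dim=-1).cos()
    print(cos_cached.shape)  # torch.Size([32768, 64])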
14 changes: 9 additions & 5 deletions vllm_ascend/worker/model_runner_v1.py
@@ -350,6 +350,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled and self.vllm_config.model_config.use_mla
         self.use_cached_npu_graph = ascend_config.torchair_graph_config.use_cached_graph
         self.torchair_graph_batch_sizes = ascend_config.torchair_graph_config.graph_batch_sizes
+        self.use_ring_mla = ascend_config.chunked_prefill_for_mla

         if ascend_config.torchair_graph_config.graph_batch_sizes_init:
             self.init_torchair_graph_batch_sizes()

Collaborator: Should we consider which scheduler is being used?

    self.use_ring_mla = ascend_config.chunked_prefill_for_mla or \
        not ascend_config.ascend_scheduler_config.enabled
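For clarity, the suggestion above would enable ring MLA not only when chunked prefill for MLA is explicitly requested, but also whenever the ascend scheduler is disabled, since vLLM's default V1 scheduler may chunk prefills on its own. An illustrative truth table (attribute names as in the suggestion; values hypothetical):

    # (chunked_prefill_for_mla, ascend_scheduler_enabled) -> use_ring_mla
    for chunked, ascend_sched in [(True, True), (False, True), (False, False)]:
        print(chunked, ascend_sched, "->", chunked or not ascend_sched)
    # True  True  -> True   chunked prefill explicitly requested
    # False True  -> False  ascend scheduler without chunked prefill
    # False False -> True   default scheduler, prefills may arrive chunked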
@@ -908,11 +909,14 @@ def _process_reqs(
         else:
             attn_state = AscendAttentionState.PrefillCacheHit

-        attn_mask = self._make_attention_mask(seq_lens=seq_lens,
-                                              query_lens=num_scheduled_tokens,
-                                              position=positions,
-                                              attn_state=attn_state)
-        self.attn_mask = attn_mask
+        # NOTE: when using ring_mla, the attn_mask does not need to be generated here.
+        if not self.use_ring_mla or attn_state == AscendAttentionState.PrefillNoCache:
+            attn_mask = self._make_attention_mask(
+                seq_lens=seq_lens,
+                query_lens=num_scheduled_tokens,
+                position=positions,
+                attn_state=attn_state)
+            self.attn_mask = attn_mask
         self.attn_state = attn_state  # type: ignore

         extra_builder_kwargs = {}

Collaborator: Can a compressed mask be used? If the sequence is too long, it might cause memory waste here. Refer to #1100.

Collaborator (author): Ring MLA's mask size is equal to the chunk size, so it won't cause memory waste.
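The memory question in the thread above is easy to quantify: a dense prefill mask is quadratic in sequence length, while a ring-MLA mask is quadratic only in the chunk size. Illustrative numbers (not from the PR):

    # Dense mask vs chunk-sized mask, one byte per bool element (illustrative).
    seq_len = 128 * 1024    # a long prefill
    chunk_size = 4 * 1024   # ring MLA processes prefill chunk by chunk

    full_mask_gib = seq_len * seq_len / 2**30         # ~16.0 GiB
    chunk_mask_mib = chunk_size * chunk_size / 2**20  # ~16.0 MiB
    print(f"{full_mask_gib:.1f} GiB vs {chunk_mask_mib:.1f} MiB")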