19 changes: 19 additions & 0 deletions vllm_ascend/utils.py
@@ -1116,3 +1116,22 @@ def enable_dsa_cp_with_layer_shard() -> bool:
    vllm_config = get_current_vllm_config()
    is_prefill_instance = vllm_config.kv_transfer_config is not None and vllm_config.kv_transfer_config.is_kv_producer
    return is_prefill_instance


def check_gdn_layer(vllm_config) -> bool:
    """
    GDN layers are marked as `linear_attention` in `hf_config.layer_types`,
    so the presence of `linear_attention` indicates a GDN-attention model.
    """
    if not hasattr(vllm_config, "model_config"):
        return False

    model_config = vllm_config.model_config
    if not hasattr(model_config, "hf_config"):
        return False

    hf_config = model_config.hf_config
    if not hasattr(hf_config, "layer_types"):
        return False

    return "linear_attention" in hf_config.layer_types
34 changes: 34 additions & 0 deletions vllm_ascend/worker/model_runner_v1.py
@@ -112,6 +112,7 @@
from vllm_ascend.spec_decode.medusa_proposer import MedusaProposer
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
from vllm_ascend.utils import (
    check_gdn_layer,
    enable_flash_comm_v1,
    enable_sp,
    is_drafter_moe_model,
@@ -229,6 +230,17 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
            dtype=torch.int32,
        )

        # query_start_loc is padded, but GDN needs an unpadded version, so
        # gdn_query_start_loc mirrors query_start_loc without the padding.
        # TODO: delete this once fia's check is removed.
        self._has_gdn = check_gdn_layer(vllm_config)
        if self._has_gdn:
            self.gdn_query_start_loc = self._make_buffer(
                self.max_num_reqs + 1,  # type: ignore[has-type]
                dtype=torch.int32,
            )

        vllm_config.scheduler_config.max_num_batched_tokens -= max_pcp_pad_tokens
        self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
        self.max_num_reqs = self.scheduler_config.max_num_seqs
Expand Down Expand Up @@ -677,6 +689,16 @@ def _prepare_inputs(
self.query_start_loc.np[1 : num_reqs + 1] = cu_num_tokens
self.query_start_loc.copy_to_gpu()

# Now, query_start_loc is padded.
# But gdn needs an unpadded one.
# gdn_query_start_loc is an unpadded version of query_start_loc.
# TODO delete it if fia's check is removed.
if self._has_gdn:
self.gdn_query_start_loc.np[0] = 0
self.gdn_query_start_loc.np[1 : num_reqs + 1] = cu_num_tokens
self.gdn_query_start_loc.np[num_reqs + 1 :].fill(cu_num_tokens[-1])
self.gdn_query_start_loc.copy_to_gpu()

self.seq_lens.np[:num_reqs] = self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens
self.seq_lens.copy_to_gpu()
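
To make the padding behaviour concrete, a small worked example of the fill logic above (the request and token counts are made up):

```python
import numpy as np

# Hypothetical batch: 2 real requests with 3 and 5 tokens, buffer sized
# for max_num_reqs = 4.
num_reqs, max_num_reqs = 2, 4
cu_num_tokens = np.array([3, 8], dtype=np.int32)

gdn_query_start_loc = np.zeros(max_num_reqs + 1, dtype=np.int32)
gdn_query_start_loc[1 : num_reqs + 1] = cu_num_tokens
gdn_query_start_loc[num_reqs + 1 :].fill(cu_num_tokens[-1])
print(gdn_query_start_loc)  # [0 3 8 8 8]: the tail repeats the real total,
                            # so padded request slots are zero-length.
```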

Expand Down Expand Up @@ -2019,6 +2041,18 @@ def _build_attn_group_metadata(
kv_cache_group.kv_cache_spec,
num_reqs_padded,
)

# Now, query_start_loc is padded.
# But gdn needs an unpadded one.
# gdn_query_start_loc is an unpadded version of query_start_loc.
# TODO delete it if fia's check is removed.
if self._has_gdn:
attn_group = self.attn_groups[kv_cache_gid][0]
builder = attn_group.get_metadata_builder(0)
if use_spec_decode and isinstance(builder, GDNAttentionMetadataBuilder):
cm.query_start_loc_cpu = self.gdn_query_start_loc.cpu[: num_reqs_padded + 1]
cm.query_start_loc = self.gdn_query_start_loc.gpu[: num_reqs_padded + 1]

if kv_cache_gid > 0:
cm.block_table_tensor, cm.slot_mapping = _get_block_table_and_slot_mapping(kv_cache_gid)
if self.speculative_config and spec_decode_common_attn_metadata is None:
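
A follow-up sketch, with the same made-up numbers as above, of what the `[: num_reqs_padded + 1]` slice hands to the GDN metadata builder; the key property (per the TODO comments) is that the final entry stays at the real token total rather than a padded one:

```python
import numpy as np

# From the example above: 2 real requests (3 and 5 tokens) padded to 4.
gdn_query_start_loc = np.array([0, 3, 8, 8, 8], dtype=np.int32)
num_reqs_padded = 4

sliced = gdn_query_start_loc[: num_reqs_padded + 1]
print(np.diff(sliced))  # [3 5 0 0]: padded slots contribute zero-length queries
print(sliced[-1])       # 8, the real token total the GDN path expects
```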