67 changes: 47 additions & 20 deletions vllm_ascend/attention/context_parallel/sfa_cp.py
@@ -45,6 +45,14 @@ def __init__(
self.block_size = (self.block_size * self.cp_virtual_block_size) // np.gcd(
self.block_size, self.cp_virtual_block_size
)
self.slot_mapping_buf = torch.empty(
(
vllm_config.scheduler_config.max_num_batched_tokens
+ 2 * self.pcp_size * vllm_config.scheduler_config.max_num_seqs,
),
dtype=torch.int32,
device=device,
)

def build(
self,
@@ -82,15 +90,31 @@ def build(
long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
assert long_seq_metadata is not None
num_actual_tokens_pcp_padded = long_seq_metadata.num_actual_tokens_pcp_padded
slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens_pcp_padded]
self.slot_mapping_buf[:num_actual_tokens_pcp_padded].copy_(
common_attn_metadata.slot_mapping[:num_actual_tokens_pcp_padded], non_blocking=True
)
if self.enable_mlapo:
slot_mapping[:num_decode_tokens] = slot_mapping[: num_decode_tokens * self.pcp_size : self.pcp_size]
slot_mapping[num_decode_tokens : num_decode_tokens * self.pcp_size].fill_(-1)
metadata_cls.slot_mapping = slot_mapping
self.slot_mapping_buf[:num_decode_tokens] = self.slot_mapping_buf[
: num_decode_tokens * self.pcp_size : self.pcp_size
]
self.slot_mapping_buf[num_decode_tokens : num_decode_tokens * self.pcp_size].fill_(-1)
elif self.speculative_config is not None and num_decodes > 0:
# when mtp, pcp_allgather_restore_idx=[696,-1,697,-1,560,-1,561,-1,100,101,102],
# slot_mapping should be [696,697,-1,-1,560,561,-1,-1,100,101,102]
num_tokens_per_request = num_decode_tokens // num_decodes
decode_slot_mapping = self.slot_mapping_buf[: num_decode_tokens * self.pcp_size].reshape(
num_decodes, -1
)
decode_slot_mapping[:, :num_tokens_per_request] = decode_slot_mapping[
:, : num_tokens_per_request * self.pcp_size : self.pcp_size
]
decode_slot_mapping[:, num_tokens_per_request : num_tokens_per_request * self.pcp_size].fill_(-1)
self.slot_mapping_buf[: num_decode_tokens * self.pcp_size] = decode_slot_mapping.flatten()
metadata_cls.slot_mapping = self.slot_mapping_buf[:num_actual_tokens_pcp_padded]
actual_seq_lengths_query = metadata_cls.cum_query_lens
if num_prefills > 0 and num_decode_tokens > 0:
prefill_q_cum_seqlens = (
actual_seq_lengths_query[num_decode_tokens:] - actual_seq_lengths_query[num_decode_tokens - 1]
actual_seq_lengths_query[num_decodes:] - actual_seq_lengths_query[num_decodes - 1]
)
else:
prefill_q_cum_seqlens = actual_seq_lengths_query
@@ -108,8 +132,9 @@ def build_cp_metadata(
) -> AscendPCPMetadata | None:
common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
assert common_long_seq_metadata is not None
q_head_kv_lens = (seq_lens // 2) * (self.pcp_rank + 1)
q_tail_kv_lens = seq_lens * self.pcp_size - (seq_lens // 2) * self.pcp_rank
num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(seq_lens.device)
q_head_kv_lens = (seq_lens // 2) * (self.pcp_rank + 1) + num_computed_tokens
q_tail_kv_lens = seq_lens * self.pcp_size - (seq_lens // 2) * self.pcp_rank + num_computed_tokens
return AscendPCPMetadata(
q_head_idx=common_long_seq_metadata.q_head_idx_tensor,
q_tail_idx=common_long_seq_metadata.q_tail_idx_tensor,
@@ -181,6 +206,7 @@ def _execute_sparse_flash_attention_process(
return self._execute_sparse_flash_attention(
ql_nope, q_pe, kv, key_rope, block_table, topk_indices, actual_seq_lengths_query, actual_seq_lengths_key
)
num_decodes = attn_metadata.num_decodes
num_decode_tokens = attn_metadata.num_decode_tokens
num_prefills = attn_metadata.num_prefills
decode_attn_out = None
@@ -190,10 +216,10 @@ def _execute_sparse_flash_attention_process(
q_pe[:num_decode_tokens],
kv,
key_rope,
block_table[:num_decode_tokens],
block_table[:num_decodes],
topk_indices[:num_decode_tokens],
actual_seq_lengths_query[:num_decode_tokens],
actual_seq_lengths_key[:num_decode_tokens],
actual_seq_lengths_query[:num_decodes],
actual_seq_lengths_key[:num_decodes],
)

if num_prefills < 1:
@@ -205,10 +231,10 @@ def _execute_sparse_flash_attention_process(
ql_nope = ql_nope[num_decode_tokens:]
q_pe = q_pe[num_decode_tokens:]
topk_indices = topk_indices[num_decode_tokens:]
block_table = block_table[num_decode_tokens:]
block_table = block_table[num_decodes:]

# q head compute
q_head_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.head_attn_nomask_seqlens[num_decode_tokens:]
q_head_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.head_attn_nomask_seqlens[num_decodes:]
q_head_output = self._execute_sparse_flash_attention(
torch.index_select(ql_nope, 0, q_head_idx),
torch.index_select(q_pe, 0, q_head_idx),
@@ -221,7 +247,7 @@ def _execute_sparse_flash_attention_process(
)

# q tail compute
q_tail_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.tail_attn_nomask_seqlens[num_decode_tokens:]
q_tail_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.tail_attn_nomask_seqlens[num_decodes:]
q_tail_output = self._execute_sparse_flash_attention(
torch.index_select(ql_nope, 0, q_tail_idx),
torch.index_select(q_pe, 0, q_tail_idx),
@@ -321,6 +347,7 @@ def indexer_select_post_process(
)

# decode compute
num_decodes = attn_metadata.num_decodes
num_decode_tokens = attn_metadata.num_decode_tokens
num_prefills = attn_metadata.num_prefills
decode_topk_indices = None
@@ -329,24 +356,24 @@ def indexer_select_post_process(
q[:num_decode_tokens],
key,
weights[:num_decode_tokens],
actual_seq_lengths_query[:num_decode_tokens],
actual_seq_lengths_key[:num_decode_tokens],
block_table[:num_decode_tokens],
actual_seq_lengths_query[:num_decodes],
actual_seq_lengths_key[:num_decodes],
block_table[:num_decodes],
)

# prefill compute
if num_prefills == 0:
return decode_topk_indices
q = q[num_decode_tokens:]
weights = weights[num_decode_tokens:]
actual_seq_lengths_key = actual_seq_lengths_key[num_decode_tokens:]
block_table = block_table[num_decode_tokens:]
actual_seq_lengths_key = actual_seq_lengths_key[num_decodes:]
block_table = block_table[num_decodes:]
# pcp split for head and tail
q_head_idx = attn_metadata.sfa_cp_metadata.q_head_idx
q_tail_idx = attn_metadata.sfa_cp_metadata.q_tail_idx

# q head compute
q_head_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.head_attn_nomask_seqlens[num_decode_tokens:]
q_head_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.head_attn_nomask_seqlens[num_decodes:]
q_head_topk_indices = self._execute_indexer_select(
q=torch.index_select(q, 0, q_head_idx),
key=key,
@@ -357,7 +384,7 @@ def indexer_select_post_process(
)

# q tail compute
q_tail_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.tail_attn_nomask_seqlens[num_decode_tokens:]
q_tail_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.tail_attn_nomask_seqlens[num_decodes:]
q_tail_topk_indices = self._execute_indexer_select(
q=torch.index_select(q, 0, q_tail_idx),
key=key,
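For readers following the new `speculative_config` branch in `sfa_cp.py` above, here is a minimal standalone sketch of the MTP slot-mapping restore it performs. The concrete sizes (`pcp_size=2`, two decode requests with two speculative tokens each) come from the in-code comment or are assumed purely for illustration; this is not the PR's code, only a host-side reproduction of the indexing pattern.

```python
import torch

# Assumed illustration values: pcp_size=2, two decode requests, two MTP
# (speculative) tokens per decode request, followed by three prefill slots.
pcp_size = 2
num_decodes = 2
num_tokens_per_request = 2
num_decode_tokens = num_decodes * num_tokens_per_request  # 4

# All-gathered slot mapping: each decode request's valid slots are
# interleaved with -1 placeholders contributed by the other pcp rank.
slot_mapping = torch.tensor(
    [696, -1, 697, -1, 560, -1, 561, -1, 100, 101, 102], dtype=torch.int32
)

# View the decode region as (num_decodes, tokens_per_request * pcp_size).
decode = slot_mapping[: num_decode_tokens * pcp_size].reshape(num_decodes, -1)

# Pull the valid slots (stride pcp_size) to the front of each row, then
# mark the remainder of the row as padding.
decode[:, :num_tokens_per_request] = decode[:, : num_tokens_per_request * pcp_size : pcp_size]
decode[:, num_tokens_per_request:].fill_(-1)

# `decode` is a view of `slot_mapping`, so this write-back mirrors the
# explicit flatten-and-copy in the PR without changing the result.
slot_mapping[: num_decode_tokens * pcp_size] = decode.flatten()

print(slot_mapping.tolist())
# [696, 697, -1, -1, 560, 561, -1, -1, 100, 101, 102]
```

The prefill tail (`[100, 101, 102]`) is left untouched, which matches the expected layout described by the comment in `build`.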
63 changes: 33 additions & 30 deletions vllm_ascend/worker/model_runner_v1.py
@@ -246,36 +246,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
self.max_num_reqs = self.scheduler_config.max_num_seqs
self.dp_size = vllm_config.parallel_config.data_parallel_size
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
try:
self.dcp_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group
self.pcp_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0
except Exception:
self.dcp_size = 1
self.dcp_rank = 0
self.pcp_size = 1
self.pcp_rank = 0
if self.pcp_size > 1:
self.model_config.max_model_len += 2 * self.pcp_size * self.max_num_reqs
max_buffer_num_tokens = self.max_num_tokens
if self.pcp_size * self.dcp_size > 1:
max_buffer_num_tokens = self.max_num_tokens + self.max_num_reqs * 2 * self.pcp_size
self.pcp_manager = PCPManager(
self.pcp_size,
self.pcp_rank,
self.dcp_size,
self.dcp_rank,
max_buffer_num_tokens,
self.max_num_reqs,
self.device,
self.vllm_config,
self.use_async_scheduling,
self.pin_memory,
)
# TODO(zhenwenqi) after https://github.com/vllm-project/vllm/pull/28988 is merged, we can delete this
self.input_ids = self._make_buffer(max_buffer_num_tokens, dtype=torch.int32)
self.positions = self._make_buffer(max_buffer_num_tokens, dtype=torch.int64)

self.sampler = AscendSampler()
self.attn_state: AscendAttentionState | None = None

@@ -310,6 +281,38 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
use_mm_prefix=self.model_config is not None and self.model_config.is_mm_prefix_lm,
)

try:
self.dcp_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group
self.pcp_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0
except Exception:
self.dcp_size = 1
self.dcp_rank = 0
self.pcp_size = 1
self.pcp_rank = 0
if self.pcp_size > 1:
self.model_config.max_model_len += 2 * self.pcp_size * self.max_num_reqs
max_buffer_num_tokens = self.max_num_tokens
if self.pcp_size * self.dcp_size > 1:
max_buffer_num_tokens = self.max_num_tokens + self.max_num_reqs * 2 * self.pcp_size
self.pcp_manager = PCPManager(
self.pcp_size,
self.pcp_rank,
self.dcp_size,
self.dcp_rank,
max_buffer_num_tokens,
self.max_num_reqs,
self.device,
self.vllm_config,
self.use_async_scheduling,
self.pin_memory,
self.use_sparse,
)
# TODO(zhenwenqi) after https://github.com/vllm-project/vllm/pull/28988 is merged, we can delete this
self.input_ids = self._make_buffer(max_buffer_num_tokens, dtype=torch.int32)
self.positions = self._make_buffer(max_buffer_num_tokens, dtype=torch.int64)

self._set_up_drafter()

# kv role
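As a sanity check on the relocated parallel-state block above, the buffer sizing it computes can be reproduced with plain arithmetic. The numbers below are assumed for illustration only; the point is that the runner's `input_ids`/`positions` buffers grow by the same `2 * pcp_size * max_num_reqs` margin used for `slot_mapping_buf` in `sfa_cp.py`.

```python
# Assumed illustration values, not taken from any real config.
max_num_tokens = 8192   # scheduler_config.max_num_batched_tokens
max_num_reqs = 256      # scheduler_config.max_num_seqs
pcp_size, dcp_size = 2, 1

max_buffer_num_tokens = max_num_tokens
if pcp_size * dcp_size > 1:
    # Two extra padded slots per request and per pcp rank.
    max_buffer_num_tokens = max_num_tokens + max_num_reqs * 2 * pcp_size

print(max_buffer_num_tokens)  # 9216
```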
25 changes: 15 additions & 10 deletions vllm_ascend/worker/pcp_utils.py
@@ -56,6 +56,7 @@ def __init__(
vllm_config: VllmConfig,
use_async_scheduling: bool,
pin_memory: bool = False,
use_sparse: bool = False,
) -> None:
self.pcp_world_size = pcp_world_size
self.pcp_world_rank = pcp_rank
@@ -97,6 +98,7 @@ def __init__(
+ self.pcp_world_size * self.dcp_world_size * self.max_num_reqs
)
)
self.use_sparse = use_sparse
if self.speculative_config and self.pcp_world_size * self.dcp_world_size > 1:
self.input_ids_pcp_full = CpuGpuBuffer(
self.max_num_tokens, dtype=torch.int32, device=device, pin_memory=pin_memory
@@ -784,16 +786,19 @@ def generate_pcp_metadata(
num_prefill_reqs = self.num_prefill_reqs
num_decode_reqs = self.num_decode_reqs
num_decode_reqs_flatten = ori_query_lens_cpu[:num_decode_reqs].sum().item()
block_table_tensor[num_decode_reqs_flatten : num_decode_reqs_flatten + num_prefill_reqs].copy_(
block_table_tensor[num_decode_reqs : num_decode_reqs + num_prefill_reqs].clone()
)
block_table_tensor[:num_decode_reqs_flatten].copy_(
block_table_tensor[:num_decode_reqs].repeat_interleave(ori_query_lens[:num_decode_reqs], dim=0)
)
block_table_tensor = block_table_tensor[: num_decode_reqs_flatten + num_prefill_reqs]
if num_reqs_padded > num_reqs:
pad_size = num_reqs_padded - num_reqs
ori_query_lens_cpu[-pad_size:] = torch.full([pad_size], ori_query_lens_cpu[-pad_size - 1].item())
if not self.use_sparse:
block_table_tensor[num_decode_reqs_flatten : num_decode_reqs_flatten + num_prefill_reqs].copy_(
block_table_tensor[num_decode_reqs : num_decode_reqs + num_prefill_reqs].clone()
)
block_table_tensor[:num_decode_reqs_flatten].copy_(
block_table_tensor[:num_decode_reqs].repeat_interleave(ori_query_lens[:num_decode_reqs], dim=0)
)
block_table_tensor = block_table_tensor[: num_decode_reqs_flatten + num_prefill_reqs]
if num_reqs_padded > num_reqs:
pad_size = num_reqs_padded - num_reqs
ori_query_lens_cpu[-pad_size:] = torch.full(
[pad_size], ori_query_lens_cpu[-pad_size - 1].item()
)
pcp_unpad_mask = self.pcp_unpad_mask_cpu[: self.pcp_padded_tokens_length]
long_seq_metadata = AscendPrefillContextParallelMetadata(
pcp_use_hybrid_attn=self.pcp_use_hybrid_attn,
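Finally, a hedged sketch of the block-table flattening that `generate_pcp_metadata` keeps for the non-sparse path (the new `use_sparse` flag skips it): each decode request's block-table row is duplicated once per query token, and the prefill rows are shifted past the flattened decode region. All shapes and block ids below are assumed for illustration.

```python
import torch

# Assumed illustration values: two decode requests with query lengths 2 and 1
# (e.g. speculative decoding), one prefill request, four block ids per row.
ori_query_lens = torch.tensor([2, 1, 5])
num_decode_reqs = 2
num_prefill_reqs = 1
num_decode_reqs_flatten = int(ori_query_lens[:num_decode_reqs].sum())  # 3

block_table = torch.tensor([
    [10, 11, 12, 13],  # decode request 0
    [20, 21, 22, 23],  # decode request 1
    [30, 31, 32, 33],  # prefill request 0
    [0, 0, 0, 0],      # spare rows in the preallocated table
    [0, 0, 0, 0],
])

# Move the prefill rows past the flattened decode region ...
block_table[num_decode_reqs_flatten : num_decode_reqs_flatten + num_prefill_reqs].copy_(
    block_table[num_decode_reqs : num_decode_reqs + num_prefill_reqs].clone()
)
# ... then expand each decode request's row once per query token.
block_table[:num_decode_reqs_flatten].copy_(
    block_table[:num_decode_reqs].repeat_interleave(ori_query_lens[:num_decode_reqs], dim=0)
)
block_table = block_table[: num_decode_reqs_flatten + num_prefill_reqs]

print(block_table.tolist())
# [[10, 11, 12, 13], [10, 11, 12, 13], [20, 21, 22, 23], [30, 31, 32, 33]]
```

With sparse flash attention enabled, the attention path in `sfa_cp.py` slices per-request tensors by `num_decodes` rather than by the per-token `num_decode_tokens`, so the block table stays in its per-request layout and this flattening is skipped.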