2 changes: 0 additions & 2 deletions tests/ut/attention/test_attention_cp.py
@@ -169,8 +169,6 @@ def test_compute_prefill_context(self, mock_npu_attention):
attn_metadata.prefill.chunked_context = MagicMock()
local_context_lens_allranks = torch.tensor([[[256, 256], [256, 256]]])
attn_metadata.prefill.chunked_context.local_context_lens_allranks = local_context_lens_allranks
- attn_metadata.prefill.chunked_context.batch_chunk_seq_mask = torch.randint(
-     0, 2, (1024, ), dtype=torch.bool)
attn_metadata.prefill.chunked_context.local_total_toks = local_context_lens_allranks[:,
0,
0].sum(
7 changes: 1 addition & 6 deletions vllm_ascend/attention/context_parallel/attention_cp.py
@@ -142,7 +142,7 @@ def build(
assert num_computed_tokens_of_pcp_dcp is not None
chunked_context_metadata = None
if num_prefills > 0:
- query_lens = query_lens[num_decode_tokens:]
+ query_lens = query_lens[num_decodes:]
context_lens_cpu = num_computed_tokens_cpu[num_decodes:num_reqs]
max_context_len_cpu = context_lens_cpu.max().item()
pcp_size = get_pcp_group().world_size
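
Note on the one-line change in the hunk above: it swaps a token count for a request count when slicing query_lens. Assuming query_lens holds one entry per scheduled request, slicing by num_decodes keeps exactly the prefill entries, while slicing by num_decode_tokens only happens to work when every decode request contributes a single token. A minimal sketch with invented numbers:

# Illustrative only; values are made up. With multi-token decodes
# (e.g. speculative decoding) the two counts diverge.
import torch

query_lens = torch.tensor([2, 2, 2, 15, 30])  # one entry per request: 3 decodes, 2 prefills
num_decodes = 3                               # decode *requests*
num_decode_tokens = 6                         # decode *tokens* (3 requests x 2 tokens)

print(query_lens[num_decodes:])        # tensor([15, 30]) -- the prefill query lens
print(query_lens[num_decode_tokens:])  # empty tensor -- the prefills would be dropped
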
@@ -172,10 +172,6 @@ def build(
kv_inverse_idx_for_chunk = None
cp_kv_recover_idx_for_chunk = None

- batch_chunk_seq_mask = local_context_lens_allranks[:, self.pcp_rank, self.dcp_rank] == 0
- batch_chunk_seq_mask = torch.repeat_interleave(
-     batch_chunk_seq_mask, repeats=(query_lens * self.pcp_size).to(self.device)
- )
chunk_seq_mask_filtered_indices = filter_chunked_req_indices(query_lens, chunked_req_mask).to(
self.device
)
@@ -187,7 +183,6 @@
local_context_lens_allranks=local_context_lens_allranks,
cp_kv_recover_idx_for_chunk=cp_kv_recover_idx_for_chunk,
kv_inverse_idx_for_chunk=kv_inverse_idx_for_chunk,
- batch_chunk_seq_mask=batch_chunk_seq_mask,
chunk_seq_mask_filtered_indices=chunk_seq_mask_filtered_indices,
local_total_toks=local_total_toks.item(),
)
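
For context on the removal above: the deleted lines built a per-token mask by repeating a per-request boolean mask query_len * pcp_size times with torch.repeat_interleave. A standalone sketch of that expansion, with shapes and values invented for illustration:

import torch

pcp_size = 2
query_lens = torch.tensor([3, 2])               # prefill query length per request
per_request_mask = torch.tensor([True, False])  # e.g. "this rank holds no local context"

# Each request's flag is repeated once per local query token (query_len * pcp_size).
per_token_mask = torch.repeat_interleave(per_request_mask,
                                         repeats=query_lens * pcp_size)
print(per_token_mask)
# tensor([ True,  True,  True,  True,  True,  True, False, False, False, False])
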
2 changes: 1 addition & 1 deletion vllm_ascend/worker/model_runner_v1.py
@@ -1016,7 +1016,7 @@ def propose_draft_token_ids(
target_positions = self._get_positions(num_scheduled_tokens)
target_hidden_states = hidden_states
if self.use_aux_hidden_state_outputs:
- target_hidden_states = torch.cat([h[:num_scheduled_tokens] for h in aux_hidden_states], dim=-1)
+ target_hidden_states = torch.cat([h for h in aux_hidden_states], dim=-1)
else:
token_indices_to_sample = None
# input_ids can be None for multimodal models.
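
On the change above: aux_hidden_states is, by all appearances, a list of per-layer [num_tokens, hidden_size] tensors that are concatenated along the hidden dimension for the EAGLE-style drafter; the removed variant also truncated each tensor to the first num_scheduled_tokens rows before concatenating. A sketch with hypothetical shapes:

import torch

num_scheduled_tokens, hidden_size = 8, 16
aux_hidden_states = [torch.randn(10, hidden_size) for _ in range(3)]  # 3 aux layers

with_slice = torch.cat([h[:num_scheduled_tokens] for h in aux_hidden_states], dim=-1)
without_slice = torch.cat(list(aux_hidden_states), dim=-1)

print(with_slice.shape)     # torch.Size([8, 48])
print(without_slice.shape)  # torch.Size([10, 48])
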