40 changes: 27 additions & 13 deletions vllm_ascend/worker/model_runner_v1.py
@@ -935,6 +935,28 @@ def _calc_spec_decode_metadata(
             logits_indices=logits_indices,
         )
 
+    def sync_and_slice_intermediate_tensors(
+        self,
+        num_tokens: int,
+        intermediate_tensors: IntermediateTensors | None,
+        sync_self: bool,
+    ) -> IntermediateTensors:
+        assert self.intermediate_tensors is not None
+
+        tp = self.vllm_config.parallel_config.tensor_parallel_size
+        is_rs = enable_sp(self.vllm_config)
+        # When sequence parallelism (flashcomm1) is enabled, the intermediate tensors are
+        # sharded across tensor-parallel ranks, so each rank only needs its own slice.
+        if sync_self:
+            assert intermediate_tensors is not None
+            for k, v in intermediate_tensors.items():
+                copy_len = num_tokens // tp if is_rs else num_tokens
+                self.intermediate_tensors[k][:copy_len].copy_(v[:copy_len], non_blocking=True)
+
+        return IntermediateTensors(
+            {k: v[: num_tokens // tp] if is_rs else v[:num_tokens] for k, v in self.intermediate_tensors.items()}
+        )
+
     # TODO: Once the PCP features are complete, this will fully inherit the classes from the vLLM community.
     def propose_draft_token_ids(
         self,
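To make the slicing rule in the new method concrete, here is a minimal standalone sketch (hypothetical helper name, not part of this PR). It assumes only what the added comment states: with sequence parallelism each tensor-parallel rank holds num_tokens // tp_size rows of every intermediate tensor, so both the copy length and the returned view must use the sharded length.

import torch

def slice_intermediate(cached: dict[str, torch.Tensor],
                       num_tokens: int, tp_size: int,
                       sp_enabled: bool) -> dict[str, torch.Tensor]:
    # With SP the buffer is sharded across TP ranks, so one rank's view
    # covers only num_tokens // tp_size rows; otherwise it covers them all.
    copy_len = num_tokens // tp_size if sp_enabled else num_tokens
    return {k: v[:copy_len] for k, v in cached.items()}

# With tp_size=4 and 1024 scheduled tokens, an SP rank sees 256 rows:
cached = {"hidden_states": torch.zeros(2048, 128)}
assert slice_intermediate(cached, 1024, 4, True)["hidden_states"].shape[0] == 256
assert slice_intermediate(cached, 1024, 4, False)["hidden_states"].shape[0] == 1024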
@@ -2300,23 +2322,15 @@ def _dummy_run(
         if get_pp_group().is_first_rank:
             intermediate_tensors = None
         else:
-            # When PP and flashcomm1 are enabled, during dummy_run the estimated space should divide num_tokens by
-            # tp_size; otherwise, on non-first PP ranks it would effectively perform an extra all-gather, leading
-            # to incorrect memory estimation and potentially causing OOM.
-            intermediate_tokens = num_tokens_padded
-            if enable_sp():
-                tp_size = get_tensor_model_parallel_world_size()
-                intermediate_tokens = (num_tokens_padded + tp_size - 1) // tp_size
             if self.intermediate_tensors is None:
-                max_actual_tokens = self.max_num_tokens
-                if enable_sp():
-                    max_actual_tokens = (self.max_num_tokens + tp_size - 1) // tp_size
+                tp_size = get_tensor_model_parallel_world_size()
+                max_actual_tokens = (
+                    (self.max_num_tokens + tp_size - 1) // tp_size if enable_sp() else self.max_num_tokens
+                )
                 self.intermediate_tensors = self.model.make_empty_intermediate_tensors(
                     batch_size=max_actual_tokens, dtype=self.dtype, device=self.device
                 )
-            intermediate_tensors = IntermediateTensors(
-                {k: v[:intermediate_tokens] for k, v in self.intermediate_tensors.items()}
-            )
+            intermediate_tensors = self.sync_and_slice_intermediate_tensors(num_tokens_padded, None, False)
 
         need_dummy_logits = not is_profile and lmhead_tp_enable()
         max_num_reqs_across_dp = max_num_reqs * self.uniform_decode_query_len
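The sizing change in this hunk shrinks the persistent buffer on non-first PP ranks. A minimal sketch of the arithmetic (standalone, hypothetical function name): under flashcomm1/SP a rank only ever needs ceil(max_num_tokens / tp_size) rows, so allocating the full max_num_tokens would, as the removed comment notes, lead to an incorrect memory estimate and potentially cause OOM.

def intermediate_buffer_rows(max_num_tokens: int, tp_size: int, sp_enabled: bool) -> int:
    # Ceil division: every rank must still cover its full shard even when
    # max_num_tokens is not a multiple of tp_size.
    return (max_num_tokens + tp_size - 1) // tp_size if sp_enabled else max_num_tokens

assert intermediate_buffer_rows(8192, 4, True) == 2048   # sharded estimate
assert intermediate_buffer_rows(8190, 4, True) == 2048   # ceil rounding
assert intermediate_buffer_rows(8192, 4, False) == 8192  # dense fallback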