19 changes: 13 additions & 6 deletions vllm_ascend/worker/model_runner_v1.py
@@ -1133,8 +1133,10 @@ def _generate_process_reqs_hidden_states(self, maybe_padded_num_tokens,
hidden_states, IntermediateTensors):
hidden_states = self._all_gather_hidden_states_and_aux(
hidden_states)
return hidden_states if self.pcp_size == 1 else self.pcp_manager.get_restore_hidden_states(
hidden_states)
if self.pcp_size > 1 and get_pp_group().is_last_rank:
hidden_states = self.pcp_manager.get_restore_hidden_states(
hidden_states)
return hidden_states

def _build_attn_state(self, num_reqs, num_scheduled_tokens,
num_valid_tokens):
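The guard added in the first hunk can be sketched as a standalone helper. This is an illustrative reduction of the diff, not the actual method: the real code calls `get_pp_group().is_last_rank` and `self.pcp_manager.get_restore_hidden_states`; the parameter names below are hypothetical.

```python
def restore_if_needed(hidden_states, pcp_size, is_last_pp_rank, restore_fn):
    """Restore hidden-state ordering only where it is defined.

    Only the last pipeline-parallel rank holds final hidden states, and only
    prefill-context-parallel runs (pcp_size > 1) need the restore step; all
    other cases pass the tensor through unchanged.
    """
    if pcp_size > 1 and is_last_pp_rank:
        return restore_fn(hidden_states)
    # pcp_size == 1 needs no restore; non-last PP ranks forward
    # IntermediateTensors untouched.
    return hidden_states
```

Compared to the old one-liner, the explicit `and get_pp_group().is_last_rank` condition avoids calling the restore path on ranks that carry `IntermediateTensors` rather than final hidden states.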
@@ -2173,19 +2175,24 @@ def _dummy_run(
else:
# When PP and flashcomm1 are enabled, during dummy_run the estimated space should divide num_tokens by tp_size;
# otherwise, on non-first PP ranks it would effectively perform an extra all-gather, leading to incorrect memory estimation and potentially causing OOM.
actual_tokens = num_tokens
intermediate_tokens = num_tokens_padded
if enable_sp():
tp_size = get_tensor_model_parallel_world_size()
actual_tokens = num_tokens // tp_size
intermediate_tokens = (num_tokens_padded + tp_size -
1) // tp_size
if self.intermediate_tensors is None:
max_actual_tokens = self.max_num_tokens
if enable_sp():
max_actual_tokens = (self.max_num_tokens + tp_size -
1) // tp_size
self.intermediate_tensors = (
self.model.make_empty_intermediate_tensors(
batch_size=actual_tokens,
batch_size=max_actual_tokens,
dtype=self.dtype,
device=self.device))
intermediate_tensors = IntermediateTensors({
k:
v[:num_tokens_padded]
v[:intermediate_tokens]
for k, v in self.intermediate_tensors.items()
})
Comment on lines +2178 to 2197
Contributor


critical

This change correctly handles intermediate_tensors during a dummy run when Pipeline Parallelism (PP) and Sequence Parallelism (SP, via flashcomm1) are enabled. There are several important fixes here:

1. **Correct sharded size calculation:** The size of intermediate tensors is now calculated with ceiling division, `(num_tokens_padded + tp_size - 1) // tp_size`, which is the proper way to size a sharded tensor. The previous floor division was incorrect.
2. **Robust buffer allocation:** `self.intermediate_tensors` is now allocated once at the maximum possible size (`self.max_num_tokens`), making it reusable across different dummy runs. Previously it was allocated from the current run's token count, which could be insufficient for subsequent runs.
3. **Correct tensor slicing:** The intermediate tensors are now sliced with the sharded size (`intermediate_tokens`), preventing potential out-of-bounds errors that could occur when using the un-sharded `num_tokens_padded`.

These changes are critical for preventing OOM errors and ensuring correctness in memory estimation during dummy runs.
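The sizing logic the comment describes can be checked in isolation. A minimal sketch with hypothetical sizes (`1024` and `10` are illustrative, not values from the PR):

```python
def ceil_div(n: int, d: int) -> int:
    # Smallest k such that k * d >= n.
    return (n + d - 1) // d

# Hypothetical configuration for illustration.
max_num_tokens = 1024
tp_size = 4

# Allocate the buffer once at the maximum sharded size, as the new code does
# with self.max_num_tokens, so it is reusable across dummy runs.
max_actual_tokens = ceil_div(max_num_tokens, tp_size)

# Per-run slicing uses the sharded size, never the un-sharded padded count.
num_tokens_padded = 10
intermediate_tokens = ceil_div(num_tokens_padded, tp_size)

assert intermediate_tokens * tp_size >= num_tokens_padded  # no tokens dropped
assert intermediate_tokens <= max_actual_tokens            # fits the buffer
```

With floor division, `10 // 4 == 2` shards would cover only 8 of 10 padded tokens; ceiling division yields 3 and covers all of them, which is the first fix the comment highlights.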

