64 changes: 39 additions & 25 deletions vllm_ascend/worker/model_runner_v1.py
@@ -30,9 +30,7 @@
import numpy.typing as npt
import torch
import torch._dynamo.cache_size
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ReduceOp
from vllm.attention import AttentionType, get_attn_backend
from vllm.attention.layer import Attention
from vllm.config import CompilationLevel, VllmConfig
@@ -548,16 +546,16 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
self.input_batch.refresh_sampling_metadata()

def _get_forward_metadata_across_dp(
self, total_num_scheduled_tokens: int,
with_prefill: bool) -> tuple[int, bool]:
forward_metadata = torch.tensor(
[total_num_scheduled_tokens, with_prefill],
device="cpu",
dtype=torch.int32)
dist.all_reduce(forward_metadata,
op=ReduceOp.MAX,
group=get_dp_group().cpu_group)
return int(forward_metadata[0]), bool(forward_metadata[1] > 0)
self, num_tokens: int,
with_prefill: bool) -> tuple[torch.Tensor, bool]:
local_forward_metadata = torch.tensor([num_tokens, with_prefill],
device="npu",
dtype=torch.int32).unsqueeze(0)
global_forward_metadata = get_dp_group().all_gather(
local_forward_metadata, dim=0)
num_tokens_across_dp = global_forward_metadata[:, 0].cpu()
with_prefill = bool(global_forward_metadata[:, 1].any())
return num_tokens_across_dp, with_prefill

def get_eagle_atten_dict(
self,
@@ -1013,23 +1011,35 @@ def _process_reqs(
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
]

num_tokens_across_dp = None
if self.dp_size > 1:
max_num_tokens, with_prefill = self._get_forward_metadata_across_dp(
total_num_scheduled_tokens, with_prefill)
num_tokens_across_dp, with_prefill = \
self._get_forward_metadata_across_dp(num_input_tokens,
with_prefill)
max_num_tokens = int(num_tokens_across_dp.max().item())
extra_builder_kwargs['max_num_tokens_across_dp'] = max_num_tokens
extra_builder_kwargs['with_prefill_across_dp'] = with_prefill

# Add graph_pad_size here
if self.torchair_graph_enabled and not with_prefill:
if self.dp_size > 1:
padded_batch_size = self.select_torchair_padded_batch_size(
max_num_tokens)
else:
padded_batch_size = self.select_torchair_padded_batch_size(
total_num_scheduled_tokens)
max_num_tokens = (max_num_tokens
if self.dp_size > 1 else num_input_tokens)
padded_batch_size = self.select_torchair_padded_batch_size(
max_num_tokens)
graph_pad_size = padded_batch_size - total_num_scheduled_tokens

extra_builder_kwargs['graph_pad_size'] = graph_pad_size
# If torchair graph is enabled and in decode mode, the dummy run
# batch size is set to the selected graph size.
dummy_num_tokens = padded_batch_size
else:
# If torchair graph is not enabled, or if with_prefill is True, the
# dummy run batch size is set to 1.
dummy_num_tokens = 1

if self.dp_size > 1:
assert num_tokens_across_dp is not None
num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
dummy_num_tokens)

if self.vllm_config.model_config.use_mla:
query_start_loc = self.query_start_loc[:num_reqs + 1]
@@ -1106,7 +1116,8 @@ def _process_reqs(
# Run forward pass
with set_forward_context(attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens):
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp):
with ProfileExecuteDuration().capture_async("forward"):
model_kwargs = {}
if self.torchair_graph_enabled:
@@ -1585,6 +1596,7 @@ def _dummy_run(
num_tokens: int,
is_compile: bool = False,
with_prefill: bool = True,
num_tokens_across_dp: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
# for dummy run with LoRA so that the num_reqs collectively
@@ -1626,9 +1638,11 @@ def _dummy_run(
for k, v in self.intermediate_tensors.items()
})

with set_forward_context(None,
self.vllm_config,
num_tokens=num_tokens):
with set_forward_context(
None,
self.vllm_config,
num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp):
if self.torchair_graph_enabled and not with_prefill:
attn_metadata = self.attn_metadata_builder.build_dummy(
num_reqs=num_tokens, num_actual_tokens=1)
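Note on the change above: `_get_forward_metadata_across_dp` now all-gathers each rank's `[num_tokens, with_prefill]` pair instead of all-reducing a single maximum, so every rank learns the per-rank token counts as well as whether any rank still has a prefill in flight. Below is a minimal standalone sketch of that reduction logic; the DP collective is simulated with a stacked tensor, and the name `simulate_forward_metadata_across_dp` is illustrative only, not part of this PR.

```python
import torch

def simulate_forward_metadata_across_dp(per_rank_metadata):
    """Mimic the all_gather in _get_forward_metadata_across_dp.

    Each DP rank contributes (num_tokens, with_prefill); after the gather,
    every rank sees the per-rank token counts and whether any rank still
    has a prefill scheduled.
    """
    # Stand-in for each rank's local [num_tokens, with_prefill] tensor and
    # the group-wide all_gather along dim 0.
    gathered = torch.tensor(per_rank_metadata, dtype=torch.int32)
    num_tokens_across_dp = gathered[:, 0]
    with_prefill = bool(gathered[:, 1].any())
    return num_tokens_across_dp, with_prefill

# Rank 1 signals a dummy run with -1; rank 2 still has a prefill scheduled.
tokens, prefill = simulate_forward_metadata_across_dp(
    [(128, False), (-1, False), (256, True)])
print(tokens.tolist(), prefill)  # [128, -1, 256] True
```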
28 changes: 21 additions & 7 deletions vllm_ascend/worker/worker_v1.py
@@ -294,16 +294,30 @@ def pin_lora(self, lora_id: int) -> bool:

def execute_dummy_batch(self) -> None:
runner = self.model_runner
max_num_tokens = 1
with_prefill = False
if runner.dp_size > 1:
max_num_tokens, with_prefill = runner._get_forward_metadata_across_dp(
max_num_tokens, with_prefill)
if runner.dp_size <= 1:
raise ValueError(
"Dummy batch execution should only be "
"performed with data parallelism enabled, but got "
f"dp_size={runner.dp_size}.")

# If torchair graph is enabled, notify the other DP ranks that this is a
# dummy run by using '-1' as a flag for num_tokens. This will be
# replaced with the final determined graph size before the forward pass.
num_tokens_across_dp, with_prefill = \
runner._get_forward_metadata_across_dp(-1, False)

if runner.torchair_graph_enabled and not with_prefill:
max_num_tokens = runner.select_torchair_padded_batch_size(
max_num_tokens = int(num_tokens_across_dp.max().item())
num_tokens = runner.select_torchair_padded_batch_size(
max_num_tokens)
runner._dummy_run(max_num_tokens,
else:
num_tokens = 1
[inline review thread]
Collaborator: why is it 1?
jianzs (author): If graph mode is off, the dummy run merely has to execute; its size imposes no computational requirement, so 1 is sufficient.

num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
num_tokens)
runner._dummy_run(num_tokens,
is_compile=False,
num_tokens_across_dp=num_tokens_across_dp,
with_prefill=with_prefill)

def _init_worker_distributed_environment(self) -> None:
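To make the rendezvous in `execute_dummy_batch` easier to follow, here is a hedged sketch of how the `-1` sentinel is resolved after the gather, under the assumption stated in the PR comments (dummy ranks advertise `-1`; graph-mode decode pads every rank to the same size). It is a pure-tensor simulation; `resolve_dummy_token_counts` and the `select_padded_batch_size` hook are illustrative names, not APIs from the codebase.

```python
import torch

def resolve_dummy_token_counts(num_tokens_across_dp, graph_mode,
                               select_padded_batch_size=lambda n: n):
    """Replace the -1 dummy-run sentinel with the size the dummy rank will run.

    In torchair graph decode mode every rank must execute the same padded
    graph size, so the dummy rank adopts the padded global maximum; outside
    graph mode a single-token dummy run is enough (its size does not matter).
    """
    if graph_mode:
        max_num_tokens = int(num_tokens_across_dp.max().item())
        num_tokens = select_padded_batch_size(max_num_tokens)
    else:
        num_tokens = 1
    resolved = num_tokens_across_dp.clone()
    resolved.masked_fill_(resolved == -1, num_tokens)
    return resolved, num_tokens

# Rank 1 advertised -1; with graph mode on it pads up to the global maximum.
print(resolve_dummy_token_counts(torch.tensor([128, -1, 256]), graph_mode=True))
# (tensor([128, 256, 256]), 256)
```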