90 changes: 51 additions & 39 deletions vllm_ascend/worker/model_runner_v1.py
@@ -3002,8 +3002,6 @@ def _dummy_run(
with_prefill) = self._sync_metadata_across_dp(num_tokens,
with_prefill)

- moe_comm_type = self._select_moe_comm_method(num_tokens)
-
# If cudagraph_mode.decode_mode() == FULL and
# cudagraph_mode.seperate_routine(). This means that we are using
# different graphs and/or modes for mixed prefill-decode batches vs.
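Note on the deletion above: it is half of a move. Later in this diff, _select_moe_comm_method is called again with the padded token count instead of the raw one. The reordering matters because the communication method can be size-dependent, so it must be chosen for the shape that will actually execute. A minimal sketch of the hazard, assuming a made-up threshold policy (the real selection logic inside _select_moe_comm_method is not shown in this diff):

def select_moe_comm_method(num_tokens: int, threshold: int = 256) -> str:
    # Hypothetical policy: small batches use all-gather, larger ones all-to-all.
    return "allgather" if num_tokens <= threshold else "alltoall"

num_tokens, num_tokens_padded = 250, 320  # padding crossed the threshold
assert select_moe_comm_method(num_tokens) == "allgather"
assert select_moe_comm_method(num_tokens_padded) == "alltoall"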
@@ -3049,23 +3047,62 @@ def _dummy_run(
if not self.in_profile_run and self.dynamic_eplb:
self.eplb_updator.forward_before()

+ has_lora = True if self.lora_config and self.compilation_config.cudagraph_specialize_lora else False
+ _ag_mode, batch_descriptor = \
+ self.aclgraph_dispatcher.dispatch(num_tokens=num_tokens, uniform_decode=uniform_decode, has_lora=has_lora)
+
+ num_tokens_padded = batch_descriptor.num_tokens
+ num_reqs_padded = (batch_descriptor.num_reqs if
+ batch_descriptor.num_reqs is not None else num_reqs)
+ if num_tokens_across_dp is not None and num_tokens_padded != num_tokens:
+ # re-padding across DP is needed if `num_tokens` was padded inside CudagraphDispatcher
+ num_tokens_across_dp[:] = num_tokens_padded
+ num_scheduled_tokens = num_scheduled_tokens.repeat(num_reqs_padded)
+
+ moe_comm_type = self._select_moe_comm_method(num_tokens_padded)
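The contract assumed by the added block above: dispatch returns a batch descriptor whose num_tokens has already been rounded up to the nearest captured graph size, and whose num_reqs may be padded as well (hence the None fallback). A rough sketch of that rounding, with illustrative names and sizes rather than the actual CudagraphDispatcher API:

import bisect
from dataclasses import dataclass
from typing import Optional

@dataclass
class BatchDescriptor:
    # Illustrative stand-in for the dispatcher's descriptor, not vLLM's API.
    num_tokens: int
    num_reqs: Optional[int] = None

def pad_to_captured_size(num_tokens: int, captured_sizes: list[int]) -> int:
    # Round up to the smallest captured graph size that fits the batch.
    idx = bisect.bisect_left(captured_sizes, num_tokens)
    return captured_sizes[idx] if idx < len(captured_sizes) else num_tokens

captured = [8, 16, 32, 64, 128]  # assumed sorted ascending
desc = BatchDescriptor(num_tokens=pad_to_captured_size(50, captured))
assert desc.num_tokens == 64  # 50 rounds up to the captured size 64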

+ # filter out the valid batch descriptor
+ if aclgraph_runtime_mode is not None:
+ # we allow forcing NONE when the dispatcher disagrees to support
+ # warm ups for aclgraph capture
+ if aclgraph_runtime_mode != CUDAGraphMode.NONE and aclgraph_runtime_mode != _ag_mode:
+ raise ValueError(
+ f"Aclgraph runtime mode mismatch at dummy_run. "
+ f"Expected {_ag_mode}, but got {aclgraph_runtime_mode}.")
+ else:
+ aclgraph_runtime_mode = _ag_mode
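The override rule enforced here is easy to misread, so a distilled sketch follows: a caller may always force eager execution (NONE) for capture warm-up, but any other forced mode must agree with the dispatcher's decision. GraphMode below is a stand-in enum, not vLLM's CUDAGraphMode:

from enum import Enum
from typing import Optional

class GraphMode(Enum):
    # Stand-in enum for illustration only.
    NONE = 0
    PIECEWISE = 1
    FULL = 2

def resolve_mode(forced: Optional[GraphMode], dispatched: GraphMode) -> GraphMode:
    if forced is None:
        return dispatched
    # Forcing NONE is always legal (eager warm-up); anything else must agree.
    if forced is not GraphMode.NONE and forced is not dispatched:
        raise ValueError(f"mode mismatch: expected {dispatched}, got {forced}")
    return forced

assert resolve_mode(None, GraphMode.FULL) is GraphMode.FULL
assert resolve_mode(GraphMode.NONE, GraphMode.FULL) is GraphMode.NONE  # warm-up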

+ # TODO(Mengqing): Set create_mixed_batch to False since it's only used in FI warmup
+ # and not supported in ASCEND now. We could remove it in the future.
+ attn_metadata = self._build_dummy_attn_metadata(
+ False,
+ num_reqs=num_reqs_padded,
+ num_tokens=num_tokens_padded,
+ max_query_len=max_query_len,
+ aclgraph_runtime_mode=aclgraph_runtime_mode,
+ force_attention=force_attention,
+ num_scheduled_tokens=num_scheduled_tokens,
+ )
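Building the dummy attention metadata from num_reqs_padded and num_tokens_padded, rather than the raw counts used before the move, keeps the metadata shapes consistent with the tensors the captured graph will replay. A toy illustration of the invariant, with made-up numbers:

num_reqs_padded, num_tokens_padded = 4, 64
# Per-request query lengths must account for the full padded token budget,
# otherwise the captured graph and its metadata disagree on shapes.
num_scheduled_tokens = [num_tokens_padded // num_reqs_padded] * num_reqs_padded
assert sum(num_scheduled_tokens) == num_tokens_padded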

with self.maybe_dummy_run_with_lora(self.lora_config,
num_scheduled_tokens,
num_sampled_tokens):
+ # Make sure padding doesn't exceed max_num_tokens
+ assert num_tokens_padded <= self.max_num_tokens
if self.is_multimodal_model:
input_ids = None
- inputs_embeds = self.inputs_embeds.gpu[:num_tokens]
+ inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded]
elif self.enable_prompt_embeds:
input_ids = None
- inputs_embeds = self.inputs_embeds.gpu[:num_tokens]
+ inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded]
else:
- input_ids = self.input_ids[:num_tokens]
+ input_ids = self.input_ids[:num_tokens_padded]
inputs_embeds = None

if self.uses_mrope:
- positions = self.mrope_positions[:, :num_tokens]
+ positions = self.mrope_positions[:, :num_tokens_padded]
else:
- positions = self.positions[:num_tokens]
+ positions = self.positions[:num_tokens_padded]
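Every branch above now slices to num_tokens_padded; mixing padded and unpadded views is exactly the bug class this PR removes. A self-contained sketch of the invariant, with invented buffer sizes:

import torch

max_num_tokens, num_tokens, num_tokens_padded = 128, 50, 64
input_ids = torch.zeros(max_num_tokens, dtype=torch.int64)
positions = torch.zeros(max_num_tokens, dtype=torch.int64)

# Every input view must use the same padded length; a single leftover
# [:num_tokens] slice would hand the replayed graph mismatched shapes.
ids_view = input_ids[:num_tokens_padded]
pos_view = positions[:num_tokens_padded]
assert ids_view.shape == pos_view.shape == (num_tokens_padded,)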

if get_pp_group().is_first_rank:
intermediate_tensors = None
@@ -3077,39 +3114,14 @@ def _dummy_run(
dtype=self.dtype,
device=self.device))
intermediate_tensors = IntermediateTensors({
- k: v[:num_tokens]
+ k:
+ v[:num_tokens_padded]
for k, v in self.intermediate_tensors.items()
})
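On non-first pipeline-parallel ranks the same rule extends to the buffered intermediate tensors: each is narrowed to the padded token count before the dummy forward. A minimal sketch, assuming two hypothetical buffers:

import torch

num_tokens_padded, hidden_size = 64, 32
intermediate = {
    # Hypothetical buffers standing in for self.intermediate_tensors.
    "hidden_states": torch.zeros(128, hidden_size),
    "residual": torch.zeros(128, hidden_size),
}
sliced = {k: v[:num_tokens_padded] for k, v in intermediate.items()}
assert all(t.shape[0] == num_tokens_padded for t in sliced.values())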
- has_lora = True if self.lora_config and self.compilation_config.cudagraph_specialize_lora else False
- # filter out the valid batch descriptor
- _ag_mode, batch_descriptor = \
- self.aclgraph_dispatcher.dispatch(num_tokens=num_tokens, uniform_decode=uniform_decode, has_lora=has_lora)
- if aclgraph_runtime_mode is not None:
- # we allow forcing NONE when the dispatcher disagrees to support
- # warm ups for aclgraph capture
- if aclgraph_runtime_mode != CUDAGraphMode.NONE and aclgraph_runtime_mode != _ag_mode:
- raise ValueError(
- f"Aclgraph runtime mode mismatch at dummy_run. "
- f"Expected {_ag_mode}, but got {aclgraph_runtime_mode}."
- )
- else:
- aclgraph_runtime_mode = _ag_mode
-
- # TODO(Mengqing): Set create_mixed_batch to False since it's only used in FI warmup
- # and not supported in ASCEND now. We could remove it in the future.
- attn_metadata = self._build_dummy_attn_metadata(
- False,
- num_reqs=num_reqs,
- num_tokens=num_tokens,
- max_query_len=max_query_len,
- aclgraph_runtime_mode=aclgraph_runtime_mode,
- force_attention=force_attention,
- num_scheduled_tokens=num_scheduled_tokens,
- )

need_dummy_logits = (not self.in_profile_run
and lmhead_tp_enable())
- max_num_reqs_across_dp = num_tokens if not with_prefill else max_num_reqs
+ max_num_reqs_across_dp = num_tokens_padded if not with_prefill else max_num_reqs
dummy_indices = torch.zeros(max_num_reqs_across_dp,
dtype=torch.int32)
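The lm-head tensor-parallel path now sizes its placeholder indices from the padded count too, presumably so every DP rank runs the dummy logits computation over matching shapes. Condensed from the lines above, with invented values:

import torch

with_prefill = False
num_tokens_padded, max_num_reqs = 64, 16
# Decode-only dummy runs size the placeholder indices by the padded
# token count (values here are illustrative, not from the PR).
max_num_reqs_across_dp = (num_tokens_padded
                          if not with_prefill else max_num_reqs)
dummy_indices = torch.zeros(max_num_reqs_across_dp, dtype=torch.int32)
assert dummy_indices.numel() == num_tokens_padded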

@@ -3129,7 +3141,7 @@ def dummy_drafter_compute_logits(hidden_states):
with set_ascend_forward_context(
attn_metadata,
self.vllm_config,
- num_tokens=num_tokens,
+ num_tokens=num_tokens_padded,
num_tokens_across_dp=num_tokens_across_dp,
with_prefill=with_prefill,
in_profile_run=self.in_profile_run,
@@ -3143,15 +3155,15 @@ def dummy_drafter_compute_logits(hidden_states):
weight_prefetch_method=self.weight_prefetch_method):
hidden_states = self._generate_dummy_run_hidden_states(
with_prefill, is_torchair_compile, input_ids, positions,
- attn_metadata, num_tokens, intermediate_tensors,
+ attn_metadata, num_tokens_padded, intermediate_tensors,
inputs_embeds)
dummy_compute_logits(hidden_states)

if self.drafter:
self.drafter.dummy_run(
- num_tokens=num_tokens,
+ num_tokens=num_tokens_padded,
with_prefill=with_prefill,
- num_reqs=num_reqs,
+ num_reqs=num_reqs_padded,
num_tokens_across_dp=num_tokens_across_dp,
aclgraph_runtime_mode=aclgraph_runtime_mode,
batch_descriptor=batch_descriptor,