Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions vllm/model_executor/models/deepseek_vl2.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,7 @@ def _call_hf_processor(
dict(prompt=prompt, **mm_data),
mm_kwargs,
)
target_dtype = self.info.ctx.model_config.dtype
pixel_values = processed_outputs.pop("pixel_values").to(
target_dtype)
pixel_values = processed_outputs["pixel_values"]
# split pixel values into patches corresponding to each image
images_spatial_crop = processed_outputs["images_spatial_crop"]
patches_per_image = [
Expand Down
8 changes: 7 additions & 1 deletion vllm/multimodal/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,11 +731,17 @@ def as_kwargs(
batched_inputs: BatchedTensorInputs,
*,
device: torch.types.Device,
dtype: Optional[torch.dtype] = None,
) -> BatchedTensorInputs:
json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)

json_mapped = json_map_leaves(
lambda x: x.to(device, non_blocking=True),
lambda x: x.to(
device=device,
# This mimics the behavior of transformers.BatchFeature
dtype=dtype if x.is_floating_point() else None,
non_blocking=True,
),
json_inputs,
)

Expand Down
7 changes: 5 additions & 2 deletions vllm/spec_decode/draft_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,8 +294,11 @@ def execute_model(
inputs_embeds=None,
positions=model_input.input_positions,
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
device=self.device),
**MultiModalKwargs.as_kwargs(
multi_modal_kwargs,
dtype=self.model_runner.model_config.dtype,
device=self.device,
),
**model_execute_kwargs,
)

Expand Down
12 changes: 9 additions & 3 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -929,8 +929,11 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
encoder_outputs = []
for grouped_mm_inputs in grouped_mm_inputs_list:
batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
device=self.device)
batched_mm_inputs = MultiModalKwargs.as_kwargs(
batched_mm_inputs,
dtype=self.model_config.dtype,
device=self.device,
)

# Run the encoder.
# `curr_group_outputs` is either of the following:
Expand Down Expand Up @@ -1874,7 +1877,10 @@ def profile_run(self) -> None:
batched_dummy_mm_inputs = MultiModalKwargs.batch(
[dummy_mm_kwargs] * max_num_mm_items)
batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
batched_dummy_mm_inputs, device=self.device)
batched_dummy_mm_inputs,
dtype=self.model_config.dtype,
device=self.device,
)

# Run multimodal encoder.
dummy_encoder_outputs = self.model.get_multimodal_embeddings(
Expand Down
14 changes: 10 additions & 4 deletions vllm/v1/worker/tpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,8 +652,11 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
encoder_outputs = []
for grouped_mm_inputs in grouped_mm_inputs_list:
batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
device=self.device)
batched_mm_inputs = MultiModalKwargs.as_kwargs(
batched_mm_inputs,
dtype=self.model_config.dtype,
device=self.device,
)

# Run the encoder.
# `curr_group_outputs` is either of the following:
Expand Down Expand Up @@ -1435,8 +1438,11 @@ def _get_mm_dummy_batch(self, modality: str,

batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] *
batch_size)
return MultiModalKwargs.as_kwargs(batched_dummy_mm_inputs,
device=self.device)
return MultiModalKwargs.as_kwargs(
batched_dummy_mm_inputs,
dtype=self.model_config.dtype,
device=self.device,
)


def _get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]:
Expand Down
7 changes: 5 additions & 2 deletions vllm/worker/cpu_enc_dec_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,8 +297,11 @@ def execute_model(
model_input.encoder_input_tokens,
"encoder_positions":
model_input.encoder_input_positions,
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
device=self.device),
**MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs or {},
dtype=self.model_config.dtype,
device=self.device,
),
"intermediate_tensors":
intermediate_tensors,
}
Expand Down
5 changes: 4 additions & 1 deletion vllm/worker/cpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,7 +628,10 @@ def execute_model(
multimodal_kwargs = {}
if model_input.multi_modal_kwargs is not None:
multimodal_kwargs = MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs, device=self.device)
model_input.multi_modal_kwargs,
dtype=self.model_config.dtype,
device=self.device,
)
execute_model_kwargs = {}
if previous_hidden_states is not None:
execute_model_kwargs.update(
Expand Down
7 changes: 5 additions & 2 deletions vllm/worker/cpu_pooling_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,11 @@ def execute_model(
model_input.input_tokens,
"positions":
model_input.input_positions,
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
device=self.device),
**MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs or {},
dtype=self.model_config.dtype,
device=self.device,
),
**cross_enc_kwargs,
"intermediate_tensors":
intermediate_tensors,
Expand Down
10 changes: 7 additions & 3 deletions vllm/worker/enc_dec_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,13 @@ def execute_model(
encoder_input_ids=model_input.encoder_input_tokens,
encoder_positions=model_input.encoder_input_positions,
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
device=self.device),
**seqlen_agnostic_kwargs)
**MultiModalKwargs.as_kwargs(
multi_modal_kwargs,
dtype=self.model_config.dtype,
device=self.device,
),
**seqlen_agnostic_kwargs,
)

logits = self.model.compute_logits(hidden_or_intermediate_states,
model_input.sampling_metadata)
Expand Down
7 changes: 5 additions & 2 deletions vllm/worker/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1845,8 +1845,11 @@ def execute_model(
inputs_embeds=model_input.inputs_embeds,
positions=model_input.input_positions,
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
device=self.device),
**MultiModalKwargs.as_kwargs(
multi_modal_kwargs,
dtype=self.model_config.dtype,
device=self.device,
),
**seqlen_agnostic_kwargs,
**model_kwargs,
)
Expand Down
7 changes: 5 additions & 2 deletions vllm/worker/multi_step_neuron_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,11 @@ def execute_model(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
input_block_ids=model_input.input_block_ids,
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
device=self.device),
**MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs or {},
dtype=self.model_config.dtype,
device=self.device,
),
)

output = self.model.sample(
Expand Down
7 changes: 5 additions & 2 deletions vllm/worker/multi_step_neuronx_distributed_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,11 @@ def execute_model(
positions=model_input.input_positions,
input_block_ids=model_input.input_block_ids,
sampling_params=sampling_params,
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
device=self.device),
**MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs or {},
dtype=self.model_config.dtype,
device=self.device,
),
)

output = self.model.sample(
Expand Down
16 changes: 10 additions & 6 deletions vllm/worker/neuron_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,9 +378,11 @@ def execute_model(
positions=model_input.input_positions,
input_block_ids=model_input.input_block_ids,
sampling_params=sampling_params,
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs
or {},
device=self.device),
**MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs or {},
dtype=self.model_config.dtype,
device=self.device,
),
)
elif current_platform.use_transformers_neuronx():
# [TODO] validate on-device sampling
Expand All @@ -389,9 +391,11 @@ def execute_model(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
input_block_ids=model_input.input_block_ids,
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs
or {},
device=self.device),
**MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs or {},
dtype=self.model_config.dtype,
device=self.device,
),
)

# Compute the logits only if the on-device sampling is turned off as
Expand Down
10 changes: 7 additions & 3 deletions vllm/worker/pooling_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,14 @@ def execute_model(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
device=self.device),
**MultiModalKwargs.as_kwargs(
multi_modal_kwargs,
dtype=self.model_config.dtype,
device=self.device,
),
**cross_enc_kwargs,
**seqlen_agnostic_kwargs)
**seqlen_agnostic_kwargs,
)

if (self.observability_config is not None
and self.observability_config.collect_model_forward_time):
Expand Down
9 changes: 6 additions & 3 deletions vllm/worker/xpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,9 +562,12 @@ def execute_model(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs
or {},
device=self.device))
**MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs or {},
dtype=self.model_config.dtype,
device=self.device,
),
)
# Compute the logits in the last pipeline stage.
if not get_pp_group().is_last_rank:
return hidden_or_intermediate_states
Expand Down