Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions vllm/model_executor/models/deepseek_vl2.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,7 @@ def _call_hf_processor(
dict(prompt=prompt, **mm_data),
mm_kwargs,
)
target_dtype = self.info.ctx.model_config.dtype
pixel_values = processed_outputs.pop("pixel_values").to(
target_dtype)
pixel_values = processed_outputs["pixel_values"]
# split pixel values into patches corresponding to each image
images_spatial_crop = processed_outputs["images_spatial_crop"]
patches_per_image = [
Expand Down
5 changes: 0 additions & 5 deletions vllm/model_executor/models/gemma3_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,11 +263,6 @@ def _call_hf_processor(
mm_data,
mm_kwargs,
)
if "pixel_values" in processed_outputs:
# Cast pixel values to model dtype already here,
# so we need to transfer less data to the GPU
processed_outputs["pixel_values"] = processed_outputs[
"pixel_values"].to(self.info.ctx.model_config.dtype)

# HF processor pops the `num_crops` kwarg, which is needed by vLLM
if (images := mm_data.get("images")) is not None:
Expand Down
8 changes: 7 additions & 1 deletion vllm/multimodal/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,11 +731,17 @@ def as_kwargs(
batched_inputs: BatchedTensorInputs,
*,
device: torch.types.Device,
dtype: Optional[torch.dtype] = None,
) -> BatchedTensorInputs:
json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)

def maybe_cast_dtype(x: torch.Tensor):
# This mimics the behavior of transformers.BatchFeature
return x.to(dtype=dtype) if x.is_floating_point() else x

json_mapped = json_map_leaves(
lambda x: x.to(device, non_blocking=True),
# NOTE: Cast the dtype before sending it to device
lambda x: maybe_cast_dtype(x).to(device=device, non_blocking=True),
json_inputs,
)

Expand Down
7 changes: 5 additions & 2 deletions vllm/spec_decode/draft_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,8 +294,11 @@ def execute_model(
inputs_embeds=None,
positions=model_input.input_positions,
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
device=self.device),
**MultiModalKwargs.as_kwargs(
multi_modal_kwargs,
dtype=self.model_runner.model_config.dtype,
device=self.device,
),
**model_execute_kwargs,
)

Expand Down
12 changes: 9 additions & 3 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -929,8 +929,11 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
encoder_outputs = []
for grouped_mm_inputs in grouped_mm_inputs_list:
batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
device=self.device)
batched_mm_inputs = MultiModalKwargs.as_kwargs(
batched_mm_inputs,
dtype=self.model_config.dtype,
device=self.device,
)

# Run the encoder.
# `curr_group_outputs` is either of the following:
Expand Down Expand Up @@ -1874,7 +1877,10 @@ def profile_run(self) -> None:
batched_dummy_mm_inputs = MultiModalKwargs.batch(
[dummy_mm_kwargs] * max_num_mm_items)
batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
batched_dummy_mm_inputs, device=self.device)
batched_dummy_mm_inputs,
dtype=self.model_config.dtype,
device=self.device,
)

# Run multimodal encoder.
dummy_encoder_outputs = self.model.get_multimodal_embeddings(
Expand Down
14 changes: 10 additions & 4 deletions vllm/v1/worker/tpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,8 +652,11 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
encoder_outputs = []
for grouped_mm_inputs in grouped_mm_inputs_list:
batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
device=self.device)
batched_mm_inputs = MultiModalKwargs.as_kwargs(
batched_mm_inputs,
dtype=self.model_config.dtype,
device=self.device,
)

# Run the encoder.
# `curr_group_outputs` is either of the following:
Expand Down Expand Up @@ -1435,8 +1438,11 @@ def _get_mm_dummy_batch(self, modality: str,

batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] *
batch_size)
return MultiModalKwargs.as_kwargs(batched_dummy_mm_inputs,
device=self.device)
return MultiModalKwargs.as_kwargs(
batched_dummy_mm_inputs,
dtype=self.model_config.dtype,
device=self.device,
)


def _get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]:
Expand Down
7 changes: 5 additions & 2 deletions vllm/worker/cpu_enc_dec_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,8 +297,11 @@ def execute_model(
model_input.encoder_input_tokens,
"encoder_positions":
model_input.encoder_input_positions,
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
device=self.device),
**MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs or {},
dtype=self.model_config.dtype,
device=self.device,
),
"intermediate_tensors":
intermediate_tensors,
}
Expand Down
5 changes: 4 additions & 1 deletion vllm/worker/cpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,7 +628,10 @@ def execute_model(
multimodal_kwargs = {}
if model_input.multi_modal_kwargs is not None:
multimodal_kwargs = MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs, device=self.device)
model_input.multi_modal_kwargs,
dtype=self.model_config.dtype,
device=self.device,
)
execute_model_kwargs = {}
if previous_hidden_states is not None:
execute_model_kwargs.update(
Expand Down
7 changes: 5 additions & 2 deletions vllm/worker/cpu_pooling_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,11 @@ def execute_model(
model_input.input_tokens,
"positions":
model_input.input_positions,
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
device=self.device),
**MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs or {},
dtype=self.model_config.dtype,
device=self.device,
),
**cross_enc_kwargs,
"intermediate_tensors":
intermediate_tensors,
Expand Down
10 changes: 7 additions & 3 deletions vllm/worker/enc_dec_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,13 @@ def execute_model(
encoder_input_ids=model_input.encoder_input_tokens,
encoder_positions=model_input.encoder_input_positions,
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
device=self.device),
**seqlen_agnostic_kwargs)
**MultiModalKwargs.as_kwargs(
multi_modal_kwargs,
dtype=self.model_config.dtype,
device=self.device,
),
**seqlen_agnostic_kwargs,
)

logits = self.model.compute_logits(hidden_or_intermediate_states,
model_input.sampling_metadata)
Expand Down
7 changes: 5 additions & 2 deletions vllm/worker/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1845,8 +1845,11 @@ def execute_model(
inputs_embeds=model_input.inputs_embeds,
positions=model_input.input_positions,
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
device=self.device),
**MultiModalKwargs.as_kwargs(
multi_modal_kwargs,
dtype=self.model_config.dtype,
device=self.device,
),
**seqlen_agnostic_kwargs,
**model_kwargs,
)
Expand Down
7 changes: 5 additions & 2 deletions vllm/worker/multi_step_neuron_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,11 @@ def execute_model(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
input_block_ids=model_input.input_block_ids,
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
device=self.device),
**MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs or {},
dtype=self.model_config.dtype,
device=self.device,
),
)

output = self.model.sample(
Expand Down
7 changes: 5 additions & 2 deletions vllm/worker/multi_step_neuronx_distributed_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,11 @@ def execute_model(
positions=model_input.input_positions,
input_block_ids=model_input.input_block_ids,
sampling_params=sampling_params,
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
device=self.device),
**MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs or {},
dtype=self.model_config.dtype,
device=self.device,
),
)

output = self.model.sample(
Expand Down
16 changes: 10 additions & 6 deletions vllm/worker/neuron_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,9 +378,11 @@ def execute_model(
positions=model_input.input_positions,
input_block_ids=model_input.input_block_ids,
sampling_params=sampling_params,
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs
or {},
device=self.device),
**MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs or {},
dtype=self.model_config.dtype,
device=self.device,
),
)
elif current_platform.use_transformers_neuronx():
# [TODO] validate on-device sampling
Expand All @@ -389,9 +391,11 @@ def execute_model(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
input_block_ids=model_input.input_block_ids,
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs
or {},
device=self.device),
**MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs or {},
dtype=self.model_config.dtype,
device=self.device,
),
)

# Compute the logits only if the on-device sampling is turned off as
Expand Down
10 changes: 7 additions & 3 deletions vllm/worker/pooling_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,14 @@ def execute_model(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
device=self.device),
**MultiModalKwargs.as_kwargs(
multi_modal_kwargs,
dtype=self.model_config.dtype,
device=self.device,
),
**cross_enc_kwargs,
**seqlen_agnostic_kwargs)
**seqlen_agnostic_kwargs,
)

if (self.observability_config is not None
and self.observability_config.collect_model_forward_time):
Expand Down
9 changes: 6 additions & 3 deletions vllm/worker/xpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,9 +562,12 @@ def execute_model(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs
or {},
device=self.device))
**MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs or {},
dtype=self.model_config.dtype,
device=self.device,
),
)
# Compute the logits in the last pipeline stage.
if not get_pp_group().is_last_rank:
return hidden_or_intermediate_states
Expand Down