diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py
index 164fa40ffebe..5c8793f59ffb 100644
--- a/vllm/model_executor/models/deepseek_vl2.py
+++ b/vllm/model_executor/models/deepseek_vl2.py
@@ -210,9 +210,7 @@ def _call_hf_processor(
             dict(prompt=prompt, **mm_data),
             mm_kwargs,
         )
-        target_dtype = self.info.ctx.model_config.dtype
-        pixel_values = processed_outputs.pop("pixel_values").to(
-            target_dtype)
+        pixel_values = processed_outputs["pixel_values"]
         # split pixel values into patches corresponding to each image
         images_spatial_crop = processed_outputs["images_spatial_crop"]
         patches_per_image = [
diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py
index 00a972d33b04..182cc86d3ca8 100644
--- a/vllm/model_executor/models/gemma3_mm.py
+++ b/vllm/model_executor/models/gemma3_mm.py
@@ -263,11 +263,6 @@ def _call_hf_processor(
             mm_data,
             mm_kwargs,
         )
-        if "pixel_values" in processed_outputs:
-            # Cast pixel values to model dtype already here,
-            # so we need to transfer less data to the GPU
-            processed_outputs["pixel_values"] = processed_outputs[
-                "pixel_values"].to(self.info.ctx.model_config.dtype)

         # HF processor pops the `num_crops` kwarg, which is needed by vLLM
         if (images := mm_data.get("images")) is not None:
diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py
index 71ef1a98e0d0..3580e0805fec 100644
--- a/vllm/multimodal/inputs.py
+++ b/vllm/multimodal/inputs.py
@@ -731,11 +731,17 @@ def as_kwargs(
         batched_inputs: BatchedTensorInputs,
         *,
         device: torch.types.Device,
+        dtype: Optional[torch.dtype] = None,
     ) -> BatchedTensorInputs:
         json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)

+        def maybe_cast_dtype(x: torch.Tensor):
+            # This mimics the behavior of transformers.BatchFeature
+            return x.to(dtype=dtype) if x.is_floating_point() else x
+
         json_mapped = json_map_leaves(
-            lambda x: x.to(device, non_blocking=True),
+            # NOTE: Cast the dtype before sending it to device
+            lambda x: maybe_cast_dtype(x).to(device=device, non_blocking=True),
             json_inputs,
         )
diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py
index a6276c563394..991d2040a878 100644
--- a/vllm/spec_decode/draft_model_runner.py
+++ b/vllm/spec_decode/draft_model_runner.py
@@ -294,8 +294,11 @@ def execute_model(
                     inputs_embeds=None,
                     positions=model_input.input_positions,
                     intermediate_tensors=intermediate_tensors,
-                    **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
-                                                 device=self.device),
+                    **MultiModalKwargs.as_kwargs(
+                        multi_modal_kwargs,
+                        dtype=self.model_runner.model_config.dtype,
+                        device=self.device,
+                    ),
                     **model_execute_kwargs,
                 )
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index aa47ac253bb9..910c0e80bb31 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -929,8 +929,11 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
         encoder_outputs = []
         for grouped_mm_inputs in grouped_mm_inputs_list:
             batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
-            batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
-                                                           device=self.device)
+            batched_mm_inputs = MultiModalKwargs.as_kwargs(
+                batched_mm_inputs,
+                dtype=self.model_config.dtype,
+                device=self.device,
+            )

             # Run the encoder.
             # `curr_group_outputs` is either of the following:
@@ -1874,7 +1877,10 @@ def profile_run(self) -> None:
         batched_dummy_mm_inputs = MultiModalKwargs.batch(
             [dummy_mm_kwargs] * max_num_mm_items)
         batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
-            batched_dummy_mm_inputs, device=self.device)
+            batched_dummy_mm_inputs,
+            dtype=self.model_config.dtype,
+            device=self.device,
+        )

         # Run multimodal encoder.
         dummy_encoder_outputs = self.model.get_multimodal_embeddings(
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index b13ff9f97e6f..46bcf64ed0c3 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -652,8 +652,11 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
         encoder_outputs = []
         for grouped_mm_inputs in grouped_mm_inputs_list:
             batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
-            batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
-                                                           device=self.device)
+            batched_mm_inputs = MultiModalKwargs.as_kwargs(
+                batched_mm_inputs,
+                dtype=self.model_config.dtype,
+                device=self.device,
+            )

             # Run the encoder.
             # `curr_group_outputs` is either of the following:
@@ -1435,8 +1438,11 @@ def _get_mm_dummy_batch(self, modality: str,
         batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] *
                                                          batch_size)

-        return MultiModalKwargs.as_kwargs(batched_dummy_mm_inputs,
-                                          device=self.device)
+        return MultiModalKwargs.as_kwargs(
+            batched_dummy_mm_inputs,
+            dtype=self.model_config.dtype,
+            device=self.device,
+        )


 def _get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]:
diff --git a/vllm/worker/cpu_enc_dec_model_runner.py b/vllm/worker/cpu_enc_dec_model_runner.py
index c2120c035175..82eeeb570d22 100644
--- a/vllm/worker/cpu_enc_dec_model_runner.py
+++ b/vllm/worker/cpu_enc_dec_model_runner.py
@@ -297,8 +297,11 @@ def execute_model(
             model_input.encoder_input_tokens,
             "encoder_positions":
             model_input.encoder_input_positions,
-            **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
-                                         device=self.device),
+            **MultiModalKwargs.as_kwargs(
+                model_input.multi_modal_kwargs or {},
+                dtype=self.model_config.dtype,
+                device=self.device,
+            ),
             "intermediate_tensors":
             intermediate_tensors,
         }
diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py
index 710ca1a13b0c..fb436a079f87 100644
--- a/vllm/worker/cpu_model_runner.py
+++ b/vllm/worker/cpu_model_runner.py
@@ -628,7 +628,10 @@ def execute_model(
         multimodal_kwargs = {}
         if model_input.multi_modal_kwargs is not None:
             multimodal_kwargs = MultiModalKwargs.as_kwargs(
-                model_input.multi_modal_kwargs, device=self.device)
+                model_input.multi_modal_kwargs,
+                dtype=self.model_config.dtype,
+                device=self.device,
+            )
         execute_model_kwargs = {}
         if previous_hidden_states is not None:
             execute_model_kwargs.update(
diff --git a/vllm/worker/cpu_pooling_model_runner.py b/vllm/worker/cpu_pooling_model_runner.py
index 1ceb2557c6b3..2a60e51261ad 100644
--- a/vllm/worker/cpu_pooling_model_runner.py
+++ b/vllm/worker/cpu_pooling_model_runner.py
@@ -50,8 +50,11 @@ def execute_model(
             model_input.input_tokens,
             "positions":
             model_input.input_positions,
-            **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
-                                         device=self.device),
+            **MultiModalKwargs.as_kwargs(
+                model_input.multi_modal_kwargs or {},
+                dtype=self.model_config.dtype,
+                device=self.device,
+            ),
             **cross_enc_kwargs,
             "intermediate_tensors":
             intermediate_tensors,
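A minimal, self-contained sketch of the casting rule that the `inputs.py` hunk above introduces (and that lets the per-model casts removed from `deepseek_vl2.py` and `gemma3_mm.py` be dropped): only floating-point leaves are cast to the target dtype, mirroring `transformers.BatchFeature`, so integer metadata keeps its dtype. This adapts the `maybe_cast_dtype` helper to take `dtype` as an explicit argument; the payload below is illustrative, not vLLM's actual `BatchedTensorInputs`.

```python
import torch


def maybe_cast_dtype(x: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # Mirrors transformers.BatchFeature: cast floating-point tensors only;
    # integer/bool tensors (token ids, crop grids, masks) pass through.
    return x.to(dtype=dtype) if x.is_floating_point() else x


# Illustrative multimodal payload: float pixel data plus integer metadata.
batch = {
    "pixel_values": torch.rand(2, 3, 336, 336),             # float32
    "images_spatial_crop": torch.tensor([[1, 2], [2, 2]]),  # int64
}

casted = {k: maybe_cast_dtype(v, torch.bfloat16) for k, v in batch.items()}

assert casted["pixel_values"].dtype == torch.bfloat16       # cast
assert casted["images_spatial_crop"].dtype == torch.int64   # unchanged
```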
diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py
index 4864163b0de2..3957e5608524 100644
--- a/vllm/worker/enc_dec_model_runner.py
+++ b/vllm/worker/enc_dec_model_runner.py
@@ -202,9 +202,13 @@ def execute_model(
             encoder_input_ids=model_input.encoder_input_tokens,
             encoder_positions=model_input.encoder_input_positions,
             intermediate_tensors=intermediate_tensors,
-            **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
-                                         device=self.device),
-            **seqlen_agnostic_kwargs)
+            **MultiModalKwargs.as_kwargs(
+                multi_modal_kwargs,
+                dtype=self.model_config.dtype,
+                device=self.device,
+            ),
+            **seqlen_agnostic_kwargs,
+        )

         logits = self.model.compute_logits(hidden_or_intermediate_states,
                                            model_input.sampling_metadata)
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 53e79adf9aae..8c968faa7810 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -1845,8 +1845,11 @@ def execute_model(
                 inputs_embeds=model_input.inputs_embeds,
                 positions=model_input.input_positions,
                 intermediate_tensors=intermediate_tensors,
-                **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
-                                             device=self.device),
+                **MultiModalKwargs.as_kwargs(
+                    multi_modal_kwargs,
+                    dtype=self.model_config.dtype,
+                    device=self.device,
+                ),
                 **seqlen_agnostic_kwargs,
                 **model_kwargs,
             )
diff --git a/vllm/worker/multi_step_neuron_model_runner.py b/vllm/worker/multi_step_neuron_model_runner.py
index 9618a4b49ff8..aafb7ab7cfb8 100644
--- a/vllm/worker/multi_step_neuron_model_runner.py
+++ b/vllm/worker/multi_step_neuron_model_runner.py
@@ -70,8 +70,11 @@ def execute_model(
            input_ids=model_input.input_tokens,
            positions=model_input.input_positions,
            input_block_ids=model_input.input_block_ids,
-           **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
-                                        device=self.device),
+           **MultiModalKwargs.as_kwargs(
+               model_input.multi_modal_kwargs or {},
+               dtype=self.model_config.dtype,
+               device=self.device,
+           ),
         )

         output = self.model.sample(
diff --git a/vllm/worker/multi_step_neuronx_distributed_model_runner.py b/vllm/worker/multi_step_neuronx_distributed_model_runner.py
index b6a3492a493b..3a9c0993e004 100644
--- a/vllm/worker/multi_step_neuronx_distributed_model_runner.py
+++ b/vllm/worker/multi_step_neuronx_distributed_model_runner.py
@@ -49,8 +49,11 @@ def execute_model(
             positions=model_input.input_positions,
             input_block_ids=model_input.input_block_ids,
             sampling_params=sampling_params,
-            **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
-                                         device=self.device),
+            **MultiModalKwargs.as_kwargs(
+                model_input.multi_modal_kwargs or {},
+                dtype=self.model_config.dtype,
+                device=self.device,
+            ),
         )

         output = self.model.sample(
diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py
index e97adf757cc1..968596471a26 100644
--- a/vllm/worker/neuron_model_runner.py
+++ b/vllm/worker/neuron_model_runner.py
@@ -378,9 +378,11 @@ def execute_model(
                 positions=model_input.input_positions,
                 input_block_ids=model_input.input_block_ids,
                 sampling_params=sampling_params,
-                **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs
-                                             or {},
-                                             device=self.device),
+                **MultiModalKwargs.as_kwargs(
+                    model_input.multi_modal_kwargs or {},
+                    dtype=self.model_config.dtype,
+                    device=self.device,
+                ),
             )
         elif current_platform.use_transformers_neuronx():
             # [TODO] validate on-device sampling
@@ -389,9 +391,11 @@
                 input_ids=model_input.input_tokens,
                 positions=model_input.input_positions,
                 input_block_ids=model_input.input_block_ids,
-                **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs
-                                             or {},
-                                             device=self.device),
+                **MultiModalKwargs.as_kwargs(
+                    model_input.multi_modal_kwargs or {},
+                    dtype=self.model_config.dtype,
+                    device=self.device,
+                ),
             )

         # Compute the logits only if the on-device sampling is turned off as
diff --git a/vllm/worker/pooling_model_runner.py b/vllm/worker/pooling_model_runner.py
index fdb7353f2f9c..912e04c435f5 100644
--- a/vllm/worker/pooling_model_runner.py
+++ b/vllm/worker/pooling_model_runner.py
@@ -119,10 +119,14 @@ def execute_model(
             input_ids=model_input.input_tokens,
             positions=model_input.input_positions,
             intermediate_tensors=intermediate_tensors,
-            **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
-                                         device=self.device),
+            **MultiModalKwargs.as_kwargs(
+                multi_modal_kwargs,
+                dtype=self.model_config.dtype,
+                device=self.device,
+            ),
             **cross_enc_kwargs,
-            **seqlen_agnostic_kwargs)
+            **seqlen_agnostic_kwargs,
+        )

         if (self.observability_config is not None
                 and self.observability_config.collect_model_forward_time):
diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py
index 7042b575aa78..79fa7d2c73e8 100644
--- a/vllm/worker/xpu_model_runner.py
+++ b/vllm/worker/xpu_model_runner.py
@@ -562,9 +562,12 @@ def execute_model(
             input_ids=model_input.input_tokens,
             positions=model_input.input_positions,
             intermediate_tensors=intermediate_tensors,
-            **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs
-                                         or {},
-                                         device=self.device))
+            **MultiModalKwargs.as_kwargs(
+                model_input.multi_modal_kwargs or {},
+                dtype=self.model_config.dtype,
+                device=self.device,
+            ),
+        )

         # Compute the logits in the last pipeline stage.
         if not get_pp_group().is_last_rank:
             return hidden_or_intermediate_states
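The NOTE in the `inputs.py` hunk ("cast the dtype before sending it to device") is why `dtype` is threaded through every runner: casting on the host shrinks the host-to-device copy whenever an fp32 processor output feeds a bf16/fp16 model, which is the same "transfer less data to the GPU" rationale stated in the comment removed from `gemma3_mm.py`. A rough sketch of the saving, with an illustrative shape rather than one from the source:

```python
import torch

# Illustrative fp32 processor output destined for a bf16 model.
pixel_values = torch.rand(8, 3, 336, 336)  # fp32: 4 bytes/element

bytes_if_cast_on_device = pixel_values.nbytes                   # copy fp32
bytes_if_cast_on_host = pixel_values.to(torch.bfloat16).nbytes  # copy bf16

print(f"{bytes_if_cast_on_device / 2**20:.1f} MiB "
      f"vs {bytes_if_cast_on_host / 2**20:.1f} MiB")
# 10.3 MiB vs 5.2 MiB: casting before the transfer halves the bytes moved.
```

Centralizing the cast in `MultiModalKwargs.as_kwargs` is also what makes it safe to delete the ad-hoc per-model casts from the DeepSeek-VL2 and Gemma3 processors at the top of this diff.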