diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 43d2c88d3b9c..1605467bc3dd 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -143,6 +143,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.model: PreTrainedModel = AutoModel.from_config( self.config, attn_implementation="vllm", + torch_dtype=vllm_config.model_config.dtype, trust_remote_code=vllm_config.model_config.trust_remote_code, ) prefix = self.model.base_model_prefix