diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index dc568928b285..f8cde9ba3c34 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -422,6 +422,16 @@ def get_end_ptr(tensor: torch.Tensor) -> int:
                     result[k] = t
         return result
 
+    def _prepare_weights(self, model_name_or_path: str,
+                         revision: Optional[str]):
+        if os.path.isdir(model_name_or_path):
+            return model_name_or_path
+        else:
+            allow_patterns = ["*.safetensors"]
+            return download_weights_from_hf(model_name_or_path,
+                                            self.load_config.download_dir,
+                                            allow_patterns, revision)
+
     def load_model(self, *, model_config: ModelConfig,
                    device_config: DeviceConfig,
                    lora_config: Optional[LoRAConfig],
@@ -432,6 +442,10 @@ def load_model(self, *, model_config: ModelConfig,
         from safetensors.torch import safe_open
 
         from vllm.distributed import get_tensor_model_parallel_rank
+
+        local_model_path = self._prepare_weights(model_config.model,
+                                                 model_config.revision)
+
         with set_default_torch_dtype(model_config.dtype):
             with torch.device(device_config.device):
                 model = _initialize_model(model_config, self.load_config,
@@ -439,7 +453,7 @@ def load_model(self, *, model_config: ModelConfig,
                                           cache_config)
                 rank = get_tensor_model_parallel_rank()
                 pattern = os.path.join(
-                    model_config.model,
+                    local_model_path,
                     self.pattern.format(rank=rank, part="*"),
                 )
                 filepaths = glob.glob(pattern)