From 4753e94928c9e2de775bbe548c55886533a4e630 Mon Sep 17 00:00:00 2001
From: Aurick Qiao
Date: Fri, 17 May 2024 11:00:00 -0400
Subject: [PATCH 1/3] add hf download to sharded state loader

---
 vllm/model_executor/model_loader/loader.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index dc568928b285..4d5ae41525ce 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -422,6 +422,16 @@ def get_end_ptr(tensor: torch.Tensor) -> int:
                     result[k] = t
         return result
 
+    def _prepare_weights(self, model_name_or_path: str,
+                         revision: Optional[str]):
+        if os.path.isdir(model_name_or_path):
+            return model_name_or_path
+        else:
+            allow_patterns = ["*.safetensors"]
+            return download_weights_from_hf(model_name_or_path,
+                                            self.load_config.download_dir,
+                                            allow_patterns, revision)
+
     def load_model(self, *, model_config: ModelConfig,
                    device_config: DeviceConfig,
                    lora_config: Optional[LoRAConfig],
@@ -432,6 +442,9 @@ def load_model(self, *, model_config: ModelConfig,
         from safetensors.torch import safe_open
 
         from vllm.distributed import get_tensor_model_parallel_rank
+
+        self._prepare_weights(model_config.model, model_config.revision)
+
         with set_default_torch_dtype(model_config.dtype):
             with torch.device(device_config.device):
                 model = _initialize_model(model_config, self.load_config,

From 05e280ca9cf814ef01601f6ff3f83b9ef10e897b Mon Sep 17 00:00:00 2001
From: Aurick Qiao
Date: Fri, 17 May 2024 11:33:53 -0400
Subject: [PATCH 2/3] update

---
 vllm/model_executor/model_loader/loader.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index 4d5ae41525ce..3a16c7ed10df 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -443,7 +443,8 @@ def load_model(self, *, model_config: ModelConfig,
 
         from vllm.distributed import get_tensor_model_parallel_rank
 
-        self._prepare_weights(model_config.model, model_config.revision)
+        local_model_path = self._prepare_weights(model_config.model,
+                                                 model_config.revision)
 
         with set_default_torch_dtype(model_config.dtype):
             with torch.device(device_config.device):
@@ -452,8 +453,7 @@ def load_model(self, *, model_config: ModelConfig,
                                           cache_config)
             rank = get_tensor_model_parallel_rank()
             pattern = os.path.join(
-                model_config.model,
-                self.pattern.format(rank=rank, part="*"),
+                local_model_path, self.pattern.format(rank=rank, part="*"),
             )
             filepaths = glob.glob(pattern)
             if not filepaths:

From b558b165e8edbe5765cb3c221f1eae20676c8395 Mon Sep 17 00:00:00 2001
From: Aurick Qiao
Date: Mon, 20 May 2024 12:28:52 -0400
Subject: [PATCH 3/3] fix formatting

---
 vllm/model_executor/model_loader/loader.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index 3a16c7ed10df..f8cde9ba3c34 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -453,7 +453,8 @@ def load_model(self, *, model_config: ModelConfig,
                                           cache_config)
             rank = get_tensor_model_parallel_rank()
             pattern = os.path.join(
-                local_model_path,
+                self.pattern.format(rank=rank, part="*"),
             )
             filepaths = glob.glob(pattern)
             if not filepaths:
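
Note: below is a minimal standalone sketch of the download-or-local-directory
logic that the new _prepare_weights helper implements, for readers who want to
try the behavior outside vLLM. It substitutes huggingface_hub.snapshot_download
for vLLM's internal download_weights_from_hf helper, and the prepare_weights
name and its parameters are illustrative, not part of this series.

    import os
    from typing import Optional

    from huggingface_hub import snapshot_download


    def prepare_weights(model_name_or_path: str,
                        revision: Optional[str] = None,
                        download_dir: Optional[str] = None) -> str:
        """Return a local directory containing the *.safetensors shards."""
        # A local checkpoint directory is used as-is.
        if os.path.isdir(model_name_or_path):
            return model_name_or_path
        # Otherwise fetch only the safetensors files for the given repo and
        # revision, mirroring the allow_patterns = ["*.safetensors"] filter
        # in the patch above.
        return snapshot_download(model_name_or_path,
                                 revision=revision,
                                 cache_dir=download_dir,
                                 allow_patterns=["*.safetensors"])

With the series applied, a sharded-state checkpoint should be loadable by
Hugging Face repo id as well as by local path, e.g. (assuming the usual
engine-args plumbing) LLM(model="org/sharded-checkpoint",
load_format="sharded_state"), where "org/sharded-checkpoint" is a
hypothetical repo name.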