Skip to content

Commit bd2a9e3

Browse files
committed
Download HF model for all nodes.
Signed-off-by: Yuxian Qiu <[email protected]>
1 parent a32a2e4 commit bd2a9e3

File tree

1 file changed

+18
-5
lines changed

1 file changed

+18
-5
lines changed

tensorrt_llm/llmapi/llm_utils.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
import torch
1313
from tqdm import tqdm
1414

15-
from .._utils import (global_mpi_rank, mpi_barrier, mpi_broadcast, mpi_rank,
16-
release_gc)
15+
from .._utils import (global_mpi_rank, local_mpi_rank, mpi_barrier,
16+
mpi_broadcast, mpi_rank, release_gc)
1717
from ..auto_parallel import AutoParallelConfig
1818
# yapf: disable
1919
from ..bindings.executor import (BatchingType, CapacitySchedulerPolicy,
@@ -627,9 +627,11 @@ def __call__(self) -> Tuple[Path, Union[Path, None]]:
627627
f'backend {self.llm_args.backend} is not supported.')
628628

629629
if self.model_loader.model_obj.is_hub_model:
630-
self._hf_model_dir = download_hf_model(
631-
self.model_loader.model_obj.model_name,
632-
self.llm_args.revision)
630+
hf_model_dirs = self.mpi_session.submit_sync(
631+
CachedModelLoader._node_download_hf_model,
632+
model=self.model_loader.model_obj.model_name,
633+
revision=self.llm_args.revision)
634+
self._hf_model_dir = hf_model_dirs[0]
633635
else:
634636
self._hf_model_dir = self.model_loader.model_obj.model_dir
635637

@@ -806,6 +808,17 @@ def build_task(engine_dir: Path):
806808

807809
return self.get_engine_dir()
808810

811+
@print_traceback_on_error
@staticmethod
def _node_download_hf_model(
    model: str,
    revision: Optional[str] = None,
) -> Optional[Path]:
    """Download a Hugging Face model from the node-local leader process.

    Only the process for which ``local_mpi_rank()`` is 0 calls
    ``download_hf_model``; every other process returns ``None`` without
    downloading anything.
    NOTE(review): local rank 0 is presumably exactly one process per
    node -- confirm against ``local_mpi_rank``.

    Args:
        model: Model name forwarded unchanged to ``download_hf_model``.
        revision: Optional revision forwarded unchanged to
            ``download_hf_model``.

    Returns:
        The result of ``download_hf_model(model, revision)`` on local
        rank 0, otherwise ``None``.
    """
    # Guard clause: processes that are not the local leader do no work.
    if local_mpi_rank() != 0:
        return None
    return download_hf_model(model, revision)
821+
809822
@print_traceback_on_error
810823
@staticmethod
811824
def _node_build_task(

0 commit comments

Comments
 (0)