 import torch
 from tqdm import tqdm

-from .._utils import (global_mpi_rank, mpi_barrier, mpi_broadcast, mpi_rank,
-                      release_gc)
+from .._utils import (global_mpi_rank, local_mpi_rank, mpi_barrier,
+                      mpi_broadcast, mpi_rank, release_gc)
 from ..auto_parallel import AutoParallelConfig
 # yapf: disable
 from ..bindings.executor import (BatchingType, CapacitySchedulerPolicy,
@@ -627,9 +627,11 @@ def __call__(self) -> Tuple[Path, Union[Path, None]]:
                 f'backend {self.llm_args.backend} is not supported.')

         if self.model_loader.model_obj.is_hub_model:
-            self._hf_model_dir = download_hf_model(
-                self.model_loader.model_obj.model_name,
-                self.llm_args.revision)
+            hf_model_dirs = self.mpi_session.submit_sync(
+                CachedModelLoader._node_download_hf_model,
+                model=self.model_loader.model_obj.model_name,
+                revision=self.llm_args.revision)
+            self._hf_model_dir = hf_model_dirs[0]
         else:
             self._hf_model_dir = self.model_loader.model_obj.model_dir

@@ -806,6 +808,17 @@ def build_task(engine_dir: Path):

         return self.get_engine_dir()

+    @print_traceback_on_error
+    @staticmethod
+    def _node_download_hf_model(
+        model: str,
+        revision: Optional[str] = None,
+    ) -> Optional[Path]:
+        if local_mpi_rank() == 0:
+            return download_hf_model(model, revision)
+        else:
+            return None
+
     @print_traceback_on_error
     @staticmethod
     def _node_build_task(
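The new `_node_download_hf_model` task runs on every worker but only downloads when `local_mpi_rank() == 0`, so each node fetches the checkpoint exactly once; the driver then takes `hf_model_dirs[0]`, the path returned by global rank 0, while the other ranks return `None`. A minimal standalone sketch of the same once-per-node pattern, assuming `mpi4py` and `huggingface_hub` are available; `download_once_per_node` and the use of a barrier in place of the PR's synchronous `submit_sync` are illustrative, not part of this change:

```python
# Illustrative sketch (not this PR's code): download a Hugging Face checkpoint
# once per node by gating on the node-local rank, analogous to the
# local_mpi_rank() check in _node_download_hf_model.
from pathlib import Path
from typing import Optional

from huggingface_hub import snapshot_download
from mpi4py import MPI


def download_once_per_node(model: str,
                           revision: Optional[str] = None) -> Optional[Path]:
    # Ranks that share a node (shared memory) land in the same communicator.
    node_comm = MPI.COMM_WORLD.Split_type(MPI.COMM_TYPE_SHARED)

    path: Optional[Path] = None
    if node_comm.Get_rank() == 0:
        # Only the first rank on each node touches the network / HF cache.
        path = Path(snapshot_download(repo_id=model, revision=revision))

    # Wait for every node's downloader before anyone proceeds; the PR gets
    # the same effect from the blocking submit_sync() call on the driver.
    MPI.COMM_WORLD.Barrier()
    return path


if __name__ == "__main__":  # e.g. mpirun -n 4 python download_once.py
    print(MPI.COMM_WORLD.Get_rank(), download_once_per_node("gpt2"))
```

Gating on the node-local rank rather than the global rank is what avoids redundant downloads in multi-node runs while still keeping single-node behavior unchanged.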