diff --git a/vllm_omni/entrypoints/omni.py b/vllm_omni/entrypoints/omni.py index 9e49f951e20..510813c6abd 100644 --- a/vllm_omni/entrypoints/omni.py +++ b/vllm_omni/entrypoints/omni.py @@ -10,6 +10,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Any, Literal, overload +import huggingface_hub from omegaconf import OmegaConf from tqdm.auto import tqdm from vllm import SamplingParams @@ -40,6 +41,9 @@ ) from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniPromptType, OmniSamplingParams from vllm_omni.metrics import OrchestratorAggregator, StageRequestStats +from vllm_omni.model_executor.model_loader.weight_utils import ( + download_weights_from_hf_specific, +) from vllm_omni.outputs import OmniRequestOutput logger = init_logger(__name__) @@ -66,14 +70,28 @@ def _dummy_snapshot_download(model_id): def omni_snapshot_download(model_id) -> str: + # If it's already a local path, just return it + if os.path.exists(model_id): + return model_id # TODO: this is just a workaround for quickly use modelscope, we should support # modelscope in weight loading feature instead of using `snapshot_download` if os.environ.get("VLLM_USE_MODELSCOPE", False): from modelscope.hub.snapshot_download import snapshot_download return snapshot_download(model_id) - else: - return _dummy_snapshot_download(model_id) + # For other cases (Hugging Face), perform a real download to ensure all + # necessary files (including *.pt for audio/diffusion) are available locally + # before stage workers are spawned. This prevents initialization timeouts. + try: + return download_weights_from_hf_specific( + model_name_or_path=model_id, + cache_dir=None, + allow_patterns=["*"], + require_all=True, + ) + except huggingface_hub.errors.RepositoryNotFoundError: + logger.warning(f"Repository not found for '{model_id}'.") + return model_id class OmniBase: diff --git a/vllm_omni/model_executor/model_loader/weight_utils.py b/vllm_omni/model_executor/model_loader/weight_utils.py index 7432ad9a2a4..c5225ec0287 100644 --- a/vllm_omni/model_executor/model_loader/weight_utils.py +++ b/vllm_omni/model_executor/model_loader/weight_utils.py @@ -20,6 +20,7 @@ def download_weights_from_hf_specific( allow_patterns: list[str], revision: str | None = None, ignore_patterns: str | list[str] | None = None, + require_all: bool = False, ) -> str: """Download model weights from Hugging Face Hub. Users can specify the allow_patterns to download only the necessary weights. @@ -35,6 +36,9 @@ def download_weights_from_hf_specific( ignore_patterns (Optional[Union[str, list[str]]]): The patterns to filter out the weight files. Files matched by any of the patterns will be ignored. + require_all (bool): If True, will iterate through and download files + matching all patterns in allow_patterns. If False, will stop after + the first pattern that matches any files. Returns: str: The path to the downloaded model weights. @@ -48,20 +52,31 @@ def download_weights_from_hf_specific( # downloading the same model weights at the same time. with get_lock(model_name_or_path, cache_dir): start_time = time.perf_counter() - for allow_pattern in allow_patterns: + if require_all: hf_folder = snapshot_download( model_name_or_path, - allow_patterns=allow_pattern, + allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, cache_dir=cache_dir, revision=revision, local_files_only=local_only, **download_kwargs, ) - # If we have downloaded weights for this allow_pattern, - # we don't need to check the rest. - if any(Path(hf_folder).glob(allow_pattern)): - break + else: + for allow_pattern in allow_patterns: + hf_folder = snapshot_download( + model_name_or_path, + allow_patterns=allow_pattern, + ignore_patterns=ignore_patterns, + cache_dir=cache_dir, + revision=revision, + local_files_only=local_only, + **download_kwargs, + ) + # If we have downloaded weights for this allow_pattern, + # we don't need to check the rest, unless require_all is set. + if any(Path(hf_folder).glob(allow_pattern)): + break time_taken = time.perf_counter() - start_time if time_taken > 0.5: logger.info(