Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 147 additions & 33 deletions vllm_omni/diffusion/model_loader/hub_prefetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,22 @@
import logging
import os
import time
from collections.abc import Iterable, Iterator
from collections.abc import Callable, Iterable, Iterator
from typing import Any

logger = logging.getLogger(__name__)

# A racy / partially-evicted HF cache (the exact failure this module defends
# against) is transient: re-running ``snapshot_download`` blocks on the peer
# writer's per-blob ``.lock`` and then returns a complete tree. So a bounded
# retry with linear backoff is what actually closes the window that a single
# best-effort attempt left open (Buildkite vllm-omni-rebase #1858: both the
# ``cuda_ti2v_hsdp`` missing-shard ``OSError`` and the ``wan_2_1_vace`` default

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider making _PREFETCH_MAX_ATTEMPTS and _PREFETCH_BACKOFF_BASE_S configurable via environment variables for CI tuning.

# ``UMT5Config`` size-mismatch were a swallowed prefetch followed by a
# ``from_pretrained`` against the half-written cache).
_PREFETCH_MAX_ATTEMPTS = 3  # total tries (first attempt included) made by the retry loops in this module
_PREFETCH_BACKOFF_BASE_S = 1.0  # seconds; multiplied by the attempt number to get the sleep before the next try


def _node_lock_dir() -> str:
"""Return a node-local directory suitable for the prefetch lock files.
Expand Down Expand Up @@ -276,7 +288,11 @@ def prefetch_subfolders(
if local_files_only or not model or os.path.isdir(model):
return

logger.info("Prefetching %s subfolders: %s", model, list(subfolders))
# Materialise ``subfolders`` up-front: it may be a one-shot generator and
# we reference it again in the retry / logging paths below.
subfolders = list(subfolders)

logger.info("Prefetching %s subfolders: %s", model, subfolders)

try:
from huggingface_hub import snapshot_download
Expand Down Expand Up @@ -306,39 +322,65 @@ def prefetch_subfolders(
# snapshot is now warm. This is what makes the prefetch race-free
# even when many ``DiffusionWorker`` subprocesses (or multiple
# OmniServer instances on the same node) hit this code in parallel.
try:
with _repo_prefetch_lock(model):
snapshot_download(
repo_id=model,
allow_patterns=allow_patterns,
)
logger.info("Prefetch complete for %s", model)
except Exception as exc:
# Best-effort: propagate only via logging. The subsequent
# ``from_pretrained`` call will raise a clearer, call-site-specific
# error (auth, 404, disk full, ...) that we'd rather surface - EXCEPT
# for auth/gating, which we escalate here with an explicit hint so
# readers of CI logs don't have to correlate the generic "OSError:
# <repo> does not appear to have a file named ..." that
# ``from_pretrained`` would otherwise emit much later with an
# unrelated-looking message.
if _looks_like_auth_error(exc):
logger.error(
"Hub prefetch for '%s' failed with an authentication / gated "
"repository error (%s: %s). The CI HF_TOKEN must (1) be set "
"in the step env, (2) be valid, and (3) belong to an account "
"that has accepted the model license on huggingface.co. See "
"docs/contributing/ci/hf_credentials.md.",
model,
type(exc).__name__,
exc,
)
else:
# A single best-effort attempt is not enough: when several diffusion
# workers race a cold cache (and the node-wide lock fails to serialise
# them, as observed for HSDP / ring launches), ``snapshot_download`` can
# raise on a half-written tree. Swallowing that and proceeding straight
# to ``from_pretrained`` is exactly what turned a recoverable prefetch
# hiccup into a hard server crash. Retry with backoff so the snapshot
# actually completes before any loader reads the cache.
for attempt in range(1, _PREFETCH_MAX_ATTEMPTS + 1):
try:
with _repo_prefetch_lock(model):
snapshot_download(
repo_id=model,
allow_patterns=allow_patterns,
)
logger.info("Prefetch complete for %s", model)
return
except Exception as exc:
# Auth / gating never heals on retry - escalate immediately with
# an explicit hint so readers of CI logs don't have to correlate
# the generic "OSError: <repo> does not appear to have a file
# named ..." that ``from_pretrained`` would otherwise emit later.
if _looks_like_auth_error(exc):
logger.error(
"Hub prefetch for '%s' failed with an authentication / gated "
"repository error (%s: %s). The CI HF_TOKEN must (1) be set "
"in the step env, (2) be valid, and (3) belong to an account "
"that has accepted the model license on huggingface.co. See "
"docs/contributing/ci/hf_credentials.md.",
model,
type(exc).__name__,
exc,
)
return

if attempt < _PREFETCH_MAX_ATTEMPTS:
backoff = _PREFETCH_BACKOFF_BASE_S * attempt
logger.warning(
"Hub prefetch for repo '%s' subfolders %s failed on attempt %d/%d (%s: %s); retrying in %.1fs",
model,
subfolders,
attempt,
_PREFETCH_MAX_ATTEMPTS,
type(exc).__name__,
exc,
backoff,
)
time.sleep(backoff)
continue

# Exhausted retries. Stay best-effort: propagate only via logging
# so the subsequent ``from_pretrained`` call surfaces the real,
# call-site-specific error (and ``from_pretrained_with_prefetch``
# gets a final chance to heal the cache).
logger.warning(
"Hub prefetch for repo '%s' subfolders %s failed (%s: %s); "
"falling back to on-demand download in from_pretrained",
"Hub prefetch for repo '%s' subfolders %s failed after %d attempts "
"(%s: %s); falling back to on-demand download in from_pretrained",
model,
list(subfolders),
subfolders,
_PREFETCH_MAX_ATTEMPTS,
type(exc).__name__,
exc,
)
Expand Down Expand Up @@ -376,6 +418,78 @@ def _looks_like_auth_error(exc: BaseException) -> bool:
return "401 client error" in msg or "403 client error" in msg or "gatedrepo" in msg


def from_pretrained_with_prefetch(
    factory: Callable[..., Any],
    model: str,
    *,
    subfolder: str,
    prefetch_list: Iterable[str],
    local_files_only: bool = False,
    max_attempts: int = _PREFETCH_MAX_ATTEMPTS,
    **from_pretrained_kwargs: Any,
) -> Any:
    """Call ``factory.from_pretrained`` healing a racy / partial HF cache.

    ``factory`` is a bound ``SomeModel.from_pretrained`` (or any callable with
    the same ``(model, *, subfolder, local_files_only, **kwargs)`` signature).

    This is a stronger sibling of :func:`retry_on_missing_shard`: that helper
    only retries the missing-shard ``OSError`` and never re-prefetches, so it
    cannot recover the second face of the same race. Two shapes of
    partial-cache failure crash the diffusion server outright:

    * ``OSError: <repo> does not appear to have a file named
      text_encoder/model-0000X-of-0000Y.safetensors`` - a shard is still under
      its ``.incomplete`` name.
    * ``RuntimeError: You set 'ignore_mismatched_sizes' to 'False' ...`` -
      ``text_encoder/config.json`` was not present yet, so ``transformers`` v5
      silently fell back to the default (tiny) config and then could not load
      the real checkpoint into it.

    Both heal once the cache is complete. So on those errors we re-run a
    *verified* prefetch (which blocks on the peer writer and retries the
    download) and reload, instead of letting the worker die. Local paths and
    ``local_files_only`` loads cannot be healed by re-fetching, so they raise
    on the first failure exactly as before.

    Args:
        factory: Callable invoked as ``factory(model, subfolder=...,
            local_files_only=..., **from_pretrained_kwargs)``.
        model: Hub repo id or local directory path.
        subfolder: Component subfolder forwarded to ``factory``.
        prefetch_list: Every subfolder to re-prefetch before a retry. It is
            materialised once, so a one-shot generator is safe.
        local_files_only: Forwarded to ``factory``; when true no healing is
            attempted.
        max_attempts: Upper bound on ``factory`` calls. Values below 1 are
            treated as 1 so the load is always attempted at least once.
        **from_pretrained_kwargs: Extra keyword arguments for ``factory``.

    Returns:
        Whatever ``factory`` returns on the first successful call.

    Raises:
        OSError, RuntimeError, ValueError: The error from the last
            ``factory`` call, re-raised unchanged once healing is impossible
            or the attempts are exhausted. Any other exception type
            propagates immediately without a retry.
    """
    prefetch_list = list(prefetch_list)
    can_heal = not local_files_only and bool(model) and not os.path.isdir(model)
    # Always try at least once: a non-positive ``max_attempts`` must not skip
    # the load entirely. Loads that cannot be healed get exactly one try.
    total_attempts = max(1, max_attempts) if can_heal else 1

    attempt = 1
    while True:
        try:
            return factory(
                model,
                subfolder=subfolder,
                local_files_only=local_files_only,
                **from_pretrained_kwargs,
            )
        except (OSError, RuntimeError, ValueError) as exc:
            if attempt >= total_attempts:
                # Bare ``raise`` keeps the original traceback intact and does
                # not rely on an ``assert`` that ``python -O`` would strip.
                raise
            backoff = _PREFETCH_BACKOFF_BASE_S * attempt
            logger.warning(
                "from_pretrained(%s, subfolder=%s) failed on attempt %d/%d "
                "(%s: %s); re-prefetching repo and retrying in %.1fs",
                model,
                subfolder,
                attempt,
                total_attempts,
                type(exc).__name__,
                exc,
                backoff,
            )
            time.sleep(backoff)
            # Force a fresh, verified snapshot of every component this pipeline
            # needs - not just ``subfolder`` - so a sibling component that was
            # also half-written gets repaired in the same pass.
            prefetch_subfolders(model, prefetch_list, local_files_only=False)
        attempt += 1


def retry_on_missing_shard(load_fn, *, max_retries: int = 3, base_delay: float = 5.0):
"""Call *load_fn* with retry on the transformers v5 shard-resolution race.

Expand Down
17 changes: 13 additions & 4 deletions vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin
from vllm_omni.diffusion.distributed.utils import get_local_device
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
from vllm_omni.diffusion.model_loader.hub_prefetch import prefetch_subfolders
from vllm_omni.diffusion.model_loader.hub_prefetch import from_pretrained_with_prefetch, prefetch_subfolders
from vllm_omni.diffusion.models.flux2_klein.flux2_klein_transformer import (
Flux2Transformer2DModel,
)
Expand Down Expand Up @@ -213,9 +213,10 @@ def __init__(
# Avoid the transformers v5 multi-worker subfolder race (see
# ``vllm_omni/diffusion/model_loader/hub_prefetch.py`` for the full
# analysis; L4 build #1043 hit this on FLUX.2-klein-4B's text_encoder).
flux2_subfolders = ["scheduler", "text_encoder", "tokenizer", "vae"]
prefetch_subfolders(
model,
["scheduler", "text_encoder", "tokenizer", "vae"],
flux2_subfolders,
local_files_only=local_files_only,
)

Expand All @@ -224,19 +225,27 @@ def __init__(
subfolder="scheduler",
local_files_only=local_files_only,
)
self.text_encoder = Qwen3ForCausalLM.from_pretrained(
# ``from_pretrained_with_prefetch`` re-prefetches and retries if a peer
# worker left the cache half-written (missing-shard ``OSError`` or the
# default-config size-mismatch ``RuntimeError``) instead of crashing
# the worker - FLUX.2-klein-4B's sharded text_encoder hit this on #1043.
self.text_encoder = from_pretrained_with_prefetch(
Qwen3ForCausalLM.from_pretrained,
model,
subfolder="text_encoder",
prefetch_list=flux2_subfolders,
local_files_only=local_files_only,
).to(self._execution_device)
self.tokenizer = Qwen2TokenizerFast.from_pretrained(
model,
subfolder="tokenizer",
local_files_only=local_files_only,
)
self.vae = AutoencoderKLFlux2.from_pretrained(
self.vae = from_pretrained_with_prefetch(
AutoencoderKLFlux2.from_pretrained,
model,
subfolder="vae",
prefetch_list=flux2_subfolders,
local_files_only=local_files_only,
).to(self._execution_device)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from vllm_omni.diffusion.distributed.parallel_state import get_classifier_free_guidance_world_size
from vllm_omni.diffusion.distributed.utils import get_local_device
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
from vllm_omni.diffusion.model_loader.hub_prefetch import from_pretrained_with_prefetch, prefetch_subfolders
from vllm_omni.diffusion.models.hidream_image import HiDreamImageTransformer2DModel
from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin
from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin
Expand Down Expand Up @@ -177,24 +178,55 @@ def __init__(
# Check if model is a local path
local_files_only = os.path.exists(model)

# See ``hub_prefetch.py`` for the transformers v5 multi-worker subfolder
# race; prefetch the in-repo component set before any from_pretrained
# (``text_encoder_4`` lives in a separate Llama repo and is unaffected).
hidream_subfolders = [
"scheduler",
"vae",
"text_encoder",
"tokenizer",
"text_encoder_2",
"tokenizer_2",
"text_encoder_3",
"tokenizer_3",
]
prefetch_subfolders(model, hidream_subfolders, local_files_only=local_files_only)

self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
model, subfolder="scheduler", local_files_only=local_files_only
)
self.vae = AutoencoderKL.from_pretrained(model, subfolder="vae", local_files_only=local_files_only).to(
self.device
)
self.text_encoder = CLIPTextModelWithProjection.from_pretrained(
model, subfolder="text_encoder", local_files_only=local_files_only
self.vae = from_pretrained_with_prefetch(
AutoencoderKL.from_pretrained,
model,
subfolder="vae",
prefetch_list=hidream_subfolders,
local_files_only=local_files_only,
).to(self.device)
self.text_encoder = from_pretrained_with_prefetch(
CLIPTextModelWithProjection.from_pretrained,
model,
subfolder="text_encoder",
prefetch_list=hidream_subfolders,
local_files_only=local_files_only,
)
self.tokenizer = CLIPTokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only)
self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
model, subfolder="text_encoder_2", local_files_only=local_files_only
self.text_encoder_2 = from_pretrained_with_prefetch(
CLIPTextModelWithProjection.from_pretrained,
model,
subfolder="text_encoder_2",
prefetch_list=hidream_subfolders,
local_files_only=local_files_only,
)
self.tokenizer_2 = CLIPTokenizer.from_pretrained(
model, subfolder="tokenizer_2", local_files_only=local_files_only
)
self.text_encoder_3 = T5EncoderModel.from_pretrained(
model, subfolder="text_encoder_3", local_files_only=local_files_only
self.text_encoder_3 = from_pretrained_with_prefetch(
T5EncoderModel.from_pretrained,
model,
subfolder="text_encoder_3",
prefetch_list=hidream_subfolders,
local_files_only=local_files_only,
)
self.tokenizer_3 = T5Tokenizer.from_pretrained(
model, subfolder="tokenizer_3", local_files_only=local_files_only
Expand Down
Loading
Loading