Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/diffusion/test_diffusion_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ def is_enabled(self):

def test_load_model_clears_cache_backend_for_unsupported_pipeline(monkeypatch):
class _DummyLoader:
def __init__(self, load_config):
del load_config
def __init__(self, load_config, od_config=None):
del load_config, od_config

def load_model(self, **kwargs):
del kwargs
Expand Down
4 changes: 4 additions & 0 deletions vllm_omni/diffusion/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,10 @@ class OmniDiffusionConfig:
# Compilation
enforce_eager: bool = False

# Parallel weight loading (for faster diffusion model startup)
enable_multithread_weight_load: bool = True
num_weight_load_threads: int = 4

# Enable sleep mode
enable_sleep_mode: bool = False

Expand Down
52 changes: 33 additions & 19 deletions vllm_omni/diffusion/model_loader/diffusers_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import dataclasses
import glob
import os
import re
import time
from collections.abc import Generator, Iterable
from pathlib import Path
Expand All @@ -20,6 +21,7 @@
filter_duplicate_safetensors_files,
filter_files_not_needed_for_inference,
maybe_download_from_modelscope,
multi_thread_safetensors_weights_iterator,
safetensors_weights_iterator,
)
from vllm.utils.import_utils import resolve_obj_by_qualname
Expand All @@ -31,16 +33,19 @@
logger = init_logger(__name__)


def _natural_sort_key(filepath: str) -> list:
"""Natural sort key for filenames with numeric components, e.g.
model-00001-of-00005.safetensors -> ['model-', 1, '-of-', 5, '.safetensors']."""
return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", os.path.basename(filepath))]


MODEL_INDEX = "model_index.json"
DIFFUSION_MODEL_WEIGHTS_INDEX = "diffusion_pytorch_model.safetensors.index.json"


class DiffusersPipelineLoader:
"""Model loader that can load diffusers pipeline components from disk."""

# default number of thread when enable multithread weight loading
DEFAULT_NUM_THREADS = 8

@dataclasses.dataclass
class ComponentSource:
"""A source for weights."""
Expand All @@ -66,18 +71,9 @@ class ComponentSource:
counter_before_loading_weights: float = 0.0
counter_after_loading_weights: float = 0.0

def __init__(self, load_config: LoadConfig):
def __init__(self, load_config: LoadConfig, od_config: OmniDiffusionConfig | None = None):
self.load_config = load_config

# TODO(Isotr0py): Enable multithreaded weight loading
# extra_config = load_config.model_loader_extra_config
# allowed_keys = {"enable_multithread_load", "num_threads"}
# unexpected_keys = set(extra_config.keys()) - allowed_keys

# if unexpected_keys:
# raise ValueError(
# f"Unexpected extra config keys for load format {load_config.load_format}: {unexpected_keys}"
# )
self.od_config = od_config

def _prepare_weights(
self,
Expand Down Expand Up @@ -171,18 +167,36 @@ def _prepare_weights(

def _get_weights_iterator(self, source: "ComponentSource") -> Generator[tuple[str, torch.Tensor], None, None]:
"""Get an iterator for the model weights based on the load format."""
hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
_, hf_weights_files, use_safetensors = self._prepare_weights(
source.model_or_path,
source.subfolder,
source.revision,
source.fall_back_to_pt,
source.allow_patterns_overrides,
)
weights_iterator = safetensors_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
self.load_config.safetensors_load_strategy,

od_config = self.od_config
use_multithread = (
use_safetensors
and od_config is not None
and getattr(od_config, "enable_multithread_weight_load", False)
and self.load_config.safetensors_load_strategy != "torchao"
)
if use_multithread:
num_threads = getattr(od_config, "num_weight_load_threads", 4)
# Keep deterministic shard order before passing to vLLM helper.
sorted_hf_weights_files = sorted(hf_weights_files, key=_natural_sort_key)
weights_iterator = multi_thread_safetensors_weights_iterator(
sorted_hf_weights_files,
self.load_config.use_tqdm_on_load,
max_workers=num_threads,
)
else:
weights_iterator = safetensors_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
self.load_config.safetensors_load_strategy,
)

if self.counter_before_loading_weights == 0.0:
self.counter_before_loading_weights = time.perf_counter()
Expand Down
2 changes: 1 addition & 1 deletion vllm_omni/diffusion/worker/diffusion_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def get_memory_context():

# Load model within forward context
load_config = LoadConfig()
model_loader = DiffusersPipelineLoader(load_config)
model_loader = DiffusersPipelineLoader(load_config, od_config=self.od_config)
time_before_load = time.perf_counter()

with get_memory_context():
Expand Down
2 changes: 2 additions & 0 deletions vllm_omni/entrypoints/async_omni.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,8 @@ def _create_default_diffusion_stage_cfg(self, kwargs: dict[str, Any]) -> dict[st
"enforce_eager": kwargs.get("enforce_eager", False),
"diffusion_load_format": kwargs.get("diffusion_load_format", "default"),
"custom_pipeline_args": kwargs.get("custom_pipeline_args", None),
"enable_multithread_weight_load": kwargs.get("enable_multithread_weight_load", True),
"num_weight_load_threads": kwargs.get("num_weight_load_threads", 4),
},
"final_output": True,
"final_output_type": "image",
Expand Down
15 changes: 15 additions & 0 deletions vllm_omni/entrypoints/cli/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,21 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu
help="Enable VAE tiling for memory optimization (useful for mitigating OOM issues).",
)

# Parallel weight loading (faster diffusion startup)
omni_config_group.add_argument(
"--disable-multithread-weight-load",
action="store_false",
dest="enable_multithread_weight_load",
default=True,
help="Disable multi-threaded safetensors loading (default: enabled with 4 threads).",
Comment on lines +246 to +250
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Wire new weight-load CLI flags into diffusion engine args

These new serve flags are parsed, but in the async serving path they are not propagated into diffusion engine_args (the default diffusion stage builder in vllm_omni/entrypoints/async_omni.py only forwards a fixed subset and omits enable_multithread_weight_load / num_weight_load_threads). Because _build_od_config only copies fields present in engine_args, --disable-multithread-weight-load (and custom thread counts) are silently ignored for default diffusion serving.

Useful? React with 👍 / 👎.

)
omni_config_group.add_argument(
"--num-weight-load-threads",
type=int,
default=4,
help="Number of threads for parallel weight loading (default: 4).",
)

# diffusion model offload parameters
omni_config_group.add_argument(
"--enable-cpu-offload",
Expand Down