diff --git a/docs/user_guide/diffusion/lora.md b/docs/user_guide/diffusion/lora.md index e45c033b848..256698752a1 100644 --- a/docs/user_guide/diffusion/lora.md +++ b/docs/user_guide/diffusion/lora.md @@ -56,6 +56,92 @@ outputs = omni.generate( !!! note "Server-side Path Requirement" The LoRA adapter path (`local_path`) must be readable on the **server** machine. If your client and server are on different machines, ensure the LoRA adapter is accessible via a shared mount or copied to the server. +## Wan2.2 LightX2V Offline Assembly + +This workflow is LoRA-adjacent: it uses external LightX2V conversion plus +`Wan2.2-Distill-Loras` to bake converted Wan2.2 I2V checkpoints into a local +Diffusers directory, instead of loading LoRA adapters at runtime. + +### Required assets + +- Base model: `Wan-AI/Wan2.2-I2V-A14B` +- Diffusers skeleton: `Wan-AI/Wan2.2-I2V-A14B-Diffusers` +- Optional external converter from the LightX2V project (not shipped in this repository) +- Optional LoRA weights: `lightx2v/Wan2.2-Distill-Loras` + +### Step 1: Optional - convert high/low-noise DiT weights with LightX2V + +Install or clone LightX2V from the upstream repository +(`https://github.com/ModelTC/LightX2V`). After cloning, the converter used +below is available at `/tools/convert/converter.py`. + +```bash +python /path/to/lightx2v/tools/convert/converter.py \ + --source /path/to/Wan2.2-I2V-A14B/high_noise_model \ + --output /tmp/wan22_lightx2v/high_noise_out \ + --output_ext .safetensors \ + --output_name diffusion_pytorch_model \ + --model_type wan_dit \ + --direction forward \ + --lora_path /path/to/wan2.2_i2v_A14b_high_noise_lora_rank64_lightx2v_4step_1022.safetensors \ + --lora_key_convert auto \ + --single_file + +python /path/to/lightx2v/tools/convert/converter.py \ + --source /path/to/Wan2.2-I2V-A14B/low_noise_model \ + --output /tmp/wan22_lightx2v/low_noise_out \ + --output_ext .safetensors \ + --output_name diffusion_pytorch_model \ + --model_type wan_dit \ + --direction forward \ + --lora_path /path/to/wan2.2_i2v_A14b_low_noise_lora_rank64_lightx2v_4step_1022.safetensors \ + --lora_key_convert auto \ + --single_file +``` + +If you are not using LightX2V, skip this step and either keep the original +Diffusers weights from the skeleton or point Step 2 at any other converted +`transformer/` and `transformer_2/` checkpoints. + +### Step 2: Assemble a final Diffusers-style directory + +```bash +python tools/wan22/assemble_wan22_i2v_diffusers.py \ + --diffusers-skeleton /path/to/Wan2.2-I2V-A14B-Diffusers \ + --transformer-weight /tmp/wan22_lightx2v/high_noise_out \ + --transformer-2-weight /tmp/wan22_lightx2v/low_noise_out \ + --output-dir /path/to/Wan2.2-I2V-A14B-Custom-Diffusers \ + --asset-mode symlink \ + --overwrite +``` + +`--transformer-weight` and `--transformer-2-weight` are optional. If you omit +them, the tool keeps the original weights from the Diffusers skeleton. + +### Step 3: Run offline inference + +```bash +python examples/offline_inference/image_to_video/image_to_video.py \ + --model /path/to/Wan2.2-I2V-A14B-Custom-Diffusers \ + --image /path/to/input.jpg \ + --prompt "A cat playing with yarn" \ + --num-frames 81 \ + --num-inference-steps 4 \ + --tensor-parallel-size 4 \ + --height 480 \ + --width 832 \ + --flow-shift 12 \ + --sample-solver euler \ + --guidance-scale 1.0 \ + --guidance-scale-high 1.0 \ + --boundary-ratio 0.875 +``` + +Notes: + +- This route avoids runtime LoRA loading changes in vLLM-Omni when you choose to bake converted weights into a local Diffusers directory. +- Output quality and speed depend on the replacement checkpoints and sampling params you choose. + ## See Also diff --git a/docs/user_guide/examples/offline_inference/image_to_video.md b/docs/user_guide/examples/offline_inference/image_to_video.md index 7a750aeff3b..6e105741a7e 100644 --- a/docs/user_guide/examples/offline_inference/image_to_video.md +++ b/docs/user_guide/examples/offline_inference/image_to_video.md @@ -62,12 +62,13 @@ Key arguments: - `--negative-prompt`: Optional list of artifacts to suppress. - `--boundary-ratio`: Boundary split ratio for two-stage MoE models. - `--flow-shift`: Scheduler flow shift (5.0 for 720p, 12.0 for 480p). +- `--sample-solver`: Wan2.2 sampling solver. Use `unipc` for the default multistep solver, or `euler` for Lightning/Distill checkpoints. - `--num-inference-steps`: Number of denoising steps (default 50). - `--fps`: Frames per second for the saved MP4 (requires `diffusers` export_to_video). - `--output`: Path to save the generated video. - `--vae-use-slicing`: Enable VAE slicing for memory optimization. - `--vae-use-tiling`: Enable VAE tiling for memory optimization. -- `--cfg-parallel-size`: set it to 2 to enable CFG Parallel. See more examples in [`user_guide`](https://github.com/vllm-project/vllm-omni/tree/main/docs/user_guide/diffusion/parallelism_acceleration.md#cfg-parallel). +- `--cfg-parallel-size`: set it to 2 to enable CFG Parallel. See more examples in [`user_guide`](https://github.com/vllm-project/vllm-omni/tree/main/docs/user_guide/diffusion/parallelism/cfg_parallel.md). - `--tensor-parallel-size`: tensor parallel size (effective for models that support TP, e.g. LTX2). - `--enable-cpu-offload`: enable CPU offloading for diffusion models. - `--use-hsdp`: Enable Hybrid Sharded Data Parallel to shard model weights across GPUs. @@ -78,6 +79,9 @@ Key arguments: > ℹ️ If you encounter OOM errors, try using `--vae-use-slicing` and `--vae-use-tiling` to reduce memory usage. +For Wan2.2 LightX2V-converted local Diffusers directories and related LoRA +assets, see the [LoRA guide](../../diffusion/lora.md#wan22-lightx2v-offline-assembly). + ## Example materials ??? abstract "image_to_video.py" diff --git a/examples/offline_inference/image_to_video/README.md b/examples/offline_inference/image_to_video/README.md index 2692c76df26..a458850a02b 100644 --- a/examples/offline_inference/image_to_video/README.md +++ b/examples/offline_inference/image_to_video/README.md @@ -59,12 +59,13 @@ Key arguments: - `--negative-prompt`: Optional list of artifacts to suppress. - `--boundary-ratio`: Boundary split ratio for two-stage MoE models. - `--flow-shift`: Scheduler flow shift (5.0 for 720p, 12.0 for 480p). +- `--sample-solver`: Wan2.2 sampling solver. Use `unipc` for the default multistep solver, or `euler` for Lightning/Distill checkpoints. - `--num-inference-steps`: Number of denoising steps (default 50). - `--fps`: Frames per second for the saved MP4 (requires `diffusers` export_to_video). - `--output`: Path to save the generated video. - `--vae-use-slicing`: Enable VAE slicing for memory optimization. - `--vae-use-tiling`: Enable VAE tiling for memory optimization. -- `--cfg-parallel-size`: set it to 2 to enable CFG Parallel. See more examples in [`user_guide`](../../../docs/user_guide/diffusion/parallelism_acceleration.md#cfg-parallel). +- `--cfg-parallel-size`: set it to 2 to enable CFG Parallel. See more examples in [`user_guide`](https://github.com/vllm-project/vllm-omni/tree/main/docs/user_guide/diffusion/parallelism/cfg_parallel.md). - `--tensor-parallel-size`: tensor parallel size (effective for models that support TP, e.g. LTX2). - `--enable-cpu-offload`: enable CPU offloading for diffusion models. - `--use-hsdp`: Enable Hybrid Sharded Data Parallel to shard model weights across GPUs. @@ -74,3 +75,6 @@ Key arguments: > ℹ️ If you encounter OOM errors, try using `--vae-use-slicing` and `--vae-use-tiling` to reduce memory usage. + +For Wan2.2 LightX2V-converted local Diffusers directories and related LoRA +assets, see the [LoRA guide](../../../docs/user_guide/diffusion/lora.md#wan22-lightx2v-offline-assembly). diff --git a/examples/offline_inference/image_to_video/image_to_video.py b/examples/offline_inference/image_to_video/image_to_video.py index c8c55c485ad..419108d9079 100644 --- a/examples/offline_inference/image_to_video/image_to_video.py +++ b/examples/offline_inference/image_to_video/image_to_video.py @@ -84,6 +84,13 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--flow-shift", type=float, default=5.0, help="Scheduler flow_shift (5.0 for 720p, 12.0 for 480p)." ) + parser.add_argument( + "--sample-solver", + type=str, + default="unipc", + choices=["unipc", "euler"], + help="Sampling solver for Wan2.2 pipelines. Use 'euler' for Lightning/Distill setups.", + ) parser.add_argument("--output", type=str, default="i2v_output.mp4", help="Path to save the video (mp4).") parser.add_argument("--fps", type=int, default=None, help="Frames per second for the output video.") parser.add_argument( @@ -305,6 +312,7 @@ def main(): print(f" Model: {args.model}") print(f" Inference steps: {args.num_inference_steps}") print(f" Frames: {args.num_frames}") + print(f" Solver: {args.sample_solver}") print( f" Parallel configuration: cfg_parallel_size={args.cfg_parallel_size}," f" tensor_parallel_size={args.tensor_parallel_size}, vae_patch_parallel_size={args.vae_patch_parallel_size}" @@ -326,9 +334,14 @@ def main(): generator=generator, guidance_scale=guidance_scale, guidance_scale_2=args.guidance_scale_high, + boundary_ratio=args.boundary_ratio, num_inference_steps=num_inference_steps, num_frames=num_frames, frame_rate=frame_rate, + extra_args={ + "sample_solver": args.sample_solver, + "flow_shift": args.flow_shift, + }, ), ) generation_end = time.perf_counter() diff --git a/examples/online_serving/image_to_video/README.md b/examples/online_serving/image_to_video/README.md index 49283bd9a06..285eeb27983 100644 --- a/examples/online_serving/image_to_video/README.md +++ b/examples/online_serving/image_to_video/README.md @@ -26,6 +26,23 @@ The script allows overriding: - `CACHE_BACKEND` (default: `none`) - `ENABLE_CACHE_DIT_SUMMARY` (default: `0`) +### Ascend / Local LightX2V Example + +For a local Wan2.2-LightX2V Diffusers directory on Ascend/NPU, you can start the server like this: + +```bash +vllm serve /path/to/Wan2.2-I2V-A14B-LightX2V-Diffusers-Lightning \ + --omni \ + --port 8091 \ + --flow-shift 12 \ + --cfg-parallel-size 1 \ + --ulysses-degree 4 \ + --use-hsdp \ + --trust-remote-code \ + --allowed-local-media-path / \ + --seed 42 +``` + ## Async Job Behavior `POST /v1/videos` is asynchronous. It creates a video job and immediately @@ -69,10 +86,35 @@ curl -X POST http://localhost:8091/v1/videos/sync \ -F "guidance_scale_2=1.0" \ -F "boundary_ratio=0.875" \ -F "flow_shift=12.0" \ + -F 'extra_params={"sample_solver":"euler"}' \ -F "seed=42" \ -o sync_i2v_output.mp4 ``` +For Wan Lightning/Distill checkpoints, pass `{"sample_solver":"euler"}` via `extra_params`. The default solver is `unipc`. + +Example matching the local LightX2V deployment above: + +```bash +curl -sS -X POST http://localhost:8091/v1/videos/sync \ + -H "Accept: video/mp4" \ + -F "prompt=A cat playing with yarn" \ + -F "input_reference=@/path/to/input.jpg" \ + -F "width=832" \ + -F "height=480" \ + -F "num_frames=81" \ + -F "fps=16" \ + -F "num_inference_steps=4" \ + -F "guidance_scale=1.0" \ + -F "guidance_scale_2=1.0" \ + -F "boundary_ratio=0.875" \ + -F "seed=42" \ + -F 'extra_params={"sample_solver":"euler"}' \ + -o ./output.mp4 +``` + +Use `/v1/videos/sync` if you want to write the MP4 directly to a file. `POST /v1/videos` is async and returns job metadata, not inline `b64_json`. + ## Storage Generated video files are stored on local disk by the async video API. @@ -96,6 +138,9 @@ export VLLM_OMNI_STORAGE_MAX_CONCURRENCY=8 # Basic image-to-video generation bash run_curl_image_to_video.sh +# Wan Lightning/Distill checkpoints +SAMPLE_SOLVER=euler bash run_curl_image_to_video.sh + # Or execute directly (OpenAI-style multipart) create_response=$(curl -s http://localhost:8091/v1/videos \ -H "Accept: application/json" \ @@ -111,6 +156,7 @@ create_response=$(curl -s http://localhost:8091/v1/videos \ -F "guidance_scale_2=1.0" \ -F "boundary_ratio=0.875" \ -F "flow_shift=12.0" \ + -F 'extra_params={"sample_solver":"euler"}' \ -F "seed=42") video_id=$(echo "$create_response" | jq -r '.id') @@ -169,9 +215,12 @@ curl -X POST http://localhost:8091/v1/videos \ -F "guidance_scale_2=1.0" \ -F "boundary_ratio=0.875" \ -F "flow_shift=12.0" \ + -F 'extra_params={"sample_solver":"euler"}' \ -F "seed=42" ``` +`sample_solver` is supported by Wan2.2 online serving through the existing `extra_params` field, which is merged into the pipeline `extra_args`. Use `unipc` for the default multistep solver, or `euler` for Lightning/Distill checkpoints. + ## Create Response Format `POST /v1/videos` returns a job record, not inline base64 video data. diff --git a/examples/online_serving/image_to_video/run_curl_image_to_video.sh b/examples/online_serving/image_to_video/run_curl_image_to_video.sh index f4c1496a69a..6f6a6f96d59 100644 --- a/examples/online_serving/image_to_video/run_curl_image_to_video.sh +++ b/examples/online_serving/image_to_video/run_curl_image_to_video.sh @@ -7,6 +7,7 @@ INPUT_IMAGE="${INPUT_IMAGE:-../../offline_inference/image_to_video/qwen-bear.png BASE_URL="${BASE_URL:-http://localhost:8099}" OUTPUT_PATH="${OUTPUT_PATH:-wan22_i2v_output.mp4}" NEGATIVE_PROMPT="${NEGATIVE_PROMPT:-}" +SAMPLE_SOLVER="${SAMPLE_SOLVER:-}" POLL_INTERVAL="${POLL_INTERVAL:-2}" if [ ! -f "$INPUT_IMAGE" ]; then @@ -34,6 +35,10 @@ if [ -n "${NEGATIVE_PROMPT}" ]; then create_cmd+=(-F "negative_prompt=${NEGATIVE_PROMPT}") fi +if [ -n "${SAMPLE_SOLVER}" ]; then + create_cmd+=(-F "extra_params={\"sample_solver\":\"${SAMPLE_SOLVER}\"}") +fi + create_response="$("${create_cmd[@]}")" video_id="$(echo "${create_response}" | jq -r '.id')" if [ -z "${video_id}" ] || [ "${video_id}" = "null" ]; then diff --git a/tests/entrypoints/openai_api/test_video_server.py b/tests/entrypoints/openai_api/test_video_server.py index 7200b38abb8..d0d968fcbfc 100644 --- a/tests/entrypoints/openai_api/test_video_server.py +++ b/tests/entrypoints/openai_api/test_video_server.py @@ -737,6 +737,28 @@ def test_extra_params_merged_with_existing_extra_args(test_client, mocker: Mocke assert captured.extra_args["zero_steps"] == 2 +def test_sample_solver_forwarded_via_extra_params(test_client, mocker: MockerFixture): + """sample_solver can be passed through existing extra_params for Wan2.2 online serving.""" + mocker.patch( + "vllm_omni.entrypoints.openai.serving_video.encode_video_base64", + return_value="Zg==", + ) + response = test_client.post( + "/v1/videos", + data={ + "prompt": "A fox running through snow.", + "extra_params": json.dumps({"sample_solver": "euler"}), + }, + ) + + assert response.status_code == 200 + video_id = response.json()["id"] + _wait_for_status(test_client, video_id, VideoGenerationStatus.COMPLETED.value) + engine = test_client.app.state.openai_serving_video._engine_client + captured = engine.captured_sampling_params_list[0] + assert captured.extra_args["sample_solver"] == "euler" + + # --------------------------------------------------------------------------- # Sync endpoint tests (POST /v1/videos/sync) # --------------------------------------------------------------------------- diff --git a/tools/wan22/assemble_wan22_i2v_diffusers.py b/tools/wan22/assemble_wan22_i2v_diffusers.py new file mode 100644 index 00000000000..8e14ca3c26d --- /dev/null +++ b/tools/wan22/assemble_wan22_i2v_diffusers.py @@ -0,0 +1,385 @@ +#!/usr/bin/env python3 +""" +Assemble a Wan2.2-I2V-A14B-Diffusers-style model directory using a Diffusers +skeleton and optional replacement transformer checkpoints. + +This tool does NOT run any external conversion step. You can use it in two +ways: +- keep the original weights from the Diffusers skeleton +- replace transformer/transformer_2 with converted checkpoints such as + LightX2V outputs +- use legacy LightX2V arg names (--high-noise-weight/--low-noise-weight), + which are accepted as aliases + +Typical use: + python tools/wan22/assemble_wan22_i2v_diffusers.py \ + --diffusers-skeleton /path/to/Wan2.2-I2V-A14B-Diffusers \ + --transformer-weight /path/to/high_noise_out/diffusion_pytorch_model.safetensors \ + --transformer-2-weight /path/to/low_noise_out/diffusion_pytorch_model.safetensors \ + --output-dir /path/to/Wan2.2-I2V-A14B-Custom-Diffusers +""" + +from __future__ import annotations + +import argparse +import json +import shutil +import sys +from dataclasses import dataclass +from pathlib import Path + +WEIGHT_CANDIDATES = ( + "diffusion_pytorch_model.safetensors", + "diffusion_pytorch_model.bin", + "diffusion_pytorch_model.pt", + "model.safetensors", + "pytorch_model.bin", + "model.pt", +) +WEIGHT_INDEX_CANDIDATES = ( + "diffusion_pytorch_model.safetensors.index.json", + "model.safetensors.index.json", + "pytorch_model.bin.index.json", +) + +ROOT_REQUIRED_FILES = ("model_index.json",) +ROOT_REQUIRED_DIRS = ("tokenizer", "text_encoder", "vae", "transformer", "transformer_2") +OPTIONAL_DIRS = ("image_encoder", "image_processor", "scheduler", "feature_extractor") + + +class AssembleError(RuntimeError): + pass + + +@dataclass(frozen=True) +class WeightSpec: + kind: str # "single" | "sharded" + single_file: Path | None = None + index_file: Path | None = None + shard_files: tuple[Path, ...] = () + + +def _load_shard_files_from_index(index_file: Path, role: str) -> tuple[Path, ...]: + try: + with index_file.open(encoding="utf-8") as f: + payload = json.load(f) + except Exception as exc: + raise AssembleError(f"Failed to parse {role} index file: {index_file}. error={exc}") from exc + + weight_map = payload.get("weight_map") + if not isinstance(weight_map, dict) or not weight_map: + raise AssembleError(f"Invalid {role} index file (missing/empty weight_map): {index_file}") + + shard_names = sorted({str(v) for v in weight_map.values()}) + shard_paths: list[Path] = [] + missing: list[str] = [] + for shard_name in shard_names: + shard_path = index_file.parent / shard_name + if not shard_path.is_file(): + missing.append(str(shard_path)) + else: + shard_paths.append(shard_path) + + if missing: + raise AssembleError(f"{role} index references missing shard file(s): " + ", ".join(missing)) + + if not shard_paths: + raise AssembleError(f"No shard files referenced by {role} index: {index_file}") + + return tuple(shard_paths) + + +def _resolve_weight_spec(path: Path, role: str) -> WeightSpec: + if path.is_file(): + return WeightSpec(kind="single", single_file=path) + + if path.is_dir(): + for name in WEIGHT_CANDIDATES: + candidate = path / name + if candidate.is_file(): + return WeightSpec(kind="single", single_file=candidate) + + for index_name in WEIGHT_INDEX_CANDIDATES: + index_file = path / index_name + if not index_file.is_file(): + continue + shard_files = _load_shard_files_from_index(index_file, role=role) + return WeightSpec( + kind="sharded", + index_file=index_file, + shard_files=shard_files, + ) + + shard_candidates = sorted(path.glob("diffusion_pytorch_model-*.safetensors")) + if shard_candidates: + raise AssembleError( + f"Detected sharded {role} files under {path}, but index json is missing. " + f"Expected one of: {', '.join(WEIGHT_INDEX_CANDIDATES)}" + ) + + raise AssembleError( + f"Cannot find {role} weight under directory: {path}. " + f"Expected one of single files [{', '.join(WEIGHT_CANDIDATES)}] " + f"or sharded index files [{', '.join(WEIGHT_INDEX_CANDIDATES)}]." + ) + + raise AssembleError(f"{role} path does not exist: {path}") + + +def _canonical_weight_name(weight_file: Path) -> str: + suffix = weight_file.suffix.lower() + if suffix == ".safetensors": + return "diffusion_pytorch_model.safetensors" + if suffix == ".bin": + return "diffusion_pytorch_model.bin" + if suffix == ".pt": + return "diffusion_pytorch_model.pt" + return weight_file.name + + +def _validate_skeleton(skeleton: Path) -> None: + if not skeleton.is_dir(): + raise AssembleError(f"--diffusers-skeleton is not a directory: {skeleton}") + + for file_name in ROOT_REQUIRED_FILES: + if not (skeleton / file_name).is_file(): + raise AssembleError(f"Missing required file in skeleton: {skeleton / file_name}") + + for dir_name in ROOT_REQUIRED_DIRS: + if not (skeleton / dir_name).is_dir(): + raise AssembleError(f"Missing required directory in skeleton: {skeleton / dir_name}") + + if not (skeleton / "transformer" / "config.json").is_file(): + raise AssembleError(f"Missing transformer config: {skeleton / 'transformer/config.json'}") + + if not (skeleton / "transformer_2" / "config.json").is_file(): + raise AssembleError(f"Missing transformer_2 config: {skeleton / 'transformer_2/config.json'}") + + +def _ensure_clean_output(output_dir: Path, overwrite: bool) -> None: + if output_dir.exists(): + if not overwrite: + raise AssembleError( + f"Output directory already exists: {output_dir}. Use --overwrite to remove and recreate it." + ) + shutil.rmtree(output_dir) + output_dir.mkdir(parents=True, exist_ok=False) + + +def _copy_or_link_dir(src: Path, dst: Path, asset_mode: str) -> None: + if asset_mode == "copy": + shutil.copytree(src, dst) + elif asset_mode == "symlink": + dst.symlink_to(src, target_is_directory=True) + else: + raise AssembleError(f"Unknown asset mode: {asset_mode}") + + +def _materialize_weight(weight: WeightSpec, dst_dir: Path, role: str) -> tuple[Path, ...]: + if weight.kind == "single": + assert weight.single_file is not None + dst = dst_dir / _canonical_weight_name(weight.single_file) + shutil.copy2(weight.single_file, dst) + return (dst,) + + if weight.kind == "sharded": + assert weight.index_file is not None + copied: list[Path] = [] + index_dst = dst_dir / weight.index_file.name + shutil.copy2(weight.index_file, index_dst) + copied.append(index_dst) + for shard_file in weight.shard_files: + shard_dst = dst_dir / shard_file.name + shutil.copy2(shard_file, shard_dst) + copied.append(shard_dst) + return tuple(copied) + + raise AssembleError(f"Unknown {role} weight kind: {weight.kind}") + + +def _assemble( + skeleton: Path, + output_dir: Path, + transformer_weight: WeightSpec, + transformer_2_weight: WeightSpec, + asset_mode: str, +) -> tuple[tuple[Path, ...], tuple[Path, ...]]: + shutil.copy2(skeleton / "model_index.json", output_dir / "model_index.json") + + for dir_name in ROOT_REQUIRED_DIRS: + if dir_name in ("transformer", "transformer_2"): + continue + _copy_or_link_dir(skeleton / dir_name, output_dir / dir_name, asset_mode) + + for dir_name in OPTIONAL_DIRS: + src_dir = skeleton / dir_name + if src_dir.is_dir(): + _copy_or_link_dir(src_dir, output_dir / dir_name, asset_mode) + + (output_dir / "transformer").mkdir(parents=True, exist_ok=True) + (output_dir / "transformer_2").mkdir(parents=True, exist_ok=True) + + shutil.copy2(skeleton / "transformer" / "config.json", output_dir / "transformer" / "config.json") + shutil.copy2(skeleton / "transformer_2" / "config.json", output_dir / "transformer_2" / "config.json") + + transformer_copied = _materialize_weight(transformer_weight, output_dir / "transformer", role="transformer") + transformer_2_copied = _materialize_weight( + transformer_2_weight, + output_dir / "transformer_2", + role="transformer_2", + ) + + return transformer_copied, transformer_2_copied + + +def _validate_output( + output_dir: Path, + transformer_copied: tuple[Path, ...], + transformer_2_copied: tuple[Path, ...], +) -> None: + if not (output_dir / "model_index.json").is_file(): + raise AssembleError("Output validation failed: model_index.json missing") + + required_paths = ( + output_dir / "tokenizer", + output_dir / "text_encoder", + output_dir / "vae", + output_dir / "transformer" / "config.json", + output_dir / "transformer_2" / "config.json", + *transformer_copied, + *transformer_2_copied, + ) + missing = [str(p) for p in required_paths if not p.exists()] + if missing: + raise AssembleError("Output validation failed, missing: " + ", ".join(missing)) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description=( + "Assemble a Wan2.2-I2V-A14B-Diffusers directory while optionally " + "replacing transformer and transformer_2 weights." + ) + ) + parser.add_argument( + "--diffusers-skeleton", + type=Path, + required=True, + help="Path to a local Wan-AI/Wan2.2-I2V-A14B-Diffusers directory.", + ) + parser.add_argument( + "--transformer-weight", + type=Path, + help=( + "Optional checkpoint file, or directory containing either a single-file " + "weight or sharded index+shards for transformer/. If omitted, keep the " + "skeleton's original transformer weights." + ), + ) + parser.add_argument( + "--transformer-2-weight", + type=Path, + help=( + "Optional checkpoint file, or directory containing either a single-file " + "weight or sharded index+shards for transformer_2/. If omitted, keep the " + "skeleton's original transformer_2 weights." + ), + ) + parser.add_argument( + "--high-noise-weight", + type=Path, + help=argparse.SUPPRESS, + ) + parser.add_argument( + "--low-noise-weight", + type=Path, + help=argparse.SUPPRESS, + ) + parser.add_argument( + "--output-dir", + type=Path, + required=True, + help="Output directory for the assembled model.", + ) + parser.add_argument( + "--asset-mode", + choices=("symlink", "copy"), + default="symlink", + help=( + "How to materialize non-transformer assets (tokenizer/text_encoder/vae/optional dirs). " + "symlink saves disk and is default." + ), + ) + parser.add_argument( + "--overwrite", + action="store_true", + help="Overwrite output-dir if it exists.", + ) + return parser.parse_args() + + +def main() -> int: + args = parse_args() + + skeleton = args.diffusers_skeleton.resolve() + output_dir = args.output_dir.resolve() + + if args.transformer_weight is not None and args.high_noise_weight is not None: + print( + "[ERROR] --transformer-weight and --high-noise-weight are aliases; please provide only one.", + file=sys.stderr, + ) + return 2 + if args.transformer_2_weight is not None and args.low_noise_weight is not None: + print( + "[ERROR] --transformer-2-weight and --low-noise-weight are aliases; please provide only one.", + file=sys.stderr, + ) + return 2 + + transformer_weight_arg = args.transformer_weight if args.transformer_weight is not None else args.high_noise_weight + transformer_2_weight_arg = ( + args.transformer_2_weight if args.transformer_2_weight is not None else args.low_noise_weight + ) + + transformer_input = ( + transformer_weight_arg.resolve() if transformer_weight_arg is not None else skeleton / "transformer" + ) + transformer_2_input = ( + transformer_2_weight_arg.resolve() if transformer_2_weight_arg is not None else skeleton / "transformer_2" + ) + + try: + _validate_skeleton(skeleton) + transformer_weight = _resolve_weight_spec(transformer_input, role="transformer") + transformer_2_weight = _resolve_weight_spec(transformer_2_input, role="transformer_2") + + _ensure_clean_output(output_dir, overwrite=args.overwrite) + transformer_copied, transformer_2_copied = _assemble( + skeleton=skeleton, + output_dir=output_dir, + transformer_weight=transformer_weight, + transformer_2_weight=transformer_2_weight, + asset_mode=args.asset_mode, + ) + _validate_output(output_dir, transformer_copied, transformer_2_copied) + except AssembleError as exc: + print(f"[ERROR] {exc}", file=sys.stderr) + return 2 + + def _weight_summary(copied: tuple[Path, ...]) -> str: + if len(copied) == 1: + return copied[0].name + return f"{copied[0].name} + {len(copied) - 1} shard files" + + print("[OK] Assembled Wan2.2 I2V Diffusers directory:") + print(f" output_dir: {output_dir}") + print(f" transformer weight: {_weight_summary(transformer_copied)}") + print(f" transformer_2 weight: {_weight_summary(transformer_2_copied)}") + print("\nUse it with vLLM-Omni, for example:") + print(f" vllm serve {output_dir} --omni --port 8091") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py index a550e576f01..84d89619e86 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py @@ -24,6 +24,7 @@ from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin, _is_rank_zero from vllm_omni.diffusion.models.schedulers import FlowUniPCMultistepScheduler +from vllm_omni.diffusion.models.wan2_2.scheduling_wan_euler import WanEulerScheduler from vllm_omni.diffusion.models.wan2_2.wan2_2_transformer import WanTransformer3DModel from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest @@ -32,6 +33,46 @@ logger = logging.getLogger(__name__) DEBUG_PERF = False +WAN_SAMPLE_SOLVER_CHOICES = {"unipc", "euler"} + + +def build_wan_scheduler(sample_solver: str, flow_shift: float) -> Any: + if sample_solver == "unipc": + return FlowUniPCMultistepScheduler( + num_train_timesteps=1000, + shift=flow_shift, + prediction_type="flow_prediction", + ) + if sample_solver == "euler": + return WanEulerScheduler( + num_train_timesteps=1000, + shift=flow_shift, + ) + + raise ValueError( + f"Unsupported Wan sample_solver: {sample_solver}. Expected one of: {sorted(WAN_SAMPLE_SOLVER_CHOICES)}" + ) + + +def resolve_wan_sample_solver(req: OmniDiffusionRequest, default: str = "unipc") -> str: + extra_args = getattr(req.sampling_params, "extra_args", {}) or {} + raw = extra_args.get("sample_solver", default) + sample_solver = str(raw).strip().lower() + if sample_solver not in WAN_SAMPLE_SOLVER_CHOICES: + raise ValueError(f"Invalid sample_solver={raw!r}. Expected one of: {sorted(WAN_SAMPLE_SOLVER_CHOICES)}") + return sample_solver + + +def resolve_wan_flow_shift(req: OmniDiffusionRequest, od_config: OmniDiffusionConfig) -> float: + extra_args = getattr(req.sampling_params, "extra_args", {}) or {} + raw_flow_shift = extra_args.get("flow_shift") + if raw_flow_shift is None: + raw_flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0 + + try: + return float(raw_flow_shift) + except (TypeError, ValueError) as exc: + raise ValueError(f"Invalid flow_shift={raw_flow_shift!r}. flow_shift must be a float.") from exc def retrieve_latents( @@ -296,13 +337,9 @@ def __init__( else: raise RuntimeError("No transformer loaded") - # Initialize UniPC scheduler - flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0 # default for 720p - self.scheduler = FlowUniPCMultistepScheduler( - num_train_timesteps=1000, - shift=flow_shift, - prediction_type="flow_prediction", - ) + self._sample_solver = "unipc" + self._flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0 + self.scheduler = build_wan_scheduler(self._sample_solver, self._flow_shift) self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 @@ -462,6 +499,13 @@ def forward( current_omni_platform.synchronize() _t_text_enc_ms = (time.perf_counter() - _t_text_enc_start) * 1000 + sample_solver = resolve_wan_sample_solver(req, default=self._sample_solver) + flow_shift = resolve_wan_flow_shift(req, self.od_config) + if sample_solver != self._sample_solver or abs(flow_shift - self._flow_shift) > 1e-6: + self.scheduler = build_wan_scheduler(sample_solver, flow_shift) + self._sample_solver = sample_solver + self._flow_shift = flow_shift + # Timesteps self.scheduler.set_timesteps(num_steps, device=device) timesteps = self.scheduler.timesteps diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py index c05ecc9c9a2..46484cd789d 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py @@ -24,10 +24,12 @@ from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.interface import SupportImageInput from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin, _is_rank_zero -from vllm_omni.diffusion.models.schedulers import FlowUniPCMultistepScheduler from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import ( + build_wan_scheduler, create_transformer_from_config, load_transformer_config, + resolve_wan_flow_shift, + resolve_wan_sample_solver, retrieve_latents, ) from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin @@ -230,13 +232,9 @@ def __init__( else: self.transformer_2 = None - # Initialize UniPC scheduler - flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0 # default for 720p - self.scheduler = FlowUniPCMultistepScheduler( - num_train_timesteps=1000, - shift=flow_shift, - prediction_type="flow_prediction", - ) + self._sample_solver = "unipc" + self._flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0 + self.scheduler = build_wan_scheduler(self._sample_solver, self._flow_shift) # VAE scale factors self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if hasattr(self.vae, "config") else 4 @@ -440,6 +438,13 @@ def forward( current_omni_platform.synchronize() _t_img_enc_ms = (time.perf_counter() - _t_img_enc_start) * 1000 + sample_solver = resolve_wan_sample_solver(req, default=self._sample_solver) + flow_shift = resolve_wan_flow_shift(req, self.od_config) + if sample_solver != self._sample_solver or abs(flow_shift - self._flow_shift) > 1e-6: + self.scheduler = build_wan_scheduler(sample_solver, flow_shift) + self._sample_solver = sample_solver + self._flow_shift = flow_shift + # Timesteps self.scheduler.set_timesteps(num_steps, device=device) timesteps = self.scheduler.timesteps diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py index 261f62fb798..939fe294a33 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py @@ -36,10 +36,12 @@ from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.interface import SupportImageInput from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin -from vllm_omni.diffusion.models.schedulers import FlowUniPCMultistepScheduler from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import ( + build_wan_scheduler, create_transformer_from_config, load_transformer_config, + resolve_wan_flow_shift, + resolve_wan_sample_solver, retrieve_latents, ) from vllm_omni.diffusion.request import OmniDiffusionRequest @@ -183,13 +185,9 @@ def __init__( transformer_config = load_transformer_config(model, "transformer", local_files_only) self.transformer = create_transformer_from_config(transformer_config) - # Initialize UniPC scheduler - flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0 # default for 720p - self.scheduler = FlowUniPCMultistepScheduler( - num_train_timesteps=1000, - shift=flow_shift, - prediction_type="flow_prediction", - ) + self._sample_solver = "unipc" + self._flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0 + self.scheduler = build_wan_scheduler(self._sample_solver, self._flow_shift) # VAE scale factors self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if hasattr(self.vae, "config") else 4 @@ -323,6 +321,13 @@ def forward( batch_size = prompt_embeds.shape[0] + sample_solver = resolve_wan_sample_solver(req, default=self._sample_solver) + flow_shift = resolve_wan_flow_shift(req, self.od_config) + if sample_solver != self._sample_solver or abs(flow_shift - self._flow_shift) > 1e-6: + self.scheduler = build_wan_scheduler(sample_solver, flow_shift) + self._sample_solver = sample_solver + self._flow_shift = flow_shift + # Timesteps self.scheduler.set_timesteps(num_steps, device=device) timesteps = self.scheduler.timesteps diff --git a/vllm_omni/diffusion/models/wan2_2/scheduling_wan_euler.py b/vllm_omni/diffusion/models/wan2_2/scheduling_wan_euler.py new file mode 100644 index 00000000000..25444044c2d --- /dev/null +++ b/vllm_omni/diffusion/models/wan2_2/scheduling_wan_euler.py @@ -0,0 +1,147 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +from dataclasses import dataclass +from types import SimpleNamespace + +import numpy as np +import torch + + +@dataclass +class WanEulerSchedulerOutput: + prev_sample: torch.FloatTensor + + +def _unsqueeze_to_ndim(in_tensor: torch.Tensor, target_ndim: int) -> torch.Tensor: + if in_tensor.ndim >= target_ndim: + return in_tensor + return in_tensor[(...,) + (None,) * (target_ndim - in_tensor.ndim)] + + +def _get_timesteps(num_steps: int, max_steps: int = 1000) -> np.ndarray: + # Keep num_steps + 1 points so Euler update can always access sigma_next. + return np.linspace(max_steps, 0, num_steps + 1, dtype=np.float32) + + +def _timestep_shift(timesteps: torch.Tensor, shift: float = 1.0) -> torch.Tensor: + return shift * timesteps / (1 + (shift - 1) * timesteps) + + +class WanEulerScheduler: + order = 1 + + def __init__( + self, + num_train_timesteps: int = 1000, + shift: float = 1.0, + device: torch.device | str = "cpu", + ) -> None: + self.num_train_timesteps = int(num_train_timesteps) + self._shift = float(shift) + self.device = device + self.config = SimpleNamespace(num_train_timesteps=self.num_train_timesteps) + self.init_noise_sigma = 1.0 + + self._step_index: int | None = None + self._begin_index: int | None = None + + self.timesteps = torch.empty(0, dtype=torch.float32) + self.sigmas = torch.empty(0, dtype=torch.float32) + self.timesteps_ori = torch.empty(0, dtype=torch.float32) + + self.set_timesteps(num_inference_steps=self.num_train_timesteps, device=self.device) + + @property + def step_index(self) -> int | None: + return self._step_index + + @property + def begin_index(self) -> int | None: + return self._begin_index + + def set_begin_index(self, begin_index: int = 0) -> None: + self._begin_index = int(begin_index) + + def index_for_timestep(self, timestep: torch.Tensor) -> int: + indices = (self.timesteps == timestep).nonzero() + if len(indices) > 0: + pos = 1 if len(indices) > 1 else 0 + return int(indices[pos].item()) + # Fallback for tiny float drift + return int(torch.argmin(torch.abs(self.timesteps - timestep)).item()) + + def _init_step_index(self, timestep: float | torch.Tensor) -> None: + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep_t = timestep.to(self.timesteps.device, dtype=self.timesteps.dtype) + else: + timestep_t = torch.tensor(timestep, device=self.timesteps.device, dtype=self.timesteps.dtype) + self._step_index = self.index_for_timestep(timestep_t) + else: + self._step_index = self._begin_index + + def set_shift(self, shift: float = 1.0) -> None: + # Compute shifted sigma schedule on [0, 1]. + sigmas_full = self.timesteps_ori / float(self.num_train_timesteps) + sigmas_full = _timestep_shift(sigmas_full, shift=float(shift)) + self.sigmas = sigmas_full + # Public timesteps are the first N points; next point is consumed as sigma_next. + self.timesteps = self.sigmas[:-1] * self.num_train_timesteps + self._shift = float(shift) + + def set_timesteps( + self, + num_inference_steps: int, + device: torch.device | str | int | None = None, + **kwargs, # noqa: ARG002 - kept for scheduler API compatibility + ) -> None: + timesteps = _get_timesteps( + num_steps=int(num_inference_steps), + max_steps=self.num_train_timesteps, + ) + self.timesteps_ori = torch.from_numpy(timesteps).to( + dtype=torch.float32, + device=device or self.device, + ) + self.set_shift(self._shift) + self._step_index = None + self._begin_index = None + + def scale_model_input(self, sample: torch.Tensor, timestep: int | None = None) -> torch.Tensor: # noqa: ARG002 + return sample + + def step( + self, + model_output: torch.FloatTensor, + timestep: float | torch.FloatTensor, + sample: torch.FloatTensor, + return_dict: bool = True, + **kwargs, # noqa: ARG002 - kept for scheduler API compatibility + ) -> WanEulerSchedulerOutput | tuple[torch.FloatTensor]: + if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)): + raise ValueError( + "Passing integer indices as timesteps is not supported. Use one value from scheduler.timesteps instead." + ) + + if self.step_index is None: + self._init_step_index(timestep) + assert self._step_index is not None + + sample_fp32 = sample.to(torch.float32) + sigma = _unsqueeze_to_ndim(self.sigmas[self._step_index], sample_fp32.ndim).to(sample_fp32.device) + sigma_next = _unsqueeze_to_ndim(self.sigmas[self._step_index + 1], sample_fp32.ndim).to(sample_fp32.device) + + prev_sample = sample_fp32 + (sigma_next - sigma) * model_output + prev_sample = prev_sample.to(model_output.dtype) + + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + return WanEulerSchedulerOutput(prev_sample=prev_sample) + + def __len__(self) -> int: + return self.num_train_timesteps diff --git a/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py b/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py index 65a2d4390ae..3b43f3eaf51 100644 --- a/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py +++ b/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py @@ -1015,6 +1015,14 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: if ".to_out.0." in lookup_name: lookup_name = lookup_name.replace(".to_out.0.", ".to_out.") + # Compatibility: some Wan conversion pipelines still keep + # block modulation keys as `blocks.N.modulation` instead of + # `blocks.N.scale_shift_table`. + if lookup_name.endswith(".modulation"): + modulation_alias = lookup_name[: -len(".modulation")] + ".scale_shift_table" + if modulation_alias in params_dict: + lookup_name = modulation_alias + if lookup_name not in params_dict: logger.warning(f"Skipping weight {original_name} -> {lookup_name}") continue diff --git a/vllm_omni/engine/async_omni_engine.py b/vllm_omni/engine/async_omni_engine.py index f7e7d53d58b..50882c59669 100644 --- a/vllm_omni/engine/async_omni_engine.py +++ b/vllm_omni/engine/async_omni_engine.py @@ -935,6 +935,8 @@ def _create_default_diffusion_stage_cfg(kwargs: dict[str, Any]) -> list: "enable_cpu_offload": kwargs.get("enable_cpu_offload", False), "enable_layerwise_offload": kwargs.get("enable_layerwise_offload", False), "enforce_eager": kwargs.get("enforce_eager", False), + "boundary_ratio": kwargs.get("boundary_ratio", None), + "flow_shift": kwargs.get("flow_shift", None), "diffusion_load_format": kwargs.get("diffusion_load_format", "default"), "custom_pipeline_args": kwargs.get("custom_pipeline_args", None), "worker_extension_cls": kwargs.get("worker_extension_cls", None),