Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions examples/offline_inference/text_to_video/run_t2v.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Launch text_to_video.py in a transformer-profiling configuration for Wan T2V.
#
# Environment switches consumed by vllm_omni (all are enabled when set to "1"):
#   VLLM_OMNI_WAN_DUMMY_TEXT_ENCODER       - encode_prompt returns all-zero prompt
#                                            embeddings instead of running the
#                                            tokenizer / text encoder.
#   VLLM_OMNI_SKIP_DUMMY_RUN               - the diffusion engine skips its
#                                            warm-up dummy run.
#   VLLM_OMNI_WAN_PROFILE_TRANSFORMER_ONLY - the Wan2.2 pipeline does not load the
#                                            tokenizer, text encoder or VAE and
#                                            forces output_type to "latent".
export VLLM_OMNI_WAN_DUMMY_TEXT_ENCODER=1
export VLLM_OMNI_SKIP_DUMMY_RUN=1
export VLLM_OMNI_WAN_PROFILE_TRANSFORMER_ONLY=1
# NOTE(review): presumably the directory torch profiler traces are written to
# (here: the current working directory) - confirm against the profiler setup.
export VLLM_TORCH_PROFILER_DIR="./"

# Checkpoint to run. The default is the Wan2.2 A14B checkpoint; set MODEL_PATH in
# the calling environment to use a different one, e.g.
#   MODEL_PATH="/mnt/disk2/hf_models/Wan2.1-T2V-1.3B-Diffusers" bash run_t2v.sh
# (Previously both paths were assigned back to back, so the first assignment was
# dead and silently overwritten by the second.)
MODEL_PATH="${MODEL_PATH:-/mnt/disk2/hf_models/Wan2.2-T2V-A14B-Diffusers}"

python text_to_video.py \
--model "$MODEL_PATH" \
--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
--height 480 --width 832 --num-frames 9 \
--guidance-scale 1.0 --guidance-scale-high 1.0 \
--boundary-ratio 0.0 --flow-shift 12.0 \
--num-inference-steps 40 --fps 16 \
--output t2v_out.mp4 \
--enforce-eager

# Optional extra argument (append to the command above to use it):
# --profiler-config '{"profiler":"torch","torch_profiler_dir":"./perf"}'
8 changes: 8 additions & 0 deletions examples/offline_inference/text_to_video/text_to_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,13 +319,15 @@ def main():
if args.negative_prompt:
prompt_dict["negative_prompt"] = args.negative_prompt

output_type = "latent"
sampling_kwargs = dict(
height=args.height,
width=args.width,
generator=generator,
guidance_scale=args.guidance_scale,
num_inference_steps=args.num_inference_steps,
num_frames=args.num_frames,
output_type=output_type,
)
if args.guidance_scale_high is not None:
sampling_kwargs["guidance_scale_2"] = args.guidance_scale_high
Expand All @@ -340,6 +342,12 @@ def main():

# Print profiling results
print(f"Total generation time: {generation_time:.4f} seconds ({generation_time * 1000:.2f} ms)")
if output_type == "latent":
if profiler_enabled:
print("\n[Profiler] Stopping profiler and collecting results...")
profile_results = omni.stop_profile()
print(profile_results)
return

audio = None
if isinstance(frames, list):
Expand Down
4 changes: 4 additions & 0 deletions vllm_omni/diffusion/diffusion_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,10 @@ def profile(self, is_start: bool = True, profile_prefix: str | None = None) -> N
raise RuntimeError(f"Could not {action} profiler: {e}") from e

def _dummy_run(self):
if os.environ.get("VLLM_OMNI_SKIP_DUMMY_RUN", "0") == "1":
logger.warning("Skipping diffusion dummy run because VLLM_OMNI_SKIP_DUMMY_RUN=1")
return

"""A dummy run to warm up the model."""
num_inference_steps = 1
height = 512
Expand Down
50 changes: 43 additions & 7 deletions vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,13 +316,25 @@ def __init__(
)
)

self.tokenizer = AutoTokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only)
self.text_encoder = UMT5EncoderModel.from_pretrained(
model, subfolder="text_encoder", torch_dtype=dtype, local_files_only=local_files_only
).to(self.device)
self.vae = DistributedAutoencoderKLWan.from_pretrained(
model, subfolder="vae", torch_dtype=dtype, local_files_only=local_files_only
).to(self.device)
self.transformer_only_profile = os.environ.get("VLLM_OMNI_WAN_PROFILE_TRANSFORMER_ONLY", "0") == "1"

if self.transformer_only_profile:
logger.warning(
"VLLM_OMNI_WAN_PROFILE_TRANSFORMER_ONLY=1: "
"skipping tokenizer, text_encoder and VAE loading. "
"This mode is only for transformer/operator profiling."
)
self.tokenizer = None
self.text_encoder = None
self.vae = None
else:
self.tokenizer = AutoTokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only)
self.text_encoder = UMT5EncoderModel.from_pretrained(
model, subfolder="text_encoder", torch_dtype=dtype, local_files_only=local_files_only
).to(self.device)
self.vae = DistributedAutoencoderKLWan.from_pretrained(
model, subfolder="vae", torch_dtype=torch.float32, local_files_only=local_files_only
).to(self.device)

# Initialize transformers with correct config (weights loaded via load_weights)
if load_transformer:
Expand Down Expand Up @@ -510,6 +522,11 @@ def forward(
width = req.sampling_params.width or width
num_frames = req.sampling_params.num_frames if req.sampling_params.num_frames else frame_num

if req.sampling_params.output_type is not None:
output_type = req.sampling_params.output_type
if os.environ.get("VLLM_OMNI_WAN_PROFILE_TRANSFORMER_ONLY", "0") == "1":
output_type = "latent"

# Ensure dimensions are compatible with VAE and patch size
# For expand_timesteps mode, we need latent dims to be even (divisible by patch_size)
patch_size = self.transformer_config.patch_size
Expand Down Expand Up @@ -821,6 +838,25 @@ def encode_prompt(
prompt_clean = [self._prompt_clean(p) for p in prompt]
batch_size = len(prompt_clean)

if (
os.environ.get("VLLM_OMNI_WAN_DUMMY_TEXT_ENCODER", "0") == "1"
or os.environ.get("VLLM_OMNI_WAN_PROFILE_TRANSFORMER_ONLY", "0") == "1"
):
text_dim = self.transformer_config.text_dim
prompt_embeds = torch.zeros(
batch_size * num_videos_per_prompt,
max_sequence_length,
text_dim,
device=device,
dtype=dtype,
)

negative_prompt_embeds = None
if do_classifier_free_guidance:
negative_prompt_embeds = torch.zeros_like(prompt_embeds)

return prompt_embeds, negative_prompt_embeds

text_inputs = self.tokenizer(
prompt_clean,
padding="max_length",
Expand Down
5 changes: 5 additions & 0 deletions vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -995,6 +995,11 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
if weight_name not in original_name:
continue
lookup_name = original_name.replace(weight_name, param_name)

if lookup_name not in params_dict:
logger.warning(f"Skipping weight {original_name} -> {lookup_name}")
break

param = params_dict[lookup_name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
Expand Down