diff --git a/examples/offline_inference/text_to_video/run_t2v.sh b/examples/offline_inference/text_to_video/run_t2v.sh
new file mode 100755
index 00000000000..a1b97e0f0d7
--- /dev/null
+++ b/examples/offline_inference/text_to_video/run_t2v.sh
@@ -0,0 +1,23 @@
+#!/usr/bin/env bash
+# Profile the Wan T2V transformer in isolation (dummy text encoder, no VAE, no warm-up run).
+set -euo pipefail
+
+export VLLM_OMNI_WAN_DUMMY_TEXT_ENCODER=1
+export VLLM_OMNI_SKIP_DUMMY_RUN=1
+export VLLM_OMNI_WAN_PROFILE_TRANSFORMER_ONLY=1
+export VLLM_TORCH_PROFILER_DIR="./"
+
+# Override by exporting MODEL_PATH, e.g. /mnt/disk2/hf_models/Wan2.1-T2V-1.3B-Diffusers.
+MODEL_PATH="${MODEL_PATH:-/mnt/disk2/hf_models/Wan2.2-T2V-A14B-Diffusers}"
+
+python text_to_video.py \
+    --model "$MODEL_PATH" \
+    --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
+    --height 480 --width 832 --num-frames 9 \
+    --guidance-scale 1.0 --guidance-scale-high 1.0 \
+    --boundary-ratio 0.0 --flow-shift 12.0 \
+    --num-inference-steps 40 --fps 16 \
+    --output t2v_out.mp4 \
+    --enforce-eager
+
+# --profiler-config '{"profiler":"torch","torch_profiler_dir":"./perf"}'
diff --git a/examples/offline_inference/text_to_video/text_to_video.py b/examples/offline_inference/text_to_video/text_to_video.py
index a96949fa54c..d1bbf27cb45 100644
--- a/examples/offline_inference/text_to_video/text_to_video.py
+++ b/examples/offline_inference/text_to_video/text_to_video.py
@@ -319,6 +319,12 @@ def main():
     if args.negative_prompt:
         prompt_dict["negative_prompt"] = args.negative_prompt
 
+    # Only transformer-only profiling (no VAE loaded) needs raw latents; normal
+    # runs keep the pipeline default so the video is still decoded and saved.
+    from os import environ as _environ  # aliased: never shadows a module-level `os`
+
+    profile_only = _environ.get("VLLM_OMNI_WAN_PROFILE_TRANSFORMER_ONLY", "0") == "1"
+    output_type = "latent" if profile_only else None
     sampling_kwargs = dict(
         height=args.height,
         width=args.width,
@@ -326,6 +332,8 @@ def main():
         guidance_scale=args.guidance_scale,
         num_inference_steps=args.num_inference_steps,
         num_frames=args.num_frames,
     )
+    if output_type is not None:
+        sampling_kwargs["output_type"] = output_type
     if args.guidance_scale_high is not None:
         sampling_kwargs["guidance_scale_2"] = args.guidance_scale_high
@@ -340,6 +348,13 @@ def main():
 
     # Print profiling results
     print(f"Total generation time: {generation_time:.4f} seconds ({generation_time * 1000:.2f} ms)")
+    if output_type == "latent":
+        # Latents cannot be saved as a video: flush the profiler (if any) and stop here.
+        if profiler_enabled:
+            print("\n[Profiler] Stopping profiler and collecting results...")
+            profile_results = omni.stop_profile()
+            print(profile_results)
+        return
 
     audio = None
     if isinstance(frames, list):
diff --git a/vllm_omni/diffusion/diffusion_engine.py b/vllm_omni/diffusion/diffusion_engine.py
index abaf5989598..d0215385a9d 100644
--- a/vllm_omni/diffusion/diffusion_engine.py
+++ b/vllm_omni/diffusion/diffusion_engine.py
@@ -403,6 +403,10 @@ def profile(self, is_start: bool = True, profile_prefix: str | None = None) -> N
             raise RuntimeError(f"Could not {action} profiler: {e}") from e
 
     def _dummy_run(self):
         """A dummy run to warm up the model."""
+        if os.environ.get("VLLM_OMNI_SKIP_DUMMY_RUN", "0") == "1":
+            logger.warning("Skipping diffusion dummy run because VLLM_OMNI_SKIP_DUMMY_RUN=1")
+            return
+
         num_inference_steps = 1
         height = 512
diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py
index 188eb70b2b2..34045a8aaa1 100644
--- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py
+++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py
@@ -316,13 +316,27 @@ def __init__(
             )
         )
 
-        self.tokenizer = AutoTokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only)
-        self.text_encoder = UMT5EncoderModel.from_pretrained(
-            model, subfolder="text_encoder", torch_dtype=dtype, local_files_only=local_files_only
-        ).to(self.device)
-        self.vae = DistributedAutoencoderKLWan.from_pretrained(
-            model, subfolder="vae", torch_dtype=dtype, local_files_only=local_files_only
-        ).to(self.device)
+        self.transformer_only_profile = os.environ.get("VLLM_OMNI_WAN_PROFILE_TRANSFORMER_ONLY", "0") == "1"
+
+        if self.transformer_only_profile:
+            logger.warning(
+                "VLLM_OMNI_WAN_PROFILE_TRANSFORMER_ONLY=1: "
+                "skipping tokenizer, text_encoder and VAE loading. "
+                "This mode is only for transformer/operator profiling."
+            )
+            self.tokenizer = None
+            self.text_encoder = None
+            self.vae = None
+        else:
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                model, subfolder="tokenizer", local_files_only=local_files_only
+            )
+            self.text_encoder = UMT5EncoderModel.from_pretrained(
+                model, subfolder="text_encoder", torch_dtype=dtype, local_files_only=local_files_only
+            ).to(self.device)
+            self.vae = DistributedAutoencoderKLWan.from_pretrained(
+                model, subfolder="vae", torch_dtype=torch.float32, local_files_only=local_files_only
+            ).to(self.device)
 
         # Initialize transformers with correct config (weights loaded via load_weights)
         if load_transformer:
@@ -510,6 +524,12 @@ def forward(
         width = req.sampling_params.width or width
         num_frames = req.sampling_params.num_frames if req.sampling_params.num_frames else frame_num
 
+        if req.sampling_params.output_type is not None:
+            output_type = req.sampling_params.output_type
+        if self.transformer_only_profile:
+            # No VAE was loaded in __init__, so decoding is impossible: always return latents.
+            output_type = "latent"
+
         # Ensure dimensions are compatible with VAE and patch size
         # For expand_timesteps mode, we need latent dims to be even (divisible by patch_size)
         patch_size = self.transformer_config.patch_size
@@ -821,6 +841,24 @@ def encode_prompt(
         prompt_clean = [self._prompt_clean(p) for p in prompt]
         batch_size = len(prompt_clean)
 
+        # Dummy embeddings are mandatory in transformer-only mode (no tokenizer/text encoder was
+        # loaded in __init__) and optional otherwise via VLLM_OMNI_WAN_DUMMY_TEXT_ENCODER=1.
+        if self.transformer_only_profile or os.environ.get("VLLM_OMNI_WAN_DUMMY_TEXT_ENCODER", "0") == "1":
+            text_dim = self.transformer_config.text_dim
+            prompt_embeds = torch.zeros(
+                batch_size * num_videos_per_prompt,
+                max_sequence_length,
+                text_dim,
+                device=device,
+                dtype=dtype,
+            )
+
+            negative_prompt_embeds = None
+            if do_classifier_free_guidance:
+                negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+
+            return prompt_embeds, negative_prompt_embeds
+
         text_inputs = self.tokenizer(
             prompt_clean,
             padding="max_length",
diff --git a/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py b/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py
index 98e10b40879..0edd2214282 100644
--- a/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py
+++ b/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py
@@ -995,6 +995,11 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                 if weight_name not in original_name:
                     continue
                 lookup_name = original_name.replace(weight_name, param_name)
+
+                if lookup_name not in params_dict:
+                    logger.warning(f"Skipping weight {original_name} -> {lookup_name}")
+                    break
+
                 param = params_dict[lookup_name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)