diff --git a/examples/offline_inference/skyreels_v3/README.md b/examples/offline_inference/skyreels_v3/README.md new file mode 100644 index 00000000000..a6804451071 --- /dev/null +++ b/examples/offline_inference/skyreels_v3/README.md @@ -0,0 +1,144 @@ +# SkyReels-V3 Offline Inference Examples + +This directory contains examples for using the SkyReels-V3 multimodal video generation models with vLLM-Omni. + +## Models + +SkyReels-V3 is a family of multimodal video generation models that support: + +- **Image-to-Video (R2V)**: Generate videos from reference images +- **Video-to-Video (V2V)**: Transform existing videos +- **Audio-to-Video (A2V)**: Generate videos guided by audio + +### Available Models + +- `Skywork/SkyReels-V3-R2V-14B`: Image-to-Video (14B parameters) +- `Skywork/SkyReels-V3-V2V-14B`: Video-to-Video (14B parameters) +- `Skywork/SkyReels-V3-A2V-19B`: Audio-to-Video (19B parameters) + +## Installation + +Install the required dependencies: + +```bash +pip install vllm-omni +pip install imageio imageio-ffmpeg # For video I/O +``` + +## Usage + +### Image-to-Video (R2V) + +Generate a video from a reference image: + +```bash +python image_to_video.py \ + --model Skywork/SkyReels-V3-R2V-14B \ + --image path/to/your/image.jpg \ + --prompt "A person walking through a beautiful garden" \ + --height 480 \ + --width 832 \ + --num-frames 81 \ + --num-inference-steps 50 \ + --guidance-scale 7.5 \ + --seed 42 \ + --output-dir ./outputs/skyreels_v3 \ + --output-format mp4 +``` + +### Parameters + +- `--model`: Model name or path (default: `Skywork/SkyReels-V3-R2V-14B`) +- `--image`: Path to the reference image (required) +- `--prompt`: Text prompt describing the desired video +- `--negative-prompt`: Negative prompt to avoid certain content (optional) +- `--height`: Video height in pixels (default: 480) +- `--width`: Video width in pixels (default: 832) +- `--num-frames`: Number of frames to generate (default: 81) +- `--num-inference-steps`: Number of denoising steps 
(default: 50, higher = better quality but slower) +- `--guidance-scale`: Classifier-free guidance scale (default: 7.5, higher = more prompt adherence) +- `--seed`: Random seed for reproducibility (default: 42) +- `--output-dir`: Output directory for generated videos +- `--output-format`: Output format: `mp4`, `gif`, or `frames` + +## Examples + +### Basic Image-to-Video + +```bash +python image_to_video.py \ + --image examples/sample_image.jpg \ + --prompt "A cinematic video of the scene" +``` + +### High-Quality Generation + +```bash +python image_to_video.py \ + --image examples/sample_image.jpg \ + --prompt "A dramatic video with dynamic camera movement" \ + --num-inference-steps 100 \ + --guidance-scale 9.0 \ + --num-frames 121 +``` + +### Generate GIF + +```bash +python image_to_video.py \ + --image examples/sample_image.jpg \ + --prompt "A looping animation" \ + --output-format gif \ + --num-frames 49 +``` + +## Tips + +1. **Image Quality**: Use high-quality reference images for best results +2. **Aspect Ratio**: The model works best with 16:9 aspect ratio (e.g., 832x480) +3. **Frame Count**: More frames = longer videos but slower generation +4. **Guidance Scale**: + - Lower (3-5): More creative, less adherence to prompt + - Medium (7-9): Balanced + - Higher (10+): Strong prompt adherence, may reduce quality +5. 
**Inference Steps**: 50 steps is usually sufficient; 100+ for highest quality + +## Performance + +- **GPU Memory**: ~24GB VRAM required for R2V-14B model +- **Generation Time**: ~2-5 minutes for 81 frames on A100 GPU +- **Batch Size**: Currently supports batch size of 1 + +## Troubleshooting + +### Out of Memory + +If you encounter OOM errors: +- Reduce `--num-frames` +- Reduce `--height` and `--width` +- Use a smaller model variant if available + +### Poor Quality + +If the output quality is poor: +- Increase `--num-inference-steps` (try 75-100) +- Adjust `--guidance-scale` (try 8-10) +- Use a higher quality reference image +- Refine your prompt to be more specific + +## Citation + +If you use SkyReels-V3 in your research, please cite: + +```bibtex +@article{skyreels2025, + title={SkyReels-V3: Multimodal Video Generation with Unified In-Context Learning}, + author={Skywork Team}, + journal={arXiv preprint}, + year={2025} +} +``` + +## License + +SkyReels-V3 models are released under the Skywork License. Please refer to the model card on Hugging Face for details. diff --git a/examples/offline_inference/skyreels_v3/image_to_video.py b/examples/offline_inference/skyreels_v3/image_to_video.py new file mode 100644 index 00000000000..81b132cb277 --- /dev/null +++ b/examples/offline_inference/skyreels_v3/image_to_video.py @@ -0,0 +1,194 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +SkyReels-V3 Image-to-Video (R2V) Offline Inference Example. + +This script demonstrates how to use the SkyReels-V3 R2V model to generate +videos from reference images using the vLLM-Omni framework. 
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the image-to-video example."""
    parser = argparse.ArgumentParser(description="SkyReels-V3 Image-to-Video Generation")
    parser.add_argument(
        "--model",
        type=str,
        default="Skywork/SkyReels-V3-R2V-14B",
        help="Model name or path (default: Skywork/SkyReels-V3-R2V-14B)",
    )
    parser.add_argument(
        "--image",
        type=str,
        required=True,
        help="Path to the reference image",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="A cinematic video",
        help="Text prompt describing the desired video",
    )
    parser.add_argument(
        "--negative-prompt",
        type=str,
        default="",
        help="Negative prompt (optional)",
    )
    parser.add_argument(
        "--height",
        type=int,
        default=480,
        help="Video height (default: 480)",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=832,
        help="Video width (default: 832)",
    )
    parser.add_argument(
        "--num-frames",
        type=int,
        default=81,
        help="Number of frames to generate (default: 81)",
    )
    parser.add_argument(
        "--num-inference-steps",
        type=int,
        default=50,
        help="Number of denoising steps (default: 50)",
    )
    parser.add_argument(
        "--guidance-scale",
        type=float,
        default=7.5,
        help="Guidance scale for classifier-free guidance (default: 7.5)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for reproducibility (default: 42)",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="./outputs/skyreels_v3",
        help="Output directory for generated videos (default: ./outputs/skyreels_v3)",
    )
    parser.add_argument(
        "--output-format",
        type=str,
        default="mp4",
        choices=["mp4", "gif", "frames"],
        help="Output format: mp4, gif, or frames (default: mp4)",
    )
    return parser


def _load_reference_image(path: str) -> Image.Image:
    """Load *path* as an RGB image and release the underlying file handle.

    Raises:
        FileNotFoundError: if *path* does not exist.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Image not found: {path}")
    # ``Image.open`` is lazy and keeps the file open; ``convert`` returns an
    # independent in-memory copy, so the handle can be closed right away.
    with Image.open(path) as handle:
        return handle.convert("RGB")


def _extract_frames(output):
    """Return the list of frames carried by one generation result.

    Raises:
        TypeError: if *output* is not an ``OmniRequestOutput``.
        ValueError: if the output carries no frames.
    """
    if not isinstance(output, OmniRequestOutput):
        raise TypeError(f"Unexpected output type: {type(output)}")
    # In diffusion mode, output.images is the full list of frames.
    frames = getattr(output, "images", None)
    if not frames:
        raise ValueError("No video data found in diffusion output.")
    return frames


def _to_pil_frame(frame) -> Image.Image:
    """Convert one frame to a PIL image that can be written as PNG."""
    if isinstance(frame, Image.Image):
        return frame
    import numpy as np

    array = np.asarray(frame)
    if np.issubdtype(array.dtype, np.floating):
        # ``Image.fromarray`` rejects float RGB arrays.
        # NOTE(review): assumes float frames are scaled to [0, 1] - confirm
        # against the pipeline's post-process function.
        array = (np.clip(array, 0.0, 1.0) * 255.0).round().astype(np.uint8)
    return Image.fromarray(array)


def _save_frames(video_frames, output_dir: Path, idx: int, output_format: str) -> None:
    """Write *video_frames* below *output_dir* as mp4, gif, or PNG frames."""
    if output_format in ("mp4", "gif"):
        # Imported lazily so that ``--output-format frames`` works without imageio.
        import imageio

        if output_format == "mp4":
            output_path = output_dir / f"video_{idx:04d}.mp4"
            imageio.mimsave(output_path, video_frames, fps=24, codec="libx264")
            print(f"\nSaved video to: {output_path}")
        else:
            output_path = output_dir / f"video_{idx:04d}.gif"
            imageio.mimsave(output_path, video_frames, fps=12)
            print(f"\nSaved GIF to: {output_path}")
        return

    # "frames": one PNG per frame in a per-video directory.
    frames_dir = output_dir / f"video_{idx:04d}_frames"
    frames_dir.mkdir(exist_ok=True)
    for frame_idx, frame in enumerate(video_frames):
        frame_path = frames_dir / f"frame_{frame_idx:04d}.png"
        _to_pil_frame(frame).save(frame_path)
    print(f"\nSaved {len(video_frames)} frames to: {frames_dir}")


def main():
    """Generate a video from a reference image and save it to disk.

    Parses the command line, loads the reference image, runs the
    ``SkyReelsV3R2VPipeline`` through ``OmniDiffusion`` and writes every
    returned result in the requested output format.
    """
    args = _build_arg_parser().parse_args()

    # Create output directory
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load reference image
    image = _load_reference_image(args.image)
    print(f"Loaded reference image: {args.image} ({image.size})")

    # Initialize the model
    print(f"Loading SkyReels-V3 model: {args.model}")
    model = OmniDiffusion(
        model=args.model,
        model_class_name="SkyReelsV3R2VPipeline",
        trust_remote_code=True,
    )

    # Prepare the request
    print(f"\nGenerating video with prompt: '{args.prompt}'")
    print("Parameters:")
    print(f"  - Resolution: {args.width}x{args.height}")
    print(f"  - Frames: {args.num_frames}")
    print(f"  - Steps: {args.num_inference_steps}")
    print(f"  - Guidance Scale: {args.guidance_scale}")
    print(f"  - Seed: {args.seed}")

    sampling_kwargs = {
        "height": args.height,
        "width": args.width,
        "num_frames": args.num_frames,
        "num_inference_steps": args.num_inference_steps,
        "guidance_scale": args.guidance_scale,
        "seed": args.seed,
    }
    # Forward --negative-prompt only when the user supplied one; previously the
    # flag was parsed and then silently dropped.
    if args.negative_prompt:
        sampling_kwargs["negative_prompt"] = args.negative_prompt

    # Generate video
    outputs = model.generate(
        prompts=[
            {
                "prompt": args.prompt,
                "multi_modal_data": {"image": image},
            }
        ],
        sampling_params=OmniDiffusionSamplingParams(**sampling_kwargs),
    )

    # Save the generated video(s)
    for idx, output in enumerate(outputs):
        _save_frames(_extract_frames(output), output_dir, idx, args.output_format)

    print("\nGeneration complete!")
def load_transformer_config(model_path: str, subfolder: str = "transformer", local_files_only: bool = True) -> dict:
    """Load the transformer ``config.json`` for a checkpoint.

    A local checkpoint directory is always inspected first, whatever the value
    of ``local_files_only``. Only when ``<model_path>/<subfolder>/config.json``
    does not exist is ``model_path`` treated as a Hugging Face Hub repo id;
    with ``local_files_only=True`` that lookup is restricted to the local
    cache and never touches the network.

    Args:
        model_path: Local checkpoint directory or Hub repo id.
        subfolder: Sub-directory that holds the transformer config.
        local_files_only: Forbid network access when resolving a repo id.

    Returns:
        The parsed config, or an empty dict when no config could be found.

    Raises:
        json.JSONDecodeError: if a local config file exists but is not valid JSON.
    """
    local_config = os.path.join(model_path, subfolder, "config.json")
    if os.path.isfile(local_config):
        with open(local_config, encoding="utf-8") as f:
            return json.load(f)

    try:
        from huggingface_hub import hf_hub_download

        hub_config = hf_hub_download(
            repo_id=model_path,
            filename=f"{subfolder}/config.json",
            local_files_only=local_files_only,
        )
        with open(hub_config, encoding="utf-8") as f:
            return json.load(f)
    except Exception as exc:
        # Best effort by design: the caller falls back to a default
        # architecture on ``{}``. Log it so the fallback is never silent.
        logging.getLogger(__name__).warning(
            "Could not load %s/config.json for %r (%s); falling back to an empty config.",
            subfolder,
            model_path,
            exc,
        )
    return {}
get_skyreels_v3_r2v_pre_process_func( + od_config: OmniDiffusionConfig, +): + """Pre-process function for R2V: load and resize input image.""" + from diffusers.video_processor import VideoProcessor + + video_processor = VideoProcessor(vae_scale_factor=8) + + def pre_process_func(request: OmniDiffusionRequest) -> OmniDiffusionRequest: + for i, prompt in enumerate(request.prompts): + multi_modal_data = prompt.get("multi_modal_data", {}) if not isinstance(prompt, str) else None + raw_image = multi_modal_data.get("image", None) if multi_modal_data is not None else None + if isinstance(prompt, str): + prompt = OmniTextPrompt(prompt=prompt) + if "additional_information" not in prompt: + prompt["additional_information"] = {} + + if raw_image is None: + raise ValueError( + "No image is provided. This model requires an image to run. " + 'Please correctly set `"multi_modal_data": {"image": , …}`' + ) + if not isinstance(raw_image, (str, PIL.Image.Image)): + raise TypeError( + f"Unsupported image format {raw_image.__class__}. 
" + 'Please correctly set `"multi_modal_data": {"image": , …}`' + ) + image = PIL.Image.open(raw_image).convert("RGB") if isinstance(raw_image, str) else raw_image + + # Calculate dimensions based on aspect ratio if not provided + if request.sampling_params.height is None or request.sampling_params.width is None: + # Default max area for 480P + max_area = 480 * 832 + aspect_ratio = image.height / image.width + + # Calculate dimensions maintaining aspect ratio + mod_value = 16 # Must be divisible by 16 + height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value + width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value + + if request.sampling_params.height is None: + request.sampling_params.height = height + if request.sampling_params.width is None: + request.sampling_params.width = width + + # Resize image to target dimensions + image = image.resize( + (request.sampling_params.width, request.sampling_params.height), # type: ignore + PIL.Image.Resampling.LANCZOS, + ) + prompt["multi_modal_data"]["image"] = image # type: ignore + + # Preprocess for VAE + prompt["additional_information"]["preprocessed_image"] = video_processor.preprocess( + image, height=request.sampling_params.height, width=request.sampling_params.width + ) + request.prompts[i] = prompt + return request + + return pre_process_func + + +class SkyReelsV3R2VPipeline(nn.Module, SupportImageInput): + """ + SkyReels-V3 Image-to-Video (R2V) Pipeline. + + Generates videos from reference images using text prompts. 
def encode_prompt(
    self,
    prompt: str | list[str],
    device: torch.device,
    num_videos_per_prompt: int = 1,
    do_classifier_free_guidance: bool = False,
    negative_prompt: str | list[str] | None = None,
) -> torch.Tensor:
    """Encode text prompt(s) with the pipeline's tokenizer and text encoder.

    Args:
        prompt: One prompt or a list of prompts.
        device: Device the token ids are moved to before encoding.
        num_videos_per_prompt: Each prompt embedding is repeated this many
            times (consecutively) along the batch axis.
        do_classifier_free_guidance: When true, negative-prompt embeddings are
            computed and concatenated *in front of* the prompt embeddings.
        negative_prompt: Negative prompt(s); ``None`` means empty strings, a
            single string is broadcast to every prompt.

    Returns:
        The prompt embeddings, or ``cat([negative, positive])`` along the
        batch axis when classifier-free guidance is enabled.

    Raises:
        ValueError: if ``negative_prompt`` is a list whose length differs from
            the number of prompts.
    """
    prompts = [prompt] if isinstance(prompt, str) else list(prompt)
    batch_size = len(prompts)

    def _embed(texts: list[str]) -> torch.Tensor:
        tokens = self.tokenizer(
            texts,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        input_ids = tokens.input_ids.to(device)
        attention_mask = tokens.attention_mask.to(device)

        # Inputs are padded to max_length, so the encoder needs the mask to
        # keep real tokens from attending to padding.
        embeds = self.text_encoder(input_ids, attention_mask=attention_mask)[0]

        # Zero the padded positions so cross-attention sees no pad content.
        # NOTE(review): this follows the Wan text-encoding convention (see
        # diffusers WanPipeline); confirm it matches SkyReels-V3 training.
        embeds = embeds * attention_mask.unsqueeze(-1).to(embeds.dtype)

        # Duplicate for multiple videos per prompt
        if num_videos_per_prompt > 1:
            embeds = embeds.repeat_interleave(num_videos_per_prompt, dim=0)
        return embeds

    prompt_embeds = _embed(prompts)
    if not do_classifier_free_guidance:
        return prompt_embeds

    if negative_prompt is None:
        negatives = [""] * batch_size
    elif isinstance(negative_prompt, str):
        negatives = [negative_prompt] * batch_size
    else:
        negatives = list(negative_prompt)
        if len(negatives) != batch_size:
            raise ValueError(
                f"`negative_prompt` has {len(negatives)} entries but `prompt` has {batch_size}; "
                "they must match."
            )

    # Unconditional half first: the denoising loop chunks the prediction as
    # (uncond, text).
    return torch.cat([_embed(negatives), prompt_embeds])
multiple videos per prompt + if num_videos_per_prompt > 1: + image_embeds = image_embeds.repeat_interleave(num_videos_per_prompt, dim=0) + + return image_embeds + + @torch.no_grad() + def forward( + self, + request: OmniDiffusionRequest, + ) -> DiffusionOutput: + """ + Generate video from image and text prompt. + + Args: + request: Diffusion request containing prompts and parameters + + Returns: + Generated video frames + """ + device = next(self.parameters()).device + dtype = next(self.parameters()).dtype + + # Extract parameters + prompt = [p["prompt"] if isinstance(p, dict) else p for p in request.prompts] + batch_size = len(prompt) + + # Get sampling parameters + height = request.sampling_params.height or 480 + width = request.sampling_params.width or 832 + num_frames = request.sampling_params.num_frames or 81 + num_inference_steps = request.sampling_params.num_inference_steps or 50 + guidance_scale = ( + request.sampling_params.guidance_scale if request.sampling_params.guidance_scale is not None else 7.5 + ) + num_videos_per_prompt = getattr(request.sampling_params, "num_outputs_per_prompt", None) or 1 + + do_classifier_free_guidance = guidance_scale > 1.0 + + # Encode text prompt + prompt_embeds = self.encode_prompt( + prompt, + device, + num_videos_per_prompt, + do_classifier_free_guidance, + request.sampling_params.negative_prompt, + ) + + # Encode reference image using CLIP + # Use the resized PIL image from multi_modal_data, not the VAE-preprocessed one + pil_images = [] + for p in request.prompts: + multi_modal_data = p.get("multi_modal_data") or {} if isinstance(p, dict) else {} + img = multi_modal_data.get("image") + if img is not None: + pil_images.append(img) + + if not pil_images: + raise ValueError("No reference images found in request (expected in prompt['multi_modal_data']['image'])") + + # Use CLIPImageProcessor to obtain CLIP-ready image tensor + image_embeds = self.encode_image(pil_images, device, num_videos_per_prompt) + + # Prepare latents + 
# Use separate temporal and spatial VAE scale factors, as required by AutoencoderKLWan. + vae_scale_factor_temporal = getattr(self.vae.config, "scale_factor_temporal", 1) + vae_scale_factor_spatial = getattr( + self.vae.config, + "scale_factor_spatial", + getattr(self, "vae_scale_factor", 1), + ) + + # Latent temporal length is typically shorter than the decoded frame count. + num_latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1 + + num_channels_latents = self.transformer.in_channels + latents_shape = ( + batch_size * num_videos_per_prompt, + num_channels_latents, + num_latent_frames, + height // vae_scale_factor_spatial, + width // vae_scale_factor_spatial, + ) + + generator = torch.Generator(device=device) + if request.sampling_params.seed is not None: + generator.manual_seed(request.sampling_params.seed) + + latents = randn_tensor(latents_shape, generator=generator, device=device, dtype=dtype) + + # Prepare scheduler + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # Denoising loop + for i, t in enumerate(timesteps): + # Expand latents for CFG + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # Expand timestep + timestep = t.expand(latent_model_input.shape[0]) + + # Predict noise + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + image_hidden_states=image_embeds.repeat(2, 1) if do_classifier_free_guidance else image_embeds, + ).sample + + # Perform CFG + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # Compute previous noisy sample + latents = self.scheduler.step(noise_pred, t, latents).prev_sample + + # Decode latents using AutoencoderKLWan 
def apply_rotary_emb_skyreels(
    hidden_states: torch.Tensor,
    freqs_cos: torch.Tensor,
    freqs_sin: torch.Tensor,
) -> torch.Tensor:
    """
    Apply rotary embeddings to input tensors using the given frequency tensors.

    Consecutive feature pairs ``(x[2i], x[2i + 1])`` are rotated by the angle
    whose cosine/sine is stored (duplicated) at positions ``2i`` and ``2i + 1``
    of the last axis of ``freqs_cos`` / ``freqs_sin``.

    Args:
        hidden_states: Input tensor whose last axis is the (even-sized) feature
            axis, e.g. ``[B, S, H, D]``.
        freqs_cos: Cosine table with a last axis of size ``D``. It must
            broadcast against ``hidden_states``, e.g. ``[1, S, 1, D]`` for a
            ``[B, S, H, D]`` input.
        freqs_sin: Sine table, same layout as ``freqs_cos``.

    Returns:
        Tensor with rotary embeddings applied, same shape and dtype as
        ``hidden_states``.
    """
    x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
    # Always stride the *last* (feature) axis. ``[:, 0::2]`` only did that for
    # 2-D tables; for a 4-D ``[1, S, 1, D]`` table it sliced the sequence axis
    # instead and broke broadcasting. ``...`` is identical for 2-D tables.
    cos = freqs_cos[..., 0::2]
    sin = freqs_sin[..., 1::2]
    out = torch.empty_like(hidden_states)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out.type_as(hidden_states)
class SkyReelsSelfAttention(nn.Module):
    """
    Self-attention for SkyReels-V3 built from vLLM linear and norm layers.

    Queries, keys and values come from one fused projection; queries and keys
    are RMS-normalized over the full ``num_heads * head_dim`` width before
    being split into heads, optionally rotated with rotary embeddings, and
    passed to the shared ``Attention`` layer.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        head_dim: int,
        eps: float = 1e-5,
        dropout: float = 0.0,
    ):
        super().__init__()

        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.inner_dim = num_heads * head_dim

        # One fused projection produces Q, K and V (tensor parallelism disabled).
        self.to_qkv = QKVParallelLinear(
            hidden_size=dim,
            head_size=head_dim,
            total_num_heads=num_heads,
            bias=True,
            disable_tp=True,
        )

        # Q/K RMS normalization, applied before the per-head reshape.
        self.norm_q = RMSNorm(self.inner_dim, eps=eps)
        self.norm_k = RMSNorm(self.inner_dim, eps=eps)

        # Output projection followed by dropout, kept as a two-entry list.
        out_projection = ReplicatedLinear(self.inner_dim, dim, bias=True)
        self.to_out = nn.ModuleList([out_projection, nn.Dropout(dropout)])

        # Non-causal scaled dot-product attention.
        self.attn = Attention(
            num_heads=num_heads,
            head_size=head_dim,
            softmax_scale=1.0 / (head_dim**0.5),
            causal=False,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
    ) -> torch.Tensor:
        """Run self-attention over ``hidden_states``.

        Args:
            hidden_states: Token features; axis 2 is the feature axis that is
                split into heads.
            rotary_emb: Optional ``(cos, sin)`` tables. 2-D tables get a
                leading axis and a head axis inserted so that they broadcast
                over batch and heads.

        Returns:
            The attended and output-projected features.
        """
        fused_qkv, _ = self.to_qkv(hidden_states)
        q_flat, k_flat, v_flat = fused_qkv.chunk(3, dim=-1)

        # Normalize Q and K while they are still flat.
        q_flat = self.norm_q(q_flat)
        k_flat = self.norm_k(k_flat)

        # Split the feature axis into (num_heads, head_dim).
        per_head = (self.num_heads, -1)
        query = q_flat.unflatten(2, per_head)
        key = k_flat.unflatten(2, per_head)
        value = v_flat.unflatten(2, per_head)

        if rotary_emb is not None:
            cos_table, sin_table = rotary_emb
            # 2-D [S, head_dim] tables become [1, S, 1, head_dim] so they line
            # up with the [B, S, num_heads, head_dim] query/key tensors.
            if cos_table.dim() == 2:
                cos_table = cos_table.unsqueeze(0).unsqueeze(2)
            if sin_table.dim() == 2:
                sin_table = sin_table.unsqueeze(0).unsqueeze(2)
            query = apply_rotary_emb_skyreels(query, cos_table, sin_table)
            key = apply_rotary_emb_skyreels(key, cos_table, sin_table)

        # Attend, merge heads back into one feature axis, restore query dtype.
        attended = self.attn(query, key, value).flatten(2, 3).type_as(query)

        projected, _ = self.to_out[0](attended)
        return self.to_out[1](projected)
+ """ + + def __init__( + self, + dim: int, + num_heads: int, + head_dim: int, + text_dim: int, + eps: float = 1e-5, + dropout: float = 0.0, + cross_attn_norm: bool = False, + ): + super().__init__() + + self.dim = dim + self.num_heads = num_heads + self.head_dim = head_dim + self.inner_dim = num_heads * head_dim + self.cross_attn_norm = cross_attn_norm + + # Query projection + self.to_q = ReplicatedLinear(dim, self.inner_dim, bias=True) + + # Key and Value projections for encoder + self.to_kv = ReplicatedLinear(text_dim, self.inner_dim * 2, bias=True) + + # QK normalization + self.norm_q = RMSNorm(self.inner_dim, eps=eps) + self.norm_k = RMSNorm(self.inner_dim, eps=eps) + + # Optional encoder normalization + if cross_attn_norm: + self.norm_encoder = RMSNorm(text_dim, eps=eps) + + # Output projection + self.to_out = nn.ModuleList( + [ + ReplicatedLinear(self.inner_dim, dim, bias=True), + nn.Dropout(dropout), + ] + ) + + # Unified attention layer + self.attn = Attention( + num_heads=num_heads, + head_size=head_dim, + softmax_scale=1.0 / (head_dim**0.5), + causal=False, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + ) -> torch.Tensor: + # Normalize encoder if needed + if self.cross_attn_norm: + encoder_hidden_states = self.norm_encoder(encoder_hidden_states) + + # Project query + query, _ = self.to_q(hidden_states) + query = self.norm_q(query) + + # Project key and value + kv, _ = self.to_kv(encoder_hidden_states) + key, value = kv.chunk(2, dim=-1) + key = self.norm_k(key) + + # Reshape for multi-head attention + query = query.unflatten(2, (self.num_heads, -1)) + key = key.unflatten(2, (self.num_heads, -1)) + value = value.unflatten(2, (self.num_heads, -1)) + + # Compute attention + hidden_states = self.attn(query, key, value) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.type_as(query) + + # Output projection + hidden_states, _ = self.to_out[0](hidden_states) + hidden_states = 
self.to_out[1](hidden_states) + + return hidden_states + + +class SkyReelsTransformerBlock(nn.Module): + """ + Transformer block for SkyReels-V3. + Includes self-attention, cross-attention, and feed-forward layers. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + text_dim: int, + ffn_dim: int | None = None, + dropout: float = 0.0, + cross_attn_norm: bool = False, + eps: float = 1e-6, + ): + super().__init__() + + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + + # Self-attention + self.norm1 = RMSNorm(dim, eps=eps) + self.attn1 = SkyReelsSelfAttention( + dim=dim, + num_heads=num_attention_heads, + head_dim=attention_head_dim, + eps=eps, + dropout=dropout, + ) + + # Cross-attention + self.norm2 = RMSNorm(dim, eps=eps) + self.attn2 = SkyReelsCrossAttention( + dim=dim, + num_heads=num_attention_heads, + head_dim=attention_head_dim, + text_dim=text_dim, + eps=eps, + dropout=dropout, + cross_attn_norm=cross_attn_norm, + ) + + # Feed-forward + self.norm3 = RMSNorm(dim, eps=eps) + ffn_dim = ffn_dim or dim * 4 + self.ff = FeedForward( + dim=dim, + dim_out=dim, + mult=ffn_dim // dim, + dropout=dropout, + activation_fn="gelu-approximate", + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> torch.Tensor: + """ + Forward pass through the transformer block. 
+ + Args: + hidden_states: Input tensor [B, S, D] + encoder_hidden_states: Encoder outputs for cross-attention [B, T, D_text] + rotary_emb: Rotary position embeddings (cos, sin) + + Returns: + Output tensor [B, S, D] + """ + # Self-attention + norm_hidden_states = self.norm1(hidden_states) + attn_output = self.attn1(norm_hidden_states, rotary_emb=rotary_emb) + hidden_states = hidden_states + attn_output + + # Cross-attention + if encoder_hidden_states is not None: + norm_hidden_states = self.norm2(hidden_states) + attn_output = self.attn2(norm_hidden_states, encoder_hidden_states) + hidden_states = hidden_states + attn_output + + # Feed-forward + norm_hidden_states = self.norm3(hidden_states) + ff_output = self.ff(norm_hidden_states) + hidden_states = hidden_states + ff_output + + return hidden_states + + +class SkyReelsTransformer3DModel(nn.Module): + """ + SkyReels-V3 3D Transformer Model for video generation. + + Supports multiple modalities: + - Text-to-Video + - Image-to-Video (R2V) + - Video-to-Video (V2V) + - Audio-to-Video (A2V) + """ + + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: int = 16, + out_channels: int | None = None, + num_layers: int = 28, + dropout: float = 0.0, + text_dim: int = 4096, + ffn_dim: int | None = None, + patch_size: tuple[int, int, int] = (1, 2, 2), + cross_attn_norm: bool = False, + eps: float = 1e-6, + max_seq_len: int = 16384, + image_dim: int | None = None, + added_kv_proj_dim: int | None = None, + ): + super().__init__() + + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + self.inner_dim = num_attention_heads * attention_head_dim + self.in_channels = in_channels + self.out_channels = out_channels or in_channels + self.patch_size = patch_size + self.text_dim = text_dim + self.image_dim = image_dim + self.added_kv_proj_dim = added_kv_proj_dim + + # Input projection + self.proj_in = Conv3dLayer( + in_channels=in_channels, + 
out_channels=self.inner_dim, + kernel_size=patch_size, + stride=patch_size, + padding=0, + bias=True, + ) + + # Timestep embedding + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.time_embedding = TimestepEmbedding( + in_channels=256, + time_embed_dim=self.inner_dim, + ) + + # Text projection + self.text_proj = PixArtAlphaTextProjection( + in_features=text_dim, + hidden_size=self.inner_dim, + ) + + # Optional image conditioning projection + if image_dim is not None: + self.image_proj = ReplicatedLinear( + image_dim, + self.inner_dim, + bias=True, + ) + + # Rotary position embeddings + self.pos_embed = SkyReelsRotaryPosEmbed( + attention_head_dim=attention_head_dim, + patch_size=patch_size, + max_seq_len=max_seq_len, + ) + + # Transformer blocks + self.transformer_blocks = nn.ModuleList( + [ + SkyReelsTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + text_dim=self.inner_dim, + ffn_dim=ffn_dim, + dropout=dropout, + cross_attn_norm=cross_attn_norm, + eps=eps, + ) + for _ in range(num_layers) + ] + ) + + # Output layers + self.norm_out = RMSNorm(self.inner_dim, eps=eps) + self.proj_out = ReplicatedLinear( + self.inner_dim, + self.out_channels * patch_size[0] * patch_size[1] * patch_size[2], + bias=True, + ) + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + encoder_attention_mask: torch.Tensor | None = None, + image_hidden_states: torch.Tensor | None = None, + return_dict: bool = True, + ) -> Transformer2DModelOutput | tuple: + """ + Forward pass through the SkyReels-V3 transformer. 
+ + Args: + hidden_states: Latent video tensor [B, C, T, H, W] + timestep: Diffusion timestep [B] + encoder_hidden_states: Text embeddings [B, seq_len, text_dim] + encoder_attention_mask: Attention mask for text + image_hidden_states: Optional image conditioning [B, image_dim] + return_dict: Whether to return a dict or tuple + + Returns: + Transformer output + """ + batch_size, channels, num_frames, height, width = hidden_states.shape + + # Project input + hidden_states = self.proj_in(hidden_states) # [B, inner_dim, T', H', W'] + + # Reshape to sequence + t_out, h_out, w_out = ( + num_frames // self.patch_size[0], + height // self.patch_size[1], + width // self.patch_size[2], + ) + hidden_states = hidden_states.flatten(2).transpose(1, 2) # [B, T'*H'*W', inner_dim] + + # Timestep embedding + timestep_emb = self.time_proj(timestep) + timestep_emb = self.time_embedding(timestep_emb) # [B, inner_dim] + timestep_emb = timestep_emb.unsqueeze(1) # [B, 1, inner_dim] + + # Text projection + encoder_hidden_states = self.text_proj(encoder_hidden_states) # [B, seq_len, inner_dim] + + # Add timestep to encoder hidden states + encoder_hidden_states = torch.cat([timestep_emb, encoder_hidden_states], dim=1) + + # Optional image conditioning + if image_hidden_states is not None and self.image_dim is not None: + image_emb, _ = self.image_proj(image_hidden_states) + image_emb = image_emb.unsqueeze(1) # [B, 1, inner_dim] + encoder_hidden_states = torch.cat([encoder_hidden_states, image_emb], dim=1) + + # Get rotary position embeddings + seq_len = t_out * h_out * w_out + freqs_cos = self.pos_embed.freqs_cos[:seq_len] # type: ignore + freqs_sin = self.pos_embed.freqs_sin[:seq_len] # type: ignore + rotary_emb = (freqs_cos, freqs_sin) + + # Transformer blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + rotary_emb=rotary_emb, + ) + + # Output projection + hidden_states = self.norm_out(hidden_states) + 
hidden_states, _ = self.proj_out(hidden_states) + + # Reshape back to video format + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, + self.out_channels, + t_out, + h_out, + w_out, + self.patch_size[0], + self.patch_size[1], + self.patch_size[2], + ) + hidden_states = hidden_states.permute(0, 1, 2, 5, 3, 6, 4, 7).reshape( + batch_size, self.out_channels, num_frames, height, width + ) + + if not return_dict: + return (hidden_states,) + + return Transformer2DModelOutput(sample=hidden_states) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + """Load model weights.""" + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if name not in params_dict: + continue + param = params_dict[name] + default_weight_loader(param, loaded_weight) diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py index 173c576a243..e6c842b905f 100644 --- a/vllm_omni/diffusion/registry.py +++ b/vllm_omni/diffusion/registry.py @@ -100,6 +100,11 @@ "pipeline_flux", "FluxPipeline", ), + "SkyReelsV3R2VPipeline": ( + "skyreels_v3", + "pipeline_skyreels_v3_r2v", + "SkyReelsV3R2VPipeline", + ), } @@ -277,6 +282,7 @@ def _apply_sequence_parallel_if_enabled(model, od_config: OmniDiffusionConfig) - "StableDiffusion3Pipeline": "get_sd3_image_post_process_func", "Flux2KleinPipeline": "get_flux2_klein_post_process_func", "FluxPipeline": "get_flux_post_process_func", + "SkyReelsV3R2VPipeline": "get_skyreels_v3_r2v_post_process_func", } _DIFFUSION_PRE_PROCESS_FUNCS = { @@ -290,6 +296,7 @@ def _apply_sequence_parallel_if_enabled(model, od_config: OmniDiffusionConfig) - "QwenImageLayeredPipeline": "get_qwen_image_layered_pre_process_func", "WanPipeline": "get_wan22_pre_process_func", "WanImageToVideoPipeline": "get_wan22_i2v_pre_process_func", + "SkyReelsV3R2VPipeline": "get_skyreels_v3_r2v_pre_process_func", } diff --git a/vllm_omni/model_executor/stage_configs/skyreels_v3_r2v.yaml 
b/vllm_omni/model_executor/stage_configs/skyreels_v3_r2v.yaml new file mode 100644 index 00000000000..b4f53c32d2a --- /dev/null +++ b/vllm_omni/model_executor/stage_configs/skyreels_v3_r2v.yaml @@ -0,0 +1,47 @@ +stage_args: + - stage_id: 0 + stage_type: diffusion + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + model_stage: dit + model_class_name: SkyReelsV3R2VPipeline + gpu_memory_utilization: 0.85 + enforce_eager: true + trust_remote_code: true + engine_output_type: video + distributed_executor_backend: "mp" + enable_prefix_caching: false + max_num_batched_tokens: 32768 + tensor_parallel_size: 1 + # SkyReels-V3 specific settings + flow_shift: 5.0 # Flow matching shift parameter + attention_backend: "sdpa" # Use SDPA for attention + final_output: true + final_output_type: video + is_comprehension: false + default_sampling_params: + height: 480 + width: 832 + num_frames: 81 + num_inference_steps: 50 + guidance_scale: 7.5 + seed: 42 + num_outputs_per_prompt: 1 + +# Runtime configuration +runtime: + enabled: true + defaults: + window_size: -1 + max_inflight: 1 + + # Distributed connectors configuration + connectors: + shared_memory_connector: + name: SharedMemoryConnector + extra: + shm_threshold_bytes: 65536 # 64KB threshold + + edges: [] # Single stage, no edges needed