From cdc74cada45dec9e34d3892b55ac4c4eb70f8f1a Mon Sep 17 00:00:00 2001 From: gDINESH13 Date: Sun, 8 Feb 2026 11:35:55 +0530 Subject: [PATCH 1/5] New Model SkyReels added to vllm-omni Signed-off-by: gDINESH13 --- .../offline_inference/skyreels_v3/README.md | 144 ++++ .../skyreels_v3/image_to_video.py | 181 ++++++ .../diffusion/models/skyreels_v3/__init__.py | 15 + .../skyreels_v3/pipeline_skyreels_v3_r2v.py | 435 +++++++++++++ .../skyreels_v3/skyreels_v3_transformer.py | 614 ++++++++++++++++++ vllm_omni/diffusion/registry.py | 7 + .../stage_configs/skyreels_v3_r2v.yaml | 47 ++ 7 files changed, 1443 insertions(+) create mode 100644 examples/offline_inference/skyreels_v3/README.md create mode 100644 examples/offline_inference/skyreels_v3/image_to_video.py create mode 100644 vllm_omni/diffusion/models/skyreels_v3/__init__.py create mode 100644 vllm_omni/diffusion/models/skyreels_v3/pipeline_skyreels_v3_r2v.py create mode 100644 vllm_omni/diffusion/models/skyreels_v3/skyreels_v3_transformer.py create mode 100644 vllm_omni/model_executor/stage_configs/skyreels_v3_r2v.yaml diff --git a/examples/offline_inference/skyreels_v3/README.md b/examples/offline_inference/skyreels_v3/README.md new file mode 100644 index 00000000000..fe5cc5e539d --- /dev/null +++ b/examples/offline_inference/skyreels_v3/README.md @@ -0,0 +1,144 @@ +# SkyReels-V3 Offline Inference Examples + +This directory contains examples for using the SkyReels-V3 multimodal video generation models with vLLM-Omni. + +## Models + +SkyReels-V3 is a family of multimodal video generation models that support: + +- **Image-to-Video (R2V)**: Generate videos from reference images +- **Video-to-Video (V2V)**: Transform existing videos +- **Audio-to-Video (A2V)**: Generate videos guided by audio + +### Available Models + +- `Skywork/SkyReels-V3-R2V-14B`: Image-to-Video (14B parameters) +- `Skywork/SkyReels-V3-V2V-14B`: Video-to-Video (14B parameters) +- `Skywork/SkyReels-V3-A2V-19B`: Audio-to-Video (19B parameters) + +## Installation + +Install the required dependencies: + +```bash +pip install vllm-omni +pip install imageio imageio-ffmpeg # For video I/O +``` + +## Usage + +### Image-to-Video (R2V) + +Generate a video from a reference image: + +```bash +python image_to_video.py \ + --model Skywork/SkyReels-V3-R2V-14B \ + --image path/to/your/image.jpg \ + --prompt "A person walking through a beautiful garden" \ + --height 480 \ + --width 832 \ + --num-frames 81 \ + --num-inference-steps 50 \ + --guidance-scale 7.5 \ + --seed 42 \ + --output-dir ./outputs/skyreels_v3 \ + --output-format mp4 +``` + +### Parameters + +- `--model`: Model name or path (default: `Skywork/SkyReels-V3-R2V-14B`) +- `--image`: Path to the reference image (required) +- `--prompt`: Text prompt describing the desired video +- `--negative-prompt`: Negative prompt to avoid certain content (optional) +- `--height`: Video height in pixels (default: 480) +- `--width`: Video width in pixels (default: 832) +- `--num-frames`: Number of frames to generate (default: 81) +- `--num-inference-steps`: Number of denoising steps (default: 50, higher = better quality but slower) +- `--guidance-scale`: Classifier-free guidance scale (default: 7.5, higher = more prompt adherence) +- `--seed`: Random seed for reproducibility (default: 42) +- `--output-dir`: Output directory for generated videos +- `--output-format`: Output format: `mp4`, `gif`, or `frames` + +## Examples + +### Basic Image-to-Video + +```bash +python image_to_video.py \ + --image examples/sample_image.jpg \ + --prompt "A cinematic video of the scene" +``` + +### High-Quality Generation + +```bash +python image_to_video.py \ + --image examples/sample_image.jpg \ + --prompt "A dramatic video with dynamic camera movement" \ + --num-inference-steps 100 \ + --guidance-scale 9.0 \ + --num-frames 121 +``` + +### Generate GIF + +```bash +python image_to_video.py \ + --image examples/sample_image.jpg \ + --prompt "A looping animation" \ + --output-format gif \ + --num-frames 49 +``` + +## Tips + +1. **Image Quality**: Use high-quality reference images for best results +2. **Aspect Ratio**: The model works best with 16:9 aspect ratio (e.g., 832x480) +3. **Frame Count**: More frames = longer videos but slower generation +4. **Guidance Scale**: + - Lower (3-5): More creative, less adherence to prompt + - Medium (7-9): Balanced + - Higher (10+): Strong prompt adherence, may reduce quality +5. **Inference Steps**: 50 steps is usually sufficient; 100+ for highest quality + +## Performance + +- **GPU Memory**: ~24GB VRAM required for R2V-14B model +- **Generation Time**: ~2-5 minutes for 81 frames on A100 GPU +- **Batch Size**: Currently supports batch size of 1 + +## Troubleshooting + +### Out of Memory + +If you encounter OOM errors: +- Reduce `--num-frames` +- Reduce `--height` and `--width` +- Use a smaller model variant if available + +### Poor Quality + +If the output quality is poor: +- Increase `--num-inference-steps` (try 75-100) +- Adjust `--guidance-scale` (try 8-10) +- Use a higher quality reference image +- Refine your prompt to be more specific + +## Citation + +If you use SkyReels-V3 in your research, please cite: + +```bibtex +@article{skyreels2025, + title={SkyReels-V3: Multimodal Video Generation with Unified In-Context Learning}, + author={Skywork Team}, + journal={arXiv preprint}, + year={2025} +} +``` + +## License + +SkyReels-V3 models are released under the Skywork License. Please refer to the model card on Hugging Face for details. \ No newline at end of file diff --git a/examples/offline_inference/skyreels_v3/image_to_video.py b/examples/offline_inference/skyreels_v3/image_to_video.py new file mode 100644 index 00000000000..4ad4f31a97f --- /dev/null +++ b/examples/offline_inference/skyreels_v3/image_to_video.py @@ -0,0 +1,181 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +SkyReels-V3 Image-to-Video (R2V) Offline Inference Example. + +This script demonstrates how to use the SkyReels-V3 R2V model to generate +videos from reference images using the vLLM-Omni framework. + +Usage: + python image_to_video.py --model Skywork/SkyReels-V3-R2V-14B \ + --image path/to/image.jpg \ + --prompt "A person walking in the park" +""" + +import argparse +import os +from pathlib import Path + +from PIL import Image + +from vllm_omni.entrypoints.omni_diffusion import OmniDiffusion + + +def main(): + parser = argparse.ArgumentParser(description="SkyReels-V3 Image-to-Video Generation") + parser.add_argument( + "--model", + type=str, + default="Skywork/SkyReels-V3-R2V-14B", + help="Model name or path (default: Skywork/SkyReels-V3-R2V-14B)", + ) + parser.add_argument( + "--image", + type=str, + required=True, + help="Path to the reference image", + ) + parser.add_argument( + "--prompt", + type=str, + default="A cinematic video", + help="Text prompt describing the desired video", + ) + parser.add_argument( + "--negative-prompt", + type=str, + default="", + help="Negative prompt (optional)", + ) + parser.add_argument( + "--height", + type=int, + default=480, + help="Video height (default: 480)", + ) + parser.add_argument( + "--width", + type=int, + default=832, + help="Video width (default: 832)", + ) + parser.add_argument( + "--num-frames", + type=int, + default=81, + help="Number of frames to generate (default: 81)", + ) + parser.add_argument( + "--num-inference-steps", + type=int, + default=50, + help="Number of denoising steps (default: 50)", + ) + parser.add_argument( + "--guidance-scale", + type=float, + default=7.5, + help="Guidance scale for classifier-free guidance (default: 7.5)", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="Random seed for reproducibility (default: 42)", + ) + parser.add_argument( + "--output-dir", + type=str, + default="./outputs/skyreels_v3", + help="Output directory for generated videos (default: ./outputs/skyreels_v3)", + ) + parser.add_argument( + "--output-format", + type=str, + default="mp4", + choices=["mp4", "gif", "frames"], + help="Output format: mp4, gif, or frames (default: mp4)", + ) + + args = parser.parse_args() + + # Create output directory + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + # Load reference image + if not os.path.exists(args.image): + raise FileNotFoundError(f"Image not found: {args.image}") + + image = Image.open(args.image).convert("RGB") + print(f"Loaded reference image: {args.image} ({image.size})") + + # Initialize the model + print(f"Loading SkyReels-V3 model: {args.model}") + model = OmniDiffusion( + model=args.model, + model_class_name="SkyReelsV3R2VPipeline", + trust_remote_code=True, + ) + + # Prepare the request + print(f"\nGenerating video with prompt: '{args.prompt}'") + print(f"Parameters:") + print(f" - Resolution: {args.width}x{args.height}") + print(f" - Frames: {args.num_frames}") + print(f" - Steps: {args.num_inference_steps}") + print(f" - Guidance Scale: {args.guidance_scale}") + print(f" - Seed: {args.seed}") + + # Generate video + outputs = model.generate( + prompts=[ + { + "prompt": args.prompt, + "multi_modal_data": {"image": image}, + } + ], + sampling_params={ + "height": args.height, + "width": args.width, + "num_frames": args.num_frames, + "num_inference_steps": args.num_inference_steps, + "guidance_scale": args.guidance_scale, + "seed": args.seed, + }, + ) + + # Save the generated video + for idx, output in enumerate(outputs): + video_frames = output.outputs[0] # Get the video frames + + if args.output_format == "mp4": + output_path = output_dir / f"video_{idx:04d}.mp4" + # Save as MP4 video + import imageio + imageio.mimsave(output_path, video_frames, fps=24, codec="libx264") + print(f"\nSaved video to: {output_path}") + + elif args.output_format == "gif": + output_path = output_dir / f"video_{idx:04d}.gif" + # Save as GIF + import imageio + imageio.mimsave(output_path, video_frames, fps=12) + print(f"\nSaved GIF to: {output_path}") + + else: # frames + frames_dir = output_dir / f"video_{idx:04d}_frames" + frames_dir.mkdir(exist_ok=True) + # Save individual frames + for frame_idx, frame in enumerate(video_frames): + frame_path = frames_dir / f"frame_{frame_idx:04d}.png" + Image.fromarray(frame).save(frame_path) + print(f"\nSaved {len(video_frames)} frames to: {frames_dir}") + + print("\nGeneration complete!") + + +if __name__ == "__main__": + main() diff --git a/vllm_omni/diffusion/models/skyreels_v3/__init__.py b/vllm_omni/diffusion/models/skyreels_v3/__init__.py new file mode 100644 index 00000000000..231f4509c86 --- /dev/null +++ b/vllm_omni/diffusion/models/skyreels_v3/__init__.py @@ -0,0 +1,15 @@ +"""SkyReels-V3 multimodal video generation models.""" + +from .pipeline_skyreels_v3_r2v import ( + SkyReelsV3R2VPipeline, + get_skyreels_v3_r2v_post_process_func, + get_skyreels_v3_r2v_pre_process_func, +) +from .skyreels_v3_transformer import SkyReelsTransformer3DModel + +__all__ = [ + "SkyReelsV3R2VPipeline", + "get_skyreels_v3_r2v_post_process_func", + "get_skyreels_v3_r2v_pre_process_func", + "SkyReelsTransformer3DModel", +] \ No newline at end of file diff --git a/vllm_omni/diffusion/models/skyreels_v3/pipeline_skyreels_v3_r2v.py b/vllm_omni/diffusion/models/skyreels_v3/pipeline_skyreels_v3_r2v.py new file mode 100644 index 00000000000..bf33941077c --- /dev/null +++ b/vllm_omni/diffusion/models/skyreels_v3/pipeline_skyreels_v3_r2v.py @@ -0,0 +1,435 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +SkyReels-V3 Image-to-Video (R2V) Pipeline Implementation. + +This pipeline supports generating videos from reference images using the +SkyReels-V3 multimodal video generation model. +""" + +from __future__ import annotations + +import json +import logging +import os +from typing import Any, cast + +import numpy as np +import PIL.Image +import torch +from diffusers import AutoencoderKLWan +from diffusers.utils.torch_utils import randn_tensor +from torch import nn +from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel +from vllm.model_executor.models.utils import AutoWeightsLoader + +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin +from vllm_omni.diffusion.distributed.utils import get_local_device +from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.interface import SupportImageInput +from vllm_omni.diffusion.models.schedulers import FlowUniPCMultistepScheduler +from vllm_omni.diffusion.models.skyreels_v3.skyreels_v3_transformer import SkyReelsTransformer3DModel +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.inputs.data import OmniTextPrompt + +logger = logging.getLogger(__name__) + + +def retrieve_latents( + encoder_output: torch.Tensor, + generator: torch.Generator | None = None, + sample_mode: str = "sample", +): + """Retrieve latents from VAE encoder output.""" + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +def load_transformer_config(model_path: str, subfolder: str = "transformer", local_files_only: bool = True) -> dict: + """Load transformer config from model directory or HF Hub.""" + if local_files_only: + config_path = os.path.join(model_path, subfolder, "config.json") + if os.path.exists(config_path): + with open(config_path) as f: + return json.load(f) + else: + try: + from huggingface_hub import hf_hub_download + + config_path = hf_hub_download( + repo_id=model_path, + filename=f"{subfolder}/config.json", + ) + with open(config_path) as f: + return json.load(f) + except Exception: + pass + return {} + + +def create_transformer_from_config(config: dict) -> SkyReelsTransformer3DModel: + """Create SkyReelsTransformer3DModel from config dict.""" + kwargs = {} + + if "patch_size" in config: + kwargs["patch_size"] = tuple(config["patch_size"]) + if "num_attention_heads" in config: + kwargs["num_attention_heads"] = config["num_attention_heads"] + if "attention_head_dim" in config: + kwargs["attention_head_dim"] = config["attention_head_dim"] + if "in_channels" in config: + kwargs["in_channels"] = config["in_channels"] + if "out_channels" in config: + kwargs["out_channels"] = config["out_channels"] + if "text_dim" in config: + kwargs["text_dim"] = config["text_dim"] + if "ffn_dim" in config: + kwargs["ffn_dim"] = config["ffn_dim"] + if "num_layers" in config: + kwargs["num_layers"] = config["num_layers"] + if "cross_attn_norm" in config: + kwargs["cross_attn_norm"] = config["cross_attn_norm"] + if "eps" in config: + kwargs["eps"] = config["eps"] + if "image_dim" in config: + kwargs["image_dim"] = config["image_dim"] + if "max_seq_len" in config: + kwargs["max_seq_len"] = config["max_seq_len"] + + return SkyReelsTransformer3DModel(**kwargs) + + +def get_skyreels_v3_r2v_post_process_func( + od_config: OmniDiffusionConfig, +): + """Post-process function for R2V: convert latents to video frames.""" + from diffusers.video_processor import VideoProcessor + + video_processor = VideoProcessor(vae_scale_factor=8) + + def post_process_func( + video: torch.Tensor, + output_type: str = "np", + ): + if output_type == "latent": + return video + return video_processor.postprocess_video(video, output_type=output_type) + + return post_process_func + + +def get_skyreels_v3_r2v_pre_process_func( + od_config: OmniDiffusionConfig, +): + """Pre-process function for R2V: load and resize input image.""" + from diffusers.video_processor import VideoProcessor + + video_processor = VideoProcessor(vae_scale_factor=8) + + def pre_process_func(request: OmniDiffusionRequest) -> OmniDiffusionRequest: + for i, prompt in enumerate(request.prompts): + multi_modal_data = prompt.get("multi_modal_data", {}) if not isinstance(prompt, str) else None + raw_image = multi_modal_data.get("image", None) if multi_modal_data is not None else None + if isinstance(prompt, str): + prompt = OmniTextPrompt(prompt=prompt) + if "additional_information" not in prompt: + prompt["additional_information"] = {} + + if raw_image is None: + raise ValueError( + "No image is provided. This model requires an image to run. " + 'Please correctly set `"multi_modal_data": {"image": , …}`' + ) + if not isinstance(raw_image, (str, PIL.Image.Image)): + raise TypeError( + f"Unsupported image format {raw_image.__class__}. " + 'Please correctly set `"multi_modal_data": {"image": , …}`' + ) + image = PIL.Image.open(raw_image).convert("RGB") if isinstance(raw_image, str) else raw_image + + # Calculate dimensions based on aspect ratio if not provided + if request.sampling_params.height is None or request.sampling_params.width is None: + # Default max area for 480P + max_area = 480 * 832 + aspect_ratio = image.height / image.width + + # Calculate dimensions maintaining aspect ratio + mod_value = 16 # Must be divisible by 16 + height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value + width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value + + if request.sampling_params.height is None: + request.sampling_params.height = height + if request.sampling_params.width is None: + request.sampling_params.width = width + + # Resize image to target dimensions + image = image.resize( + (request.sampling_params.width, request.sampling_params.height), # type: ignore + PIL.Image.Resampling.LANCZOS, + ) + prompt["multi_modal_data"]["image"] = image # type: ignore + + # Preprocess for VAE + prompt["additional_information"]["preprocessed_image"] = video_processor.preprocess( + image, height=request.sampling_params.height, width=request.sampling_params.width + ) + request.prompts[i] = prompt + return request + + return pre_process_func + + +class SkyReelsV3R2VPipeline(nn.Module, SupportImageInput, CFGParallelMixin): + """ + SkyReels-V3 Image-to-Video (R2V) Pipeline. + + Generates videos from reference images using text prompts. + """ + + def __init__( + self, + od_config: OmniDiffusionConfig, + ): + super().__init__() + self.od_config = od_config + model = od_config.model + + # Load model components + loader = DiffusersPipelineLoader(model, local_files_only=od_config.local_files_only) + + # Load VAE + self.vae = loader.load_model(AutoencoderKLWan, "vae") + self.vae.requires_grad_(False) + self.vae.eval() + + # Load text encoder and tokenizer + self.text_encoder = loader.load_model(UMT5EncoderModel, "text_encoder") + self.text_encoder.requires_grad_(False) + self.text_encoder.eval() + self.tokenizer = loader.load_tokenizer(AutoTokenizer, "tokenizer") + + # Load CLIP for image conditioning + self.image_encoder = loader.load_model(CLIPVisionModel, "image_encoder") + self.image_encoder.requires_grad_(False) + self.image_encoder.eval() + self.image_processor = loader.load_processor(CLIPImageProcessor, "image_processor") + + # Load or create transformer + transformer_config = load_transformer_config(model, local_files_only=od_config.local_files_only) + if transformer_config: + self.transformer = create_transformer_from_config(transformer_config) + else: + # Default configuration for SkyReels-V3-R2V-14B + self.transformer = SkyReelsTransformer3DModel( + num_attention_heads=16, + attention_head_dim=88, + in_channels=16, + num_layers=28, + text_dim=4096, + image_dim=1024, # CLIP image embedding dimension + patch_size=(1, 2, 2), + ) + + # Load transformer weights + loader.load_module_weights(self.transformer, "transformer") + + # Load scheduler + self.scheduler = loader.load_scheduler(FlowUniPCMultistepScheduler, "scheduler") + + # Move to device + device = get_local_device() + self.to(device) + + # Set VAE scaling factor + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + def encode_prompt( + self, + prompt: str | list[str], + device: torch.device, + num_videos_per_prompt: int = 1, + do_classifier_free_guidance: bool = False, + negative_prompt: str | list[str] | None = None, + ) -> torch.Tensor: + """Encode text prompt using UMT5.""" + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + # Tokenize + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids.to(device) + + # Encode + prompt_embeds = self.text_encoder(text_input_ids)[0] + + # Duplicate for multiple videos per prompt + if num_videos_per_prompt > 1: + prompt_embeds = prompt_embeds.repeat_interleave(num_videos_per_prompt, dim=0) + + # Handle classifier-free guidance + if do_classifier_free_guidance: + if negative_prompt is None: + negative_prompt = [""] * batch_size + elif isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] * batch_size + + uncond_tokens = self.tokenizer( + negative_prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + uncond_input_ids = uncond_tokens.input_ids.to(device) + negative_prompt_embeds = self.text_encoder(uncond_input_ids)[0] + + if num_videos_per_prompt > 1: + negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(num_videos_per_prompt, dim=0) + + # Concatenate for CFG + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def encode_image( + self, + image: PIL.Image.Image | torch.Tensor, + device: torch.device, + num_videos_per_prompt: int = 1, + ) -> torch.Tensor: + """Encode reference image using CLIP.""" + if isinstance(image, PIL.Image.Image): + image = self.image_processor(images=image, return_tensors="pt").pixel_values + image = image.to(device=device, dtype=self.image_encoder.dtype) + + # Encode + image_embeds = self.image_encoder(image).pooler_output + + # Duplicate for multiple videos per prompt + if num_videos_per_prompt > 1: + image_embeds = image_embeds.repeat_interleave(num_videos_per_prompt, dim=0) + + return image_embeds + + @torch.no_grad() + def forward( + self, + request: OmniDiffusionRequest, + ) -> DiffusionOutput: + """ + Generate video from image and text prompt. + + Args: + request: Diffusion request containing prompts and parameters + + Returns: + Generated video frames + """ + device = next(self.parameters()).device + dtype = next(self.parameters()).dtype + + # Extract parameters + prompt = [p["prompt"] if isinstance(p, dict) else p for p in request.prompts] + batch_size = len(prompt) + + # Get sampling parameters + height = request.sampling_params.height or 480 + width = request.sampling_params.width or 832 + num_frames = request.sampling_params.num_frames or 81 + num_inference_steps = request.sampling_params.num_inference_steps or 50 + guidance_scale = request.sampling_params.guidance_scale or 7.5 + num_videos_per_prompt = request.sampling_params.num_videos_per_prompt or 1 + + do_classifier_free_guidance = guidance_scale > 1.0 + + # Encode text prompt + prompt_embeds = self.encode_prompt( + prompt, + device, + num_videos_per_prompt, + do_classifier_free_guidance, + request.sampling_params.negative_prompt, + ) + + # Encode reference image + images = [] + for p in request.prompts: + if isinstance(p, dict) and "additional_information" in p: + img = p["additional_information"].get("preprocessed_image") + if img is not None: + images.append(img) + + if not images: + raise ValueError("No preprocessed images found in request") + + image_tensor = torch.cat(images, dim=0).to(device=device, dtype=dtype) + image_embeds = self.encode_image(image_tensor, device, num_videos_per_prompt) + + # Prepare latents + num_channels_latents = self.transformer.in_channels + latents_shape = ( + batch_size * num_videos_per_prompt, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + + generator = torch.Generator(device=device) + if request.sampling_params.seed is not None: + generator.manual_seed(request.sampling_params.seed) + + latents = randn_tensor(latents_shape, generator=generator, device=device, dtype=dtype) + + # Prepare scheduler + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # Denoising loop + for i, t in enumerate(timesteps): + # Expand latents for CFG + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # Expand timestep + timestep = t.expand(latent_model_input.shape[0]) + + # Predict noise + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + image_hidden_states=image_embeds if do_classifier_free_guidance else image_embeds.repeat(2, 1), + ).sample + + # Perform CFG + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # Compute previous noisy sample + latents = self.scheduler.step(noise_pred, t, latents).prev_sample + + # Decode latents + latents = latents / self.vae.config.scaling_factor + video = self.vae.decode(latents).sample + + return DiffusionOutput( + output=video, + request_id=request.request_id, + ) + diff --git a/vllm_omni/diffusion/models/skyreels_v3/skyreels_v3_transformer.py b/vllm_omni/diffusion/models/skyreels_v3/skyreels_v3_transformer.py new file mode 100644 index 00000000000..e30a237bfdd --- /dev/null +++ b/vllm_omni/diffusion/models/skyreels_v3/skyreels_v3_transformer.py @@ -0,0 +1,614 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +SkyReels-V3 Transformer Model Implementation. + +This module implements the transformer architecture for SkyReels-V3, +a multimodal video generation model supporting: +- Image-to-Video (R2V) +- Video-to-Video (V2V) +- Audio-to-Video (A2V) +""" + +import math +from collections.abc import Iterable +from typing import Any + +import torch +import torch.nn as nn +from diffusers.models.attention import FeedForward +from diffusers.models.embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.normalization import FP32LayerNorm +from vllm.logger import init_logger +from vllm.model_executor.layers.conv import Conv3dLayer +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import QKVParallelLinear, ReplicatedLinear +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from vllm_omni.diffusion.attention.layer import Attention +from vllm_omni.diffusion.distributed.sp_plan import ( + SequenceParallelInput, + SequenceParallelOutput, +) + +logger = init_logger(__name__) + + +def apply_rotary_emb_skyreels( + hidden_states: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, +) -> torch.Tensor: + """ + Apply rotary embeddings to input tensors using the given frequency tensors. + + Args: + hidden_states: Input tensor of shape [B, S, H, D] + freqs_cos: Cosine frequencies + freqs_sin: Sine frequencies + + Returns: + Tensor with rotary embeddings applied + """ + x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1) + cos = freqs_cos[..., 0::2] + sin = freqs_sin[..., 1::2] + out = torch.empty_like(hidden_states) + out[..., 0::2] = x1 * cos - x2 * sin + out[..., 1::2] = x1 * sin + x2 * cos + return out.type_as(hidden_states) + + +class SkyReelsRotaryPosEmbed(nn.Module): + """ + Rotary position embeddings for 3D video data (temporal + spatial dimensions). + Adapted for SkyReels-V3 architecture. + """ + + def __init__( + self, + attention_head_dim: int, + patch_size: tuple[int, int, int], + max_seq_len: int, + theta: float = 10000.0, + ): + super().__init__() + + self.attention_head_dim = attention_head_dim + self.patch_size = patch_size + self.max_seq_len = max_seq_len + + # Split dimensions for temporal, height, width + h_dim = w_dim = 2 * (attention_head_dim // 6) + t_dim = attention_head_dim - h_dim - w_dim + freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64 + + freqs_cos = [] + freqs_sin = [] + + for dim in [t_dim, h_dim, w_dim]: + freq_cos, freq_sin = self._get_1d_rotary_pos_embed(dim, max_seq_len, theta, freqs_dtype) + freqs_cos.append(freq_cos) + freqs_sin.append(freq_sin) + + self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False) + self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False) + + @staticmethod + def _get_1d_rotary_pos_embed( + dim: int, + max_seq_len: int, + theta: float, + freqs_dtype: torch.dtype, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Generate 1D rotary position embeddings.""" + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype) / dim)) + t = torch.arange(max_seq_len, dtype=freqs_dtype) + freqs = torch.outer(t, freqs) + # Repeat interleave for real representation + freqs_cos = freqs.cos().repeat_interleave(2, dim=-1) + freqs_sin = freqs.sin().repeat_interleave(2, dim=-1) + return freqs_cos, freqs_sin + + def forward( + self, + hidden_states: torch.Tensor, + t: int, + h: int, + w: int, + ) -> torch.Tensor: + """ + Apply rotary position embeddings. + + Args: + hidden_states: Input tensor [B, S, H, D] + t: Temporal dimension + h: Height dimension + w: Width dimension + + Returns: + Tensor with rotary embeddings applied + """ + # Get position indices + seq_len = t * h * w + freqs_cos = self.freqs_cos[:seq_len] # type: ignore + freqs_sin = self.freqs_sin[:seq_len] # type: ignore + + return apply_rotary_emb_skyreels(hidden_states, freqs_cos, freqs_sin) + + +class SkyReelsSelfAttention(nn.Module): + """ + Optimized self-attention module for SkyReels-V3 using vLLM layers. + """ + + def __init__( + self, + dim: int, + num_heads: int, + head_dim: int, + eps: float = 1e-5, + dropout: float = 0.0, + ): + super().__init__() + + self.dim = dim + self.num_heads = num_heads + self.head_dim = head_dim + self.inner_dim = num_heads * head_dim + + # Fused QKV projection using vLLM's optimized layer + self.to_qkv = QKVParallelLinear( + hidden_size=dim, + head_size=head_dim, + total_num_heads=num_heads, + bias=True, + disable_tp=True, + ) + + # QK normalization using vLLM's RMSNorm + self.norm_q = RMSNorm(self.inner_dim, eps=eps) + self.norm_k = RMSNorm(self.inner_dim, eps=eps) + + # Output projection + self.to_out = nn.ModuleList( + [ + ReplicatedLinear(self.inner_dim, dim, bias=True), + nn.Dropout(dropout), + ] + ) + + # Unified attention layer + self.attn = Attention( + num_heads=num_heads, + head_size=head_dim, + softmax_scale=1.0 / (head_dim**0.5), + causal=False, + ) + + def forward( + self, + hidden_states: torch.Tensor, + rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> torch.Tensor: + # Fused QKV projection + qkv, _ = self.to_qkv(hidden_states) + query, key, value = qkv.chunk(3, dim=-1) + + # Apply QK normalization + query = self.norm_q(query) + key = self.norm_k(key) + + # Reshape for multi-head attention + query = query.unflatten(2, (self.num_heads, -1)) + key = key.unflatten(2, (self.num_heads, -1)) + value = value.unflatten(2, (self.num_heads, -1)) + + # Apply rotary embeddings + if rotary_emb is not None: + freqs_cos, freqs_sin = rotary_emb + query = apply_rotary_emb_skyreels(query, freqs_cos, freqs_sin) + key = apply_rotary_emb_skyreels(key, freqs_cos, freqs_sin) + + # Compute attention using unified attention layer + hidden_states = self.attn(query, key, value) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.type_as(query) + + # Output projection + hidden_states, _ = self.to_out[0](hidden_states) + hidden_states = self.to_out[1](hidden_states) + + return hidden_states + + +class SkyReelsCrossAttention(nn.Module): + """ + Optimized cross-attention module for SkyReels-V3 using vLLM layers. + """ + + def __init__( + self, + dim: int, + num_heads: int, + head_dim: int, + text_dim: int, + eps: float = 1e-5, + dropout: float = 0.0, + cross_attn_norm: bool = False, + ): + super().__init__() + + self.dim = dim + self.num_heads = num_heads + self.head_dim = head_dim + self.inner_dim = num_heads * head_dim + self.cross_attn_norm = cross_attn_norm + + # Query projection + self.to_q = ReplicatedLinear(dim, self.inner_dim, bias=True) + + # Key and Value projections for encoder + self.to_kv = ReplicatedLinear(text_dim, self.inner_dim * 2, bias=True) + + # QK normalization + self.norm_q = RMSNorm(self.inner_dim, eps=eps) + self.norm_k = RMSNorm(self.inner_dim, eps=eps) + + # Optional encoder normalization + if cross_attn_norm: + self.norm_encoder = RMSNorm(text_dim, eps=eps) + + # Output projection + self.to_out = nn.ModuleList( + [ + ReplicatedLinear(self.inner_dim, dim, bias=True), + nn.Dropout(dropout), + ] + ) + + # Unified attention layer + self.attn = Attention( + num_heads=num_heads, + head_size=head_dim, + softmax_scale=1.0 / (head_dim**0.5), + causal=False, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + ) -> torch.Tensor: + # Normalize encoder if needed + if self.cross_attn_norm: + encoder_hidden_states = self.norm_encoder(encoder_hidden_states) + + # Project query + query, _ = self.to_q(hidden_states) + query = self.norm_q(query) + + # Project key and value + kv, _ = self.to_kv(encoder_hidden_states) + key, value = kv.chunk(2, dim=-1) + key = self.norm_k(key) + + # Reshape for multi-head attention + query = query.unflatten(2, (self.num_heads, -1)) + key = key.unflatten(2, (self.num_heads, -1)) + value = value.unflatten(2, (self.num_heads, -1)) + + # Compute attention + hidden_states = self.attn(query, key, value) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.type_as(query) + + # Output projection + hidden_states, _ = self.to_out[0](hidden_states) + hidden_states = self.to_out[1](hidden_states) + + return hidden_states + + +class SkyReelsTransformerBlock(nn.Module): + """ + Transformer block for SkyReels-V3. + Includes self-attention, cross-attention, and feed-forward layers. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + text_dim: int, + ffn_dim: int | None = None, + dropout: float = 0.0, + cross_attn_norm: bool = False, + eps: float = 1e-6, + ): + super().__init__() + + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + + # Self-attention + self.norm1 = RMSNorm(dim, eps=eps) + self.attn1 = SkyReelsSelfAttention( + dim=dim, + num_heads=num_attention_heads, + head_dim=attention_head_dim, + eps=eps, + dropout=dropout, + ) + + # Cross-attention + self.norm2 = RMSNorm(dim, eps=eps) + self.attn2 = SkyReelsCrossAttention( + dim=dim, + num_heads=num_attention_heads, + head_dim=attention_head_dim, + text_dim=text_dim, + eps=eps, + dropout=dropout, + cross_attn_norm=cross_attn_norm, + ) + + # Feed-forward + self.norm3 = RMSNorm(dim, eps=eps) + ffn_dim = ffn_dim or dim * 4 + self.ff = FeedForward( + dim=dim, + dim_out=dim, + mult=ffn_dim // dim, + dropout=dropout, + activation_fn="gelu-approximate", + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> torch.Tensor: + """ + Forward pass through the transformer block. + + Args: + hidden_states: Input tensor [B, S, D] + encoder_hidden_states: Encoder outputs for cross-attention [B, T, D_text] + rotary_emb: Rotary position embeddings (cos, sin) + + Returns: + Output tensor [B, S, D] + """ + # Self-attention + norm_hidden_states = self.norm1(hidden_states) + attn_output = self.attn1(norm_hidden_states, rotary_emb=rotary_emb) + hidden_states = hidden_states + attn_output + + # Cross-attention + if encoder_hidden_states is not None: + norm_hidden_states = self.norm2(hidden_states) + attn_output = self.attn2(norm_hidden_states, encoder_hidden_states) + hidden_states = hidden_states + attn_output + + # Feed-forward + norm_hidden_states = self.norm3(hidden_states) + ff_output = self.ff(norm_hidden_states) + hidden_states = hidden_states + ff_output + + return hidden_states + + +class SkyReelsTransformer3DModel(nn.Module): + """ + SkyReels-V3 3D Transformer Model for video generation. + + Supports multiple modalities: + - Text-to-Video + - Image-to-Video (R2V) + - Video-to-Video (V2V) + - Audio-to-Video (A2V) + """ + + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: int = 16, + out_channels: int | None = None, + num_layers: int = 28, + dropout: float = 0.0, + text_dim: int = 4096, + ffn_dim: int | None = None, + patch_size: tuple[int, int, int] = (1, 2, 2), + cross_attn_norm: bool = False, + eps: float = 1e-6, + max_seq_len: int = 16384, + image_dim: int | None = None, + added_kv_proj_dim: int | None = None, + ): + super().__init__() + + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + self.inner_dim = num_attention_heads * attention_head_dim + self.in_channels = in_channels + self.out_channels = out_channels or in_channels + self.patch_size = patch_size + self.text_dim = text_dim + self.image_dim = image_dim + self.added_kv_proj_dim = added_kv_proj_dim + + # Input projection + self.proj_in = Conv3dLayer( + in_channels=in_channels, + out_channels=self.inner_dim, + kernel_size=patch_size, + stride=patch_size, + padding=0, + bias=True, + ) + + # Timestep embedding + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.time_embedding = TimestepEmbedding( + in_channels=256, + time_embed_dim=self.inner_dim, + ) + + # Text projection + self.text_proj = PixArtAlphaTextProjection( + in_features=text_dim, + hidden_size=self.inner_dim, + ) + + # Optional image conditioning projection + if image_dim is not None: + self.image_proj = ReplicatedLinear( + image_dim, + self.inner_dim, + bias=True, + ) + + # Rotary position embeddings + self.pos_embed = SkyReelsRotaryPosEmbed( + attention_head_dim=attention_head_dim, + patch_size=patch_size, + max_seq_len=max_seq_len, + ) + + # Transformer blocks + self.transformer_blocks = nn.ModuleList( + [ + SkyReelsTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + text_dim=self.inner_dim, + ffn_dim=ffn_dim, + dropout=dropout, + cross_attn_norm=cross_attn_norm, + eps=eps, + ) + for _ in range(num_layers) + ] + ) + + # Output layers + self.norm_out = RMSNorm(self.inner_dim, eps=eps) + self.proj_out = ReplicatedLinear( + self.inner_dim, + self.out_channels * patch_size[0] * patch_size[1] * patch_size[2], + bias=True, + ) + + # Sequence parallel plan (for distributed training) + self._sp_plan = { + "input": SequenceParallelInput(split_dim=1, expected_dims=3), + "output": SequenceParallelOutput(gather_dim=1, expected_dims=3), + } + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + encoder_attention_mask: torch.Tensor | None = None, + image_hidden_states: torch.Tensor | None = None, + return_dict: bool = True, + ) -> Transformer2DModelOutput | tuple: + """ + Forward pass through the SkyReels-V3 transformer. + + Args: + hidden_states: Latent video tensor [B, C, T, H, W] + timestep: Diffusion timestep [B] + encoder_hidden_states: Text embeddings [B, seq_len, text_dim] + encoder_attention_mask: Attention mask for text + image_hidden_states: Optional image conditioning [B, image_dim] + return_dict: Whether to return a dict or tuple + + Returns: + Transformer output + """ + batch_size, channels, num_frames, height, width = hidden_states.shape + + # Project input + hidden_states = self.proj_in(hidden_states) # [B, inner_dim, T', H', W'] + + # Reshape to sequence + t_out, h_out, w_out = ( + num_frames // self.patch_size[0], + height // self.patch_size[1], + width // self.patch_size[2], + ) + hidden_states = hidden_states.flatten(2).transpose(1, 2) # [B, T'*H'*W', inner_dim] + + # Timestep embedding + timestep_emb = self.time_proj(timestep) + timestep_emb = self.time_embedding(timestep_emb) # [B, inner_dim] + timestep_emb = timestep_emb.unsqueeze(1) # [B, 1, inner_dim] + + # Text projection + encoder_hidden_states = self.text_proj(encoder_hidden_states) # [B, seq_len, inner_dim] + + # Add timestep to encoder hidden states + encoder_hidden_states = torch.cat([timestep_emb, encoder_hidden_states], dim=1) + + # Optional image conditioning + if image_hidden_states is not None and self.image_dim is not None: + image_emb = self.image_proj(image_hidden_states).unsqueeze(1) # [B, 1, inner_dim] + encoder_hidden_states = torch.cat([encoder_hidden_states, image_emb], dim=1) + + # Get rotary position embeddings + seq_len = t_out * h_out * w_out + freqs_cos = self.pos_embed.freqs_cos[:seq_len] # type: ignore + freqs_sin = self.pos_embed.freqs_sin[:seq_len] # type: ignore + rotary_emb = (freqs_cos, freqs_sin) + + # Transformer blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + rotary_emb=rotary_emb, + ) + + # Output projection + hidden_states = self.norm_out(hidden_states) + hidden_states = self.proj_out(hidden_states) + + # Reshape back to video format + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, + self.out_channels, + t_out, + h_out, + w_out, + self.patch_size[0], + self.patch_size[1], + self.patch_size[2], + ) + hidden_states = hidden_states.permute(0, 1, 2, 5, 3, 6, 4, 7).reshape( + batch_size, self.out_channels, num_frames, height, width + ) + + if not return_dict: + return (hidden_states,) + + return Transformer2DModelOutput(sample=hidden_states) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + """Load model weights.""" + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if name not in params_dict: + continue + param = params_dict[name] + default_weight_loader(param, loaded_weight) + diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py index 0243524793f..255ed825eb4 100644 --- a/vllm_omni/diffusion/registry.py +++ b/vllm_omni/diffusion/registry.py @@ -95,6 +95,11 @@ "pipeline_flux", "FluxPipeline", ), + "SkyReelsV3R2VPipeline": ( + "skyreels_v3", + "pipeline_skyreels_v3_r2v", + "SkyReelsV3R2VPipeline", + ), } @@ -233,6 +238,7 @@ def _apply_sequence_parallel_if_enabled(model, od_config: OmniDiffusionConfig) - "StableDiffusion3Pipeline": "get_sd3_image_post_process_func", "Flux2KleinPipeline": "get_flux2_klein_post_process_func", "FluxPipeline": "get_flux_post_process_func", + "SkyReelsV3R2VPipeline": "get_skyreels_v3_r2v_post_process_func", } _DIFFUSION_PRE_PROCESS_FUNCS = { @@ -246,6 +252,7 @@ def _apply_sequence_parallel_if_enabled(model, od_config: OmniDiffusionConfig) - "QwenImageLayeredPipeline": "get_qwen_image_layered_pre_process_func", "WanPipeline": "get_wan22_pre_process_func", "WanImageToVideoPipeline": "get_wan22_i2v_pre_process_func", + "SkyReelsV3R2VPipeline": "get_skyreels_v3_r2v_pre_process_func", } diff --git a/vllm_omni/model_executor/stage_configs/skyreels_v3_r2v.yaml b/vllm_omni/model_executor/stage_configs/skyreels_v3_r2v.yaml new file mode 100644 index 00000000000..b4f53c32d2a --- /dev/null +++ b/vllm_omni/model_executor/stage_configs/skyreels_v3_r2v.yaml @@ -0,0 +1,47 @@ +stage_args: + - stage_id: 0 + stage_type: diffusion + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + model_stage: dit + model_class_name: SkyReelsV3R2VPipeline + gpu_memory_utilization: 0.85 + enforce_eager: true + trust_remote_code: true + engine_output_type: video + distributed_executor_backend: "mp" + enable_prefix_caching: false + max_num_batched_tokens: 32768 + tensor_parallel_size: 1 + # SkyReels-V3 specific settings + flow_shift: 5.0 # Flow matching shift parameter + attention_backend: "sdpa" # Use SDPA for attention + final_output: true + final_output_type: video + is_comprehension: false + default_sampling_params: + height: 480 + width: 832 + num_frames: 81 + num_inference_steps: 50 + guidance_scale: 7.5 + seed: 42 + num_outputs_per_prompt: 1 + +# Runtime configuration +runtime: + enabled: true + defaults: + window_size: -1 + max_inflight: 1 + + # Distributed connectors configuration + connectors: + shared_memory_connector: + name: SharedMemoryConnector + extra: + shm_threshold_bytes: 65536 # 64KB threshold + + edges: [] # Single stage, no edges needed From 4b0e05e81e2b951a788277232b0b67ae054af751 Mon Sep 17 00:00:00 2001 From: gDINESH13 Date: Sun, 8 Feb 2026 11:40:06 +0530 Subject: [PATCH 2/5] pre-commit formatting Signed-off-by: gDINESH13 --- examples/offline_inference/skyreels_v3/README.md | 4 ++-- .../skyreels_v3/image_to_video.py | 14 ++++++++------ vllm_omni/diffusion/models/skyreels_v3/__init__.py | 2 +- .../models/skyreels_v3/pipeline_skyreels_v3_r2v.py | 13 +++++-------- .../models/skyreels_v3/skyreels_v3_transformer.py | 6 +----- 5 files changed, 17 insertions(+), 22 deletions(-) diff --git a/examples/offline_inference/skyreels_v3/README.md b/examples/offline_inference/skyreels_v3/README.md index fe5cc5e539d..a6804451071 100644 --- a/examples/offline_inference/skyreels_v3/README.md +++ b/examples/offline_inference/skyreels_v3/README.md @@ -97,7 +97,7 @@ python image_to_video.py \ 1. **Image Quality**: Use high-quality reference images for best results 2. **Aspect Ratio**: The model works best with 16:9 aspect ratio (e.g., 832x480) 3. **Frame Count**: More frames = longer videos but slower generation -4. **Guidance Scale**: +4. **Guidance Scale**: - Lower (3-5): More creative, less adherence to prompt - Medium (7-9): Balanced - Higher (10+): Strong prompt adherence, may reduce quality @@ -141,4 +141,4 @@ If you use SkyReels-V3 in your research, please cite: ## License -SkyReels-V3 models are released under the Skywork License. Please refer to the model card on Hugging Face for details. \ No newline at end of file +SkyReels-V3 models are released under the Skywork License. Please refer to the model card on Hugging Face for details. diff --git a/examples/offline_inference/skyreels_v3/image_to_video.py b/examples/offline_inference/skyreels_v3/image_to_video.py index 4ad4f31a97f..a3b6d94c660 100644 --- a/examples/offline_inference/skyreels_v3/image_to_video.py +++ b/examples/offline_inference/skyreels_v3/image_to_video.py @@ -98,7 +98,7 @@ def main(): choices=["mp4", "gif", "frames"], help="Output format: mp4, gif, or frames (default: mp4)", ) - + args = parser.parse_args() # Create output directory @@ -108,7 +108,7 @@ def main(): # Load reference image if not os.path.exists(args.image): raise FileNotFoundError(f"Image not found: {args.image}") - + image = Image.open(args.image).convert("RGB") print(f"Loaded reference image: {args.image} ({image.size})") @@ -122,7 +122,7 @@ def main(): # Prepare the request print(f"\nGenerating video with prompt: '{args.prompt}'") - print(f"Parameters:") + print("Parameters:") print(f" - Resolution: {args.width}x{args.height}") print(f" - Frames: {args.num_frames}") print(f" - Steps: {args.num_inference_steps}") @@ -150,21 +150,23 @@ def main(): # Save the generated video for idx, output in enumerate(outputs): video_frames = output.outputs[0] # Get the video frames - + if args.output_format == "mp4": output_path = output_dir / f"video_{idx:04d}.mp4" # Save as MP4 video import imageio + imageio.mimsave(output_path, video_frames, fps=24, codec="libx264") print(f"\nSaved video to: {output_path}") - + elif args.output_format == "gif": output_path = output_dir / f"video_{idx:04d}.gif" # Save as GIF import imageio + imageio.mimsave(output_path, video_frames, fps=12) print(f"\nSaved GIF to: {output_path}") - + else: # frames frames_dir = output_dir / f"video_{idx:04d}_frames" frames_dir.mkdir(exist_ok=True) diff --git a/vllm_omni/diffusion/models/skyreels_v3/__init__.py b/vllm_omni/diffusion/models/skyreels_v3/__init__.py index 231f4509c86..cd8cb042553 100644 --- a/vllm_omni/diffusion/models/skyreels_v3/__init__.py +++ b/vllm_omni/diffusion/models/skyreels_v3/__init__.py @@ -12,4 +12,4 @@ "get_skyreels_v3_r2v_post_process_func", "get_skyreels_v3_r2v_pre_process_func", "SkyReelsTransformer3DModel", -] \ No newline at end of file +] diff --git a/vllm_omni/diffusion/models/skyreels_v3/pipeline_skyreels_v3_r2v.py b/vllm_omni/diffusion/models/skyreels_v3/pipeline_skyreels_v3_r2v.py index bf33941077c..24c8b117fe8 100644 --- a/vllm_omni/diffusion/models/skyreels_v3/pipeline_skyreels_v3_r2v.py +++ b/vllm_omni/diffusion/models/skyreels_v3/pipeline_skyreels_v3_r2v.py @@ -13,7 +13,6 @@ import json import logging import os -from typing import Any, cast import numpy as np import PIL.Image @@ -22,7 +21,6 @@ from diffusers.utils.torch_utils import randn_tensor from torch import nn from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel -from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin @@ -345,7 +343,7 @@ def forward( # Extract parameters prompt = [p["prompt"] if isinstance(p, dict) else p for p in request.prompts] batch_size = len(prompt) - + # Get sampling parameters height = request.sampling_params.height or 480 width = request.sampling_params.width or 832 @@ -372,10 +370,10 @@ def forward( img = p["additional_information"].get("preprocessed_image") if img is not None: images.append(img) - + if not images: raise ValueError("No preprocessed images found in request") - + image_tensor = torch.cat(images, dim=0).to(device=device, dtype=dtype) image_embeds = self.encode_image(image_tensor, device, num_videos_per_prompt) @@ -388,11 +386,11 @@ def forward( height // self.vae_scale_factor, width // self.vae_scale_factor, ) - + generator = torch.Generator(device=device) if request.sampling_params.seed is not None: generator.manual_seed(request.sampling_params.seed) - + latents = randn_tensor(latents_shape, generator=generator, device=device, dtype=dtype) # Prepare scheduler @@ -432,4 +430,3 @@ def forward( output=video, request_id=request.request_id, ) - diff --git a/vllm_omni/diffusion/models/skyreels_v3/skyreels_v3_transformer.py b/vllm_omni/diffusion/models/skyreels_v3/skyreels_v3_transformer.py index e30a237bfdd..963208a544f 100644 --- a/vllm_omni/diffusion/models/skyreels_v3/skyreels_v3_transformer.py +++ b/vllm_omni/diffusion/models/skyreels_v3/skyreels_v3_transformer.py @@ -11,16 +11,13 @@ - Audio-to-Video (A2V) """ -import math from collections.abc import Iterable -from typing import Any import torch import torch.nn as nn from diffusers.models.attention import FeedForward from diffusers.models.embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps from diffusers.models.modeling_outputs import Transformer2DModelOutput -from diffusers.models.normalization import FP32LayerNorm from vllm.logger import init_logger from vllm.model_executor.layers.conv import Conv3dLayer from vllm.model_executor.layers.layernorm import RMSNorm @@ -540,7 +537,7 @@ def forward( # Project input hidden_states = self.proj_in(hidden_states) # [B, inner_dim, T', H', W'] - + # Reshape to sequence t_out, h_out, w_out = ( num_frames // self.patch_size[0], @@ -611,4 +608,3 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): continue param = params_dict[name] default_weight_loader(param, loaded_weight) - From 2e7264710a248e2113b0b88fc2eed494e1bf066c Mon Sep 17 00:00:00 2001 From: gDINESH13 Date: Sun, 8 Feb 2026 15:12:59 +0530 Subject: [PATCH 3/5] feat(add new model): Review comments Signed-off-by: gDINESH13 --- .../skyreels_v3/image_to_video.py | 30 +++++++++++++------ .../skyreels_v3/pipeline_skyreels_v3_r2v.py | 18 +++++++++-- 2 files changed, 36 insertions(+), 12 deletions(-) diff --git a/examples/offline_inference/skyreels_v3/image_to_video.py b/examples/offline_inference/skyreels_v3/image_to_video.py index a3b6d94c660..a0bbdfc298f 100644 --- a/examples/offline_inference/skyreels_v3/image_to_video.py +++ b/examples/offline_inference/skyreels_v3/image_to_video.py @@ -21,6 +21,8 @@ from PIL import Image from vllm_omni.entrypoints.omni_diffusion import OmniDiffusion +from vllm_omni.inputs.data import OmniDiffusionSamplingParams +from vllm_omni.outputs import OmniRequestOutput def main(): @@ -137,19 +139,29 @@ def main(): "multi_modal_data": {"image": image}, } ], - sampling_params={ - "height": args.height, - "width": args.width, - "num_frames": args.num_frames, - "num_inference_steps": args.num_inference_steps, - "guidance_scale": args.guidance_scale, - "seed": args.seed, - }, + sampling_params=OmniDiffusionSamplingParams( + height=args.height, + width=args.width, + num_frames=args.num_frames, + num_inference_steps=args.num_inference_steps, + guidance_scale=args.guidance_scale, + seed=args.seed, + ), ) # Save the generated video for idx, output in enumerate(outputs): - video_frames = output.outputs[0] # Get the video frames + # Extract video frames from OmniRequestOutput + video_frames = None + if isinstance(output, OmniRequestOutput): + if hasattr(output, "images") and output.images: + video_frames = output.images[0] + elif hasattr(output, "multimodal_output") and output.multimodal_output: + video_frames = output.multimodal_output[0] + else: + raise ValueError("No video data found in diffusion output.") + else: + raise TypeError(f"Unexpected output type: {type(output)}") if args.output_format == "mp4": output_path = output_dir / f"video_{idx:04d}.mp4" diff --git a/vllm_omni/diffusion/models/skyreels_v3/pipeline_skyreels_v3_r2v.py b/vllm_omni/diffusion/models/skyreels_v3/pipeline_skyreels_v3_r2v.py index 24c8b117fe8..bbbd7227b1d 100644 --- a/vllm_omni/diffusion/models/skyreels_v3/pipeline_skyreels_v3_r2v.py +++ b/vllm_omni/diffusion/models/skyreels_v3/pipeline_skyreels_v3_r2v.py @@ -243,6 +243,9 @@ def __init__( # Load scheduler self.scheduler = loader.load_scheduler(FlowUniPCMultistepScheduler, "scheduler") + if hasattr(od_config, "flow_shift") and od_config.flow_shift is not None: + self.scheduler.config.shift = od_config.flow_shift + # Move to device device = get_local_device() self.to(device) @@ -422,9 +425,18 @@ def forward( # Compute previous noisy sample latents = self.scheduler.step(noise_pred, t, latents).prev_sample - # Decode latents - latents = latents / self.vae.config.scaling_factor - video = self.vae.decode(latents).sample + # Decode latents using AutoencoderKLWan normalization + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] return DiffusionOutput( output=video, From 635e3046b295a384d70dbca7cd1faf5c9f063eff Mon Sep 17 00:00:00 2001 From: gDINESH13 Date: Mon, 9 Feb 2026 20:09:54 +0530 Subject: [PATCH 4/5] Removed CFG Parallel Mixin Signed-off-by: gDINESH13 --- .../diffusion/models/skyreels_v3/pipeline_skyreels_v3_r2v.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm_omni/diffusion/models/skyreels_v3/pipeline_skyreels_v3_r2v.py b/vllm_omni/diffusion/models/skyreels_v3/pipeline_skyreels_v3_r2v.py index bbbd7227b1d..dd7081ae793 100644 --- a/vllm_omni/diffusion/models/skyreels_v3/pipeline_skyreels_v3_r2v.py +++ b/vllm_omni/diffusion/models/skyreels_v3/pipeline_skyreels_v3_r2v.py @@ -23,7 +23,6 @@ from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig -from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.interface import SupportImageInput @@ -186,7 +185,7 @@ def pre_process_func(request: OmniDiffusionRequest) -> OmniDiffusionRequest: return pre_process_func -class SkyReelsV3R2VPipeline(nn.Module, SupportImageInput, CFGParallelMixin): +class SkyReelsV3R2VPipeline(nn.Module, SupportImageInput): """ SkyReels-V3 Image-to-Video (R2V) Pipeline. From fc4d64a3b002f7c8c4417ecef49a703fd617be72 Mon Sep 17 00:00:00 2001 From: gDINESH13 Date: Wed, 11 Feb 2026 16:20:18 +0530 Subject: [PATCH 5/5] Applied review comments Signed-off-by: gDINESH13 --- .../skyreels_v3/image_to_video.py | 5 +- .../skyreels_v3/pipeline_skyreels_v3_r2v.py | 50 ++++++++++++------- .../skyreels_v3/skyreels_v3_transformer.py | 27 +++++----- 3 files changed, 47 insertions(+), 35 deletions(-) diff --git a/examples/offline_inference/skyreels_v3/image_to_video.py b/examples/offline_inference/skyreels_v3/image_to_video.py index a0bbdfc298f..81b132cb277 100644 --- a/examples/offline_inference/skyreels_v3/image_to_video.py +++ b/examples/offline_inference/skyreels_v3/image_to_video.py @@ -154,10 +154,9 @@ def main(): # Extract video frames from OmniRequestOutput video_frames = None if isinstance(output, OmniRequestOutput): + # In diffusion mode, output.images is the full list of frames if hasattr(output, "images") and output.images: - video_frames = output.images[0] - elif hasattr(output, "multimodal_output") and output.multimodal_output: - video_frames = output.multimodal_output[0] + video_frames = output.images else: raise ValueError("No video data found in diffusion output.") else: diff --git a/vllm_omni/diffusion/models/skyreels_v3/pipeline_skyreels_v3_r2v.py b/vllm_omni/diffusion/models/skyreels_v3/pipeline_skyreels_v3_r2v.py index dd7081ae793..d1613da7885 100644 --- a/vllm_omni/diffusion/models/skyreels_v3/pipeline_skyreels_v3_r2v.py +++ b/vllm_omni/diffusion/models/skyreels_v3/pipeline_skyreels_v3_r2v.py @@ -307,12 +307,12 @@ def encode_prompt( def encode_image( self, - image: PIL.Image.Image | torch.Tensor, + image: PIL.Image.Image | list[PIL.Image.Image] | torch.Tensor, device: torch.device, num_videos_per_prompt: int = 1, ) -> torch.Tensor: """Encode reference image using CLIP.""" - if isinstance(image, PIL.Image.Image): + if isinstance(image, (PIL.Image.Image, list)): image = self.image_processor(images=image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=self.image_encoder.dtype) @@ -351,8 +351,10 @@ def forward( width = request.sampling_params.width or 832 num_frames = request.sampling_params.num_frames or 81 num_inference_steps = request.sampling_params.num_inference_steps or 50 - guidance_scale = request.sampling_params.guidance_scale or 7.5 - num_videos_per_prompt = request.sampling_params.num_videos_per_prompt or 1 + guidance_scale = ( + request.sampling_params.guidance_scale if request.sampling_params.guidance_scale is not None else 7.5 + ) + num_videos_per_prompt = getattr(request.sampling_params, "num_outputs_per_prompt", None) or 1 do_classifier_free_guidance = guidance_scale > 1.0 @@ -365,28 +367,40 @@ def forward( request.sampling_params.negative_prompt, ) - # Encode reference image - images = [] + # Encode reference image using CLIP + # Use the resized PIL image from multi_modal_data, not the VAE-preprocessed one + pil_images = [] for p in request.prompts: - if isinstance(p, dict) and "additional_information" in p: - img = p["additional_information"].get("preprocessed_image") - if img is not None: - images.append(img) + multi_modal_data = p.get("multi_modal_data") or {} if isinstance(p, dict) else {} + img = multi_modal_data.get("image") + if img is not None: + pil_images.append(img) - if not images: - raise ValueError("No preprocessed images found in request") + if not pil_images: + raise ValueError("No reference images found in request (expected in prompt['multi_modal_data']['image'])") - image_tensor = torch.cat(images, dim=0).to(device=device, dtype=dtype) - image_embeds = self.encode_image(image_tensor, device, num_videos_per_prompt) + # Use CLIPImageProcessor to obtain CLIP-ready image tensor + image_embeds = self.encode_image(pil_images, device, num_videos_per_prompt) # Prepare latents + # Use separate temporal and spatial VAE scale factors, as required by AutoencoderKLWan. + vae_scale_factor_temporal = getattr(self.vae.config, "scale_factor_temporal", 1) + vae_scale_factor_spatial = getattr( + self.vae.config, + "scale_factor_spatial", + getattr(self, "vae_scale_factor", 1), + ) + + # Latent temporal length is typically shorter than the decoded frame count. + num_latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1 + num_channels_latents = self.transformer.in_channels latents_shape = ( batch_size * num_videos_per_prompt, num_channels_latents, - num_frames, - height // self.vae_scale_factor, - width // self.vae_scale_factor, + num_latent_frames, + height // vae_scale_factor_spatial, + width // vae_scale_factor_spatial, ) generator = torch.Generator(device=device) @@ -413,7 +427,7 @@ def forward( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds, - image_hidden_states=image_embeds if do_classifier_free_guidance else image_embeds.repeat(2, 1), + image_hidden_states=image_embeds.repeat(2, 1) if do_classifier_free_guidance else image_embeds, ).sample # Perform CFG diff --git a/vllm_omni/diffusion/models/skyreels_v3/skyreels_v3_transformer.py b/vllm_omni/diffusion/models/skyreels_v3/skyreels_v3_transformer.py index 963208a544f..e4e69d70c87 100644 --- a/vllm_omni/diffusion/models/skyreels_v3/skyreels_v3_transformer.py +++ b/vllm_omni/diffusion/models/skyreels_v3/skyreels_v3_transformer.py @@ -25,10 +25,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm_omni.diffusion.attention.layer import Attention -from vllm_omni.diffusion.distributed.sp_plan import ( - SequenceParallelInput, - SequenceParallelOutput, -) logger = init_logger(__name__) @@ -50,8 +46,9 @@ def apply_rotary_emb_skyreels( Tensor with rotary embeddings applied """ x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1) - cos = freqs_cos[..., 0::2] - sin = freqs_sin[..., 1::2] + # freqs_cos and freqs_sin are 2D [seq_len, dim], so use : indexing not ... + cos = freqs_cos[:, 0::2] + sin = freqs_sin[:, 1::2] out = torch.empty_like(hidden_states) out[..., 0::2] = x1 * cos - x2 * sin out[..., 1::2] = x1 * sin + x2 * cos @@ -206,6 +203,13 @@ def forward( # Apply rotary embeddings if rotary_emb is not None: freqs_cos, freqs_sin = rotary_emb + # freqs_* are 2D [S, head_dim]; make them broadcastable with + # query/key of shape [B, S, num_heads, head_dim] by reshaping to + # [1, S, 1, head_dim]. + if freqs_cos.dim() == 2: + freqs_cos = freqs_cos.unsqueeze(0).unsqueeze(2) + if freqs_sin.dim() == 2: + freqs_sin = freqs_sin.unsqueeze(0).unsqueeze(2) query = apply_rotary_emb_skyreels(query, freqs_cos, freqs_sin) key = apply_rotary_emb_skyreels(key, freqs_cos, freqs_sin) @@ -504,12 +508,6 @@ def __init__( bias=True, ) - # Sequence parallel plan (for distributed training) - self._sp_plan = { - "input": SequenceParallelInput(split_dim=1, expected_dims=3), - "output": SequenceParallelOutput(gather_dim=1, expected_dims=3), - } - def forward( self, hidden_states: torch.Tensor, @@ -559,7 +557,8 @@ def forward( # Optional image conditioning if image_hidden_states is not None and self.image_dim is not None: - image_emb = self.image_proj(image_hidden_states).unsqueeze(1) # [B, 1, inner_dim] + image_emb, _ = self.image_proj(image_hidden_states) + image_emb = image_emb.unsqueeze(1) # [B, 1, inner_dim] encoder_hidden_states = torch.cat([encoder_hidden_states, image_emb], dim=1) # Get rotary position embeddings @@ -578,7 +577,7 @@ def forward( # Output projection hidden_states = self.norm_out(hidden_states) - hidden_states = self.proj_out(hidden_states) + hidden_states, _ = self.proj_out(hidden_states) # Reshape back to video format hidden_states = hidden_states.transpose(1, 2).reshape(