diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 0fbbe6eb7c8..34a16a313d9 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -40,6 +40,7 @@ th { |`Qwen3TTSForConditionalGeneration` | Qwen3-TTS-12Hz-1.7B-CustomVoice | `Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice` | |`Qwen3TTSForConditionalGeneration` | Qwen3-TTS-12Hz-1.7B-VoiceDesign | `Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign` | |`Qwen3TTSForConditionalGeneration` | Qwen3-TTS-12Hz-1.7B-Base | `Qwen/Qwen3-TTS-12Hz-0.6B-Base` | +|`NextStep11Pipeline` | NextStep-1.1 | `stepfun-ai/NextStep-1.1` | ## List of Supported Models for NPU diff --git a/docs/user_guide/diffusion_acceleration.md b/docs/user_guide/diffusion_acceleration.md index 1a2e0a7d23f..4a7ba65eeb8 100644 --- a/docs/user_guide/diffusion_acceleration.md +++ b/docs/user_guide/diffusion_acceleration.md @@ -65,6 +65,7 @@ The following table shows which models are currently supported by each accelerat | **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | | **Bagel** | `ByteDance-Seed/BAGEL-7B-MoT` | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | | **FLUX.1-dev** | `black-forest-labs/FLUX.1-dev` | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | +| **NextStep-1.1** | `stepfun-ai/NextStep-1.1` | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | ### VideoGen diff --git a/examples/offline_inference/text_to_image/README.md b/examples/offline_inference/text_to_image/README.md index 9c5ceb927b7..2bfae847238 100644 --- a/examples/offline_inference/text_to_image/README.md +++ b/examples/offline_inference/text_to_image/README.md @@ -1,6 +1,6 @@ # Text-To-Image -This folder provides several entrypoints for experimenting with `Qwen/Qwen-Image` `Qwen/Qwen-Image-2512` `Tongyi-MAI/Z-Image-Turbo` using vLLM-Omni: +This folder provides several entrypoints for experimenting with `Qwen/Qwen-Image` `Qwen/Qwen-Image-2512` `Tongyi-MAI/Z-Image-Turbo` `stepfun-ai/NextStep-1.1` using vLLM-Omni, note that NextStep-1.1 has different architecture so we treat it differently regarding running arguments and pipeline. - `text_to_image.py`: command-line script for single image generation with advanced options. - `web_demo.py`: lightweight Gradio UI for interactive prompt/seed/CFG exploration. @@ -74,6 +74,8 @@ if __name__ == "__main__": ## Local CLI Usage +### Qwen/Tongyi Models + ```bash python text_to_image.py \ --model Tongyi-MAI/Z-Image-Turbo \ @@ -87,7 +89,26 @@ python text_to_image.py \ --output outputs/coffee.png ``` -Key arguments: +### NextStep Models + +NextStep-1.1 can have extra arguments +```bash +python text_to_image.py \ + --model stepfun-ai/NextStep-1.1 \ + --prompt "A baby panda wearing an Iron Man mask, holding a board with 'NextStep-1' written on it" \ + --height 512 \ + --width 512 \ + --num-inference-steps 28 \ + --guidance-scale 7.5 \ + --guidance-scale-2 1.0 \ + --cfg-schedule constant \ + --output nextstep_output.png \ + --seed 42 +``` + +### Key Arguments + +**Common arguments:** - `--prompt`: text description (string). - `--seed`: integer seed for deterministic sampling. @@ -98,8 +119,15 @@ Key arguments: - `--output`: path to save the generated PNG. - `--vae-use-slicing`: enable VAE slicing for memory optimization. - `--vae-use-tiling`: enable VAE tiling for memory optimization. -- `--cfg-parallel-size`: set it to 2 to enable CFG Parallel. See more examples in [`user_guide`](../../../docs/user_guide/diffusion/parallelism_acceleration.md#cfg-parallel). +- `--cfg-parallel-size`: set it to 2 to enable CFG Parallel. See more examples in [`user_guide`](../../../docs/user_guide/diffusion_acceleration.md#using-cfg-parallel). - `--enable-cpu-offload`: enable CPU offloading for diffusion models. +- `--guidance-scale`: classifier-free guidance scale. + +**NextStep-1.1 specific:** +- `--guidance-scale-2`: secondary guidance scale, e.g. image-level CFG (default: 1.0). +- `--timesteps-shift`: timesteps shift parameter for sampling (default: 1.0). +- `--cfg-schedule`: CFG schedule type, "constant" or "linear" (default: "constant"). +- `--use-norm`: apply layer normalization to sampled tokens. > ℹ️ If you encounter OOM errors, try using `--vae-use-slicing` and `--vae-use-tiling` to reduce memory usage. diff --git a/examples/offline_inference/text_to_image/text_to_image.py b/examples/offline_inference/text_to_image/text_to_image.py index db87abf008d..04f66863c60 100644 --- a/examples/offline_inference/text_to_image/text_to_image.py +++ b/examples/offline_inference/text_to_image/text_to_image.py @@ -16,13 +16,26 @@ from vllm_omni.platforms import current_omni_platform +def is_nextstep_model(model_name: str) -> bool: + """Check if the model is a NextStep model by reading its config.""" + from vllm.transformers_utils.config import get_hf_file_to_dict + + try: + cfg = get_hf_file_to_dict("config.json", model_name) + if cfg and cfg.get("model_type") == "nextstep": + return True + except Exception: + pass + return False + + def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description="Generate an image with Qwen-Image.") + parser = argparse.ArgumentParser(description="Generate an image with supported diffusion models.") parser.add_argument( "--model", default="Qwen/Qwen-Image", help="Diffusion model name or local path. Supported models: " - "Qwen/Qwen-Image, Tongyi-MAI/Z-Image-Turbo, Qwen/Qwen-Image-2512", + "Qwen/Qwen-Image, Tongyi-MAI/Z-Image-Turbo, Qwen/Qwen-Image-2512, stepfun-ai/NextStep-1.1", ) parser.add_argument("--prompt", default="a cup of coffee on the table", help="Text prompt for image generation.") parser.add_argument( @@ -153,16 +166,43 @@ def parse_args() -> argparse.Namespace: default=1, help="Number of ranks used for VAE patch/tile parallelism (decode/encode).", ) + # NextStep-1.1 specific arguments + parser.add_argument( + "--guidance-scale-2", + type=float, + default=1.0, + help="Secondary guidance scale (e.g. image-level CFG for NextStep-1.1).", + ) + parser.add_argument( + "--timesteps-shift", + type=float, + default=1.0, + help="[NextStep-1.1 only] Timesteps shift parameter for sampling.", + ) + parser.add_argument( + "--cfg-schedule", + type=str, + default="constant", + choices=["constant", "linear"], + help="[NextStep-1.1 only] CFG schedule type.", + ) + parser.add_argument( + "--use-norm", + action="store_true", + help="[NextStep-1.1 only] Apply layer normalization to sampled tokens.", + ) return parser.parse_args() def main(): args = parse_args() generator = torch.Generator(device=current_omni_platform.device_type).manual_seed(args.seed) + use_nextstep = is_nextstep_model(args.model) - # Configure cache based on backend type cache_config = None - if args.cache_backend == "cache_dit": + cache_backend = args.cache_backend + + if cache_backend == "cache_dit": # cache-dit configuration: Hybrid DBCache + SCM + TaylorSeer # All parameters marked with [cache-dit only] in DiffusionCacheConfig cache_config = { @@ -179,7 +219,7 @@ def main(): "scm_steps_mask_policy": None, # SCM mask policy: None (disabled), "slow", "medium", "fast", "ultra" "scm_steps_policy": "dynamic", # SCM steps policy: "dynamic" or "static" } - elif args.cache_backend == "tea_cache": + elif cache_backend == "tea_cache": # TeaCache configuration # All parameters marked with [tea_cache only] in DiffusionCacheConfig cache_config = { @@ -213,19 +253,24 @@ def main(): elif args.quantization: quant_kwargs["quantization"] = args.quantization - omni = Omni( - model=args.model, - enable_layerwise_offload=args.enable_layerwise_offload, - vae_use_slicing=args.vae_use_slicing, - vae_use_tiling=args.vae_use_tiling, - cache_backend=args.cache_backend, - cache_config=cache_config, - enable_cache_dit_summary=args.enable_cache_dit_summary, - parallel_config=parallel_config, - enforce_eager=args.enforce_eager, - enable_cpu_offload=args.enable_cpu_offload, + # Initialize Omni with model-specific settings + omni_kwargs = { + "model": args.model, + "enable_layerwise_offload": args.enable_layerwise_offload, + "vae_use_slicing": args.vae_use_slicing, + "vae_use_tiling": args.vae_use_tiling, + "cache_backend": cache_backend, + "cache_config": cache_config, + "enable_cache_dit_summary": args.enable_cache_dit_summary, + "parallel_config": parallel_config, + "enforce_eager": args.enforce_eager, + "enable_cpu_offload": args.enable_cpu_offload, **quant_kwargs, - ) + } + if use_nextstep: + # NextStep-1.1 requires explicit pipeline class + omni_kwargs["model_class_name"] = "NextStep11Pipeline" + omni = Omni(**omni_kwargs) if profiler_enabled: print("[Profiler] Starting profiling...") @@ -236,7 +281,7 @@ def main(): print("Generation Configuration:") print(f" Model: {args.model}") print(f" Inference steps: {args.num_inference_steps}") - print(f" Cache backend: {args.cache_backend if args.cache_backend else 'None (no acceleration)'}") + print(f" Cache backend: {cache_backend if cache_backend else 'None (no acceleration)'}") print(f" Quantization: {args.quantization if args.quantization else 'None (BF16)'}") if ignored_layers: print(f" Ignored layers: {ignored_layers}") @@ -250,6 +295,13 @@ def main(): print(f"{'=' * 60}\n") generation_start = time.perf_counter() + + extra_args = { + "timesteps_shift": args.timesteps_shift, + "cfg_schedule": args.cfg_schedule, + "use_norm": args.use_norm, + } + outputs = omni.generate( { "prompt": args.prompt, @@ -261,10 +313,13 @@ def main(): generator=generator, true_cfg_scale=args.cfg_scale, guidance_scale=args.guidance_scale, + guidance_scale_2=args.guidance_scale_2, num_inference_steps=args.num_inference_steps, num_outputs_per_prompt=args.num_images_per_prompt, + extra_args=extra_args, ), ) + generation_end = time.perf_counter() generation_time = generation_end - generation_start diff --git a/tests/diffusion/models/nextstep_1_1/test_nextstep_cfg_parallel_layout.py b/tests/diffusion/models/nextstep_1_1/test_nextstep_cfg_parallel_layout.py new file mode 100644 index 00000000000..9caf667be6a --- /dev/null +++ b/tests/diffusion/models/nextstep_1_1/test_nextstep_cfg_parallel_layout.py @@ -0,0 +1,217 @@ +from types import SimpleNamespace + +import pytest +import torch +from PIL import Image + +import vllm_omni.diffusion.models.nextstep_1_1.pipeline_nextstep_1_1 as nextstep_pipeline_module +from vllm_omni.diffusion.models.nextstep_1_1.modeling_nextstep_heads import FlowMatchingHead +from vllm_omni.diffusion.models.nextstep_1_1.pipeline_nextstep_1_1 import NextStep11Pipeline + + +class _DummyImageHead: + def __init__(self, token_dim: int): + self.token_dim = token_dim + self.calls = [] + + def sample( + self, + c: torch.Tensor, + cfg: float, + cfg_img: float, + cfg_mult: int, + timesteps_shift: float, + num_sampling_steps: int, + noise_repeat: int, + ) -> torch.Tensor: + self.calls.append( + { + "batch": c.shape[0], + "cfg": cfg, + "cfg_img": cfg_img, + "cfg_mult": cfg_mult, + "noise_repeat": noise_repeat, + } + ) + batch_per_prompt = c.shape[0] // cfg_mult + return torch.ones(batch_per_prompt, self.token_dim, dtype=c.dtype, device=c.device) + + +class _DummyModel: + def __init__(self, hidden_dim: int, token_dim: int): + self.hidden_dim = hidden_dim + self.image_head = _DummyImageHead(token_dim) + self.forward_batches = [] + + def image_out_projector(self, c: torch.Tensor) -> torch.Tensor: + return c + + def image_in_projector(self, sampled_tokens: torch.Tensor) -> torch.Tensor: + bsz = sampled_tokens.shape[0] + return torch.zeros(bsz, 1, self.hidden_dim, dtype=sampled_tokens.dtype, device=sampled_tokens.device) + + def forward_model(self, inputs_embeds: torch.Tensor, attention_mask, past_key_values, use_cache: bool): + del attention_mask, use_cache + self.forward_batches.append(inputs_embeds.shape[0]) + return SimpleNamespace( + last_hidden_state=torch.zeros( + inputs_embeds.shape[0], + 1, + self.hidden_dim, + dtype=inputs_embeds.dtype, + device=inputs_embeds.device, + ), + past_key_values=past_key_values, + ) + + +def _make_minimal_pipeline_for_decoding(hidden_dim: int = 8, token_dim: int = 4): + pipeline = object.__new__(NextStep11Pipeline) + pipeline.config = SimpleNamespace(latent_channels=token_dim, latent_patch_size=1, use_gen_pos_embed=False) + pipeline.model = _DummyModel(hidden_dim=hidden_dim, token_dim=token_dim) + return pipeline + + +@pytest.mark.parametrize( + ("cfg", "cfg_img", "has_image_conditions", "expected_cfg_mult", "expected_cfg_img"), + [ + (1.0, 1.0, False, 1, 1.0), + (7.5, 1.0, False, 2, 1.0), + (7.5, 8.0, False, 2, 1.0), + (7.5, 1.5, True, 3, 1.5), + ], +) +def test_resolve_cfg_layout(cfg, cfg_img, has_image_conditions, expected_cfg_mult, expected_cfg_img): + cfg_mult, effective_cfg_img = NextStep11Pipeline._resolve_cfg_layout(cfg, cfg_img, has_image_conditions) + assert cfg_mult == expected_cfg_mult + assert effective_cfg_img == expected_cfg_img + + +def test_build_captions_ignores_image_cfg_without_image_conditions(): + pipeline = object.__new__(NextStep11Pipeline) + pipeline._image_str = lambda hw: f"" + + captions, images, cfg_mult, effective_cfg_img = pipeline._build_captions( + captions=["a prompt"], + images=None, + num_images_per_caption=1, + positive_prompt=None, + negative_prompt="bad quality", + cfg=7.5, + cfg_img=8.0, + ) + + assert cfg_mult == 2 + assert effective_cfg_img == 1.0 + assert images is None + assert captions == ["a prompt", "bad quality"] + + +def test_build_captions_enables_three_way_cfg_when_image_conditions_exist(): + pipeline = object.__new__(NextStep11Pipeline) + pipeline._image_str = lambda hw: f"" + + image = Image.new("RGB", (64, 32)) + captions, images, cfg_mult, effective_cfg_img = pipeline._build_captions( + captions=["a prompt"], + images=[image], + num_images_per_caption=1, + positive_prompt=None, + negative_prompt="bad quality", + cfg=7.5, + cfg_img=1.5, + ) + + assert cfg_mult == 3 + assert effective_cfg_img == 1.5 + assert len(captions) == 3 + assert captions[1].startswith(" Tensor: + return x * torch.sigmoid(x) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.in_channels = in_channels + + self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1) + + def attention(self, h_: Tensor) -> Tensor: + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous() + k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous() + v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous() + h_ = nn.functional.scaled_dot_product_attention(q, k, v) + + return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) + + def forward(self, x: Tensor) -> Tensor: + return x + self.proj_out(self.attention(x)) + + +class ResnetBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + + self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + h = x + h = self.norm1(h) + h = swish(h) + h = self.conv1(h) + + h = self.norm2(h) + h = swish(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + + return x + h + + +class Downsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x: Tensor): + pad = (0, 1, 0, 1) + x = nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class Upsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: Tensor): + x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + x = self.conv(x) + return x + + +class Encoder(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + ch: int, + ch_mult: list[int], + num_res_blocks: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + # downsampling + self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + block_in = self.ch + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # end + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: Tensor) -> Tensor: + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + ch: int, + out_ch: int, + ch_mult: list[int], + num_res_blocks: int, + in_channels: int, + resolution: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.ffactor = 2 ** (self.num_resolutions - 1) + + # compute in_ch_mult, block_in and curr_res at lowest res + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + + # z to block_in + self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in) + curr_res = curr_res * 2 + self.up.insert(0, up) + + # end + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, z: Tensor) -> Tensor: + # get dtype for proper tracing + upscale_dtype = next(self.up.parameters()).dtype + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # cast to proper dtype + h = h.to(upscale_dtype) + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +def layer_norm_2d(input: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: + # input.shape = (bsz, c, h, w) + _input = input.permute(0, 2, 3, 1) + _input = F.layer_norm(_input, _input.size()[-1:], None, None, eps) + _input = _input.permute(0, 3, 1, 2) + return _input + + +class AutoencoderKL(nn.Module): + def __init__(self, params: AutoEncoderParams): + super().__init__() + self.params = params + + # Create a config-like object for compatibility + class Config: + def __init__(self, params): + self.latent_channels = params.z_channels + self.block_out_channels = params.ch_mult + self.scaling_factor = params.scaling_factor + self.shift_factor = params.shift_factor + self.out_channels = params.out_ch + + self.config = Config(params) + + self.encoder = Encoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.decoder = Decoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + out_ch=params.out_ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + + self.encoder_norm = params.encoder_norm + self.psz = params.psz + + # Tiling attributes for VAE patch parallelism + self.use_tiling = False + vae_scale_factor = 2 ** (len(params.ch_mult) - 1) + self.tile_latent_min_size = int(params.resolution / vae_scale_factor) + self.tile_overlap_factor = 0.25 + self.tile_sample_min_size = params.resolution + + self.apply(self._init_weights) + + def _init_weights(self, module): + std = 0.02 + if isinstance(module, (nn.Conv2d, nn.Linear)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.GroupNorm): + if module.weight is not None: + module.weight.data.fill_(1.0) + if module.bias is not None: + module.bias.data.zero_() + + @property + def dtype(self): + return self.encoder.conv_in.weight.dtype + + @property + def device(self): + return self.encoder.conv_in.weight.device + + def patchify(self, img: torch.Tensor): + """ + img: (bsz, C, H, W) + x: (bsz, patch_size**2 * C, H / patch_size, W / patch_size) + """ + bsz, c, h, w = img.shape + p = self.psz + h_, w_ = h // p, w // p + + img = img.reshape(bsz, c, h_, p, w_, p) + img = torch.einsum("nchpwq->ncpqhw", img) + x = img.reshape(bsz, c * p**2, h_, w_) + return x + + def unpatchify(self, x: torch.Tensor): + """ + x: (bsz, patch_size**2 * C, H / patch_size, W / patch_size) + img: (bsz, C, H, W) + """ + bsz = x.shape[0] + p = self.psz + c = self.config.latent_channels + h_, w_ = x.shape[2], x.shape[3] + + x = x.reshape(bsz, c, p, p, h_, w_) + x = torch.einsum("ncpqhw->nchpwq", x) + img = x.reshape(bsz, c, h_ * p, w_ * p) + return img + + def encode(self, x: torch.Tensor, return_dict: bool = True): + moments = self.encoder(x) + + mean, logvar = torch.chunk(moments, 2, dim=1) + if self.psz is not None: + mean = self.patchify(mean) + + if self.encoder_norm: + mean = layer_norm_2d(mean) + + if self.psz is not None: + mean = self.unpatchify(mean) + + moments = torch.cat([mean, logvar], dim=1).contiguous() + + posterior = DiagonalGaussianDistribution(moments, deterministic=self.params.deterministic) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def decode(self, z: torch.Tensor, return_dict: bool = True): + dec = self.decoder(z) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[2], b.shape[2], blend_extent) + for y in range(blend_extent): + b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for x in range(blend_extent): + b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent) + return b + + def forward(self, input, sample_posterior=True, noise_strength=0.0): + posterior = self.encode(input).latent_dist + z = posterior.sample() if sample_posterior else posterior.mode() + if noise_strength > 0.0: + p = torch.distributions.Uniform(0, noise_strength) + z = z + p.sample((z.shape[0],)).reshape(-1, 1, 1, 1).to(z.device) * randn_tensor( + z.shape, device=z.device, dtype=z.dtype + ) + dec = self.decode(z).sample + return dec, posterior + + @classmethod + def from_pretrained(cls, model_path, **kwargs): + config_path = os.path.join(model_path, "config.json") + ckpt_path = os.path.join(model_path, "checkpoint.pt") + + if not os.path.isdir(model_path): + raise ValueError(f"Model path does not exist: {model_path}") + if not os.path.isfile(config_path): + raise ValueError(f"Config file not found: {config_path}") + if not os.path.isfile(ckpt_path): + raise ValueError(f"Checkpoint file not found: {ckpt_path}") + + state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True) + + with open(config_path) as f: + config: dict = json.load(f) + config.update(kwargs) + + # Filter out kwargs that are not in AutoEncoderParams + valid_kwargs = {} + valid_params = { + "resolution", + "in_channels", + "ch", + "out_ch", + "ch_mult", + "num_res_blocks", + "z_channels", + "scaling_factor", + "shift_factor", + "deterministic", + "encoder_norm", + "psz", + } + for key, value in config.items(): + if key in valid_params: + valid_kwargs[key] = value + + params = AutoEncoderParams(**valid_kwargs) + model = cls(params) + model.load_state_dict(state_dict, strict=False) + return model diff --git a/vllm_omni/diffusion/models/nextstep_1_1/modeling_nextstep.py b/vllm_omni/diffusion/models/nextstep_1_1/modeling_nextstep.py new file mode 100644 index 00000000000..ded3079265e --- /dev/null +++ b/vllm_omni/diffusion/models/nextstep_1_1/modeling_nextstep.py @@ -0,0 +1,446 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adapted from NextStep-1.1 (https://huggingface.co/stepfun-ai/NextStep-1.1) +# Original: models/nextstep_model.py — local version with TP-aware layers. + +from __future__ import annotations + +import json +from collections.abc import Iterable + +import numpy as np +import torch +import torch.nn as nn +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.models.llama.configuration_llama import LlamaConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from vllm_omni.diffusion.models.nextstep_1_1.modeling_nextstep_heads import ( + FlowMatchingHead, +) +from vllm_omni.diffusion.models.nextstep_1_1.modeling_nextstep_llama import ( + LlamaDecoderLayer, + LlamaRMSNorm, + LlamaRotaryEmbedding, +) + +# --------------------------------------------------------------------------- +# Positional embedding utilities (inlined from remote utils/model_utils.py) +# --------------------------------------------------------------------------- + + +def _get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: np.ndarray) -> np.ndarray: + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega + pos = pos.reshape(-1) + out = np.einsum("m,d->md", pos, omega) + return np.concatenate([np.sin(out), np.cos(out)], axis=1) + + +def _get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: np.ndarray) -> np.ndarray: + assert embed_dim % 2 == 0 + emb_h = _get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) + emb_w = _get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) + return np.concatenate([emb_h, emb_w], axis=1) + + +def get_2d_sincos_pos_embed(embed_dim: int, grid_size: int) -> np.ndarray: + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) + grid = np.stack(grid, axis=0).reshape(2, 1, grid_size, grid_size) + return _get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + + +# --------------------------------------------------------------------------- +# NextStepConfig — extends LlamaConfig with NextStep-specific fields. +# This mirrors the remote models/config.py. +# --------------------------------------------------------------------------- + + +class NextStepConfig(LlamaConfig): + model_type = "nextstep" + + def __init__( + self, + vae_name_or_path: str | None = None, + latent_size: int = 32, + latent_patch_size: int = 2, + latent_channels: int = 16, + boi: int | None = None, + eoi: int | None = None, + image_placeholder_id: int | None = None, + pad_token_id_added: int | None = None, + lm_loss_weight: float = 0.01, + im_loss_weight: float = 1.0, + fm_head_dim: int = 1536, + fm_head_layers: int = 12, + fm_head_batch_mul: int = 4, + o_attention_bias: bool | None = None, + **kwargs, + ): + super().__init__(**kwargs) + self.vae_name_or_path = vae_name_or_path + self.latent_size = latent_size + self.latent_patch_size = latent_patch_size + self.latent_channels = latent_channels + self.boi = boi + self.eoi = eoi + self.image_placeholder_id = image_placeholder_id + self.pad_token_id_added = pad_token_id_added + self.lm_loss_weight = lm_loss_weight + self.im_loss_weight = im_loss_weight + self.fm_head_dim = fm_head_dim + self.fm_head_layers = fm_head_layers + self.fm_head_batch_mul = fm_head_batch_mul + self.o_attention_bias = self.attention_bias if o_attention_bias is None else o_attention_bias + + @classmethod + def from_json(cls, path: str) -> NextStepConfig: + with open(path) as f: + data = json.load(f) + # Remove keys that are not constructor parameters (accept extra keys) + # LlamaConfig.__init__ uses **kwargs to absorb extras. + return cls(**data) + + +# --------------------------------------------------------------------------- +# NextStepModel — main model class +# --------------------------------------------------------------------------- + + +class NextStepModel(nn.Module): + def __init__(self, config: NextStepConfig): + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = LlamaRotaryEmbedding(config=config) + + # lm_head is part of the checkpoint but not used during image generation + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Image projectors. + token_dim = config.latent_channels * config.latent_patch_size**2 + self.image_in_projector = nn.Linear(token_dim, config.hidden_size) + self.image_out_projector = nn.Linear(config.hidden_size, config.hidden_size) + + # Flow-matching head (no TP — tiny network) + self.image_head = FlowMatchingHead( + input_dim=token_dim, + cond_dim=config.hidden_size, + dim=config.fm_head_dim, + layers=config.fm_head_layers, + ) + + # Optional generation position embeddings + if getattr(config, "use_gen_pos_embed", False): + self._init_gen_pos_embed() + + # ------------------------------------------------------------------ + # Generation positional embeddings + # ------------------------------------------------------------------ + + def _init_gen_pos_embed(self): + self.register_buffer( + "gen_pos_embed", + torch.from_numpy(get_2d_sincos_pos_embed(self.config.hidden_size, self.config.base_image_grid_size)) + .float() + .unsqueeze(0), + ) + + def gen_pos_embed_with_ar(self, h: int, w: int) -> torch.Tensor: + bsz, hw, dim = self.gen_pos_embed.shape + side = int(hw**0.5) + gen_pos_embed = self.gen_pos_embed.reshape(bsz, side, side, dim) + gen_pos_embed = gen_pos_embed[:, :h, :w, :] + return gen_pos_embed.reshape(bsz, -1, dim) + + # ------------------------------------------------------------------ + # Patchify / Unpatchify + # ------------------------------------------------------------------ + + def patchify(self, img: torch.Tensor) -> torch.Tensor: + bsz, c, h, w = img.shape + p = self.config.latent_patch_size + h_, w_ = h // p, w // p + img = img.reshape(bsz, c, h_, p, w_, p) + img = torch.einsum("nchpwq->nhwcpq", img) + return img.reshape(bsz, h_ * w_, c * p**2) + + def unpatchify( + self, + x: torch.Tensor, + h: int | None = None, + w: int | None = None, + ) -> torch.Tensor: + bsz = x.shape[0] + p = self.config.latent_patch_size + c = self.config.latent_channels + if h is None and w is None: + h_ = w_ = int(x.shape[1] ** 0.5) + else: + h_, w_ = h, w + assert h_ * w_ == x.shape[1], f"Invalid sequence length {x.shape[1]}." + x = x.reshape(bsz, h_, w_, c, p, p) + x = torch.einsum("nhwcpq->nchpwq", x) + return x.reshape(bsz, c, h_ * p, w_ * p) + + # ------------------------------------------------------------------ + # Input embedding preparation + # ------------------------------------------------------------------ + + def prepare_inputs_embeds( + self, + input_ids: torch.LongTensor, + latents: torch.FloatTensor | None = None, + ) -> torch.Tensor: + if latents is None: + return self.embed_tokens(input_ids) + + bs, seq_length = input_ids.shape + inputs_embeds = torch.zeros( + (bs, seq_length, self.config.hidden_size), + device=self.embed_tokens.weight.device, + dtype=self.embed_tokens.weight.dtype, + ) + im_indices = input_ids == self.config.image_placeholder_id + lm_indices = ~im_indices + + if isinstance(latents, list): + tokens = torch.cat([self.patchify(latent) for latent in latents], dim=1) + else: + tokens = self.patchify(latents) + + image_embeds = self.image_in_projector(tokens) + image_embeds = image_embeds.view(-1, self.config.hidden_size) + + token_embeds = self.embed_tokens(input_ids[lm_indices]) + + inputs_embeds[im_indices] = image_embeds.to(inputs_embeds.dtype) + inputs_embeds[lm_indices] = token_embeds + + return inputs_embeds + + # ------------------------------------------------------------------ + # Causal mask utilities + # ------------------------------------------------------------------ + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ) -> torch.Tensor | None: + # For SDPA we build the 4D mask; flash_attention_2 uses None. + attn_impl = getattr(self.config, "_attn_implementation", "sdpa") + if attn_impl == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + if attn_impl == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + attn_impl == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ) -> torch.Tensor: + if attention_mask is not None and attention_mask.dim() == 4: + return attention_mask + + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), + fill_value=min_dtype, + dtype=dtype, + device=device, + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(padding_mask, min_dtype) + + return causal_mask + + # ------------------------------------------------------------------ + # Forward through decoder layers + # ------------------------------------------------------------------ + + def forward_model( + self, + inputs_embeds: torch.FloatTensor, + attention_mask: torch.Tensor | None = None, + past_key_values: Cache | list[torch.FloatTensor] | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + cache_position: torch.LongTensor | None = None, + ) -> BaseModelOutputWithPast: + output_attentions = ( + output_attentions if output_attentions is not None else getattr(self.config, "output_attentions", False) + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else getattr(self.config, "output_hidden_states", False) + ) + use_cache = use_cache if use_cache is not None else getattr(self.config, "use_cache", True) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values, + output_attentions, + ) + hidden_states = inputs_embeds + + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + # ------------------------------------------------------------------ + # Weight loading with TP sharding support + # ------------------------------------------------------------------ + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name_in_model, weight_name_in_checkpoint, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + return loaded_params diff --git a/vllm_omni/diffusion/models/nextstep_1_1/modeling_nextstep_heads.py b/vllm_omni/diffusion/models/nextstep_1_1/modeling_nextstep_heads.py new file mode 100644 index 00000000000..7768ac76518 --- /dev/null +++ b/vllm_omni/diffusion/models/nextstep_1_1/modeling_nextstep_heads.py @@ -0,0 +1,309 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adapted from NextStep-1.1 (https://huggingface.co/stepfun-ai/NextStep-1.1) +# Original: models/heads.py — FlowMatchingHead and components. +# No TP needed: the FM head is tiny (dim=1536, 12 layers). + +from __future__ import annotations + +import math + +import torch +import torch.nn as nn + +# --------------------------------------------------------------------------- +# Utilities (inlined from remote utils/model_utils.py) +# --------------------------------------------------------------------------- + + +def expand_t(t: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + dims = [1] * (len(x.size()) - 1) + return t.view(t.size(0), *dims) + + +def randn_tensor( + shape: tuple[int, ...], + noise_repeat: int, + device: torch.device, + dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + bsz = shape[0] + if bsz % noise_repeat != 0: + raise ValueError(f"Batch size ({bsz}) must be divisible by noise repeat ({noise_repeat})") + _shape = (noise_repeat,) + shape[1:] + _tensor = torch.randn(_shape, device=device, dtype=dtype).repeat(bsz // noise_repeat, 1) + return _tensor + + +# --------------------------------------------------------------------------- +# Adaptive LayerNorm modulation +# --------------------------------------------------------------------------- + + +def modulate( + x: torch.Tensor, + shift: torch.Tensor | None, + scale: torch.Tensor | None = None, +) -> torch.Tensor: + if shift is None: + return x * (1 + scale) + return x * (1 + scale) + shift + + +# --------------------------------------------------------------------------- +# ResBlock +# --------------------------------------------------------------------------- + + +class ResBlock(nn.Module): + def __init__(self, channels: int, mlp_ratio: float = 1.0): + super().__init__() + self.channels = channels + self.intermediate_size = int(channels * mlp_ratio) + self.in_ln = nn.LayerNorm(self.channels, eps=1e-6) + self.mlp = nn.Sequential( + nn.Linear(self.channels, self.intermediate_size), + nn.SiLU(), + nn.Linear(self.intermediate_size, self.channels), + ) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(channels, 3 * channels, bias=True)) + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(y).chunk(3, dim=-1) + h = modulate(self.in_ln(x), shift_mlp, scale_mlp) + h = self.mlp(h) + return x + gate_mlp * h + + +# --------------------------------------------------------------------------- +# FinalLayer +# --------------------------------------------------------------------------- + + +class FinalLayer(nn.Module): + def __init__(self, model_channels: int, out_channels: int): + super().__init__() + self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(model_channels, out_channels, bias=True) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(model_channels, 2 * model_channels, bias=True)) + + def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + + +# --------------------------------------------------------------------------- +# TimestepEmbedder +# --------------------------------------------------------------------------- + + +class TimestepEmbedder(nn.Module): + def __init__(self, hidden_size: int, frequency_embedding_size: int = 256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t: torch.Tensor, dim: int, max_period: float = 10000.0) -> torch.Tensor: + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( + device=t.device + ) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t: torch.Tensor) -> torch.Tensor: + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + return self.mlp(t_freq.to(self.mlp[0].weight.dtype)) + + +# --------------------------------------------------------------------------- +# SimpleMLPAdaLN +# --------------------------------------------------------------------------- + + +class SimpleMLPAdaLN(nn.Module): + def __init__( + self, + input_dim: int, + cond_dim: int, + dim: int = 1536, + layers: int = 12, + mlp_ratio: float = 1.0, + ): + super().__init__() + self.input_dim = input_dim + self.cond_dim = cond_dim + self.dim = dim + + self.time_embed = TimestepEmbedder(dim) + self.cond_embed = nn.Linear(cond_dim, dim) + self.input_proj = nn.Linear(input_dim, dim) + + self.res_blocks = nn.ModuleList([ResBlock(dim, mlp_ratio) for _ in range(layers)]) + self.final_layer = FinalLayer(dim, input_dim) + + self.initialize_weights() + + def initialize_weights(self): + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02) + nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02) + + for block in self.res_blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def forward(self, x: torch.Tensor, t: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + x = self.input_proj(x) + t = self.time_embed(t) + c = self.cond_embed(c) + y = t + c + + for block in self.res_blocks: + x = block(x, y) + + return self.final_layer(x, y) + + +# --------------------------------------------------------------------------- +# FlowMatchingHead +# --------------------------------------------------------------------------- + + +class FlowMatchingHead(nn.Module): + def __init__( + self, + input_dim: int, + cond_dim: int, + dim: int = 1536, + layers: int = 12, + mlp_ratio: float = 1.0, + ): + super().__init__() + self.input_dim = input_dim + self.net = SimpleMLPAdaLN( + input_dim=input_dim, + cond_dim=cond_dim, + dim=dim, + layers=layers, + mlp_ratio=mlp_ratio, + ) + + @property + def dtype(self): + return self.net.input_proj.weight.dtype + + @property + def device(self): + return self.net.input_proj.weight.device + + def get_score_from_velocity( + self, + velocity: torch.Tensor, + x: torch.Tensor, + t: torch.Tensor, + ) -> torch.Tensor: + t = expand_t(t, x) + alpha_t, d_alpha_t = t, 1 + sigma_t, d_sigma_t = 1 - t, -1 + mean = x + reverse_alpha_ratio = alpha_t / d_alpha_t + var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t + score = (reverse_alpha_ratio * velocity - mean) / var + return score + + def get_velocity_from_cfg( + self, + velocity: torch.Tensor, + cfg: float, + cfg_img: float, + cfg_mult: int, + ) -> torch.Tensor: + if cfg_mult == 2: + cond_v, uncond_v = torch.chunk(velocity, 2, dim=0) + velocity = uncond_v + cfg * (cond_v - uncond_v) + elif cfg_mult == 3: + cond_v, uncond_v1, uncond_v2 = torch.chunk(velocity, 3, dim=0) + velocity = uncond_v2 + cfg_img * (uncond_v1 - uncond_v2) + cfg * (cond_v - uncond_v1) + return velocity + + @torch.no_grad() + def sample( + self, + c: torch.Tensor, + cfg: float = 1.0, + cfg_img: float = 1.0, + cfg_mult: int | None = None, + timesteps_shift: float = 1.0, + num_sampling_steps: int = 20, + last_step_size: float = 0.0, + noise_repeat: int = 1, + ) -> torch.Tensor: + if cfg_mult is None: + cfg_mult = 1 + if cfg > 1.0: + cfg_mult += 1 + if cfg_img > 1.0: + cfg_mult += 1 + + if cfg_mult <= 0: + raise ValueError(f"Invalid cfg_mult={cfg_mult}; expected a positive value.") + if c.shape[0] % cfg_mult != 0: + raise ValueError( + f"Invalid CFG layout: condition batch size {c.shape[0]} is not divisible by cfg_mult={cfg_mult}." + ) + + noise = randn_tensor((c.shape[0] // cfg_mult, self.input_dim), noise_repeat, self.device) + + x = noise + xs = [] + + t0, t1 = 0, 1 + timesteps = torch.linspace(t0, t1, num_sampling_steps + 1, device=c.device)[:-1] + timesteps = timesteps / (timesteps_shift - (timesteps_shift - 1) * timesteps) + timesteps = torch.cat([timesteps, torch.ones(1, device=c.device)]) + + for ti, tj in zip(timesteps[:-1], timesteps[1:]): + dt = tj - ti + + combined = torch.cat([x] * cfg_mult, dim=0) + velocity = self.net(combined.to(c.dtype), ti.expand(c.shape[0]).to(c), c) + velocity = velocity.to(torch.float32) + + velocity = self.get_velocity_from_cfg(velocity, cfg, cfg_img, cfg_mult) + score = self.get_score_from_velocity(velocity, x, ti.expand(x.shape[0]).to(x)) + drift = velocity + (1 - expand_t(ti.expand(x.shape[0]).to(x), x)) * score + + w_cur = randn_tensor((c.shape[0] // cfg_mult, self.input_dim), noise_repeat, self.device) + dw = w_cur * torch.sqrt(dt) + + mean_x = x + drift * dt + x = mean_x + torch.sqrt(2 * (1 - expand_t(ti.expand(x.shape[0]).to(x), x))) * dw + xs.append(x) + + if len(xs) != num_sampling_steps: + raise ValueError(f"Samples ({len(xs)}) does not match the number of steps ({num_sampling_steps})") + + return xs[-1].to(c.dtype) diff --git a/vllm_omni/diffusion/models/nextstep_1_1/modeling_nextstep_llama.py b/vllm_omni/diffusion/models/nextstep_1_1/modeling_nextstep_llama.py new file mode 100644 index 00000000000..7b367b6ff49 --- /dev/null +++ b/vllm_omni/diffusion/models/nextstep_1_1/modeling_nextstep_llama.py @@ -0,0 +1,285 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adapted from NextStep-1.1 (https://huggingface.co/stepfun-ai/NextStep-1.1) +# Original: models/llama_model.py — made TP-aware for vLLM-Omni. + +from __future__ import annotations + +import torch +import torch.nn as nn +from transformers import ROPE_INIT_FUNCTIONS +from transformers.cache_utils import Cache +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) + +# --------------------------------------------------------------------------- +# Utilities (inlined from remote utils/model_utils.py) +# --------------------------------------------------------------------------- + + +def rotate_half(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + unsqueeze_dim: int = 1, +) -> tuple[torch.Tensor, torch.Tensor]: + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + batch, num_kv_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim) + + +# --------------------------------------------------------------------------- +# LlamaRMSNorm — no TP needed +# --------------------------------------------------------------------------- + + +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size: int, eps: float = 1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +# --------------------------------------------------------------------------- +# LlamaRotaryEmbedding — no TP needed +# --------------------------------------------------------------------------- + + +class LlamaRotaryEmbedding(nn.Module): + def __init__(self, config, device=None): + super().__init__() + self.rope_type = "default" + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + @torch.no_grad() + def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# --------------------------------------------------------------------------- +# LlamaAttention — TP-aware (fused QKV + RowParallel o_proj) +# --------------------------------------------------------------------------- + + +class LlamaAttention(nn.Module): + def __init__(self, config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.is_causal = True + + # TP-aware: fused QKV projection + self.qkv_proj = QKVParallelLinear( + hidden_size=self.hidden_size, + head_size=self.head_dim, + total_num_heads=self.num_heads, + total_num_kv_heads=self.num_key_value_heads, + bias=config.attention_bias, + ) + # TP-aware: row-parallel output projection + self.o_proj = RowParallelLinear( + self.num_heads * self.head_dim, + self.hidden_size, + bias=getattr(config, "o_attention_bias", config.attention_bias), + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + past_key_value: Cache | None = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: torch.LongTensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor | None, Cache | None]: + bsz, q_len, _ = hidden_states.size() + + # Fused QKV projection + qkv, _ = self.qkv_proj(hidden_states) + # Split into Q, K, V — sizes account for TP sharding + qkv = qkv.view(bsz, q_len, -1, self.head_dim) + # QKVParallelLinear output: [q_heads_local, k_heads_local, v_heads_local] + num_local_heads = self.qkv_proj.num_heads + num_local_kv_heads = self.qkv_proj.num_kv_heads + split_sizes = [num_local_heads, num_local_kv_heads, num_local_kv_heads] + query_states, key_states, value_states = qkv.split(split_sizes, dim=2) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + cache_kwargs = { + "sin": sin, + "cos": cos, + "cache_position": cache_position, + } + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + num_local_kv_groups = num_local_heads // num_local_kv_heads + key_states = repeat_kv(key_states, num_local_kv_groups) + value_states = repeat_kv(value_states, num_local_kv_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + attn_output, _ = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +# --------------------------------------------------------------------------- +# LlamaMLP — TP-aware (fused gate_up + RowParallel down) +# --------------------------------------------------------------------------- + + +class LlamaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + + # TP-aware: fused gate + up projection + self.gate_up_proj = MergedColumnParallelLinear( + self.hidden_size, + [self.intermediate_size] * 2, + bias=config.mlp_bias, + ) + # TP-aware: row-parallel down projection + self.down_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=config.mlp_bias, + ) + self.act_fn = nn.SiLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(x) + gate, up = gate_up.chunk(2, dim=-1) + down, _ = self.down_proj(self.act_fn(gate) * up) + return down + + +# --------------------------------------------------------------------------- +# LlamaDecoderLayer — uses TP-aware Attention + MLP +# --------------------------------------------------------------------------- + + +class LlamaDecoderLayer(nn.Module): + def __init__(self, config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx) + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + past_key_value: Cache | None = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: torch.LongTensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs, + ) -> tuple: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + if use_cache: + outputs += (present_key_value,) + + return outputs diff --git a/vllm_omni/diffusion/models/nextstep_1_1/pipeline_nextstep_1_1.py b/vllm_omni/diffusion/models/nextstep_1_1/pipeline_nextstep_1_1.py new file mode 100644 index 00000000000..4a20c28c5e6 --- /dev/null +++ b/vllm_omni/diffusion/models/nextstep_1_1/pipeline_nextstep_1_1.py @@ -0,0 +1,710 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adapted from NextStep-1.1 (https://huggingface.co/stepfun-ai/NextStep-1.1) + +import os +import re +from collections.abc import Iterable +from typing import Literal + +import numpy as np +import torch +import torch.nn.functional as F +import torchvision.transforms as transforms +from diffusers.image_processor import VaeImageProcessor +from PIL import Image +from torch import nn +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PreTrainedTokenizer +from transformers.cache_utils import StaticCache +from vllm.logger import init_logger +from vllm.model_executor.models.utils import AutoWeightsLoader + +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.parallel_state import ( + get_cfg_group, + get_classifier_free_guidance_rank, + get_classifier_free_guidance_world_size, +) +from vllm_omni.diffusion.distributed.utils import get_local_device +from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.nextstep_1_1.modeling_flux_vae import AutoencoderKL +from vllm_omni.diffusion.models.nextstep_1_1.modeling_nextstep import ( + NextStepConfig, + NextStepModel, +) +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.model_executor.model_loader.weight_utils import ( + download_weights_from_hf_specific, +) + +logger = init_logger(__name__) + + +def layer_norm(input: torch.Tensor, normalized_shape: torch.Size, eps: float = 1e-6) -> torch.Tensor: + return F.layer_norm(input, normalized_shape, None, None, eps) + + +def set_seed(seed: int): + """Set random seed for reproducibility.""" + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + +def to_pil(image: torch.Tensor, mode: str = "11") -> Image.Image: + """Convert tensor to PIL Image.""" + if mode == "11": + # Assuming image is in [-1, 1] range + image = (image + 1) / 2 + image = image.clamp(0, 1) + image = image.permute(1, 2, 0).cpu().numpy() + image = (image * 255).astype(np.uint8) + return Image.fromarray(image) + + +################################################################################# +# Sine/Cosine Positional Embedding Functions # +################################################################################# + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, pe_interpolation=1.0): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] + """ + grid_h = np.arange(grid_size, dtype=np.float32) / pe_interpolation + grid_w = np.arange(grid_size, dtype=np.float32) / pe_interpolation + grid = np.meshgrid(grid_w, grid_h) + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) + emb = np.concatenate([emb_h, emb_w], axis=1) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega + + pos = pos.reshape(-1) + out = np.einsum("m,d->md", pos, omega) + + emb_sin = np.sin(out) + emb_cos = np.cos(out) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) + return emb + + +def hw2str(h: int, w: int) -> str: + """Convert height and width to string format.""" + return f"{h}*{w}" + + +DEFAULT_IMAGE_AREA_TOKEN = "<|image_area|>" + + +def get_nextstep11_post_process_func(od_config: OmniDiffusionConfig): + """Return post-processing function for NextStep-1.1 pipeline outputs.""" + vae_scale_factor = 8 # Default for NextStep VAE + + image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor) + + def post_process_func(images: torch.Tensor): + return image_processor.postprocess(images) + + return post_process_func + + +class NextStep11Pipeline(nn.Module): + """ + NextStep-1.1 Pipeline for text-to-image generation. + + This pipeline implements the autoregressive flow-based image generation + model from StepFun. It uses an LLM backbone with a flow matching head + to generate images autoregressively. + """ + + def __init__( + self, + *, + od_config: OmniDiffusionConfig, + prefix: str = "", + ): + super().__init__() + self.od_config = od_config + self._execution_device = get_local_device() + + model_path = od_config.model + local_files_only = os.path.exists(model_path) + + if not local_files_only: + model_path = download_weights_from_hf_specific(model_path, None, ["*"]) + + # Load tokenizer (still uses trust_remote_code for tokenizer only) + self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained( + model_path, + local_files_only=True, + model_max_length=512, + padding_side="left", + use_fast=True, + trust_remote_code=True, + ) + self.tokenizer.add_eos_token = False + + # Load model from local TP-aware code (weights loaded later via load_weights) + config = NextStepConfig.from_json(os.path.join(model_path, "config.json")) + self.model = NextStepModel(config) + self.model.eval() + + # Load config + self.config = self.model.config + + # Load VAE + vae_path = getattr(self.config, "vae_name_or_path", None) + if vae_path is None: + vae_path = os.path.join(model_path, "vae") + elif not os.path.isabs(vae_path): + # Resolve relative vae_name_or_path (e.g. "vae/") against model dir + vae_path = os.path.join(model_path, vae_path) + + if os.path.exists(vae_path): + self.vae = AutoencoderKL.from_pretrained(vae_path) + else: + # Try loading from model directory + vae_checkpoint = os.path.join(model_path, "vae", "checkpoint.pt") + vae_config = os.path.join(model_path, "vae", "config.json") + if os.path.exists(vae_checkpoint) and os.path.exists(vae_config): + self.vae = AutoencoderKL.from_pretrained(os.path.join(model_path, "vae")) + else: + raise ValueError(f"Could not find VAE at {vae_path}") + + self.vae.eval() + + # Calculate down factor + vae_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + latent_patch_size = getattr(self.config, "latent_patch_size", 2) + self.down_factor = vae_factor * latent_patch_size + + # Get VAE parameters + self.shift_factor = getattr(self.vae.config, "shift_factor", 0.0) + self.scaling_factor = getattr(self.vae.config, "scaling_factor", 1.0) + + # Get special token IDs from config + self.boi = getattr(self.config, "boi", None) + self.eoi = getattr(self.config, "eoi", None) + self.image_placeholder_id = getattr(self.config, "image_placeholder_id", None) + + # Image processing + self.pil2tensor = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), + ] + ) + + self.__device = self._execution_device + self.__dtype = od_config.dtype + + # Weight sources: model weights from safetensors, prefixed with "model." + # so AutoWeightsLoader dispatches to self.model.load_weights() + self.weights_sources = [ + DiffusersPipelineLoader.ComponentSource( + model_or_path=model_path, + subfolder=None, + revision=None, + prefix="model.", + fall_back_to_pt=True, + allow_patterns_overrides=["model-*.safetensors", "model.safetensors"], + ) + ] + + @property + def device(self): + return self.__device + + @property + def dtype(self): + return self.__dtype + + def to(self, device=None, dtype=None): + if device is not None: + self.__device = device + if dtype is not None: + self.__dtype = dtype + self.model.to(self.__device, dtype=self.__dtype) + self.vae.to(self.__device, dtype=self.__dtype) + return self + + def _image_str(self, hw: tuple[int, int] = (256, 256)): + """Generate image token string for given height/width.""" + latent_hw = (hw[0] // self.down_factor, hw[1] // self.down_factor) + image_ids = [self.boi] + [self.image_placeholder_id] * (latent_hw[0] * latent_hw[1]) + [self.eoi] + image_str = DEFAULT_IMAGE_AREA_TOKEN + hw2str(*latent_hw) + self.tokenizer.decode(image_ids) + return image_str + + def _check_input( + self, captions: str | list[str], images: Image.Image | list[Image.Image] | None + ) -> tuple[list[str], list[Image.Image] | None]: + """Validate and process input captions and images.""" + if not isinstance(captions, list): + captions = [captions] + + if images is not None: + if not isinstance(images, list): + images = [images] + + # Validate image count matches tokens in captions + image_token_count = 0 + for caption in captions: + num_image_token = len(re.findall(r"", caption)) + if num_image_token != 1: + raise ValueError( + f"Caption must contain exactly one token. " + f"Found {num_image_token} in: {caption[:100]}..." + ) + image_token_count += num_image_token + if image_token_count != len(images): + raise ValueError( + f"Number of images ({len(images)}) does not match number of image tokens ({image_token_count})." + ) + + hws = [(image.size[1], image.size[0]) for image in images] + + # Replace tokens with corresponding image_str + processed_captions = [] + image_idx = 0 + for caption in captions: + processed_caption = caption + num_image_tokens = processed_caption.count("") + + for _ in range(num_image_tokens): + processed_caption = processed_caption.replace("", self._image_str(hws[image_idx]), 1) + image_idx += 1 + + processed_captions.append(processed_caption) + + captions = processed_captions + return captions, images + + @staticmethod + def _resolve_cfg_layout(cfg: float, cfg_img: float, has_image_conditions: bool) -> tuple[int, float]: + """Resolve the active CFG branch layout for the current request.""" + use_text_cfg = cfg > 1.0 + # Image CFG branch is only meaningful when the request has an image condition. + use_img_cfg = use_text_cfg and has_image_conditions and cfg_img != 1.0 + cfg_mult = 1 + int(use_text_cfg) + int(use_img_cfg) + effective_cfg_img = cfg_img if use_img_cfg else 1.0 + return cfg_mult, effective_cfg_img + + def _build_captions( + self, + captions: str | list[str], + images: list[Image.Image] | None = None, + num_images_per_caption: int = 1, + positive_prompt: str | None = None, + negative_prompt: str | None = None, + cfg: float = 1.0, + cfg_img: float = 1.0, + ) -> tuple[list[str], list[Image.Image] | None, int, float]: + """Build captions with CFG support.""" + if not isinstance(captions, list): + captions = [captions] + captions = [caption for caption in captions for _ in range(num_images_per_caption)] + if images is not None: + images = [image for image in images for _ in range(num_images_per_caption)] + + # Add positive prompt + if positive_prompt is not None and positive_prompt != "": + captions = [f"{caption} {positive_prompt}" for caption in captions] + + cfg_mult, effective_cfg_img = self._resolve_cfg_layout(cfg, cfg_img, images is not None) + + # Add negative prompt for CFG + if negative_prompt is None: + negative_prompt = "" + num_samples = len(captions) + if cfg_mult == 3: + w, h = images[0].size + captions = captions + [self._image_str((h, w)) + negative_prompt] * num_samples + images = images + images + captions = captions + [negative_prompt] * num_samples + elif cfg_mult == 2: + captions = captions + [negative_prompt] * num_samples + + return captions, images, cfg_mult, effective_cfg_img + + def _add_prefix_ids(self, hw: tuple[int, int], input_ids: torch.Tensor, attention_mask: torch.Tensor): + """Add prefix IDs for image generation.""" + prefix_str = DEFAULT_IMAGE_AREA_TOKEN + hw2str(hw[0] // self.down_factor, hw[1] // self.down_factor) + prefix_output = self.tokenizer(prefix_str, truncation=False, add_special_tokens=True, return_tensors="pt") + prefix_input_ids = prefix_output.input_ids.to(input_ids.device, dtype=input_ids.dtype) + prefix_attention_mask = prefix_output.attention_mask.to(attention_mask.device, dtype=attention_mask.dtype) + + # Remove bos token + if self.tokenizer.bos_token is not None: + prefix_input_ids = prefix_input_ids[:, 1:] + prefix_attention_mask = prefix_attention_mask[:, 1:] + + # Add boi token + prefix_input_ids = torch.cat( + [ + prefix_input_ids, + prefix_input_ids.new_tensor([self.boi]).unsqueeze(0), + ], + dim=1, + ) + prefix_attention_mask = torch.cat( + [ + prefix_attention_mask, + prefix_attention_mask.new_ones((prefix_attention_mask.shape[0], 1)), + ], + dim=1, + ) + + bsz = input_ids.shape[0] + input_ids = torch.cat([input_ids, prefix_input_ids.expand(bsz, -1)], dim=1) + attention_mask = torch.cat([attention_mask, prefix_attention_mask.expand(bsz, -1)], dim=1) + + return input_ids, attention_mask + + @torch.no_grad() + def decoding( + self, + c: torch.Tensor, + attention_mask: torch.Tensor, + past_key_values, + max_new_len: int, + num_images_per_caption: int, + use_norm: bool = False, + cfg: float = 1.0, + cfg_img: float = 1.0, + cfg_mult: int = 1, + cfg_schedule: Literal["linear", "constant"] = "constant", + timesteps_shift: float = 1.0, + num_sampling_steps: int = 20, + progress: bool = True, + hw: tuple[int, int] = (256, 256), + ): + """Autoregressive image token decoding with optional CFG-Parallel.""" + if cfg_mult <= 0: + raise ValueError(f"Invalid cfg_mult={cfg_mult}; expected a positive value.") + + full_bsz = c.shape[0] + if full_bsz % cfg_mult != 0: + raise ValueError( + f"Invalid CFG layout: batch size {full_bsz} is not divisible by active CFG multiplier {cfg_mult}." + ) + + # CFG-Parallel: each rank handles one portion of the CFG batch + cfg_world_size = get_classifier_free_guidance_world_size() + cfg_parallel = cfg_world_size > 1 and cfg_mult > 1 + if cfg_parallel and cfg_world_size != cfg_mult: + logger.warning( + "CFG parallel world size (%d) does not match the number of active " + "CFG branches (%d); falling back to non-parallel CFG.", + cfg_world_size, + cfg_mult, + ) + cfg_parallel = False + if cfg_parallel: + cfg_rank = get_classifier_free_guidance_rank() + cfg_group = get_cfg_group() + # Split batch: rank 0 gets positive, rank 1 gets first uncond, etc. + batch_per_rank = full_bsz // cfg_mult + start = cfg_rank * batch_per_rank + end = start + batch_per_rank + c = c[start:end] + attention_mask = attention_mask[start:end] + # Slice the StaticCache to keep only this rank's portion. + # We rebuild a new StaticCache for this rank's batch size and copy + # the relevant slices from the pre-filled cache. + if isinstance(past_key_values, StaticCache): + old_cache = past_key_values + new_cache = StaticCache( + config=self.config, + max_cache_len=old_cache.get_max_cache_shape(), + ) + for layer_idx, layer in enumerate(old_cache.layers): + if not layer.is_initialized: + continue + # Force-initialize the new layer with sliced tensors + new_layer = new_cache.layers[layer_idx] + sliced_keys = layer.keys[start:end] + new_layer.lazy_initialization(sliced_keys) + new_layer.keys.copy_(sliced_keys) + new_layer.values.copy_(layer.values[start:end]) + # Preserve cumulative_length for sliding window layers + # (StaticSlidingWindowLayer tracks seq position via this + # counter, not by counting non-zero entries) + if hasattr(layer, "cumulative_length"): + new_layer.cumulative_length = layer.cumulative_length + past_key_values = new_cache + else: + cfg_rank = 0 + cfg_group = None + + token_dim = self.config.latent_channels * self.config.latent_patch_size**2 + indices = list(range(max_new_len)) + indices = tqdm(indices, desc="Generating") if progress else indices + + tokens = None + for step in indices: + # CFG schedule + if cfg_schedule == "linear": + tokens_len = 0 if tokens is None else tokens.shape[1] + cfg_iter = 1 + (cfg - 1) * (max_new_len - tokens_len) / max_new_len + cfg_img_iter = 1 + (cfg_img - 1) * (max_new_len - tokens_len) / max_new_len + elif cfg_schedule == "constant": + cfg_iter = cfg + cfg_img_iter = cfg_img + else: + raise NotImplementedError(f"Unknown cfg_schedule: {cfg_schedule}") + + if cfg_parallel: + # Each rank projects its own portion + c_proj = self.model.image_out_projector(c) + # Gather projected context from all CFG ranks + c_gathered = cfg_group.all_gather(c_proj, separate_tensors=True) + c_full = torch.cat(c_gathered, dim=0) # (full_bsz, 1, hidden) + + # Rank 0 runs the FM head sampling (it needs full CFG batch) + if cfg_rank == 0: + token_sampled = self.model.image_head.sample( + c=c_full.squeeze(1), + cfg=cfg_iter, + cfg_img=cfg_img_iter, + cfg_mult=cfg_mult, + timesteps_shift=timesteps_shift, + num_sampling_steps=num_sampling_steps, + noise_repeat=num_images_per_caption, + ) + else: + token_sampled = torch.empty( + batch_per_rank, + token_dim, + device=c.device, + dtype=c.dtype, + ) + # Broadcast sampled token from rank 0 to all ranks + token_sampled = token_sampled.contiguous() + cfg_group.broadcast(token_sampled, src=0) + else: + c_proj = self.model.image_out_projector(c) + token_sampled = self.model.image_head.sample( + c=c_proj.squeeze(1), + cfg=cfg_iter, + cfg_img=cfg_img_iter, + cfg_mult=cfg_mult, + timesteps_shift=timesteps_shift, + num_sampling_steps=num_sampling_steps, + noise_repeat=num_images_per_caption, + ) + + if use_norm: + token_sampled = layer_norm(token_sampled, normalized_shape=token_sampled.size()[1:]) + + if tokens is not None: + tokens = torch.cat([tokens, token_sampled.unsqueeze(1)], dim=1) + else: + tokens = token_sampled.unsqueeze(1) + + # Prepare input embeds for next LLM step + cur_inputs_embeds = self.model.image_in_projector(tokens[:, -1:]) + if not cfg_parallel and cfg_mult > 1: + # Non-parallel: duplicate embeds for all active CFG branches. + cur_inputs_embeds = torch.cat([cur_inputs_embeds] * cfg_mult, dim=0) + # In CFG-parallel mode, each rank already has its portion — no duplication needed + + attention_mask = torch.cat( + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], + dim=-1, + ) + outputs = self.model.forward_model( + inputs_embeds=cur_inputs_embeds, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=True, + ) + past_key_values = outputs.past_key_values + c = outputs.last_hidden_state[:, -1:] + if getattr(self.config, "use_gen_pos_embed", False): + c = c + self.model.gen_pos_embed_with_ar(hw[0], hw[1])[:, step + 1 : step + 2, :] + + return tokens + + @torch.no_grad() + def forward( + self, + req: OmniDiffusionRequest, + prompt: str | list[str] | None = None, + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 28, + guidance_scale: float = 7.5, + negative_prompt: str | list[str] | None = None, + num_images_per_prompt: int = 1, + generator: torch.Generator | None = None, + seed: int | None = None, + **kwargs, + ) -> DiffusionOutput: + """ + Generate images from text prompts. + + Args: + req: OmniDiffusionRequest containing generation parameters + prompt: Text prompt(s) for generation + height: Output image height + width: Output image width + num_inference_steps: Number of sampling steps (default 28 for NextStep-1.1) + guidance_scale: CFG scale + negative_prompt: Negative prompt for CFG + num_images_per_prompt: Number of images per prompt + generator: Random generator for reproducibility + seed: Random seed + + Returns: + DiffusionOutput containing generated images + """ + # Extract parameters from request + # req.prompts is a list of str or dict; req.sampling_params holds all generation params + first_prompt = req.prompts[0] if req.prompts else None + if first_prompt is not None: + if isinstance(first_prompt, str): + prompt = first_prompt + else: + prompt = first_prompt.get("prompt") or prompt + negative_prompt = first_prompt.get("negative_prompt", negative_prompt) + + height = req.sampling_params.height or height or 512 + width = req.sampling_params.width or width or 512 + num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps + if req.sampling_params.guidance_scale_provided: + guidance_scale = req.sampling_params.guidance_scale + num_images_per_prompt = ( + req.sampling_params.num_outputs_per_prompt + if req.sampling_params.num_outputs_per_prompt > 0 + else num_images_per_prompt + ) + seed = req.sampling_params.seed if req.sampling_params.seed is not None else seed + + # NextStep-specific parameters from request extra + cfg_img = ( + req.sampling_params.guidance_scale_2 + if req.sampling_params.guidance_scale_2 is not None + else req.sampling_params.extra_args.get("cfg_img", 1.0) + ) + timesteps_shift = req.sampling_params.extra_args.get("timesteps_shift", 1.0) + use_norm = req.sampling_params.extra_args.get("use_norm", False) + cfg_schedule = req.sampling_params.extra_args.get("cfg_schedule", "constant") + positive_prompt = req.sampling_params.extra_args.get("positive_prompt", None) + + # Set seed for reproducibility (use generator if provided, else fall back to seed) + if generator is None and seed is not None: + set_seed(seed) + elif generator is not None: + torch.manual_seed(generator.initial_seed()) + + # Prepare hw tuple + hw = (height, width) + + # Check and process inputs (no image inputs for t2i) + captions, images = self._check_input(prompt, None) + + # Build captions with CFG + captions, images, cfg_mult, effective_cfg_img = self._build_captions( + captions, + images, + num_images_per_prompt, + positive_prompt, + negative_prompt, + guidance_scale, + cfg_img, + ) + + # No input images for text-to-image + latents = None + + # Add BOS token to captions before tokenizing + captions = [ + self.tokenizer.bos_token + caption if self.tokenizer.bos_token is not None else caption + for caption in captions + ] + + # Tokenize captions and add prefix ids + output = self.tokenizer( + captions, + padding="longest", + truncation=False, + add_special_tokens=True, + return_tensors="pt", + ) + input_ids = output.input_ids.to(self.device) + attention_mask = output.attention_mask.to(self.device) + input_ids, attention_mask = self._add_prefix_ids(hw, input_ids, attention_mask) + + # LLM prefill + max_new_len = (hw[0] // self.down_factor) * (hw[1] // self.down_factor) + max_cache_len = input_ids.shape[1] + max_new_len + past_key_values = StaticCache( + config=self.config, + max_cache_len=max_cache_len, + ) + inputs_embeds = self.model.prepare_inputs_embeds(input_ids, latents) + outputs = self.model.forward_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=True, + ) + past_key_values = outputs.past_key_values + c = outputs.last_hidden_state[:, -1:] + if getattr(self.config, "use_gen_pos_embed", False): + c = c + self.model.gen_pos_embed_with_ar(height, width)[:, 0:1, :] + + # Decoding + tokens = self.decoding( + c=c, + attention_mask=attention_mask, + past_key_values=past_key_values, + max_new_len=max_new_len, + num_images_per_caption=num_images_per_prompt, + use_norm=use_norm, + cfg=guidance_scale, + cfg_img=effective_cfg_img, + cfg_mult=cfg_mult, + cfg_schedule=cfg_schedule, + timesteps_shift=timesteps_shift, + num_sampling_steps=num_inference_steps, + progress=True, + hw=hw, + ) + + # Unpatchify + latents = self.model.unpatchify(tokens) + latents = (latents / self.scaling_factor) + self.shift_factor + + # Decode latents + sampled_images = self.vae.decode(latents.to(self.vae.dtype)).sample + sampled_images = sampled_images.detach().cpu().to(torch.float32) + + return DiffusionOutput(output=sampled_images) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """Load model weights.""" + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py index 173c576a243..f862dd0b522 100644 --- a/vllm_omni/diffusion/registry.py +++ b/vllm_omni/diffusion/registry.py @@ -95,6 +95,11 @@ "pipeline_flux2_klein", "Flux2KleinPipeline", ), + "NextStep11Pipeline": ( + "nextstep_1_1", + "pipeline_nextstep_1_1", + "NextStep11Pipeline", + ), "FluxPipeline": ( "flux", "pipeline_flux", @@ -116,6 +121,12 @@ _VAE_PATCH_PARALLEL_ALLOWLIST = { # Only enable for models we have validated end-to-end. "ZImagePipeline", + "NextStep11Pipeline", +} + +_NO_CACHE_ACCELERATION = { + # Pipelines that do not support cache acceleration (cache_dit / tea_cache). + "NextStep11Pipeline", } @@ -276,6 +287,7 @@ def _apply_sequence_parallel_if_enabled(model, od_config: OmniDiffusionConfig) - "LongCatImageEditPipeline": "get_longcat_image_post_process_func", "StableDiffusion3Pipeline": "get_sd3_image_post_process_func", "Flux2KleinPipeline": "get_flux2_klein_post_process_func", + "NextStep11Pipeline": "get_nextstep11_post_process_func", "FluxPipeline": "get_flux_post_process_func", } diff --git a/vllm_omni/diffusion/worker/diffusion_model_runner.py b/vllm_omni/diffusion/worker/diffusion_model_runner.py index 134a2ae397a..5fee656451e 100644 --- a/vllm_omni/diffusion/worker/diffusion_model_runner.py +++ b/vllm_omni/diffusion/worker/diffusion_model_runner.py @@ -27,6 +27,7 @@ from vllm_omni.diffusion.forward_context import set_forward_context from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.offloader import get_offload_backend +from vllm_omni.diffusion.registry import _NO_CACHE_ACCELERATION from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.distributed.omni_connectors.kv_transfer_manager import OmniKVTransferManager from vllm_omni.platforms import current_omni_platform @@ -148,7 +149,16 @@ def get_memory_context(): self.cache_backend = get_cache_backend(self.od_config.cache_backend, self.od_config.cache_config) if self.cache_backend is not None: - self.cache_backend.enable(self.pipeline) + if self.od_config.model_class_name in _NO_CACHE_ACCELERATION: + logger.warning( + "Cache backend '%s' is not supported for %s; disabling cache acceleration.", + self.od_config.cache_backend, + self.od_config.model_class_name, + ) + self.cache_backend = None + self.od_config.cache_backend = None + else: + self.cache_backend.enable(self.pipeline) logger.info("Model runner: Initialization complete.") @@ -196,7 +206,12 @@ def execute_model(self, req: OmniDiffusionRequest) -> DiffusionOutput: output = self.pipeline.forward(req) # NOTE: - if self.od_config.cache_backend == "cache_dit" and self.od_config.enable_cache_dit_summary: + if ( + self.cache_backend is not None + and self.cache_backend.is_enabled() + and self.od_config.cache_backend == "cache_dit" + and self.od_config.enable_cache_dit_summary + ): cache_summary(self.pipeline, details=True) return output diff --git a/vllm_omni/entrypoints/async_omni_diffusion.py b/vllm_omni/entrypoints/async_omni_diffusion.py index 11a3c07e135..072daee30d4 100644 --- a/vllm_omni/entrypoints/async_omni_diffusion.py +++ b/vllm_omni/entrypoints/async_omni_diffusion.py @@ -80,24 +80,36 @@ def __init__( if engine_input_source is not None: self.od_config.omni_kv_config.setdefault("engine_input_source", engine_input_source) + # Diffusers-style models expose `model_index.json` with `_class_name`. + # Non-diffusers models (e.g. Bagel, NextStep) only have `config.json`, + # so we fall back to reading that and mapping model_type manually. try: config_dict = get_hf_file_to_dict("model_index.json", od_config.model) - od_config.model_class_name = config_dict.get("_class_name", None) - od_config.update_multimodal_support() + if config_dict is not None: + od_config.model_class_name = config_dict.get("_class_name", None) + od_config.update_multimodal_support() - tf_config_dict = get_hf_file_to_dict("transformer/config.json", od_config.model) - od_config.tf_model_config = TransformerConfig.from_dict(tf_config_dict) - except (AttributeError, OSError, ValueError): + tf_config_dict = get_hf_file_to_dict("transformer/config.json", od_config.model) + od_config.tf_model_config = TransformerConfig.from_dict(tf_config_dict) + else: + raise FileNotFoundError("model_index.json not found") + except (AttributeError, OSError, ValueError, FileNotFoundError): cfg = get_hf_file_to_dict("config.json", od_config.model) if cfg is None: raise ValueError(f"Could not find config.json or model_index.json for model {od_config.model}") model_type = cfg.get("model_type") architectures = cfg.get("architectures") or [] + # Bagel/NextStep models don't have a model_index.json, so we set the pipeline class name manually if model_type == "bagel" or "BagelForConditionalGeneration" in architectures: od_config.model_class_name = "BagelPipeline" od_config.tf_model_config = TransformerConfig() od_config.update_multimodal_support() + elif model_type == "nextstep": + if od_config.model_class_name is None: + od_config.model_class_name = "NextStep11Pipeline" + od_config.tf_model_config = TransformerConfig() + od_config.update_multimodal_support() elif architectures and len(architectures) == 1: od_config.model_class_name = architectures[0] else: diff --git a/vllm_omni/entrypoints/omni_diffusion.py b/vllm_omni/entrypoints/omni_diffusion.py index 9f6dde15b1c..e0b8911d40b 100644 --- a/vllm_omni/entrypoints/omni_diffusion.py +++ b/vllm_omni/entrypoints/omni_diffusion.py @@ -55,31 +55,41 @@ def __init__(self, od_config: OmniDiffusionConfig | None = None, **kwargs): self.od_config.omni_kv_config.setdefault("engine_input_source", engine_input_source) # Diffusers-style models expose `model_index.json` with `_class_name`. - # Bagel models (and other non-diffusers) typically expose `config.json`. + # Non-diffusers models (e.g. Bagel, NextStep) only have `config.json`, + # so we fall back to reading that and mapping model_type manually. try: config_dict = get_hf_file_to_dict( "model_index.json", od_config.model, ) - od_config.model_class_name = config_dict.get("_class_name", None) - od_config.update_multimodal_support() + if config_dict is not None: + od_config.model_class_name = config_dict.get("_class_name", None) + od_config.update_multimodal_support() - tf_config_dict = get_hf_file_to_dict( - "transformer/config.json", - od_config.model, - ) - od_config.tf_model_config = TransformerConfig.from_dict(tf_config_dict) - except (AttributeError, OSError, ValueError): + tf_config_dict = get_hf_file_to_dict( + "transformer/config.json", + od_config.model, + ) + od_config.tf_model_config = TransformerConfig.from_dict(tf_config_dict) + else: + raise FileNotFoundError("model_index.json not found") + except (AttributeError, OSError, ValueError, FileNotFoundError): cfg = get_hf_file_to_dict("config.json", od_config.model) if cfg is None: raise ValueError(f"Could not find config.json or model_index.json for model {od_config.model}") model_type = cfg.get("model_type") architectures = cfg.get("architectures") or [] + # Bagel/NextStep models don't have a model_index.json, so we set the pipeline class name manually if model_type == "bagel" or "BagelForConditionalGeneration" in architectures: od_config.model_class_name = "BagelPipeline" od_config.tf_model_config = TransformerConfig() od_config.update_multimodal_support() + elif model_type == "nextstep": + if od_config.model_class_name is None: + od_config.model_class_name = "NextStep11Pipeline" + od_config.tf_model_config = TransformerConfig() + od_config.update_multimodal_support() elif architectures and len(architectures) == 1: od_config.model_class_name = architectures[0] else: