diff --git a/benchmarks/diffusion/diffusion_benchmark_serving.py b/benchmarks/diffusion/diffusion_benchmark_serving.py index 32ec48a698f..c52f54d0963 100644 --- a/benchmarks/diffusion/diffusion_benchmark_serving.py +++ b/benchmarks/diffusion/diffusion_benchmark_serving.py @@ -64,6 +64,7 @@ import argparse import ast import asyncio +import base64 import glob import json import logging @@ -818,6 +819,71 @@ def calculate_metrics( return metrics +def _save_generated_outputs( + outputs: list[RequestFuncOutput], + requests_list: list[RequestFuncInput], + save_dir: str, +) -> None: + """Decode and save base64 images/videos from successful responses.""" + os.makedirs(save_dir, exist_ok=True) + saved = 0 + + for idx, (req, out) in enumerate(zip(requests_list, outputs)): + if not out.success or not out.response_body: + continue + + media_urls: list[str] = [] + + # Chat-completions style: choices[*].message.content[*].image_url.url + choices = out.response_body.get("choices", []) + if isinstance(choices, list): + for choice in choices: + content = (choice or {}).get("message", {}).get("content") + if not isinstance(content, list): + continue + for item in content: + if not isinstance(item, dict) or item.get("type") != "image_url": + continue + url = (item.get("image_url") or {}).get("url", "") + if isinstance(url, str) and url.startswith("data:"): + media_urls.append(url) + + # Images endpoint style: data[*].b64_json + data_items = out.response_body.get("data", []) + if isinstance(data_items, list): + for data_item in data_items: + if not isinstance(data_item, dict): + continue + b64_json = data_item.get("b64_json", "") + if isinstance(b64_json, str) and b64_json: + media_urls.append(f"data:image/png;base64,{b64_json}") + + for img_idx, url in enumerate(media_urls): + if "," not in url: + continue + + try: + header, b64_data = url.split(",", 1) + ext = "png" + if "image/jpeg" in header: + ext = "jpg" + elif "image/webp" in header: + ext = "webp" + elif "video/mp4" in header: 
+ ext = "mp4" + + img_bytes = base64.b64decode(b64_data) + fname = f"req_{idx:04d}_{img_idx}.{ext}" + fpath = os.path.join(save_dir, fname) + with open(fpath, "wb") as f: + f.write(img_bytes) + saved += 1 + except Exception as e: + print(f"Warning: failed to save image for request {idx}: {e}") + + print(f"Saved {saved} generated image(s) to {save_dir}") + + def wait_for_service(base_url: str, timeout: int = 120) -> None: print(f"Waiting for service at {base_url}...") start_time = time.time() @@ -995,6 +1061,9 @@ async def limited_request_func(req, session, pbar): print("\n" + "=" * 60) + if args.save_dir: + _save_generated_outputs(outputs, requests_list, args.save_dir) + if args.output_file: with open(args.output_file, "w") as f: json.dump(metrics, f, indent=2) @@ -1069,8 +1138,10 @@ async def limited_request_func(req, session, pbar): parser.add_argument( "--warmup-num-inference-steps", type=int, - default=1, - help="num_inference_steps used for warmup requests.", + default=2, + help="num_inference_steps used for warmup requests. " + "Must be >= 2 to ensure at least one denoising step is executed " + "(some models, e.g. Bagel, run num_timesteps-1 denoising iterations).", ) parser.add_argument("--width", type=int, default=None, help="Image/Video width.") parser.add_argument("--height", type=int, default=None, help="Image/Video height.") @@ -1104,6 +1175,13 @@ async def limited_request_func(req, session, pbar): default=3.0, help="SLO target multiplier: slo_ms = estimated_exec_time_ms * slo_scale (default: 3).", ) + parser.add_argument( + "--save-dir", + type=str, + default=None, + help="Directory to save generated images/outputs for visual inspection. 
" + "If not set, generated outputs are discarded after metric collection.", + ) parser.add_argument("--disable-tqdm", action="store_true", help="Disable progress bar.") parser.add_argument( "--enable-negative-prompt", diff --git a/benchmarks/kernels/mot_linear_benchmarks.py b/benchmarks/kernels/mot_linear_benchmarks.py new file mode 100644 index 00000000000..7d22c82090d --- /dev/null +++ b/benchmarks/kernels/mot_linear_benchmarks.py @@ -0,0 +1,1015 @@ +# ruff: noqa: N803 + +"""MoT (Mixture-of-Tokens) GEMM kernel benchmark and auto-tuning. + +Generates optimal Triton kernel configurations for MoT GEMM operations +across different batch sizes, model shapes, TP configurations, and hardware. + +Usage: + # Auto-tune and save configs: + python benchmarks/kernels/mot_linear_benchmarks.py \ + --model ByteDance-Seed/BAGEL-7B-MoT \ + --tp-size 1 --dtype w16a16 --tune \ + --save-dir vllm_omni/diffusion/layers/mot/configs/ + + # Auto-tune with local model path (offline clusters): + python benchmarks/kernels/mot_linear_benchmarks.py \ + --model /data/models/BAGEL-7B-MoT \ + --tp-size 2 --tune + + # Benchmark only (measure with default configs, no search): + python benchmarks/kernels/mot_linear_benchmarks.py \ + --model ByteDance-Seed/BAGEL-7B-MoT \ + --tp-size 1 --dtype w16a16 +""" + +import argparse +import gc +import json +import logging +import math +import os +import time +from datetime import datetime +from itertools import product +from typing import Any + +import ray +import torch +from ray.experimental.tqdm_ray import tqdm +from vllm.platforms import current_platform +from vllm.transformers_utils.config import get_config +from vllm.triton_utils import triton +from vllm.utils.torch_utils import set_random_seed + +# NOTE: you should use the same naming system for the kernel to load properly +from vllm_omni.diffusion.layers.mot.ops.mot_gemm import build_config_filename, get_device_name + +# clear the triton cache from time to time, usually no need to change 
+_CACHE_CLEAR_INTERVAL_ENV = "VLLM_MOT_TUNE_CACHE_CLEAR_INTERVAL" +TRITON_CACHE_CLEAR_INTERVAL = int(os.environ.get(_CACHE_CLEAR_INTERVAL_ENV, "50")) + +# represent the token number of each generated image +_VAE_CHUNK_SIZE_ENV = "VAE_CHUNK_SIZE" +VAE_CHUNK_SIZE = int(os.environ.get(_VAE_CHUNK_SIZE_ENV, "1024")) + +logger = logging.getLogger(__name__) + +# ===================================================================== +# Utility Functions +# ===================================================================== + + +def clear_triton_cache(): + """Clear Triton JIT compilation cache and Python/CUDA memory.""" + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + try: + if hasattr(triton, "runtime") and hasattr(triton.runtime, "cache") and hasattr(triton.runtime.cache, "clear"): + triton.runtime.cache.clear() + except Exception: + pass + gc.collect() + + +# TODO: check rocm/npus support +# based on https://docs.nvidia.com/cuda/cuda-runtime-api/ +# structcudaDeviceProp.html#structcudaDeviceProp_16cede1829516e86917f0842a5f6498c8 +def get_max_shared_memory() -> int: + """Return the maximum shared memory per block in bytes.""" + props = torch.cuda.get_device_properties(0) + if hasattr(props, "shared_memory_per_block_option"): + return props.shared_memory_per_block_option + return getattr(props, "shared_memory_per_block", 49152) + + +def get_max_regs() -> int: + """Return the maximum registers per block in bytes.""" + props = torch.cuda.get_device_properties(0) + if hasattr(props, "regs_per_block"): + return props.regs_per_block + return 65536 + + +def get_sm_count() -> int: + """Get the number of physical SMs on the target GPU (A100 = 108, H100 = 132).""" + return torch.cuda.get_device_properties(0).multi_processor_count + + +def get_ab_element_bytes(dtype_str: str) -> tuple[int, int]: + """Return ``(activation_bytes, weight_bytes)`` for a dtype config.""" + if dtype_str == "w16a16": + return 2, 2 + elif dtype_str == "fp8_w8a8": + return 1, 1 + 
def build_regular_indices(
    image_num: int,
    vae_chunk_size: int,
    device: str = "cuda",
) -> tuple[torch.Tensor, torch.Tensor, int]:
    """Build deterministic MoT indices with per-image [Text][VAE...][Text].

    Every image occupies ``vae_chunk_size + 2`` consecutive rows: one text
    row, ``vae_chunk_size`` VAE rows, then one closing text row.

    Args:
        image_num: Number of images laid out back to back; must be > 0.
        vae_chunk_size: Number of VAE rows per image; must be > 0.
        device: Device on which both index tensors are created.

    Returns:
        ``(text_indices, vae_indices, exact_M)`` where both tensors are
        ``torch.long`` and ``exact_M`` is the total number of rows.

    Raises:
        ValueError: If ``image_num`` or ``vae_chunk_size`` is not positive.
    """
    if image_num <= 0:
        raise ValueError(f"image_num must be > 0, got {image_num}")
    if vae_chunk_size <= 0:
        raise ValueError(f"{_VAE_CHUNK_SIZE_ENV} must be > 0, got {vae_chunk_size}")

    rows_per_image = vae_chunk_size + 2
    first_rows = [image_idx * rows_per_image for image_idx in range(image_num)]

    # The two text rows bracket each image's VAE chunk.
    text_rows = [row for first in first_rows for row in (first, first + rows_per_image - 1)]
    vae_rows = [row for first in first_rows for row in range(first + 1, first + rows_per_image - 1)]

    text_indices = torch.tensor(text_rows, dtype=torch.long, device=device)
    vae_indices = torch.tensor(vae_rows, dtype=torch.long, device=device)
    return text_indices, vae_indices, image_num * rows_per_image
from a HuggingFace model config. + + Supports both remote HuggingFace model IDs and local checkpoint paths. + + Returns + ------- + shapes : list[MoTShape] + De-duplicated GEMM shapes (K, N) with TP applied. + model_name : str + Cleaned model name for the config filename. + """ + config = get_config(model=model, trust_remote_code=trust_remote_code) + model_name = model.rstrip("/").split("/")[-1] + + text_config = getattr(config, "text_config", config) + + hidden_size: int = text_config.hidden_size + num_attention_heads: int = text_config.num_attention_heads + num_kv_heads: int = getattr(text_config, "num_key_value_heads", num_attention_heads) + head_dim: int = getattr(text_config, "head_dim", hidden_size // num_attention_heads) + intermediate_size: int = text_config.intermediate_size + + # ---- Compute per-TP shapes ---- + + # QKV_PROJ (QKVParallelLinear, output partitioned by TP) + q_out = num_attention_heads * head_dim + kv_out = 2 * num_kv_heads * head_dim + qkv_total = q_out + kv_out + assert qkv_total % tp_size == 0, f"QKV output {qkv_total} not divisible by tp {tp_size}" + qkv_N = qkv_total // tp_size + + # O_PROJ (RowParallelLinear, input partitioned by TP) + assert q_out % tp_size == 0, f"Q output {q_out} not divisible by tp {tp_size}" + o_K = q_out // tp_size + o_N = hidden_size + + # FFN gate+up (MergedColumnParallelLinear, output partitioned by TP) + gate_up_total = 2 * intermediate_size + assert gate_up_total % tp_size == 0, f"Gate-up output {gate_up_total} not divisible by tp {tp_size}" + gate_up_N = gate_up_total // tp_size + + # FFN down (RowParallelLinear, input partitioned by TP) + assert intermediate_size % tp_size == 0, f"Intermediate size {intermediate_size} not divisible by tp {tp_size}" + down_K = intermediate_size // tp_size + down_N = hidden_size + + shapes = [ + MoTShape(K=o_K, N=o_N, comment="O_PROJ"), + MoTShape(K=hidden_size, N=qkv_N, comment="QKV_PROJ"), + MoTShape(K=hidden_size, N=gate_up_N, comment="FFN_GATE_UP_PROJ"), + 
def estimate_sram_bytes(config: dict[str, int], dtype_str: str) -> int:
    """Estimate SRAM (shared memory) usage for a Triton tile config.

    One pipeline stage holds an A tile (``BLOCK_M x BLOCK_K``) and a B tile
    (``BLOCK_N x BLOCK_K``); the total is that footprint times the number
    of stages:

        (BLOCK_M * BLOCK_K * a_bytes + BLOCK_N * BLOCK_K * b_bytes)
            * num_stages

    Args:
        config: Tile config providing ``BLOCK_SIZE_M``, ``BLOCK_SIZE_N``,
            ``BLOCK_SIZE_K`` and ``num_stages``.
        dtype_str: Dtype config name passed to ``get_ab_element_bytes`` to
            obtain the per-element sizes of A and B.

    Returns:
        Estimated shared-memory footprint in bytes.
    """
    block_m, block_n, block_k, num_stages = (
        config[name] for name in ("BLOCK_SIZE_M", "BLOCK_SIZE_N", "BLOCK_SIZE_K", "num_stages")
    )
    act_bytes, weight_bytes = get_ab_element_bytes(dtype_str)

    a_tile_bytes = block_m * block_k * act_bytes
    b_tile_bytes = block_n * block_k * weight_bytes
    return (a_tile_bytes + b_tile_bytes) * num_stages
+ """ + bm = config["BLOCK_SIZE_M"] + bn = config["BLOCK_SIZE_N"] + bk = config["BLOCK_SIZE_K"] + warps = config["num_warps"] + num_threads = warps * warp_size + + a_bytes, b_bytes = get_ab_element_bytes(dtype_str) + + # Physical register standard: 1 register = 32-bit (4 bytes) + + # [Accumulator C] + # Triton uses fp32/int32 by default as the accumulator for fp16/int8 to prevent overflow + regs_c = (bm * bn) / num_threads * 1.0 + + # [MMA slices A and B] + # Data is loaded into registers from SRAM to participate in Tensor Core operations + regs_a = ((bm * bk) / num_threads) * (a_bytes / 4.0) + regs_b = ((bk * bn) / num_threads) * (b_bytes / 4.0) + + # [MoT specific routing overhead] + # real_row_idxs is tl.int64 (8 bytes), each element needs 2 32-bit registers + regs_routing = (bm / num_threads) * 2.0 + + # [Quantization specific Epilogue overhead] + # W8A8 needs to load scale_a and scale_b after the loop for de-quantization + regs_epilogue = 0.0 + if dtype_str == "fp8_w8a8": + # fp8*token-wise quant scenario: scale_a length is bm, scale_b length is bn + regs_epilogue = ((bm + bn) / num_threads) * 1.0 + elif dtype_str == "int8_w8a16": + # Weight-Only*token-wise quant scenario: usually only scale_b is needed + regs_epilogue = (bn / num_threads) * 1.0 + + # [Control flow and base pointer constant overhead] + # Includes: loop counter(k), pointer addressing, + # Mask predicate calculation, TMA state machine, etc. 
def get_mot_search_space(
    M: int,
    K: int,
    N: int,
    dtype_str: str,
    max_sram: int,
    max_regs: int,
    num_sms: int,
) -> list[dict[str, int]]:
    """Generate a pruned search space of Triton tile configs for MoT GEMM.

    The full Cartesian product of the tile parameters is enumerated and a
    config is dropped when any of these holds:

    * a block dimension exceeds the problem dimension rounded up to a power
      of two (with a floor of 32);
    * the launch grid has fewer than ``num_sms // 4`` blocks;
    * the estimated shared-memory use exceeds 90% of ``max_sram``;
    * ``estimate_register_pressure`` rejects the config.

    Args:
        M: Number of rows of the activation matrix.
        K: Reduction dimension.
        N: Number of output columns.
        dtype_str: Dtype config name forwarded to the resource estimators.
        max_sram: Shared-memory budget per block, in bytes.
        max_regs: Register budget per block.
        num_sms: Number of SMs on the target GPU.

    Returns:
        The surviving configs, in enumeration order.
    """
    param_ranges = {
        "BLOCK_SIZE_M": [32, 64, 128, 256],
        "BLOCK_SIZE_N": [32, 64, 128, 256],
        "BLOCK_SIZE_K": [32, 64, 128],
        "GROUP_SIZE_M": [4, 8, 16],
        "num_warps": [4, 8],
        "num_stages": [2, 3, 4, 5],
    }

    def next_power_of_2(n: int) -> int:
        return 1 if n == 0 else 2 ** (n - 1).bit_length()

    # Limits that do not depend on the candidate config are computed once.
    max_block_m = max(32, next_power_of_2(M))
    max_block_n = max(32, next_power_of_2(N))
    max_block_k = max(32, next_power_of_2(K))
    min_total_blocks = num_sms // 4
    sram_budget = max_sram * 0.9

    keys = tuple(param_ranges)
    configs: list[dict[str, int]] = []

    for vals in product(*param_ranges.values()):
        cfg = dict(zip(keys, vals))
        bm = cfg["BLOCK_SIZE_M"]
        bn = cfg["BLOCK_SIZE_N"]
        bk = cfg["BLOCK_SIZE_K"]

        # --- Dimension-based pruning ---
        # No separate minimum-tile-area check is needed: the smallest tile
        # in param_ranges is already 32 x 32.
        if bm > max_block_m or bn > max_block_n or bk > max_block_k:
            continue

        # --- Occupancy-based pruning ---
        grid_m = (M + bm - 1) // bm
        grid_n = (N + bn - 1) // bn
        if grid_m * grid_n < min_total_blocks:
            continue

        # --- SRAM capacity check ---
        if estimate_sram_bytes(cfg, dtype_str) > sram_budget:
            continue

        # --- register spilling check ---
        if not estimate_register_pressure(cfg, dtype_str, max_regs):
            continue

        configs.append(cfg)

    return configs
device="cuda") + elif use_int8_w8a16: + A = torch.randn(M, K, dtype=dtype, device="cuda") + B_text = torch.randint(-127, 127, (K, N), dtype=torch.int8, device="cuda") + B_vae = torch.randint(-127, 127, (K, N), dtype=torch.int8, device="cuda") + B_text_scale = torch.ones(N, dtype=torch.float32, device="cuda") + B_vae_scale = torch.ones(N, dtype=torch.float32, device="cuda") + C = torch.empty(M, N, dtype=dtype, device="cuda") + else: + A = torch.randn(M, K, dtype=dtype, device="cuda") + B_text = torch.randn(K, N, dtype=dtype, device="cuda") + B_vae = torch.randn(K, N, dtype=dtype, device="cuda") + C = torch.empty(M, N, dtype=dtype, device="cuda") + + def run(): + invoke_mot_gemm( + A=A, + B_text=B_text, + B_vae=B_vae, + C=C, + bias_text=None, + bias_vae=None, + text_indices=text_indices, + vae_indices=vae_indices, + A_scale=A_scale, + B_text_scale=B_text_scale, + B_vae_scale=B_vae_scale, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=False, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=False, + A_per_channel_quant=use_fp8_w8a8, + B_per_channel_quant=use_int8_w8a16, + config=config, + ) + + # JIT warmup + run() + torch.cuda.synchronize() + + # Capture 1 invocations with CUDA Graph + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + run() + torch.cuda.synchronize() + + # Warmup replays + for _ in range(5): + graph.replay() + torch.cuda.synchronize() + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + latencies: list[float] = [] + for _ in range(num_iters): + if cache_flusher is not None: + cache_flusher.zero_() + torch.cuda.synchronize() + + start_event.record() + graph.replay() + end_event.record() + end_event.synchronize() + + latencies.append(start_event.elapsed_time(end_event)) + + latencies.sort() + valid_latencies = latencies[1:-1] if len(latencies) > 2 else latencies + + avg_us = sum(valid_latencies) / len(valid_latencies) * 1000 # ms → us + graph.reset() + + return avg_us + + +# 
===================================================================== +# Ray Worker +# ===================================================================== + + +@ray.remote(num_gpus=1) +class BenchmarkWorker: + def __init__(self, seed: int) -> None: + # Ray will automatically set CUDA_VISIBLE_DEVICES, + # so the GPU seen by the worker is always the logical 0 + self.logical_device_id = 0 + torch.set_default_device(f"cuda:{self.logical_device_id}") + + set_random_seed(seed) + self.seed = seed + + # ---- Benchmark (use default config, report latency) ---- + + def benchmark( + self, + image_num: int, + K: int, + N: int, + dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + ) -> tuple[dict[str, int], float]: + set_random_seed(self.seed) + from vllm_omni.diffusion.layers.mot.ops.mot_gemm import ( + get_best_mot_config, + ) + + M = get_exact_m(image_num, VAE_CHUNK_SIZE) + loaded_m_key, config = get_best_mot_config(M, N, K) + if loaded_m_key == -1: + print( + " [config] WARNING: No tuned config found — " + "using conservative default. " + "Performance numbers are NOT representative. " + "Run mot_linear_benchmarks.py --tune to generate configs." 
+ ) + else: + print(f" [config] Tuned config loaded (actual M={M}, loaded M={loaded_m_key}) config = {config})") + kernel_time = benchmark_config( + config, + image_num, + K, + N, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + num_iters=10, + ) + return config, kernel_time + + # ---- Tune (search over all configs, return best) ---- + + def tune( + self, + image_num: int, + K: int, + N: int, + dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + search_space: list[dict[str, int]], + ) -> dict[str, int] | None: + set_random_seed(self.seed) + M = get_exact_m(image_num, VAE_CHUNK_SIZE) + + best_config: dict[str, int] | None = None + best_time = float("inf") + + # Diagnosis counters + total_configs = len(search_space) + err_oom = 0 + err_triton_resources = 0 + err_other = 0 + + with torch.cuda.device(self.logical_device_id): + cache_flusher = torch.empty(int(256 * 1024 * 1024 / 4), dtype=torch.int32, device="cuda") + + for idx, config in enumerate(tqdm(search_space)): + try: + kernel_time = benchmark_config( + config, + image_num, + K, + N, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + num_iters=10, + cache_flusher=cache_flusher, + ) + except triton.runtime.autotuner.OutOfResources: + err_triton_resources += 1 + continue + except torch.cuda.OutOfMemoryError: + err_oom += 1 + clear_triton_cache() + continue + except Exception: + err_other += 1 + logger.exception("Config %s failed unexpectedly", config) + clear_triton_cache() + continue + + if kernel_time < best_time: + best_time = kernel_time + best_config = config + + if TRITON_CACHE_CLEAR_INTERVAL > 0 and idx > 0 and idx % TRITON_CACHE_CLEAR_INTERVAL == 0: + clear_triton_cache() + + del cache_flusher + clear_triton_cache() + + if best_config is None: + diag_msg = ( + f"\n🚨 [CRITICAL] TUNING FAILED for M={M}, K={K}, N={N}\n" + f" Total configs tested: {total_configs}\n" + f" - Triton OutOfResources (SRAM/Regs): {err_triton_resources}\n" + f" - CUDA OOM: {err_oom}\n" + f" - Other Errors: {err_other}\n" + f" 💡 
def _merge_config_comments(old_comment: Any, new_comment: Any) -> Any:
    """Combine two ``_comment`` values without repeating an existing part."""
    if not old_comment:
        return new_comment
    if not new_comment:
        return old_comment
    # Delimiter-aware containment, so repeated checkpoints of the same run do
    # not keep appending the same comment ("A / B / B / ...").
    if f" / {new_comment} / " in f" / {old_comment} / ":
        return old_comment
    return f"{old_comment} / {new_comment}"


def save_configs(
    results: dict[str, dict[int, dict[str, int]]],
    shapes: list[MoTShape],
    model_name: str,
    tp_size: int,
    device_name: str,
    dtype_str: str,
    save_dir: str,
) -> str:
    """Merge tuned configs into the per-device/dtype JSON config file.

    The filename comes from ``build_config_filename`` so that it matches the
    path the tuning driver uses when it resumes from a checkpoint.

    Behavior:
    - Create a new file if it does not exist.
    - If it exists, merge by shape key (``K_N``) and M key.
    - Existing entries are preserved unless overwritten by current results.
    - The file is replaced atomically, so an interrupted run cannot leave a
      truncated checkpoint behind.

    Args:
        results: ``{shape_key: {M: tile_config}}`` produced by this run.
        shapes: Shapes whose ``config_key()`` values cover ``results``.
        model_name: Model name recorded in each entry's ``_comment``.
        tp_size: Tensor-parallel size recorded in each entry's ``_comment``.
        device_name: Device name used to build the filename.
        dtype_str: Dtype config name used to build the filename.
        save_dir: Directory holding the config file (created if needed).

    Returns:
        Path of the written config file.

    Raises:
        KeyError: If ``results`` contains a shape key not present in
            ``shapes``.
    """
    shape_map = {s.config_key(): s for s in shapes}

    current_output: dict[str, Any] = {}
    for config_key, m_configs in results.items():
        shape = shape_map[config_key]
        # Make comments self-descriptive across mixed model/tp runs.
        comment = f"model={model_name}|tp={tp_size}|op={shape.comment}"
        entry: dict[str, Any] = {"_comment": comment}
        for m_val in sorted(m_configs):
            entry[str(m_val)] = sort_config(m_configs[m_val])
        current_output[config_key] = entry

    filename = build_config_filename(device_name, dtype_str)
    os.makedirs(save_dir, exist_ok=True)
    filepath = os.path.join(save_dir, filename)

    merged_output: dict[str, Any] = {}
    if os.path.isfile(filepath):
        try:
            with open(filepath, encoding="utf-8") as f:
                existing = json.load(f)
        except (OSError, ValueError) as exc:
            # json.JSONDecodeError and UnicodeDecodeError are ValueErrors.
            print(f"WARNING: Failed to read existing config {filepath}: {exc}. Overwrite with newly tuned configs.")
        else:
            if isinstance(existing, dict):
                merged_output = existing
            else:
                print(f"WARNING: Existing config is not a JSON object: {filepath}. Overwrite with newly tuned configs.")

    # Merge on two levels: shape key -> M key
    for config_key, new_entry in current_output.items():
        old_entry = merged_output.get(config_key, {})
        if not isinstance(old_entry, dict):
            old_entry = {}
        merged_entry = dict(old_entry)
        merged_comment = _merge_config_comments(old_entry.get("_comment"), new_entry.get("_comment"))
        merged_entry.update(new_entry)
        merged_entry["_comment"] = merged_comment
        merged_output[config_key] = merged_entry

    print(f"Saving merged config to {filepath}")
    # Write to a sibling temp file first, then swap it in with os.replace.
    tmp_path = f"{filepath}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(merged_output, f, indent=2)
        f.write("\n")
    os.replace(tmp_path, filepath)

    return filepath
+ "This benchmark script has not been fully tested on" + "AMD GPUs and may produce errors or suboptimal results." + ) + + # ---- 1. Extract model shapes ---- + shapes, model_name = get_mot_shapes(args.model, args.tp_size, args.trust_remote_code) + print(f"\nModel: {model_name} | TP: {args.tp_size}") + print(f"Detected {len(shapes)} unique GEMM shape(s):") + for s in shapes: + print(f" {s}") + + # ---- 2. Determine dtype ---- + dtype_str: str = args.dtype + use_fp8_w8a8 = dtype_str == "fp8_w8a8" + use_int8_w8a16 = dtype_str == "int8_w8a16" + dtype = torch.bfloat16 + + # ---- 3. Image counts ---- + image_nums: list[int] = args.batch_size if args.batch_size is not None else [1, 2, 4, 8, 16] + + # ---- 4. Initialize Ray workers ---- + ray.init() + num_gpus = int(ray.available_resources()["GPU"]) + workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)] + print(f"\nRay initialized with {num_gpus} GPU worker(s)") + + device_name = ray.get(workers[0].get_device_name.remote()) + max_sram = ray.get(workers[0].get_max_shared_memory.remote()) + max_regs = ray.get(workers[0].get_max_regs.remote()) + sm_count = ray.get(workers[0].get_sm_count.remote()) + + print( + f"Device: {device_name} | Max SRAM/Block: {max_sram} bytes\n" + f" Max Regs/Block: {max_regs} 32-bit regs\n" + f" SM on GPU: {sm_count} \n" + ) + + # ---- Helper: round-robin distribute tasks to workers ---- + def distribute(method: str, inputs: list[tuple[Any, ...]]) -> list[Any]: + futures = [] + for i, input_args in enumerate(inputs): + worker = workers[i % num_gpus] + futures.append(getattr(worker, method).remote(*input_args)) + return ray.get(futures) + + # ---- 5. 
TUNE mode ---- + if args.tune: + start = time.time() + + # 1) Checkpoint loading and resuming + filename = build_config_filename(device_name, dtype_str) + filepath = os.path.join(args.save_dir, filename) + + existing_history: dict[str, Any] = {} + if os.path.isfile(filepath): + try: + with open(filepath) as f: + existing_history = json.load(f) + print(f"Loaded existing checkpoint from {filepath}, resuming...") + except Exception as e: + print(f"WARNING: Failed to load existing checkpoint: {e}") + + # 2) Build task queue and execute checkpoint filtering + pending_futures = {} + task_counter = 0 + + for shape in shapes: + shape_key = shape.config_key() + for image_num in image_nums: + exact_M = get_exact_m(image_num, VAE_CHUNK_SIZE) + + if shape_key in existing_history and str(exact_M) in existing_history[shape_key]: + print(f"Skipping image_num={image_num} (M={exact_M}), Shape={shape_key} (Already tuned)") + continue + + # Only tune parameters that have not been tuned yet + search_space = get_mot_search_space( + M=exact_M, + K=shape.K, + N=shape.N, + dtype_str=dtype_str, + max_sram=max_sram, + max_regs=max_regs, + num_sms=sm_count, + ) + if len(search_space) == 0: + print( + f"WARNING: empty search space for " + f"{shape.config_key()} image_num={image_num} (M={exact_M}), " + "skipping" + ) + continue + + # Round-robin assign to Worker + worker = workers[task_counter % num_gpus] + future = worker.tune.remote( + image_num, shape.K, shape.N, dtype, use_fp8_w8a8, use_int8_w8a16, search_space + ) + + # Bind future with its corresponding metadata + pending_futures[future] = (shape, image_num, exact_M) + task_counter += 1 + + print(f"Starting tuning: {len(pending_futures)} new tasks pending...") + + # 3)Async streaming collect results and incremental checkpoint (Streaming Checkpoint) + results: dict[str, dict[int, dict[str, int]]] = {} + + # ray.wait will return when any task is completed + # file I/O is executed serially + while pending_futures: + done_refs, not_done_refs 
= ray.wait(list(pending_futures.keys()), num_returns=1) + + for ready_future in done_refs: + shape, image_num, exact_M = pending_futures.pop(ready_future) + config_key = shape.config_key() + try: + best_config = ray.get(ready_future) + + if best_config is None: + print( + f"⚠️ SKIPPING CHECKPOINT for image_num={image_num}, " + f"M={exact_M}, Shape={config_key} due to tuning failure. " + "Please review the worker diagnostics above." + ) + continue + + # Put the temporary results of this run into the result set + results.setdefault(config_key, {})[exact_M] = best_config + save_configs( + results={config_key: {exact_M: best_config}}, + shapes=shapes, + model_name=model_name, + tp_size=args.tp_size, + device_name=device_name, + dtype_str=dtype_str, + save_dir=args.save_dir, + ) + print(f"Checkpoint saved for image_num={image_num}, M={exact_M}, Shape={config_key}") + + except Exception as e: + print( + f"🚨 CRITICAL ERROR: Task failed for image_num={image_num}, " + f"M={exact_M}, Shape={config_key}. Error: {e}" + ) + + elapsed = time.time() - start + print(f"\nTuning completed in {elapsed:.1f}s") + print(f"Complete Configs saved to: {filepath}") + + # ---- 6. 
if __name__ == "__main__":
    # Script entry point: declare the CLI, parse it, and hand the namespace
    # to main(). Options are listed as (flag names, add_argument keyword
    # arguments) pairs and registered in order, so --help shows them in the
    # same sequence as the table below.
    _cli_options = [
        (
            ("--model",),
            {
                "type": str,
                "required": True,
                "help": "HuggingFace model name or local checkpoint path",
            },
        ),
        (
            ("--tp-size", "-tp"),
            {
                "type": int,
                "default": 1,
                "help": "Tensor parallel size (default: 1)",
            },
        ),
        (
            ("--dtype",),
            {
                "type": str,
                "default": "w16a16",
                "choices": ["w16a16", "fp8_w8a8", "int8_w8a16"],
                "help": "Weight/activation dtype (default: w16a16)",
            },
        ),
        (
            ("--batch-size",),
            {
                "type": int,
                "nargs": "+",
                "default": None,
                "help": "Image counts to tune/benchmark, note M=batch_size*(VAE_CHUNK_SIZE+2) (default: 1 2 4 8 16)",
            },
        ),
        (
            ("--tune",),
            {
                "action": "store_true",
                "help": "Enable auto-tuning mode (search for best configs)",
            },
        ),
        (
            ("--save-dir",),
            {
                "type": str,
                "default": "./",
                "help": "Directory to save tuned config JSON (default: ./)",
            },
        ),
        (
            ("--seed",),
            {
                "type": int,
                "default": 0,
                "help": "Random seed (default: 0)",
            },
        ),
        (
            ("--trust-remote-code",),
            {
                "action": "store_true",
                "help": "Trust remote code when loading HuggingFace config",
            },
        ),
    ]

    parser = argparse.ArgumentParser(
        description="MoT GEMM kernel benchmark and auto-tuning",
    )
    for _flags, _option_kwargs in _cli_options:
        parser.add_argument(*_flags, **_option_kwargs)

    args = parser.parse_args()
    main(args)
+ +Usage:: + pytest tests/diffusion/kernels/mot/test_mot_linear.py -v -s +""" + +from __future__ import annotations + +import os +import time +from dataclasses import dataclass + +import pytest +import torch +from vllm.config import VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.linear import ( + QKVParallelLinear, + RowParallelLinear, +) + +from vllm_omni.diffusion.distributed.parallel_state import ( + destroy_distributed_env, + init_distributed_environment, + initialize_model_parallel, + model_parallel_is_initialized, +) +from vllm_omni.diffusion.layers.mot.mot_qkv_parallel_linear import ( + MoTQKVParallelLinear, +) +from vllm_omni.diffusion.layers.mot.mot_row_parallel_linear import ( + MoTRowParallelLinear, +) +from vllm_omni.diffusion.layers.mot.ops.mot_gemm import ( + get_best_mot_config, +) + +pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.gpu] + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +# BAGEL-7B-MoT architecture parameters +_BAGEL_HEAD_SIZE = 128 +_BAGEL_TOTAL_NUM_HEADS = 28 +_BAGEL_TOTAL_NUM_KV_HEADS = 4 +_VAE_CHUNK_SIZE = 1024 # the token number of one image +_IMAGE_NUM = [1, 2, 4, 8] + + +@pytest.fixture(scope="module", autouse=True) +def _init_single_rank_tp_env(): + """Initialize single-rank distributed/TP env for vLLM linear params.""" + os.environ.setdefault("RANK", "0") + os.environ.setdefault("LOCAL_RANK", "0") + os.environ.setdefault("WORLD_SIZE", "1") + os.environ.setdefault("MASTER_ADDR", "127.0.0.1") + os.environ.setdefault("MASTER_PORT", "29501") + + if not torch.distributed.is_initialized(): + init_distributed_environment(world_size=1, rank=0, local_rank=0) + + if not model_parallel_is_initialized(): + initialize_model_parallel( + data_parallel_size=1, + cfg_parallel_size=1, + sequence_parallel_size=1, + ulysses_degree=1, + ring_degree=1, + tensor_parallel_size=1, 
def _parse_dtype(dtype_str: str) -> DTypeConfig:
    """Translate a test dtype label into a ``DTypeConfig``.

    Labels handled today (both unquantized; every quantization flag keeps its
    ``False`` default):
        "w16a16_bf16" — ``torch.bfloat16`` parameters
        "w16a16_fp16" — ``torch.float16`` parameters

    Any other label, including the reserved "fp8_w8a8", "int8_w8a16" and
    "int4_w4a16", calls ``pytest.skip`` so the calling test is skipped.
    """
    torch_dtype_by_label = {
        "w16a16_bf16": torch.bfloat16,
        "w16a16_fp16": torch.float16,
    }
    chosen = torch_dtype_by_label.get(dtype_str)
    if chosen is None:
        # pytest.skip raises, so nothing is returned on this path.
        pytest.skip(f"Quantized dtype '{dtype_str}' not yet implemented in layer test")
    return DTypeConfig(torch_dtype=chosen)
+ ) + else: + print(f" [config] Tuned config loaded (actual M={M}, loaded M={loaded_m_key}) config = {config})") + + +def _make_indices(image_num: int, vae_chunk_size: int, device: str = "cuda") -> tuple[torch.Tensor, torch.Tensor, int]: + """ + Simulate exact Bagel-MoT distributions for image generation: + Pattern per image like: [1 Text] + [4096 VAE] + [1 Text] + Returns text_indices, vae_indices, and the exact total M. + """ + text_idx_list = [] + vae_idx_list = [] + + current_idx = 0 + for _ in range(image_num): + text_idx_list.append(current_idx) + current_idx += 1 + + vae_idx_list.extend(range(current_idx, current_idx + vae_chunk_size)) + current_idx += vae_chunk_size + + text_idx_list.append(current_idx) + current_idx += 1 + + text_indices = torch.tensor(text_idx_list, dtype=torch.long, device=device) + vae_indices = torch.tensor(vae_idx_list, dtype=torch.long, device=device) + + exact_M = current_idx # exact_M = image_num * (vae_chunk_size + 2) + + return text_indices, vae_indices, exact_M + + +def _benchmark(fn, warmup: int = 20, iters: int = 100) -> float: + """Return mean latency in milliseconds.""" + cache_flusher = torch.empty(int(256 * 1024 * 1024 / 4), dtype=torch.int32, device="cuda") + for _ in range(warmup): + fn() + torch.cuda.synchronize() + + t_flush_start = time.perf_counter() + for _ in range(iters): + cache_flusher.zero_() + torch.cuda.synchronize() + flush_time_total = time.perf_counter() - t_flush_start + + # 3. Measure the total time of "flush cache + operator execution" + torch.cuda.synchronize() + t_total_start = time.perf_counter() + for _ in range(iters): + cache_flusher.zero_() + fn() + torch.cuda.synchronize() + total_time = time.perf_counter() - t_total_start + + # 4. 
Asynchronous subtraction separation + # The E2E time of the pure operator = total time - flush time + pure_fn_time_total = total_time - flush_time_total + avg_ms = (pure_fn_time_total / iters) * 1000.0 + + del cache_flusher + return max(avg_ms, 0.001) + + +def _sync_weights(ref_text, ref_vae, mot_layer): + """Assign the same random weights to reference layers and MoT layer.""" + with torch.no_grad(): + W_text = torch.randn_like(ref_text.weight) * 0.02 + W_vae = torch.randn_like(ref_vae.weight) * 0.02 + ref_text.weight.copy_(W_text) + ref_vae.weight.copy_(W_vae) + mot_layer.weight.copy_(W_text) + mot_layer.gen_exp.weight.copy_(W_vae) + if ref_text.bias is not None and mot_layer.bias is not None: + b_text = torch.randn_like(ref_text.bias) * 0.02 + ref_text.bias.copy_(b_text) + mot_layer.bias.copy_(b_text) + if ref_vae.bias is not None and mot_layer.gen_exp.bias is not None: + b_vae = torch.randn_like(ref_vae.bias) * 0.02 + ref_vae.bias.copy_(b_vae) + mot_layer.gen_exp.bias.copy_(b_vae) + + +def _reference_forward(x, text_indices, vae_indices, text_linear, vae_linear): + """Reference path: index-gather → 2x standard linear → index-scatter.""" + M = x.size(0) + out_text = text_linear(x[text_indices]) + out_vae = vae_linear(x[vae_indices]) + if isinstance(out_text, tuple): + out_text = out_text[0] + if isinstance(out_vae, tuple): + out_vae = out_vae[0] + N = out_text.size(-1) + output = torch.empty(M, N, dtype=x.dtype, device=x.device) + output[text_indices] = out_text + output[vae_indices] = out_vae + return output + + +def _check_and_report(ref: torch.Tensor, mot: torch.Tensor, tag: str): + """Compare outputs, print metrics, and assert correctness. + + Both ``ref`` and ``mot`` are in the original compute dtype (e.g. bf16). + We upcast to fp32 solely for computing error metrics with higher + arithmetic precision — the actual layer outputs remain bf16. 
def _run_timing(
    ref_fn,
    mot_fn,
    tag: str,
    warmup: int = 20,
    iters: int = 100,
):
    """Time the reference and MoT callables and print one comparison line.

    Both callables are measured with ``_benchmark`` using the same ``warmup``
    and ``iters``; the reference path is measured first. The printed speedup
    is ``ref_ms / mot_ms`` (``inf`` if the MoT time is not positive).
    """
    ref_ms, mot_ms = (_benchmark(fn, warmup=warmup, iters=iters) for fn in (ref_fn, mot_fn))

    if mot_ms > 0:
        speedup = ref_ms / mot_ms
    else:
        speedup = float("inf")

    print(f" [{tag}] Ref(2x linear): {ref_ms:.3f} ms | MoT(fused): {mot_ms:.3f} ms | Speedup: {speedup:.2f}x")
set_current_vllm_config(VllmConfig()): + text_linear = QKVParallelLinear( + hidden_size=K, + head_size=_BAGEL_HEAD_SIZE, + total_num_heads=_BAGEL_TOTAL_NUM_HEADS, + total_num_kv_heads=_BAGEL_TOTAL_NUM_KV_HEADS, + bias=bias, + params_dtype=dcfg.torch_dtype, + disable_tp=True, + ).cuda() + vae_linear = QKVParallelLinear( + hidden_size=K, + head_size=_BAGEL_HEAD_SIZE, + total_num_heads=_BAGEL_TOTAL_NUM_HEADS, + total_num_kv_heads=_BAGEL_TOTAL_NUM_KV_HEADS, + bias=bias, + params_dtype=dcfg.torch_dtype, + disable_tp=True, + ).cuda() + mot_linear = MoTQKVParallelLinear( + hidden_size=K, + head_size=_BAGEL_HEAD_SIZE, + total_num_heads=_BAGEL_TOTAL_NUM_HEADS, + total_num_kv_heads=_BAGEL_TOTAL_NUM_KV_HEADS, + bias=bias, + vae_bias=bias, + params_dtype=dcfg.torch_dtype, + disable_tp=True, + ).cuda() + + assert text_linear.output_size_per_partition == N, ( + f"Expected output_size_per_partition={N}, " + f"got {text_linear.output_size_per_partition}. " + f"Check head parameters." + ) + + _sync_weights(text_linear, vae_linear, mot_linear) + + text_idx, vae_idx, M = _make_indices(image_num, _VAE_CHUNK_SIZE) + x = torch.randn(M, K, dtype=dcfg.torch_dtype, device="cuda") + + tag = f"QKVParallel M={M} K={K} N={N}" + _report_mot_config(K, N, M) + + with torch.no_grad(): + ref = _reference_forward( + x, + text_idx, + vae_idx, + text_linear, + vae_linear, + ) + mot_out, _ = mot_linear(x, text_idx, vae_idx) + + _check_and_report(ref, mot_out, tag) + + with torch.no_grad(): + _run_timing( + lambda: _reference_forward( + x, + text_idx, + vae_idx, + text_linear, + vae_linear, + ), + lambda: mot_linear(x, text_idx, vae_idx), + tag, + ) + + +# ========================================================================= +# Test: o proj +# ========================================================================= + + +@pytest.mark.parametrize("bias", [True, False]) +@pytest.mark.parametrize( + "image_num, K, N, dtype", + [(num, 3584, 3584, "w16a16_bf16") for num in _IMAGE_NUM], + ids=lambda val: 
"", +) +def test_mot_o_proj( + image_num: int, + K: int, + N: int, + dtype: str, + bias: bool, +): + dcfg = _parse_dtype(dtype) + torch.manual_seed(42) + + with set_current_vllm_config(VllmConfig()): + text_linear = RowParallelLinear( + K, + N, + bias=bias, + input_is_parallel=True, + params_dtype=dcfg.torch_dtype, + disable_tp=True, + ).cuda() + vae_linear = RowParallelLinear( + K, + N, + bias=bias, + input_is_parallel=True, + params_dtype=dcfg.torch_dtype, + disable_tp=True, + ).cuda() + mot_linear = MoTRowParallelLinear( + K, + N, + bias=bias, + vae_bias=bias, + input_is_parallel=True, + params_dtype=dcfg.torch_dtype, + disable_tp=True, + ).cuda() + + _sync_weights(text_linear, vae_linear, mot_linear) + + text_idx, vae_idx, M = _make_indices(image_num, _VAE_CHUNK_SIZE) + x = torch.randn(M, K, dtype=dcfg.torch_dtype, device="cuda") + + tag = f"O Proj M={M} K={K} N={N}" + _report_mot_config(K, N, M) + + # Correctness (also warms up Triton JIT compilation) + with torch.no_grad(): + ref = _reference_forward( + x, + text_idx, + vae_idx, + text_linear, + vae_linear, + ) + mot_out, _ = mot_linear(x, text_idx, vae_idx) + + _check_and_report(ref, mot_out, tag) + + # Performance + with torch.no_grad(): + _run_timing( + lambda: _reference_forward( + x, + text_idx, + vae_idx, + text_linear, + vae_linear, + ), + lambda: mot_linear(x, text_idx, vae_idx), + tag, + ) diff --git a/tests/diffusion/kernels/mot/test_mot_norm.py b/tests/diffusion/kernels/mot/test_mot_norm.py new file mode 100644 index 00000000000..af6724d1e8b --- /dev/null +++ b/tests/diffusion/kernels/mot/test_mot_norm.py @@ -0,0 +1,278 @@ +# ruff: noqa: N803, E741 +"""Layer-level correctness & performance test for MoTRMSNorm. 
+ +Compares two equivalent computation paths: + - Reference: 2x vLLM RMSNorm CUDA kernel + PyTorch index scatter/gather + (rms_norm(x[text_idx], text_w) + rms_norm(x[vae_idx], vae_w)) + - Target: 1x MoTRMSNorm fused Triton kernel + (mot_norm(x, text_indices, vae_indices)) + +Usage:: + pytest tests/diffusion/kernels/mot/test_mot_norm.py -v -s +""" + +from __future__ import annotations + +import time + +import pytest +import torch +from vllm.config import VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.layernorm import rms_norm + +from vllm_omni.diffusion.layers.mot.mot_layernorm import MoTRMSNorm + +pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.gpu] + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +_EPS = 1e-6 + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_indices( + M: int, + text_ratio: float, + device: str = "cuda", +) -> tuple[torch.Tensor, torch.Tensor]: + M_text = max(1, int(M * text_ratio)) + perm = torch.randperm(M, device=device) + text_indices = perm[:M_text].sort().values + vae_indices = perm[M_text:].sort().values + return text_indices, vae_indices + + +def _benchmark(fn, warmup: int = 20, iters: int = 200) -> float: + """Return mean latency in milliseconds.""" + for _ in range(warmup): + fn() + torch.cuda.synchronize() + t0 = time.perf_counter() + for _ in range(iters): + fn() + torch.cuda.synchronize() + return (time.perf_counter() - t0) / iters * 1000.0 + + +def _reference_forward( + x: torch.Tensor, + text_indices: torch.Tensor, + vae_indices: torch.Tensor, + text_weight: torch.Tensor, + vae_weight: torch.Tensor, + eps: float, +) -> torch.Tensor: + """Reference: index-gather → 2x vLLM rms_norm CUDA kernel → scatter.""" + output = 
torch.empty_like(x) + output[text_indices] = rms_norm(x[text_indices], text_weight, eps) + output[vae_indices] = rms_norm(x[vae_indices], vae_weight, eps) + return output + + +def _reference_forward_head_norm( + x: torch.Tensor, + text_indices: torch.Tensor, + vae_indices: torch.Tensor, + text_weight: torch.Tensor, + vae_weight: torch.Tensor, + eps: float, +) -> torch.Tensor: + """Reference for head_norm=True path. + + The MoT routing happens on token dimension while RMSNorm is applied + independently on each head's last dimension. + """ + output = torch.empty_like(x) + hidden_size = x.shape[-1] + output[text_indices] = rms_norm( + x[text_indices].reshape(-1, hidden_size), + text_weight, + eps, + ).reshape_as(x[text_indices]) + output[vae_indices] = rms_norm( + x[vae_indices].reshape(-1, hidden_size), + vae_weight, + eps, + ).reshape_as(x[vae_indices]) + return output + + +def _check_and_report(ref: torch.Tensor, mot: torch.Tensor, tag: str): + """Compare outputs, print metrics, and assert correctness. + + Both ``ref`` and ``mot`` are in the original compute dtype (e.g. bf16). + We upcast to fp32 solely for computing error metrics with higher + arithmetic precision — the actual layer outputs remain bf16. + """ + ref_hp = ref.float() + mot_hp = mot.float() + + abs_err = (ref_hp - mot_hp).abs() + max_abs = abs_err.max().item() + + denom = ref_hp.abs().clamp(min=1.0) + max_rel = (abs_err / denom).max().item() + + cos_sim = ( + torch.nn.functional.cosine_similarity( + ref_hp, + mot_hp, + dim=-1, + ) + .min() + .item() + ) + + print(f"\n [{tag}] max_abs={max_abs:.4e} max_rel={max_rel:.4e} min_cos_sim={cos_sim:.6f}") + + # RMSNorm is element-wise (no cross-element accumulation like GEMM), + # so the error between two fp32-accumulating implementations should + # be very small — well within bf16 rounding. 
+ assert cos_sim > 0.99, f"Cosine similarity too low: {cos_sim:.6f}" + assert max_rel < 0.05, f"Max relative error too large: {max_rel:.4e}" + + +def _run_timing( + ref_fn, + mot_fn, + tag: str, + warmup: int = 20, + iters: int = 200, +): + """Benchmark both paths and print timing comparison.""" + ref_ms = _benchmark(ref_fn, warmup=warmup, iters=iters) + mot_ms = _benchmark(mot_fn, warmup=warmup, iters=iters) + speedup = ref_ms / mot_ms if mot_ms > 0 else float("inf") + print(f" [{tag}] Ref(2x rms_norm): {ref_ms:.3f} ms | MoT(fused): {mot_ms:.3f} ms | Speedup: {speedup:.2f}x") + + +# ========================================================================= +# Test: MoTRMSNorm vs 2x vLLM rms_norm +# ========================================================================= + + +@pytest.mark.parametrize( + "M, hidden_size, text_ratio", + [ + (2048, 3584, 0.01), + (8192, 3584, 0.01), + (2048, 128, 0.01), + (8192, 128, 0.01), + ], + ids=[ + "M2048_H3584_layernorm", + "M8192_H3584_layernorm", + "M2048_H128_qknorm", + "M8192_H128_qknorm", + ], +) +def test_mot_rms_norm(M: int, hidden_size: int, text_ratio: float): + torch.manual_seed(42) + dtype = torch.bfloat16 + + # --- Build MoT layer --- + mot_norm = MoTRMSNorm(hidden_size, eps=_EPS).cuda() + with torch.no_grad(): + W_text = torch.randn(hidden_size, dtype=dtype, device="cuda") + W_vae = torch.randn(hidden_size, dtype=dtype, device="cuda") + mot_norm.weight.data.copy_(W_text) + mot_norm.gen_weight.data.copy_(W_vae) + + # --- Build reference weights (same data) --- + ref_text_weight = W_text.clone() + ref_vae_weight = W_vae.clone() + + # --- Inputs --- + x = torch.randn(M, hidden_size, dtype=dtype, device="cuda") + text_idx, vae_idx = _make_indices(M, text_ratio) + + tag = f"RMSNorm M={M} H={hidden_size}" + + # vLLM's rms_norm() uses vllm._custom_ops which may inspect global + # config; wrap in VllmConfig context for safety. 
@pytest.mark.parametrize(
    "M, num_heads, head_dim, text_ratio",
    [
        (2048, 28, 128, 0.01),
    ],
    ids=[
        "M2048_NH28_HD128_qknorm_head_norm",
    ],
)
def test_mot_rms_norm_head_norm(
    M: int,
    num_heads: int,
    head_dim: int,
    text_ratio: float,
):
    """Correctness of ``MoTRMSNorm(head_norm=True)`` on ``(M, num_heads, head_dim)`` input.

    The layer output is compared against ``_reference_forward_head_norm``
    using the same text / VAE weights and the same index split.
    """
    torch.manual_seed(42)
    bf16_on_cuda = {"dtype": torch.bfloat16, "device": "cuda"}

    # --- MoT layer under test (head_norm path) ---
    fused_norm = MoTRMSNorm(head_dim, head_norm=True, eps=_EPS).cuda()
    with torch.no_grad():
        text_scale = torch.randn(head_dim, **bf16_on_cuda)
        vae_scale = torch.randn(head_dim, **bf16_on_cuda)
        fused_norm.weight.data.copy_(text_scale)
        fused_norm.gen_weight.data.copy_(vae_scale)

    # --- Independent copies of the same weights for the reference path ---
    reference_weights = (text_scale.clone(), vae_scale.clone())

    # --- Inputs ---
    x = torch.randn(M, num_heads, head_dim, **bf16_on_cuda)
    text_idx, vae_idx = _make_indices(M, text_ratio)

    tag = f"RMSNorm(head_norm=True) M={M} NH={num_heads} HD={head_dim}"

    with set_current_vllm_config(VllmConfig()):
        # --- Correctness (this first call also triggers any JIT warm-up) ---
        with torch.no_grad():
            ref = _reference_forward_head_norm(x, text_idx, vae_idx, *reference_weights, _EPS)
            mot_out = fused_norm(x, text_idx, vae_idx)

        _check_and_report(ref, mot_out, tag)
--git a/vllm_omni/diffusion/layers/mot/configs/README b/vllm_omni/diffusion/layers/mot/configs/README new file mode 100644 index 00000000000..6661c6c9aaa --- /dev/null +++ b/vllm_omni/diffusion/layers/mot/configs/README @@ -0,0 +1,50 @@ +This directory contains auto-tuned Triton kernel configurations for the +MoT (Mixture-of-Tokens) GEMM and RMSNorm operators used by BAGEL and other +MoT-architecture diffusion models. + +File naming convention: + device_name=<device_name>,dtype=<dtype>.json + +For example: + device_name=NVIDIA_A100-SXM4-80GB,dtype=w16a16.json + +Each JSON file maps (K, N) matrix shapes to a dictionary of batch sizes (M) +and their optimal Triton tile configurations: + + { + "3584_4608": { // K=3584, N=4608 (QKV projection) + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 3 + }, + ... + }, + "3584_3584": { ... } // K=3584, N=3584 (output projection) + } + +Config loading order (3-tier, see ops/mot_gemm.py): + 1. $VLLM_TUNED_CONFIG_FOLDER/ (env override) + 2. This directory: vllm_omni/diffusion/layers/mot/configs/ + 3. Conservative default config (compiles everywhere, sub-optimal perf) + +If no config file matches the current device, a warning is printed with +instructions to run the auto-tuning benchmark. + +To generate configs for your hardware: + + python benchmarks/kernels/mot_linear_benchmarks.py \ + --model ByteDance-Seed/BAGEL-7B-MoT \ + --tp-size 1 --dtype w16a16 --tune \ + --save-dir vllm_omni/diffusion/layers/mot/configs/ + +For multi-GPU tuning (uses Ray for parallel search): + + python benchmarks/kernels/mot_linear_benchmarks.py \ + --model ByteDance-Seed/BAGEL-7B-MoT \ + --tp-size 2 --tune + +See benchmarks/kernels/mot_linear_benchmarks.py for full options.
class MoTRMSNorm(CustomOp):
    """RMS normalization with separate weights for text and gen (VAE) tokens.

    The layer owns two weight vectors of length ``hidden_size``: ``weight``
    and ``gen_weight``.

    * *und* mode (``text_indices is None``): the whole input is normalised
      with ``weight``, like a plain RMSNorm.
    * *gen* mode: rows selected by ``text_indices`` are normalised with
      ``weight`` and rows selected by ``vae_indices`` with ``gen_weight``.
    """

    def __init__(
        self,
        hidden_size: int,
        head_norm: bool = False,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.head_norm = head_norm
        self.variance_epsilon = eps
        # Both weights start as all-ones, i.e. an identity scale.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.gen_weight = nn.Parameter(torch.ones(hidden_size))

    # ------------------------------------------------------------------
    # Native (pure-PyTorch) fallback
    # ------------------------------------------------------------------
    def forward_native(
        self,
        x: torch.Tensor,
        text_indices: torch.Tensor | None = None,
        vae_indices: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Pure-PyTorch forward: normalise each token group with its own weight."""
        if text_indices is not None:
            # gen mode: gather each group, normalise it, scatter it back.
            routed = torch.empty_like(x)
            for row_ids, scale in ((text_indices, self.weight), (vae_indices, self.gen_weight)):
                routed[row_ids] = self._rms_norm_native(x[row_ids], scale)
            return routed

        # und mode: a single weight for every token.
        return self._rms_norm_native(x, self.weight)

    # ------------------------------------------------------------------
    # CUDA fast-path
    # ------------------------------------------------------------------
    def forward_cuda(
        self,
        x: torch.Tensor,
        text_indices: torch.Tensor | None = None,
        vae_indices: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """CUDA forward: vLLM ``rms_norm`` in und mode, ``mot_rms_norm`` in gen mode."""
        if text_indices is not None:
            # gen mode: hand both weights and both index sets to the MoT kernel.
            from vllm_omni.diffusion.layers.mot.ops.mot_rms_norm import (
                mot_rms_norm,
            )

            return mot_rms_norm(
                x,
                self.weight.data,
                self.gen_weight.data,
                text_indices,
                vae_indices,
                head_norm=self.head_norm,
                eps=self.variance_epsilon,
            )

        # und mode: delegate to vLLM's rms_norm with the text weight.
        from vllm.model_executor.layers.layernorm import rms_norm

        return rms_norm(x, self.weight.data, self.variance_epsilon)

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    def _rms_norm_native(self, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        """RMS-normalise ``x`` over its last dim in fp32, scale by ``weight``, cast back."""
        source_dtype = x.dtype
        x_fp32 = x.to(torch.float32)
        mean_square = torch.mean(torch.square(x_fp32), dim=-1, keepdim=True)
        normalised = x_fp32 * torch.rsqrt(mean_square + self.variance_epsilon)
        scaled = normalised * weight.to(torch.float32)
        return scaled.to(source_dtype)

    def extra_repr(self) -> str:
        """Return the ``hidden_size`` / ``eps`` summary used in the module repr."""
        return f"hidden_size={self.hidden_size}, eps={self.variance_epsilon}"
+ + Forward behavior: + - und mode (text_indices is None): fully reuse super().forward() + - gen mode: call the MoT fused GEMM kernel + """ + + def __init__( + self, + hidden_size: int, + head_size: int, + total_num_heads: int, + total_num_kv_heads: int | None = None, + bias: bool = True, + vae_bias: bool = False, + skip_bias_add: bool = False, + params_dtype: torch.dtype | None = None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + *, + return_bias: bool = True, + disable_tp: bool = False, + v_head_size: int | None = None, + ): + # ---- 1) Parent class creates text weights ---- + # super().__init__ will do the following: + # quant_method.create_weights(self, ...) → self.weight, self.weight_scale, ... + # QKVParallelLinear hardcodes gather_output=False + super().__init__( + hidden_size, + head_size, + total_num_heads, + total_num_kv_heads, + bias, + skip_bias_add, + params_dtype, + quant_config, + prefix, + return_bias=return_bias, + disable_tp=disable_tp, + v_head_size=v_head_size, + ) + + # ---- 2) Create vae weights (permanent submodule) ---- + assert self.quant_method is not None + + self.gen_exp = torch.nn.Module() + + # Use the same weight_loader as text + vae_weight_loader = ( + self.weight_loader_v2 + if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED + else self.weight_loader + ) + self.quant_method.create_weights( + layer=self.gen_exp, + input_size_per_partition=self.input_size_per_partition, + output_partition_sizes=self.output_partition_sizes, + input_size=self.input_size, + output_size=self.output_size, + params_dtype=self.params_dtype, + weight_loader=vae_weight_loader, + ) + + # Make gen_exp discoverable by vLLM framework's process_weights_after_loading + self.gen_exp.quant_method = self.quant_method + + # ---- 3) vae bias ---- + if vae_bias: + self.gen_exp.bias = Parameter(torch.empty(self.output_size_per_partition, dtype=self.params_dtype)) + set_weight_attrs( + self.gen_exp.bias, + {"output_dim": 0, 
"weight_loader": self.weight_loader}, + ) + else: + self.gen_exp.register_parameter("bias", None) + + self.update_param_tp_status() + + # ================================================================== + # Forward + # ================================================================== + def forward( + self, + input_: torch.Tensor, + text_indices: torch.Tensor | None = None, + vae_indices: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]: + # ---- und mode: fully reuse parent class (only text path) ---- + if text_indices is None: + return super().forward(input_) + + # ---- gen mode: fuse MoT GEMM ---- + output_parallel = self._mot_gemm_dispatch(input_, text_indices, vae_indices) + + # QKVParallelLinear hardcodes gather_output=False, this branch never enters; + # retained for future subclass changes + if self.gather_output and self.tp_size > 1: + output = tensor_model_parallel_all_gather(output_parallel) + else: + output = output_parallel + + if not self.return_bias: + return output + + if self.skip_bias_add and text_indices is not None and (self.bias is not None or self.gen_exp.bias is not None): + merged_bias = torch.zeros( + output.size(0), + self.output_size_per_partition, + dtype=output.dtype, + device=output.device, + ) + if self.bias is not None: + merged_bias[text_indices] = self.bias + if self.gen_exp.bias is not None: + merged_bias[vae_indices] = self.gen_exp.bias + return output, merged_bias + + output_bias = self.bias if self.skip_bias_add else None + return output, output_bias + + # ================================================================== + # MoT GEMM dispatcher + # ================================================================== + def _mot_gemm_dispatch( + self, + x: torch.Tensor, + text_indices: torch.Tensor, + vae_indices: torch.Tensor, + ) -> torch.Tensor: + # Any other backend (ROCm, XPU, TPU, CPU) uses the safe vllm fallback. 
+ if not current_platform.is_cuda(): + return self._mot_fallback(x, text_indices, vae_indices) + + """Dispatch to different MoT kernel paths based on weight dtype / quant attributes.""" + bias_text = self.bias if not self.skip_bias_add else None + bias_vae = self.gen_exp.bias if not self.skip_bias_add else None + + w_text = self.weight + w_vae = self.gen_exp.weight + assert w_text.dtype.is_floating_point == w_vae.dtype.is_floating_point, ( + "weight of text expert and image expert should be the same dtype." + ) + + # w_text.dtype.itemsize >= 2 means bytes_per_element >= 2 (16bits or 32bits) + if w_text.dtype.is_floating_point and w_text.dtype.itemsize >= 2: + return self._mot_gemm_unquantized( + x, + text_indices, + vae_indices, + bias_text, + bias_vae, + ) + elif w_text.dtype == torch.float8_e4m3fn: + return self._mot_gemm_fp8_w8a8( + x, + text_indices, + vae_indices, + bias_text, + bias_vae, + ) + elif w_text.dtype == torch.int8: + return self._mot_gemm_weight_only( + x, + text_indices, + vae_indices, + bias_text, + bias_vae, + ) + else: + return self._mot_fallback( + x, + text_indices, + vae_indices, + ) + + # ================================================================== + # Implementation of each quantization path + # ================================================================== + def _mot_gemm_unquantized(self, x, text_idx, vae_idx, bias_t, bias_v): + """BF16/FP16 path.""" + N = self.output_size_per_partition + C = torch.zeros(x.size(0), N, dtype=x.dtype, device=x.device) + invoke_mot_gemm( + A=x, + B_text=self.weight.data.t(), + B_vae=self.gen_exp.weight.data.t(), + C=C, + bias_text=bias_t, + bias_vae=bias_v, + text_indices=text_idx, + vae_indices=vae_idx, + A_scale=None, + B_text_scale=None, + B_vae_scale=None, + use_fp8_w8a8=False, + use_int8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + A_per_channel_quant=False, + B_per_channel_quant=False, + ) + return C + + def _mot_gemm_fp8_w8a8(self, x, text_idx, vae_idx, bias_t, bias_v): + 
"""FP8 W8A8 path. + + 1) Activation quantization: reuse vllm's ops.scaled_fp8_quant + 2) MoT GEMM: text/vae use different fp8 weights and weight_scale + 3) De-quantization: done by MoT kernel's epilogue + """ + from vllm import _custom_ops as ops + + x_2d = x.view(-1, x.shape[-1]) + input_scale = getattr(self, "input_scale", None) + x_fp8, x_scale = ops.scaled_fp8_quant( + x_2d, + input_scale, + use_per_token_if_dynamic=True, + ) + + N = self.output_size_per_partition + C = torch.zeros(x.size(0), N, dtype=x.dtype, device=x.device) + + invoke_mot_gemm( + A=x_fp8, + B_text=self.weight.data, + B_vae=self.gen_exp.weight.data, + C=C, + bias_text=bias_t, + bias_vae=bias_v, + text_indices=text_idx, + vae_indices=vae_idx, + A_scale=x_scale, + B_text_scale=self.weight_scale.data, + B_vae_scale=self.gen_exp.weight_scale.data, + use_fp8_w8a8=True, + use_int8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + A_per_channel_quant=True, + B_per_channel_quant=False, + ) + return C + + def _mot_gemm_weight_only(self, x, text_idx, vae_idx, bias_t, bias_v): + """Weight-Only W8A16 path. + + Activation values kept as bf16/fp16, weights are int8 + per-channel scale. + De-quantization done by MoT kernel's epilogue. + """ + N = self.output_size_per_partition + C = torch.zeros(x.size(0), N, dtype=x.dtype, device=x.device) + invoke_mot_gemm( + A=x, + B_text=self.weight.data.t(), + B_vae=self.gen_exp.weight.data.t(), + C=C, + bias_text=bias_t, + bias_vae=bias_v, + text_indices=text_idx, + vae_indices=vae_idx, + A_scale=None, + B_text_scale=self.weight_scale.data, + B_vae_scale=self.gen_exp.weight_scale.data, + use_fp8_w8a8=False, + use_int8_w8a8=False, + use_int8_w8a16=True, + use_int4_w4a16=False, + A_per_channel_quant=False, + B_per_channel_quant=True, + ) + return C + + def _mot_fallback(self, x, text_idx, vae_idx): + """Fallback: fall back to gather/scatter + quant_method.apply. + + For unsupported quantization types, call standard forward for text/vae tokens. 
+ """ + assert self.quant_method is not None + + bias_text = self.bias if not self.skip_bias_add else None + bias_vae = self.gen_exp.bias if not self.skip_bias_add else None + + output = torch.zeros( + x.size(0), + self.output_size_per_partition, + dtype=x.dtype, + device=x.device, + ) + output_text = self.quant_method.apply(self, x[text_idx], bias_text) + + output_vae = self.quant_method.apply( + self.gen_exp, + x[vae_idx], + bias_vae, + ) + output[text_idx] = output_text + output[vae_idx] = output_vae + return output diff --git a/vllm_omni/diffusion/layers/mot/mot_row_parallel_linear.py b/vllm_omni/diffusion/layers/mot/mot_row_parallel_linear.py new file mode 100644 index 00000000000..f69aab2d7a4 --- /dev/null +++ b/vllm_omni/diffusion/layers/mot/mot_row_parallel_linear.py @@ -0,0 +1,346 @@ +from __future__ import annotations + +import torch +from torch.nn.parameter import Parameter +from vllm.distributed import ( + split_tensor_along_last_dim, + tensor_model_parallel_all_reduce, +) +from vllm.model_executor.layers.linear import ( + WEIGHT_LOADER_V2_SUPPORTED, + RowParallelLinear, +) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, +) +from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform + +from vllm_omni.diffusion.layers.mot.ops.mot_gemm import invoke_mot_gemm + + +class MoTRowParallelLinear(RowParallelLinear): + """RowParallelLinear with Mixture-of-Tokens routing. + + text weights: directly on self (self.weight, self.weight_scale, ...), + created by RowParallelLinear.__init__ standard process. + vae weights: on permanent submodule self.gen_exp (self.gen_exp.weight, ...), + created by quant_method.create_weights(self.gen_exp, ...). + gen_exp.quant_method points to the same quant_method, enabling + vLLM framework's process_weights_after_loading to automatically + discover and process it. 
class MoTRowParallelLinear(RowParallelLinear):
    """RowParallelLinear with Mixture-of-Tokens routing.

    text weights: directly on self (self.weight, self.weight_scale, ...),
        created by RowParallelLinear.__init__ standard process.
    vae weights: on permanent submodule self.gen_exp (self.gen_exp.weight, ...),
        created by quant_method.create_weights(self.gen_exp, ...).
        gen_exp.quant_method points to the same quant_method, enabling
        vLLM framework's process_weights_after_loading to automatically
        discover and process it.

    Forward behavior:
      - und mode (text_indices is None): fully reuse super().forward()
      - gen mode: call MoT fused GEMM kernel, then execute TP all-reduce
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        vae_bias: bool = False,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        reduce_results: bool = True,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        """Create the text expert via the parent class, then the VAE expert on ``self.gen_exp``.

        ``vae_bias`` controls whether ``self.gen_exp.bias`` is allocated; every other
        argument is forwarded unchanged to ``RowParallelLinear.__init__``.
        """
        # ---- Step 1: Parent class creates text weights ----
        # super().__init__ internally calls:
        #   quant_method.create_weights(self, ...) → self.weight, self.weight_scale, ...
        super().__init__(
            input_size,
            output_size,
            bias,
            input_is_parallel,
            skip_bias_add,
            params_dtype,
            reduce_results,
            quant_config,
            prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )

        # ---- Step 2: Create vae weights (permanent submodule) ----
        assert self.quant_method is not None

        self.gen_exp = torch.nn.Module()

        # Select weight_loader consistent with text
        vae_weight_loader = (
            self.weight_loader_v2
            if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
            else self.weight_loader
        )
        self.quant_method.create_weights(
            layer=self.gen_exp,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=vae_weight_loader,
        )

        # Enable gen_exp to be automatically discovered by vLLM framework's process_weights_after_loading
        self.gen_exp.quant_method = self.quant_method

        # ---- Step 3: vae bias ----
        # RowParallelLinear's bias is full output_size (not sharded)
        if vae_bias:
            self.gen_exp.bias = Parameter(torch.empty(self.output_size, dtype=self.params_dtype))
            set_weight_attrs(
                self.gen_exp.bias,
                {"output_dim": 0, "weight_loader": self.weight_loader},
            )
        else:
            self.gen_exp.register_parameter("bias", None)

        self.update_param_tp_status()

    # ==================================================================
    # Forward
    # ==================================================================
    def forward(
        self,
        input_: torch.Tensor,
        text_indices: torch.Tensor | None = None,
        vae_indices: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        """Project ``input_``, routing rows to the text or VAE expert, then all-reduce.

        Args:
            input_: Activations; rows are addressed by ``text_indices`` / ``vae_indices``.
            text_indices: Row indices handled by the text expert. ``None`` selects
                und mode, which delegates entirely to ``super().forward``.
            vae_indices: Row indices handled by the VAE expert (gen mode).

        Returns:
            ``output`` alone when ``return_bias`` is False, otherwise ``(output, bias)``.
            With ``skip_bias_add`` in gen mode the returned bias is a per-row tensor
            carrying the text bias on text rows and the VAE bias on VAE rows.
        """
        # ---- und mode: fully reuse parent class (only text path) ----
        if text_indices is None:
            return super().forward(input_)

        # ---- gen mode ----
        # Handle input_is_parallel (consistent with parent class logic)
        if self.input_is_parallel:
            input_parallel = input_
        else:
            split_input = split_tensor_along_last_dim(input_, num_partitions=self.tp_size)
            input_parallel = split_input[self.tp_rank].contiguous()

        # Fused MoT GEMM
        output_parallel = self._mot_gemm_dispatch(input_parallel, text_indices, vae_indices)

        # ---- TP communication: all-reduce (consistent with parent class) ----
        if self.reduce_results and self.tp_size > 1:
            output = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output = output_parallel

        if not self.return_bias:
            return output

        # text_indices is known to be non-None here (und mode returned above).
        has_any_bias = self.bias is not None or self.gen_exp.bias is not None
        if self.skip_bias_add and has_any_bias:
            # Construct per-token mixed bias
            merged_bias = torch.zeros(
                output.size(0),
                self.output_size_per_partition,
                dtype=output.dtype,
                device=output.device,
            )
            if self.bias is not None:
                merged_bias[text_indices] = self.bias
            if self.gen_exp.bias is not None:
                merged_bias[vae_indices] = self.gen_exp.bias
            return output, merged_bias

        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    # ==================================================================
    # MoT GEMM Dispatcher
    # ==================================================================
    def _mot_gemm_dispatch(
        self,
        x: torch.Tensor,
        text_indices: torch.Tensor,
        vae_indices: torch.Tensor,
    ) -> torch.Tensor:
        """Dispatch to different MoT kernel paths based on weight dtype / quant attributes."""

        # Any other backend (ROCm, XPU, TPU, CPU) uses the safe vllm fallback.
        if not current_platform.is_cuda():
            return self._mot_fallback(x, text_indices, vae_indices)

        # RowParallelLinear: bias only fused into GEMM at rank 0,
        # other ranks pass None to avoid duplicate accumulation
        fuse_bias = not (self.tp_rank > 0 or self.skip_bias_add)
        bias_text = self.bias if fuse_bias else None
        bias_vae = self.gen_exp.bias if fuse_bias else None

        # Determine quantization type by weight dtype, avoid isinstance coupling to specific quant_method
        # TODO: Currently does not support online quantization fp8, does not support quantization types involving int4
        w_text = self.weight
        w_vae = self.gen_exp.weight
        assert w_text.dtype.is_floating_point == w_vae.dtype.is_floating_point, (
            "weight of text expert and image expert should be the same dtype."
        )

        # w_text.dtype.itemsize >= 2 means bytes_per_element >= 2 (16bits or 32bits)
        if w_text.dtype.is_floating_point and w_text.dtype.itemsize >= 2:
            # ---- Path 0: BF16 / FP16 (unquantized) ----
            return self._mot_gemm_unquantized(x, text_indices, vae_indices, bias_text, bias_vae)
        if w_text.dtype == torch.float8_e4m3fn:
            # ---- Path 1: FP8 W8A8 ----
            return self._mot_gemm_fp8_w8a8(x, text_indices, vae_indices, bias_text, bias_vae)
        if w_text.dtype == torch.int8:
            # ---- Path 2: Weight-Only INT8 W8A16 ----
            return self._mot_gemm_weight_only(x, text_indices, vae_indices, bias_text, bias_vae)
        # ---- Fallback: gather/scatter + quant_method.apply ----
        return self._mot_fallback(x, text_indices, vae_indices)

    # ==================================================================
    # Specific implementations for each quantization path
    # ==================================================================
    def _mot_gemm_unquantized(self, x, text_idx, vae_idx, bias_t, bias_v):
        """BF16/FP16 path: fused MoT GEMM on the transposed text/VAE weights."""
        N = self.output_size_per_partition
        # Zero-initialised (as in MoTQKVParallelLinear): rows addressed by neither index
        # set must not carry uninitialised memory into the subsequent all-reduce.
        C = torch.zeros(x.size(0), N, dtype=x.dtype, device=x.device)
        invoke_mot_gemm(
            A=x,
            B_text=self.weight.data.t(),
            B_vae=self.gen_exp.weight.data.t(),
            C=C,
            bias_text=bias_t,
            bias_vae=bias_v,
            text_indices=text_idx,
            vae_indices=vae_idx,
            A_scale=None,
            B_text_scale=None,
            B_vae_scale=None,
            use_fp8_w8a8=False,
            use_int8_w8a8=False,
            use_int8_w8a16=False,
            use_int4_w4a16=False,
            A_per_channel_quant=False,
            B_per_channel_quant=False,
        )
        return C

    def _mot_gemm_fp8_w8a8(self, x, text_idx, vae_idx, bias_t, bias_v):
        """FP8 W8A8 path.

        1) Activation quantization: reuse vllm's ops.scaled_fp8_quant
        2) MoT GEMM: text/vae each use different fp8 weights and weight_scale
        3) Dequantization: completed by MoT kernel internal epilogue
        """
        from vllm import _custom_ops as ops

        x_2d = x.view(-1, x.shape[-1])
        input_scale = getattr(self, "input_scale", None)
        x_fp8, x_scale = ops.scaled_fp8_quant(
            x_2d,
            input_scale,
            use_per_token_if_dynamic=True,
        )

        N = self.output_size_per_partition
        # Size the output from the flattened activations the kernel actually consumes
        # (identical to x.size(0) for 2-D inputs); zero-initialised, see unquantized path.
        C = torch.zeros(x_2d.size(0), N, dtype=x.dtype, device=x.device)

        # weight has been transposed to (K, N) in process_weights_after_loading
        invoke_mot_gemm(
            A=x_fp8,
            B_text=self.weight.data,
            B_vae=self.gen_exp.weight.data,
            C=C,
            bias_text=bias_t,
            bias_vae=bias_v,
            text_indices=text_idx,
            vae_indices=vae_idx,
            A_scale=x_scale,
            B_text_scale=self.weight_scale.data,
            B_vae_scale=self.gen_exp.weight_scale.data,
            use_fp8_w8a8=True,
            use_int8_w8a8=False,
            use_int8_w8a16=False,
            use_int4_w4a16=False,
            A_per_channel_quant=True,
            B_per_channel_quant=False,
        )
        return C

    def _mot_gemm_weight_only(self, x, text_idx, vae_idx, bias_t, bias_v):
        """Weight-Only W8A16 path.

        Activations remain bf16/fp16, weights are int8 + per-channel scale.
        MoT kernel internally performs immediate dequantization.
        """
        N = self.output_size_per_partition
        # Zero-initialised for the same reason as the unquantized path.
        C = torch.zeros(x.size(0), N, dtype=x.dtype, device=x.device)
        invoke_mot_gemm(
            A=x,
            B_text=self.weight.data.t(),
            B_vae=self.gen_exp.weight.data.t(),
            C=C,
            bias_text=bias_t,
            bias_vae=bias_v,
            text_indices=text_idx,
            vae_indices=vae_idx,
            A_scale=None,
            B_text_scale=self.weight_scale.data,
            B_vae_scale=self.gen_exp.weight_scale.data,
            use_fp8_w8a8=False,
            use_int8_w8a8=False,
            use_int8_w8a16=True,
            use_int4_w4a16=False,
            A_per_channel_quant=False,
            B_per_channel_quant=True,
        )
        return C

    def _mot_fallback(self, x, text_idx, vae_idx):
        """Fallback: degrade to gather/scatter + quant_method.apply.

        For unsupported quantization types, call standard forward separately for text/vae tokens.
        """
        assert self.quant_method is not None

        # RowParallelLinear: bias only fused at rank 0
        fuse_bias = not (self.tp_rank > 0 or self.skip_bias_add)
        bias_text = self.bias if fuse_bias else None
        bias_vae = self.gen_exp.bias if fuse_bias else None

        # Zero-initialised so rows outside both index sets are well defined.
        output = torch.zeros(
            x.size(0),
            self.output_size_per_partition,
            dtype=x.dtype,
            device=x.device,
        )
        # text tokens → standard quant_method (operate on weights on self)
        output[text_idx] = self.quant_method.apply(self, x[text_idx], bias_text)
        # vae tokens → same quant_method (operate on weights on gen_exp)
        output[vae_idx] = self.quant_method.apply(
            self.gen_exp,
            x[vae_idx],
            bias_vae,
        )
        return output
+# ===================================================================== + +_CONFIGS_DIR = pathlib.Path(__file__).resolve().parent.parent / "configs" +_ENV_CONFIG_FOLDER = "VLLM_TUNED_CONFIG_FOLDER" + + +def get_device_name() -> str: + """Sanitized GPU device name, matching mot_linear_benchmarks.py output.""" + raw = torch.cuda.get_device_name(0) + name = re.sub(r"[^a-zA-Z0-9]", "", raw.replace(" ", "")) + for prefix in ("NVIDIA", "AMD"): + if name.startswith(prefix): + name = name[len(prefix) :] + # --- Device Aliasing Patch --- + alias_map = { + "A800": "A100", + "H800": "H100", + } + # map A800/H800 to A100/H100 for chinese market + for key, target in alias_map.items(): + if key in name: + name = name.replace(key, target) + break + return name + + +def build_config_filename(device_name: str, dtype_str: str) -> str: + return f"device_name={device_name},dtype={dtype_str}.json" + + +def _try_load_json(filepath: str) -> dict | None: + if os.path.isfile(filepath): + with open(filepath) as f: + return json.load(f) + return None + + +@functools.lru_cache +def _load_mot_config_file(dtype_str: str) -> dict | None: + """Load and cache the full MoT config JSON (one file per device/dtype). + + Search order: + 1. ``$VLLM_TUNED_CONFIG_FOLDER/device_name=...,dtype=....json`` + 2. ``vllm_omni/.../mot/configs/device_name=...,dtype=....json`` + 3. Return ``None`` (caller falls back to ``get_mot_default_config``). 
+ """ + device_name = get_device_name() + filename = build_config_filename(device_name, dtype_str) + + config_file_paths: list[str] = [] + env_dir = os.environ.get(_ENV_CONFIG_FOLDER) + if env_dir: + config_file_paths.append(str(pathlib.Path(env_dir) / filename)) + config_file_paths.append(str(_CONFIGS_DIR / filename)) + + for path in config_file_paths: + data = _try_load_json(path) + if data is not None: + logger.info("MoT config loaded from %s", path) + return data + + logger.warning( + f"\n{'=' * 80}\n" + f" ⚠️ [WARNING] No tuned MoT config found.\n" + f" Searched paths: {', '.join(config_file_paths)}\n" + f" Using conservative default configs which are NOT optimal.\n" + f" Run `python benchmarks/kernels/mot_linear_benchmarks.py --tune` \n" + f" to generate hardware-specific optimal configs.\n" + f"{'=' * 80}\n" + ) + return None + + +@functools.lru_cache +def get_mot_configs( + K: int, + N: int, + dtype_str: str | None = None, +) -> dict[int, dict[str, int]] | None: + """Return ``{M: tile_config}`` for a given (K, N) shape, or ``None``. + + The return value maps an irregular grid of batch sizes (M) to Triton + tile configurations. The caller should pick the entry whose M is + closest to the actual batch size. + + Config file is selected by ``device_name + dtype``. 
+ """ + file_data = _load_mot_config_file(dtype_str or "w16a16") + if file_data is None: + return None + + shape_entry = file_data.get(f"{K}_{N}") + if shape_entry is None: + logger.warning( + f"\n{'=' * 80}\n" + f" ⚠️ [WARNING] MoT config file found, but NO tuned entry for shape K={K}, N={N}.\n" + f" Using conservative default configs which are NOT optimal for this specific shape.\n" + f" Run `python benchmarks/kernels/mot_linear_benchmarks.py --tune` \n" + f" to generate hardware-specific optimal configs.\n" + f"{'=' * 80}\n" + ) + return None + + return {int(k): dict(v) for k, v in shape_entry.items() if k != "_comment"} + + +def get_mot_default_config( + M: int, + N: int, + K: int, + dtype: str | None = None, + block_quant_shape: list[int] | None = None, +) -> dict[str, int]: + """Conservative fallback config guaranteed to compile on all hardware. + + Trades peak performance for universal compatibility (T4 / V100 / A100 / + H100, CUDA & ROCm). + """ + # FP8 block-wise quantization requires strict alignment + if dtype == "fp8_w8a8" and block_quant_shape is not None: + return { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": block_quant_shape[0], + "BLOCK_SIZE_K": block_quant_shape[1], + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + } + + # Very small M (tail batches, final feature concat, etc.) 
+ if M <= 64: + return { + "BLOCK_SIZE_M": 16 if M <= 16 else 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + } + + # Standard fallback for typical image-generation M (2048 / 4096 / …) + # SRAM usage: (64*32 + 64*32) * 2 * 2 = 16 KB — safe on decade-old GPUs + return { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + } + + +def get_best_mot_config(M: int, N: int, K: int, dtype_str: str | None = None) -> tuple[int, dict[str, int]]: + configs = get_mot_configs(K, N, dtype_str) + if configs: + loaded_m_key = min(configs.keys(), key=lambda x: abs(x - M)) + return loaded_m_key, configs[loaded_m_key] + else: + return -1, get_mot_default_config(M, N, K, dtype=dtype_str) + + +# ================================================================= +# Part 1: The Router (Routing Component) +# Responsibilities: Handle PID mapping, Text/VAE distribution, indirect index loading, pointer calculation +# ================================================================= +@triton.jit +def _get_mot_pointers( + # System Inputs + pid, + # Matrix Pointers + a_ptr, + b_text_ptr, + b_vae_ptr, + bias_text_ptr, + bias_vae_ptr, + scale_a_ptr, + scale_b_text_ptr, + scale_b_vae_ptr, + # Indices & Meta + text_indices_ptr, + vae_indices_ptr, + M_text, + M_vae, + N, + # Strides (need to select based on Text/VAE) + stride_bk_text, + stride_bn_text, + stride_bk_vae, + stride_bn_vae, + # Block Config + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # 1. Calculate Text/VAE task boundaries + num_pid_m_text = tl.cdiv(M_text, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + + # 2. 
# =================================================================
# Part 1: The Router (Routing Component)
# Responsibilities: Handle PID mapping, Text/VAE distribution, indirect index loading, pointer calculation
# =================================================================
@triton.jit
def _get_mot_pointers(
    # System Inputs
    pid,
    # Matrix Pointers
    a_ptr,
    b_text_ptr,
    b_vae_ptr,
    bias_text_ptr,
    bias_vae_ptr,
    scale_a_ptr,
    scale_b_text_ptr,
    scale_b_vae_ptr,
    # Indices & Meta
    text_indices_ptr,
    vae_indices_ptr,
    M_text,
    M_vae,
    N,
    # Strides (need to select based on Text/VAE)
    stride_bk_text,
    stride_bn_text,
    stride_bk_vae,
    stride_bn_vae,
    # Block Config
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Map a flat program id to one output tile of either the text or the VAE expert.

    Program ids below ``cdiv(M_text, BLOCK_SIZE_M) * cdiv(N, BLOCK_SIZE_N)`` are
    assigned to the text expert; the remaining ids are re-based to zero and
    assigned to the VAE expert. For the selected expert this returns the tile
    coordinates, the row indices loaded from the expert's index array (with
    their validity mask), the column offsets (with their mask), the row limit,
    and the expert's weight / bias / weight-scale pointers and weight strides.

    ``a_ptr`` and ``scale_a_ptr`` are accepted but not read by this helper.
    """
    # 1. Calculate Text/VAE task boundaries
    num_pid_m_text = tl.cdiv(M_text, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)

    # 2. MoT Routing
    # Initialize VAE variables first and overwrite for Text

    # VAE Path
    cur_pid = pid - (num_pid_m_text * num_pid_n)
    # Select VAE Pointers
    cur_b_ptr = b_vae_ptr
    cur_bias_ptr = bias_vae_ptr
    cur_scale_b_ptr = scale_b_vae_ptr
    cur_indices_ptr = vae_indices_ptr
    # Select VAE Strides / Limits
    cur_stride_bk = stride_bk_vae
    cur_stride_bn = stride_bn_vae
    M_limit = M_vae

    # Text Path
    if pid < num_pid_m_text * num_pid_n:
        cur_pid = pid
        # Select Text Pointers
        cur_b_ptr = b_text_ptr
        cur_bias_ptr = bias_text_ptr
        cur_scale_b_ptr = scale_b_text_ptr
        cur_indices_ptr = text_indices_ptr
        # Select Text Strides / Limits
        cur_stride_bk = stride_bk_text
        cur_stride_bn = stride_bn_text
        M_limit = M_text

    # 3. Calculate Grid coordinates(grouping)
    # Tiles are visited in groups of up to GROUP_SIZE_M row-tiles; the last group
    # may be smaller, hence the clamp in group_size_m_adj.
    cur_num_pid_m = tl.cdiv(M_limit, BLOCK_SIZE_M)

    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = cur_pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m_adj = tl.minimum(cur_num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (cur_pid % group_size_m_adj)
    pid_n = (cur_pid % num_pid_in_group) // group_size_m_adj

    # 4. Load indirect indices (Indirect Indexing for A)
    # Calculate the M range covered by current Block [0, BLOCK_SIZE_M]
    offs_m_idx = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
    m_mask = offs_m_idx < M_limit
    # Load real row indices from the index array
    # for very big K (eg. ffn down proj*TP=1,K=16k), if M is also huge(M>130K)
    # loading int32 indices may result in integer overflow when we compute offs_m
    # Masked-out lanes read index 0; callers must apply m_mask before using them.
    real_row_idxs = tl.load(cur_indices_ptr + offs_m_idx, mask=m_mask, other=0).to(tl.int64)

    # 5. Calculate N-dimension Offsets
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
    n_mask = offs_n < N

    return (
        pid_m,
        pid_n,  # Grid Coordinates
        real_row_idxs,
        m_mask,  # M-dim info (Indirect)
        offs_n,
        n_mask,  # N-dim info
        M_limit,  # Boundary
        cur_b_ptr,
        cur_bias_ptr,  # Selected Pointers
        cur_scale_b_ptr,  # Selected Scale Pointer
        cur_stride_bk,
        cur_stride_bn,  # Selected Strides
    )
# Core A: Standard GEMM (for BF16/FP16 and W8A8)
# Feature: No dequantization inside Loop, Scale is applied after Loop ends
@triton.jit
def _core_standard_gemm(
    # Pointers
    a_ptr,
    b_ptr,
    # Offsets & Masks
    real_row_idxs,
    m_mask,
    offs_n,
    n_mask,
    offs_k,
    # Strides
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    # Quant-Related
    scale_a_ptr,
    scale_b_ptr,
    stride_scale_a,
    stride_scale_b,
    # Loop Info
    K,
    BLOCK_SIZE_K: tl.constexpr,
    # Configs
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    ACCUMULATOR_DTYPE: tl.constexpr,
    IS_W8A8: tl.constexpr,
    # Accelerate Configs
    EVEN_K: tl.constexpr,
    EVEN_N: tl.constexpr,
    STRIDE_AK_IS_1: tl.constexpr,
    STRIDE_BK_IS_1: tl.constexpr,
    STRIDE_BN_IS_1: tl.constexpr,
):
    """Accumulate one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of A[rows] @ B over K.

    Rows of A are gathered through ``real_row_idxs`` (masked by ``m_mask``);
    columns of B are addressed by ``offs_n`` (masked by ``n_mask`` unless
    ``EVEN_N``). When ``IS_W8A8`` is set, the accumulator is converted to
    float32 after the loop and multiplied by the A scale of each gathered row
    and by the B scale of each column. Returns the accumulator tile.
    """
    # 1. Stride optimizations (Bypassing compiler limitations for unit strides)
    _stride_ak = 1 if STRIDE_AK_IS_1 else stride_ak
    _stride_bk = 1 if STRIDE_BK_IS_1 else stride_bk
    _stride_bn = 1 if STRIDE_BN_IS_1 else stride_bn

    # Pointer initialization (A uses indirect indexing, B uses standard striding)
    a_ptrs = a_ptr + (stride_am * real_row_idxs[:, None] + _stride_ak * offs_k[None, :])
    b_ptrs = b_ptr + (_stride_bk * offs_k[:, None] + _stride_bn * offs_n[None, :])

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=ACCUMULATOR_DTYPE)

    # 2. Unified Main Loop
    # Triton evaluates `constexpr` conditions at compile time, ensuring zero runtime overhead
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Compute K-mask only if needed
        if not EVEN_K:
            # Remaining K elements for this iteration; masks the ragged final tile.
            mask_k = offs_k < K - k * BLOCK_SIZE_K

        # Load A
        if EVEN_K:
            a = tl.load(a_ptrs, mask=m_mask[:, None], other=0.0)
        else:
            a = tl.load(a_ptrs, mask=m_mask[:, None] & mask_k[None, :], other=0.0)

        # Load B
        if EVEN_K and EVEN_N:
            b = tl.load(b_ptrs)
        elif not EVEN_K and EVEN_N:
            b = tl.load(b_ptrs, mask=mask_k[:, None], other=0.0)
        elif EVEN_K and not EVEN_N:
            b = tl.load(b_ptrs, mask=n_mask[None, :], other=0.0)
        else:
            b = tl.load(b_ptrs, mask=mask_k[:, None] & n_mask[None, :], other=0.0)

        # Compute & Advance
        accumulator = tl.dot(a, b, accumulator, out_dtype=ACCUMULATOR_DTYPE)
        a_ptrs += BLOCK_SIZE_K * _stride_ak
        b_ptrs += BLOCK_SIZE_K * _stride_bk

    # 3. Epilogue (only needed for W8A8)
    if IS_W8A8:
        accumulator = accumulator.to(tl.float32)
        # Load Scale A
        # Indexed by the gathered row index, i.e. one scale per source row of A.
        scale_a_ptrs = scale_a_ptr + real_row_idxs * stride_scale_a
        sa = tl.load(scale_a_ptrs, mask=m_mask, other=1.0)
        accumulator = accumulator * sa[:, None]

        # Load Scale B
        scale_b_ptrs = scale_b_ptr + offs_n * stride_scale_b
        sb = tl.load(scale_b_ptrs, mask=n_mask, other=1.0)
        accumulator = accumulator * sb[None, :]

    return accumulator
# Core B: Weight Only GEMM (for W4A16 / W8A16)
# Feature: Dequantization inside Loop (Dequantize-on-the-fly)
# Currently only supports W8A16
@triton.jit
def _core_weight_only_gemm(
    # Pointers
    a_ptr,
    b_ptr,
    # Offsets & Masks
    real_row_idxs,
    m_mask,
    offs_n,
    n_mask,
    offs_k,
    # Strides
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    # Quant-Related
    scale_b_ptr,
    stride_scale_b,  # 1(per-channel) or 0(per-tensor)
    # Loop Info
    K,
    BLOCK_SIZE_K: tl.constexpr,
    # Configs
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    ACCUMULATOR_DTYPE: tl.constexpr,
    COMPUTE_DTYPE: tl.constexpr,  # bf16 or fp16
    WEIGHT_BITS: tl.constexpr,  # 4 or 8
    # Accelerate Configs
    EVEN_K: tl.constexpr,
    EVEN_N: tl.constexpr,
    STRIDE_AK_IS_1: tl.constexpr,
    STRIDE_BK_IS_1: tl.constexpr,
    STRIDE_BN_IS_1: tl.constexpr,
):
    """Accumulate one output tile with integer weights cast to ``COMPUTE_DTYPE``.

    Inside the K loop the loaded weight tile is only type-cast; the per-column
    weight scale is applied once to the accumulator after the loop. Compilation
    fails via ``tl.static_assert`` unless ``WEIGHT_BITS == 8``. Returns the
    scaled accumulator tile.
    """
    # Compile-time Check
    tl.static_assert(WEIGHT_BITS == 8, "For weight-only, we only support W8A16 at this point")

    # 1. Stride optimizations
    _stride_ak = 1 if STRIDE_AK_IS_1 else stride_ak
    _stride_bk = 1 if STRIDE_BK_IS_1 else stride_bk
    _stride_bn = 1 if STRIDE_BN_IS_1 else stride_bn

    # A rows are gathered through real_row_idxs; B uses regular striding.
    a_ptrs = a_ptr + (stride_am * real_row_idxs[:, None] + _stride_ak * offs_k[None, :])
    b_ptrs = b_ptr + (_stride_bk * offs_k[:, None] + _stride_bn * offs_n[None, :])

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=ACCUMULATOR_DTYPE)

    # 2. Unified Main Loop
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        if not EVEN_K:
            # Remaining K elements for this iteration; masks the ragged final tile.
            mask_k = offs_k < K - k * BLOCK_SIZE_K

        # Load A
        if EVEN_K:
            a = tl.load(a_ptrs, mask=m_mask[:, None], other=0.0)
        else:
            a = tl.load(a_ptrs, mask=m_mask[:, None] & mask_k[None, :], other=0.0)

        # Load B (Int)
        if EVEN_K and EVEN_N:
            b_int = tl.load(b_ptrs)
        elif not EVEN_K and EVEN_N:
            b_int = tl.load(b_ptrs, mask=mask_k[:, None], other=0.0)
        elif EVEN_K and not EVEN_N:
            b_int = tl.load(b_ptrs, mask=n_mask[None, :], other=0.0)
        else:
            b_int = tl.load(b_ptrs, mask=mask_k[:, None] & n_mask[None, :], other=0.0)

        # --- Dequantize Logic (Type Cast Only) ---
        # No multiplication here, correct as long as per-tensor/per-channel scaling (scale_b.shape=(N,))
        b_compute = b_int.to(COMPUTE_DTYPE)

        # Compute & Advance
        accumulator = tl.dot(a, b_compute, accumulator, out_dtype=ACCUMULATOR_DTYPE)

        a_ptrs += BLOCK_SIZE_K * _stride_ak
        b_ptrs += BLOCK_SIZE_K * _stride_bk

    # 3. Epilogue: Apply Scale B safely outside the loop
    # Load Scale B
    scale_b_ptrs = scale_b_ptr + offs_n * stride_scale_b
    sb = tl.load(scale_b_ptrs, mask=n_mask, other=1.0)  # shape=(N,)

    accumulator = accumulator * sb[None, :]
    return accumulator
+ BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + # accelerate config + EVEN_K: tl.constexpr, + EVEN_N: tl.constexpr, + STRIDE_AK_IS_1: tl.constexpr, + STRIDE_BK_IS_1: tl.constexpr, + STRIDE_BN_IS_1: tl.constexpr, + # Quant-related dtypes + ACCUMULATOR_DTYPE: tl.constexpr, + COMPUTE_DTYPE: tl.constexpr, + OUTPUT_DTYPE: tl.constexpr, + # Quant Control + # 0=None, 1=W8A8, 2=W8A16, 3=W4A16 + QUANT_TYPE: tl.constexpr, + HAS_BIAS: tl.constexpr = False, +): + pid = tl.program_id(axis=0) + + # ----------------------------------------------------------- + # 1. Routing Phase (General) + # ----------------------------------------------------------- + ( + pid_m, + pid_n, + real_row_idxs, + m_mask, + offs_n, + n_mask, + M_limit, + cur_b_ptr, + cur_bias_ptr, + cur_scale_b_ptr, + cur_stride_bk, + cur_stride_bn, + ) = _get_mot_pointers( + pid, + a_ptr, + b_text_ptr, + b_vae_ptr, + bias_text_ptr, + bias_vae_ptr, + scale_a_ptr, + scale_b_text_ptr, + scale_b_vae_ptr, + text_indices_ptr, + vae_indices_ptr, + M_text, + M_vae, + N, + stride_bk_text, + stride_bn_text, + stride_bk_vae, + stride_bn_vae, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + GROUP_SIZE_M, + ) + + offs_k = tl.arange(0, BLOCK_SIZE_K) + + # ----------------------------------------------------------- + # 2. 
Compute Phase (Static Dispatch) + # ----------------------------------------------------------- + if QUANT_TYPE == 0: # FP16 / BF16 Standard + c = _core_standard_gemm( + a_ptr, + cur_b_ptr, + real_row_idxs, + m_mask, + offs_n, + n_mask, + offs_k, + stride_am, + stride_ak, + cur_stride_bk, + cur_stride_bn, + 0, + 0, + 0, + 0, + K, + BLOCK_SIZE_K, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + ACCUMULATOR_DTYPE, + IS_W8A8=False, + EVEN_K=EVEN_K, + EVEN_N=EVEN_N, + STRIDE_AK_IS_1=STRIDE_AK_IS_1, + STRIDE_BK_IS_1=STRIDE_BK_IS_1, + STRIDE_BN_IS_1=STRIDE_BN_IS_1, + ) + elif QUANT_TYPE == 1: # W8A8 + c = _core_standard_gemm( + a_ptr, + cur_b_ptr, + real_row_idxs, + m_mask, + offs_n, + n_mask, + offs_k, + stride_am, + stride_ak, + cur_stride_bk, + cur_stride_bn, + scale_a_ptr, + cur_scale_b_ptr, + stride_scale_a, + stride_scale_b, + K, + BLOCK_SIZE_K, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + ACCUMULATOR_DTYPE, + IS_W8A8=True, + EVEN_K=EVEN_K, + EVEN_N=EVEN_N, + STRIDE_AK_IS_1=STRIDE_AK_IS_1, + STRIDE_BK_IS_1=STRIDE_BK_IS_1, + STRIDE_BN_IS_1=STRIDE_BN_IS_1, + ) + elif QUANT_TYPE == 2 or QUANT_TYPE == 3: # Weight Only + bits = 8 if QUANT_TYPE == 2 else 4 + c = _core_weight_only_gemm( + a_ptr, + cur_b_ptr, + real_row_idxs, + m_mask, + offs_n, + n_mask, + offs_k, + stride_am, + stride_ak, + cur_stride_bk, + cur_stride_bn, + cur_scale_b_ptr, + stride_scale_b, + K, + BLOCK_SIZE_K, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + ACCUMULATOR_DTYPE, + COMPUTE_DTYPE, + WEIGHT_BITS=bits, + EVEN_K=EVEN_K, + EVEN_N=EVEN_N, + STRIDE_AK_IS_1=STRIDE_AK_IS_1, + STRIDE_BK_IS_1=STRIDE_BK_IS_1, + STRIDE_BN_IS_1=STRIDE_BN_IS_1, + ) + + # ----------------------------------------------------------- + # 3. 
Store Phase (General) + # ----------------------------------------------------------- + + # Bias Add + if HAS_BIAS: + bias = tl.load(cur_bias_ptr + offs_n, mask=n_mask, other=0.0) + c = c + bias[None, :] # reshape to (1, BLOCK_SIZE_N) + + # Cast C into the output dtype + c = c.to(OUTPUT_DTYPE) + + # Store + offs_cm = real_row_idxs + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_n[None, :] + + # Need to be careful with the Mask here: real_row_idxs may contain 0 + # (out-of-bounds padding), but c_mask needs the real boundary + store_mask = m_mask[:, None] & n_mask[None, :] + tl.store(c_ptrs, c, mask=store_mask) + + +# Define is_weak_contiguous +def is_weak_contiguous(x: torch.Tensor): + strides = x.stride() + sizes = x.shape + is_not_transpose = strides[0] == 1 and (strides[1] >= max(1, sizes[0])) + is_transpose = strides[1] == 1 and (strides[0] >= max(1, sizes[1])) + return is_transpose or is_not_transpose + + +def invoke_mot_gemm( + # Inputs + A: torch.Tensor, + B_text: torch.Tensor, + B_vae: torch.Tensor, + C: torch.Tensor, + bias_text: torch.Tensor | None, + bias_vae: torch.Tensor | None, + # Indices + text_indices: torch.Tensor, + vae_indices: torch.Tensor, + # Quant Scales (None if disabled) + A_scale: torch.Tensor | None, + B_text_scale: torch.Tensor | None, + B_vae_scale: torch.Tensor | None, + # Quant Flags + use_fp8_w8a8: bool, + use_int8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, + # Quant Type + A_per_channel_quant: bool, + B_per_channel_quant: bool, + # Config + config: dict[str, Any] | None = None, +): + # ------ 1. 
Basic Assertions ------ + M = A.size(0) + K = B_text.size(0) + N = B_text.size(1) + + if config is None: + if use_fp8_w8a8: + _dtype_str: str | None = "fp8_w8a8" + elif use_int8_w8a16: + _dtype_str = "int8_w8a16" + else: + _dtype_str = None + + loaded_m_key, config = get_best_mot_config(M, N, K, _dtype_str) + + assert len(A.shape) == 2 and len(C.shape) == 2, ( + "The input tensor and output tensor should be flattened to (batch_size*seq_len, hidden_dim)" + ) + + assert K == A.size(1), "the weights' first dimension should matchinputtensor's last dimension (hidden_dim)" + + assert K == B_vae.size(0) and N == B_vae.size(1), ( + "the weights dimension for text expert andimage expert should be the same" + ) + + assert C.size(0) == M and C.size(1) == N, "the output tensor shape is not correct" + + M_text = text_indices.size(0) + M_vae = vae_indices.size(0) + assert M_text + M_vae == M, "the length sum of text and image indices should match input tensor's first dimension" + if bias_text is not None: + assert bias_text.dtype == C.dtype, "the bias tensor dtype should match the output tensor dtype" + if bias_vae is not None: + assert bias_vae.dtype == C.dtype, "the bias tensor dtype should match the output tensor dtype" + + assert is_weak_contiguous(A) + assert is_weak_contiguous(B_text) + assert is_weak_contiguous(B_vae) + + # --- 2. 
Quantization Logic Translation --- + + def triton_dtype(torch_dtype): + if torch_dtype == torch.float8_e4m3fn: + return getattr(tl, "float8e4m3fn", tl.float8e4nv) + elif torch_dtype == torch.float8_e5m2: + return tl.float8e5 + return { + torch.float16: tl.float16, + torch.bfloat16: tl.bfloat16, + torch.float32: tl.float32, + torch.int8: tl.int8, + torch.float8_e4m3fn: tl.float8e4nv, + torch.float8_e5m2: tl.float8e5, + }[torch_dtype] + + # Determine QUANT_TYPE + # 0=None, 1=W8A8, 2=W8A16, 3=W4A16 + quant_type = 0 + ACCUMULATOR_DTYPE = tl.float32 + COMPUTE_DTYPE = triton_dtype(A.dtype) + OUTPUT_DTYPE = triton_dtype(C.dtype) + + if use_int8_w8a8 or use_fp8_w8a8: + quant_type = 1 + assert A_scale is not None, "W8A8 requires A_scale" + assert B_text_scale is not None and B_vae_scale is not None, "W8A8 requires B_text_scale and B_vae_scale" + if use_int8_w8a8: + ACCUMULATOR_DTYPE = tl.int32 + assert ( + A.dtype == torch.int8 + and B_text.dtype == torch.int8 + and B_vae.dtype == torch.int8 + and C.dtype in [torch.float16, torch.bfloat16] + ), "if you want to use INT8_W8A8, A should be INT8, B should be INT8, C should be FP16/BF16" + else: + ACCUMULATOR_DTYPE = tl.float32 + assert ( + A.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] + and B_text.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] + and B_vae.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] + and C.dtype in [torch.float16, torch.bfloat16] + ), "if you want to use FP8_W8A8, A should be FP8, B should be FP8, C should be FP16/BF16" + elif use_int8_w8a16: + quant_type = 2 + assert B_text_scale is not None and B_vae_scale is not None, "W8A16 requires B_text_scale and B_vae_scale" + ACCUMULATOR_DTYPE = tl.float32 + assert ( + A.dtype in [torch.float16, torch.bfloat16] + and B_text.dtype == torch.int8 + and B_vae.dtype == torch.int8 + and C.dtype in [torch.float16, torch.bfloat16] + ), "if you want to use INT8_W8A16, A should be FP16/BF16, B should be INT8, C should be FP16/BF16" + + elif use_int4_w4a16: + raise 
NotImplementedError("For weight-only, we only support W8A16 at this point") + # quant_type = 3 + # ACCUMULATOR_DTYPE=tl.float32 + + # accelerate config + EVEN_K = K % config["BLOCK_SIZE_K"] == 0 + EVEN_N = N % config["BLOCK_SIZE_N"] == 0 + STRIDE_AK_IS_1 = A.stride(1) == 1 + STRIDE_BK_IS_1 = (B_text.stride(0) == 1) and (B_vae.stride(0) == 1) + STRIDE_BN_IS_1 = (B_text.stride(1) == 1) and (B_vae.stride(1) == 1) + + # bias check + assert (bias_text is None) == (bias_vae is None), ( + "Bias must be provided for both Text and VAE simultaneously, or neither." + ) + has_bias = bias_text is not None + + # --- 3. Grid Calculation --- + def grid(META): + return ( + (triton.cdiv(M_text, META["BLOCK_SIZE_M"]) + triton.cdiv(M_vae, META["BLOCK_SIZE_M"])) + * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + + # --- 4. Launch --- + run_config = config.copy() + run_config.update( + { + "QUANT_TYPE": quant_type, + "ACCUMULATOR_DTYPE": ACCUMULATOR_DTYPE, + "COMPUTE_DTYPE": COMPUTE_DTYPE, + "OUTPUT_DTYPE": OUTPUT_DTYPE, + "HAS_BIAS": has_bias, + "EVEN_K": EVEN_K, + "EVEN_N": EVEN_N, + "STRIDE_AK_IS_1": STRIDE_AK_IS_1, + "STRIDE_BK_IS_1": STRIDE_BK_IS_1, + "STRIDE_BN_IS_1": STRIDE_BN_IS_1, + } + ) + + # Pointers (Handle None -> 0) + p_a_scale = A_scale if A_scale is not None else 0 + p_b_text_scale = B_text_scale if B_text_scale is not None else 0 + p_b_vae_scale = B_vae_scale if B_vae_scale is not None else 0 + p_bias_text = bias_text if bias_text is not None else 0 + p_bias_vae = bias_vae if bias_vae is not None else 0 + + # Quantization granularity + stride_scale_a = 1 if A_per_channel_quant else 0 + stride_scale_b = 1 if B_per_channel_quant else 0 + + mot_unified_gemm_kernel[grid]( + # Inputs + A, + B_text, + B_vae, + C, + p_bias_text, + p_bias_vae, + text_indices, + vae_indices, + # Dimensions + M_text, + M_vae, + N, + K, + # Strides + A.stride(0), + A.stride(1), + B_text.stride(0), + B_text.stride(1), + B_vae.stride(0), + B_vae.stride(1), + C.stride(0), + C.stride(1), + # Scales + 
p_a_scale, + p_b_text_scale, + p_b_vae_scale, + stride_scale_a, + stride_scale_b, + # Config + **run_config, + ) diff --git a/vllm_omni/diffusion/layers/mot/ops/mot_rms_norm.py b/vllm_omni/diffusion/layers/mot/ops/mot_rms_norm.py new file mode 100644 index 00000000000..4d40a015af5 --- /dev/null +++ b/vllm_omni/diffusion/layers/mot/ops/mot_rms_norm.py @@ -0,0 +1,283 @@ +# ruff: noqa: N803, E741 +import torch +from vllm.triton_utils import tl, triton + + +@triton.jit +def _mot_rms_norm_kernel( + input_ptr, + text_weight_ptr, + vae_weight_ptr, + text_indices_ptr, # MoT Routing Info + vae_indices_ptr, # MoT Routing Info + M_text, # MoT Routing Info + output_ptr, + input_row_stride, + output_row_stride, + n_cols, + eps, + BLOCK_SIZE: tl.constexpr, +): + """ + input shape: (batch_size*seq_len, hidden_size) + RMSNorm: y = x / sqrt(mean(x^2) + eps) * weight + """ + # Step 0: MoT Routing + pid = tl.program_id(0).to(tl.int64) + + # dummy init (must be scalar int64[] to match tl.load return type) + row_idx = tl.cast(0, tl.int64) + weight_ptr = text_weight_ptr + + if pid < M_text: + # --- Text Path --- + row_idx = tl.load(text_indices_ptr + pid) + weight_ptr = text_weight_ptr + else: + # --- VAE Path --- + vae_pid = pid - M_text + row_idx = tl.load(vae_indices_ptr + vae_pid) + weight_ptr = vae_weight_ptr + + row_start_ptr = input_ptr + row_idx * input_row_stride + output_row_start_ptr = output_ptr + row_idx * output_row_stride + + # Step 1: Compute sum of squares in float32 to avoid overflow + sum_sq = tl.zeros([1], dtype=tl.float32) + for col_offset in range(0, n_cols, BLOCK_SIZE): + col_idx = col_offset + tl.arange(0, BLOCK_SIZE) + mask = col_idx < n_cols + + vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0) + # Convert to float32 for accumulation to prevent overflow + vals_f32 = vals.to(tl.float32) + sq_vals = vals_f32 * vals_f32 + sum_sq += tl.sum(tl.where(mask, sq_vals, 0.0)) + + # Step 2: Compute RMS (root mean square) in float32 + mean_sq = sum_sq / n_cols + 
rms = tl.sqrt(mean_sq + eps) + inv_rms = 1.0 / rms + + # Step 3: Normalize and apply weight + for col_offset in range(0, n_cols, BLOCK_SIZE): + col_idx = col_offset + tl.arange(0, BLOCK_SIZE) + mask = col_idx < n_cols + vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0) + weight = tl.load(weight_ptr + col_idx, mask=mask, other=1.0) + # Compute in float32 then convert back to input dtype + vals_f32 = vals.to(tl.float32) + weight_f32 = weight.to(tl.float32) + output_f32 = vals_f32 * inv_rms * weight_f32 + output = output_f32.to(vals.dtype) + tl.store(output_row_start_ptr + col_idx, output, mask=mask) + + +@triton.jit +def _mot_rms_norm_qk_kernel( + input_ptr, + text_weight_ptr, + vae_weight_ptr, + text_indices_ptr, + vae_indices_ptr, + M_text, + output_ptr, + stride_tok, + stride_head, + out_stride_tok, + out_stride_head, + num_heads, + head_dim, + eps, + BLOCK_DIM: tl.constexpr, # for head_dim=128,should be 128 + SHARED_WEIGHT: tl.constexpr, # weights are shared across heads +): + pid_tok = tl.program_id(0).to(tl.int64) + + # MoT Routing + if pid_tok < M_text: + row_idx = tl.load(text_indices_ptr + pid_tok) + weight_ptr = text_weight_ptr + else: + vae_pid = pid_tok - M_text + row_idx = tl.load(vae_indices_ptr + vae_pid) + weight_ptr = vae_weight_ptr + + dim_offsets = tl.arange(0, BLOCK_DIM) + mask = dim_offsets < head_dim + + # input.shape=(4098,28,128)=(seq_len,num_heads,head_dim) + # weight.shape=(28,128)=(num_heads,head_dim) or (128,)=(head_dim) + tok_in_ptr = input_ptr + row_idx * stride_tok + tok_out_ptr = output_ptr + row_idx * out_stride_tok + + # if weight.shape=(128,) + if SHARED_WEIGHT: + weight = tl.load(weight_ptr + dim_offsets, mask=mask, other=1.0) + weight_f32 = weight.to(tl.float32) + + # one program loop over all heads + for h in range(num_heads): + head_in_ptr = tok_in_ptr + h * stride_head + head_out_ptr = tok_out_ptr + h * out_stride_head + + # if weight.shape=(28,128) + if not SHARED_WEIGHT: + weight_offset = h * head_dim + 
dim_offsets + weight = tl.load(weight_ptr + weight_offset, mask=mask, other=1.0) + weight_f32 = weight.to(tl.float32) + + vals = tl.load(head_in_ptr + dim_offsets, mask=mask, other=0.0) + vals_f32 = vals.to(tl.float32) + + sq_vals = vals_f32 * vals_f32 + sum_sq = tl.sum(tl.where(mask, sq_vals, 0.0)) + mean_sq = sum_sq / head_dim + rms = tl.sqrt(mean_sq + eps) + inv_rms = 1.0 / rms + + output_f32 = vals_f32 * inv_rms * weight_f32 + tl.store(head_out_ptr + dim_offsets, output_f32.to(vals.dtype), mask=mask) + + +def mot_rms_norm( + input: torch.Tensor, + text_weight: torch.Tensor, + vae_weight: torch.Tensor, + text_indices: torch.Tensor, + vae_indices: torch.Tensor, + head_norm: bool = False, + eps: float = 1e-6, + block_size: int | None = None, +) -> torch.Tensor: + """ + Compute RMS normalization using Triton kernel. + + RMS Norm normalizes the input by the root mean square and scales by weight: + output = input / sqrt(mean(input^2) + eps) * weight + + Args: + input: Input tensor of shape (batch_size*seq_len, hidden_size) or (batch_size,seq_len, hidden_size) + text_weight: Weight for text tokens tensor of shape (hidden_size,) + vae_weight: Weight for vae tokens tensor of shape (hidden_size,) + text_indices: indices of text tokens, (batch_size*2,) + vae_indices: indices of vae tokens, (batch_size*(seq_len-2),) + eps: Small constant for numerical stability + + Returns: + Tensor with RMS normalization applied along the last dimension + """ + + assert input.shape[-1] == text_weight.shape[-1], ( + f"Input last dimension ({input.shape[-1]}) must match Text weight dimension ({text_weight.shape[-1]})" + ) + assert input.shape[-1] == vae_weight.shape[0], ( + f"Input last dimension ({input.shape[-1]}) must match VAE weight dimension ({vae_weight.shape[-1]})" + ) + + original_shape = input.shape + text_indices = text_indices.reshape(-1) + vae_indices = vae_indices.reshape(-1) + M_text = text_indices.shape[0] + M_vae = vae_indices.shape[0] + num_tokens = M_text + M_vae + + if 
not head_norm: + input_2d = input.reshape(-1, input.shape[-1]) + assert input_2d.shape[0] == num_tokens, ( + f"batch_size={input_2d.shape[0]}, len(text_indices)={M_text}, len(vae_indices)={M_vae}" + f"for layer norm, batched_token_length should match the sum of indices_length" + ) + input_2d = input_2d.contiguous() + text_indices = text_indices.contiguous() + vae_indices = vae_indices.contiguous() + + text_weight = text_weight.contiguous() + vae_weight = vae_weight.contiguous() + + n_rows, n_cols = input_2d.shape + + output = torch.empty_like(input_2d) + + if block_size is None: + block_size = triton.next_power_of_2(n_cols) + block_size = min(block_size, 4096) + + num_warps = 4 + if block_size >= 2048: + num_warps = 8 + elif block_size >= 1024: + num_warps = 4 + else: + num_warps = 2 + + grid = (n_rows,) + _mot_rms_norm_kernel[grid]( + input_2d, + text_weight, + vae_weight, + text_indices, + vae_indices, + M_text, + output, + input_2d.stride(0), + output.stride(0), + n_cols, + eps, + BLOCK_SIZE=block_size, + num_warps=num_warps, + ) + return output.reshape(original_shape) + else: + # qk norm scenarios: + # input.shape=(batch_size, seq_len,head_num, head_dim) or + # input.shape=(batch_size*seq_len,head_num, head_dim) + assert len(original_shape) > 2, ( + "If head_norm=True,input shape be 3D or 3D, last 2 dimensions should be head_num and head_dim" + ) + num_heads = input.shape[-2] + head_dim = input.shape[-1] + + is_shared_weight = text_weight.dim() == 1 + if not is_shared_weight: + assert num_heads == text_weight.shape[0] and num_heads == vae_weight.shape[0], ( + "when weights are not shared across heads, the first dimension of " + "weights should be num of heads and match input.shape[-2]" + ) + + # reshape to 3D + input_3d = input.view(-1, num_heads, head_dim) + + input_3d = input_3d.contiguous() + text_indices = text_indices.contiguous() + vae_indices = vae_indices.contiguous() + text_weight = text_weight.contiguous() + vae_weight = vae_weight.contiguous() + + 
output_3d = torch.empty_like(input_3d) + + block_dim = triton.next_power_of_2(head_dim) + num_warps = 4 if block_dim > 128 else 2 + num_stages = 3 + + _mot_rms_norm_qk_kernel[(num_tokens,)]( + input_3d, + text_weight, + vae_weight, + text_indices, + vae_indices, + M_text, + output_3d, + input_3d.stride(0), + input_3d.stride(1), + output_3d.stride(0), + output_3d.stride(1), + num_heads, + head_dim, + eps, + BLOCK_DIM=block_dim, + SHARED_WEIGHT=is_shared_weight, + num_warps=num_warps, + num_stages=num_stages, + ) + return output_3d.view(original_shape) diff --git a/vllm_omni/diffusion/models/bagel/bagel_transformer.py b/vllm_omni/diffusion/models/bagel/bagel_transformer.py index d1254f84566..73f88027fa9 100644 --- a/vllm_omni/diffusion/models/bagel/bagel_transformer.py +++ b/vllm_omni/diffusion/models/bagel/bagel_transformer.py @@ -22,11 +22,11 @@ ) from transformers.utils import ModelOutput from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( ColumnParallelLinear, MergedColumnParallelLinear, - QKVParallelLinear, RowParallelLinear, ) from vllm.model_executor.layers.quantization.base_config import ( @@ -48,6 +48,9 @@ get_sp_group, ) from vllm_omni.diffusion.forward_context import get_forward_context, is_forward_context_available +from vllm_omni.diffusion.layers.mot.mot_layernorm import MoTRMSNorm +from vllm_omni.diffusion.layers.mot.mot_qkv_parallel_linear import MoTQKVParallelLinear +from vllm_omni.diffusion.layers.mot.mot_row_parallel_linear import MoTRowParallelLinear from vllm_omni.diffusion.layers.rope import RotaryEmbedding @@ -149,6 +152,13 @@ def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> tuple[torch.Te class BagelMLP(nn.Module): + """FFN with Mixture-of-Tokens routing via MoT parallel linear layers. 
+ + gate_proj + up_proj are fused into a single MoTMergedColumnParallelLinear. + down_proj uses MoTRowParallelLinear. Both layers hold text weights on self + and vae weights on self.gen_exp, routing by text_indices / vae_indices. + """ + def __init__( self, hidden_size: int, @@ -158,29 +168,35 @@ def __init__( prefix: str = "", ) -> None: super().__init__() + + self.intermediate_size = intermediate_size + self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, - [intermediate_size, intermediate_size], + input_size=hidden_size, + output_sizes=[intermediate_size, intermediate_size], bias=False, + gather_output=False, quant_config=quant_config, prefix=f"{prefix}.gate_up_proj", ) self.down_proj = RowParallelLinear( - intermediate_size, - hidden_size, - input_is_parallel=True, + input_size=intermediate_size, + output_size=hidden_size, bias=False, + input_is_parallel=True, quant_config=quant_config, prefix=f"{prefix}.down_proj", ) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. Only silu is supported.") - self.act_fn = nn.SiLU() + self.act_fn = SiluAndMul() - def forward(self, x): + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: gate_up, _ = self.gate_up_proj(x) - gate, up = gate_up.chunk(2, dim=-1) - x = self.act_fn(gate) * up + x = self.act_fn(gate_up) x, _ = self.down_proj(x) return x @@ -279,12 +295,9 @@ class BaseNavitOutputWithPast(ModelOutput): class PackedAttentionMoT(nn.Module): """Packed attention with Mixture-of-Tokens routing for understanding/generation. - Uses vLLM's QKVParallelLinear and RowParallelLinear for tensor parallelism - support, following the same pattern as vLLM's Qwen2Attention. - - The q/k/v projections are stacked into a single QKVParallelLinear: - - qkv_proj : stacks q_proj + k_proj + v_proj (understanding + gen text) - - qkv_proj_moe_gen : stacks q_proj_moe_gen + k_proj_moe_gen + v_proj_moe_gen (gen vae) + Uses MoTQKVParallelLinear and MoTRowParallelLinear for tensor parallelism. 
+ Text and vae weights are held within the same MoT layer (text on self, + vae on self.gen_exp). Token routing is driven by text_indices / vae_indices. """ def __init__( @@ -314,47 +327,26 @@ def __init__( self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim - # Understanding mode projections (stacked q/k/v) - self.qkv_proj = QKVParallelLinear( + self.qkv_proj = MoTQKVParallelLinear( self.hidden_size, self.head_dim, self.total_num_heads, self.total_num_kv_heads, bias=True, + vae_bias=True, quant_config=quant_config, prefix=f"{prefix}.qkv_proj", ) - self.o_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, - self.hidden_size, + self.o_proj = MoTRowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=self.hidden_size, bias=False, quant_config=quant_config, prefix=f"{prefix}.o_proj", ) - # Generation mode MoE projections (stacked q/k/v) - self.qkv_proj_moe_gen = QKVParallelLinear( - self.hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.qkv_proj_moe_gen", - ) - self.o_proj_moe_gen = RowParallelLinear( - self.total_num_heads * self.head_dim, - self.hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.o_proj_moe_gen", - ) - - # QK normalization - self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) - self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) - self.q_norm_moe_gen = RMSNorm(self.head_dim, eps=config.rms_norm_eps) - self.k_norm_moe_gen = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.q_norm = MoTRMSNorm(self.head_dim, head_norm=True, eps=config.rms_norm_eps) + self.k_norm = MoTRMSNorm(self.head_dim, head_norm=True, eps=config.rms_norm_eps) self.rotary_op = RotaryEmbedding(is_neox_style=True) @@ -389,60 +381,45 @@ def _forward_sp_gen( ) -> tuple[torch.Tensor, NaiveCache | None]: """SP-aware attention for gen mode denoising. 
- Converts packed format to batched (1, S, H, D) and uses the diffusion - Attention layer (Ulysses / Ring) with joint mechanism: + Uses MoT unified layers for projection/norm (fused text/vae routing), + then splits into text/vae for the DiffusionAttention (Ulysses / Ring) + with joint mechanism: - Main Q/K/V: VAE tokens (split across SP ranks) - Joint Q: text marker Q (replicated) - Joint K/V: KV cache K/V + text marker K/V (replicated) """ + text_indices = packed_text_indexes + vae_indices = packed_vae_token_indexes + packed_query_sequence = packed_query_sequence.to(torch.bfloat16) - packed_text_query_sequence = packed_query_sequence[packed_text_indexes] - packed_vae_query_sequence = packed_query_sequence[packed_vae_token_indexes] - - # Project text tokens through base qkv - text_qkv, _ = self.qkv_proj(packed_text_query_sequence) - text_q, text_k, text_v = text_qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - - # Project vae tokens through moe_gen qkv - vae_qkv, _ = self.qkv_proj_moe_gen(packed_vae_query_sequence) - vae_q, vae_k, vae_v = vae_qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - - # Reshape to (tokens, heads, head_dim) - text_q = text_q.view(-1, self.num_heads, self.head_dim) - text_k = text_k.view(-1, self.num_kv_heads, self.head_dim) - text_v = text_v.view(-1, self.num_kv_heads, self.head_dim) - vae_q = vae_q.view(-1, self.num_heads, self.head_dim) - vae_k = vae_k.view(-1, self.num_kv_heads, self.head_dim) - vae_v = vae_v.view(-1, self.num_kv_heads, self.head_dim) - - # Apply QK norms - text_q = self.q_norm(text_q.to(torch.float32)) - text_k = self.k_norm(text_k.to(torch.float32)) - vae_q = self.q_norm_moe_gen(vae_q.to(torch.float32)) - vae_k = self.k_norm_moe_gen(vae_k.to(torch.float32)) - - # Apply RoPE - need to build per-token cos/sin for text and vae separately - # packed_query_position_embeddings are ordered as the packed sequence - cos_full, sin_full = [x[..., : self.head_dim // 2] for x in 
packed_query_position_embeddings] - - # Extract cos/sin for text and vae positions - text_cos = cos_full[packed_text_indexes] - text_sin = sin_full[packed_text_indexes] - vae_cos = cos_full[packed_vae_token_indexes] - vae_sin = sin_full[packed_vae_token_indexes] - - text_q = self.rotary_op(text_q.to(text_cos.dtype).unsqueeze(0), text_cos, text_sin).squeeze(0) - text_k = self.rotary_op(text_k.to(text_cos.dtype).unsqueeze(0), text_cos, text_sin).squeeze(0) - vae_q = self.rotary_op(vae_q.to(vae_cos.dtype).unsqueeze(0), vae_cos, vae_sin).squeeze(0) - vae_k = self.rotary_op(vae_k.to(vae_cos.dtype).unsqueeze(0), vae_cos, vae_sin).squeeze(0) - - text_q = text_q.to(torch.bfloat16) - text_k = text_k.to(torch.bfloat16) - text_v = text_v.to(torch.bfloat16) - vae_q = vae_q.to(torch.bfloat16) - vae_k = vae_k.to(torch.bfloat16) - vae_v = vae_v.to(torch.bfloat16) + # MoT QKV projection — fused kernel routes text/vae to correct weights + qkv, _ = self.qkv_proj(packed_query_sequence, text_indices, vae_indices) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = q.view(-1, self.num_heads, self.head_dim) + k = k.view(-1, self.num_kv_heads, self.head_dim) + v = v.view(-1, self.num_kv_heads, self.head_dim) + + # MoT QK norms — routes text/vae to weight/gen_weight internally + q = self.q_norm(q.to(torch.float32), text_indices, vae_indices) + k = self.k_norm(k.to(torch.float32), text_indices, vae_indices) + + # RoPE on full packed tensor + cos, sin = [x[..., : self.head_dim // 2] for x in packed_query_position_embeddings] + q = self.rotary_op(q.to(cos.dtype).unsqueeze(0), cos, sin).squeeze(0) + k = self.rotary_op(k.to(cos.dtype).unsqueeze(0), cos, sin).squeeze(0) + + q = q.to(torch.bfloat16) + k = k.to(torch.bfloat16) + v = v.to(torch.bfloat16) + + # Split into text / vae for SP attention + text_q = q[text_indices] + text_k = k[text_indices] + text_v = v[text_indices] + vae_q = q[vae_indices] + vae_k = k[vae_indices] + vae_v = v[vae_indices] # Build joint K/V: 
[kv_cache, text_markers] (replicated across SP ranks) if past_key_values is not None and past_key_values.key_cache[self.layer_idx] is not None: @@ -472,15 +449,19 @@ def _forward_sp_gen( joint_strategy="front", ), ) - # attn_out: (1, text_len + local_vae_len, H, D) + # attn_out: (1, text_len + vae_len, H, D) text_len = text_q.shape[0] - attn_out = attn_out.squeeze(0) # (text_len + local_vae_len, H, D) + attn_out = attn_out.squeeze(0) text_attn = attn_out[:text_len].reshape(text_len, self.q_size) vae_attn = attn_out[text_len:].reshape(-1, self.q_size) - # Apply output projections - text_out, _ = self.o_proj(text_attn) - vae_out, _ = self.o_proj_moe_gen(vae_attn) + # MoT output projection — construct local packed tensor with local indices + local_packed = torch.cat([text_attn, vae_attn], dim=0) + local_text_idx = torch.arange(text_len, device=local_packed.device) + local_vae_idx = torch.arange(text_len, text_len + vae_attn.shape[0], device=local_packed.device) + local_out, _ = self.o_proj(local_packed, local_text_idx, local_vae_idx) + text_out = local_out[:text_len] + vae_out = local_out[text_len:] # Merge back into packed format total_len = packed_query_sequence.shape[0] @@ -522,56 +503,25 @@ def forward( packed_text_indexes=packed_text_indexes, ) - if mode == "und": - qkv, _ = self.qkv_proj(packed_query_sequence) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - packed_query_states = q.view(-1, self.num_heads, self.head_dim) - packed_key_states = k.view(-1, self.num_kv_heads, self.head_dim) - packed_value_states = v.view(-1, self.num_kv_heads, self.head_dim) - packed_query_states = self.q_norm(packed_query_states) - packed_key_states = self.k_norm(packed_key_states) - elif mode == "gen": - packed_query_sequence = packed_query_sequence.to(torch.bfloat16) - - packed_text_query_sequence = packed_query_sequence[packed_text_indexes] - packed_vae_query_sequence = packed_query_sequence[packed_vae_token_indexes] - - # Project text tokens through 
base qkv - text_qkv, _ = self.qkv_proj(packed_text_query_sequence) - text_q, text_k, text_v = text_qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - - # Project vae tokens through moe_gen qkv - vae_qkv, _ = self.qkv_proj_moe_gen(packed_vae_query_sequence) - vae_q, vae_k, vae_v = vae_qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + text_indices = packed_text_indexes if mode == "gen" else None + vae_indices = packed_vae_token_indexes if mode == "gen" else None - # Merge into packed tensors - total_len = packed_query_sequence.shape[0] - packed_query_states = packed_query_sequence.new_zeros((total_len, self.q_size)) - packed_key_states = packed_query_sequence.new_zeros((total_len, self.kv_size)) - packed_value_states = packed_query_sequence.new_zeros((total_len, self.kv_size)) - - packed_query_states[packed_text_indexes] = text_q - packed_query_states[packed_vae_token_indexes] = vae_q - packed_key_states[packed_text_indexes] = text_k - packed_key_states[packed_vae_token_indexes] = vae_k - packed_value_states[packed_text_indexes] = text_v - packed_value_states[packed_vae_token_indexes] = vae_v + if mode == "gen": + packed_query_sequence = packed_query_sequence.to(torch.bfloat16) - packed_query_states = packed_query_states.view(-1, self.num_heads, self.head_dim) - packed_key_states = packed_key_states.view(-1, self.num_kv_heads, self.head_dim) - packed_value_states = packed_value_states.view(-1, self.num_kv_heads, self.head_dim) + # QKV projection — MoT layer handles text/vae routing internally + qkv, _ = self.qkv_proj(packed_query_sequence, text_indices, vae_indices) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + packed_query_states = q.view(-1, self.num_heads, self.head_dim) + packed_key_states = k.view(-1, self.num_kv_heads, self.head_dim) + packed_value_states = v.view(-1, self.num_kv_heads, self.head_dim) + if mode == "gen": packed_query_states = packed_query_states.to(torch.float32) - 
packed_query_states[packed_text_indexes] = self.q_norm(packed_query_states[packed_text_indexes]) - packed_query_states[packed_vae_token_indexes] = self.q_norm_moe_gen( - packed_query_states[packed_vae_token_indexes] - ) - packed_key_states = packed_key_states.to(torch.float32) - packed_key_states[packed_text_indexes] = self.k_norm(packed_key_states[packed_text_indexes]) - packed_key_states[packed_vae_token_indexes] = self.k_norm_moe_gen( - packed_key_states[packed_vae_token_indexes] - ) + + packed_query_states = self.q_norm(packed_query_states, text_indices, vae_indices) + packed_key_states = self.k_norm(packed_key_states, text_indices, vae_indices) cos, sin = [x[..., : self.head_dim // 2] for x in packed_query_position_embeddings] packed_query_states = self.rotary_op(packed_query_states.to(cos.dtype).unsqueeze(0), cos, sin).squeeze(0) @@ -612,15 +562,9 @@ def forward( causal=is_causal, ) packed_attn_output = packed_attn_output.reshape(-1, self.q_size) - if mode == "und": - packed_attn_output, _ = self.o_proj(packed_attn_output) - elif mode == "gen": - text_out, _ = self.o_proj(packed_attn_output[packed_text_indexes]) - vae_out, _ = self.o_proj_moe_gen(packed_attn_output[packed_vae_token_indexes]) - full_output = text_out.new_zeros((packed_attn_output.shape[0], self.hidden_size)) - full_output[packed_text_indexes] = text_out - full_output[packed_vae_token_indexes] = vae_out - packed_attn_output = full_output + + # Output projection — MoT layer handles text/vae routing internally + packed_attn_output, _ = self.o_proj(packed_attn_output, text_indices, vae_indices) if update_past_key_values: past_key_values.key_cache[self.layer_idx] = merged_key_states @@ -640,12 +584,15 @@ def __init__( prefix: str = "", ): super().__init__() + self.layer_idx = layer_idx self.hidden_size = config.hidden_size self.self_attn = attn_module( config, layer_idx, parallel_config=parallel_config, quant_config=quant_config, prefix=f"{prefix}.self_attn" ) + self.input_layernorm = 
MoTRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.mlp = BagelMLP( config.hidden_size, config.intermediate_size, @@ -660,8 +607,7 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.mlp_moe_gen", ) - self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.input_layernorm_moe_gen = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm_moe_gen = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -684,18 +630,12 @@ def forward( ) -> BaseNavitOutputWithPast: if packed_query_sequence is None: packed_query_sequence = hidden_states + + text_indices = packed_text_indexes if mode == "gen" else None + vae_indices = packed_vae_token_indexes if mode == "gen" else None + residual = packed_query_sequence - if mode == "und": - packed_query_sequence = self.input_layernorm(packed_query_sequence) - elif mode == "gen": - packed_query_sequence_ = torch.zeros_like(packed_query_sequence) - packed_query_sequence_[packed_text_indexes] = self.input_layernorm( - packed_query_sequence[packed_text_indexes] - ) - packed_query_sequence_[packed_vae_token_indexes] = self.input_layernorm_moe_gen( - packed_query_sequence[packed_vae_token_indexes] - ) - packed_query_sequence = packed_query_sequence_ + packed_query_sequence = self.input_layernorm(packed_query_sequence, text_indices, vae_indices) # Self Attention packed_query_sequence, past_key_values = self.self_attn( @@ -726,7 +666,9 @@ def forward( packed_vae_query_sequence = self.post_attention_layernorm_moe_gen(packed_vae_query_sequence).to( torch.bfloat16 ) - + packed_normed = torch.zeros_like(packed_query_sequence).to(torch.bfloat16) + packed_normed[packed_text_indexes] = packed_text_query_sequence + packed_normed[packed_vae_token_indexes] = packed_vae_query_sequence packed_query_sequence_ = torch.zeros_like(packed_query_sequence).to(torch.bfloat16) 
packed_query_sequence_[packed_text_indexes] = self.mlp(packed_text_query_sequence) packed_query_sequence_[packed_vae_token_indexes] = self.mlp_moe_gen(packed_vae_query_sequence) @@ -765,9 +707,7 @@ def __init__( ] ) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - if self.use_moe: - self.norm_moe_gen = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm = MoTRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = BagelRotaryEmbedding(config=config) # Initialize weights and apply final processing @@ -820,18 +760,9 @@ def forward( **extra_inputs, ) - if self.use_moe: - if mode == "und": - packed_query_sequence = self.norm(packed_query_sequence) - elif mode == "gen": - packed_query_sequence_ = torch.zeros_like(packed_query_sequence) - packed_query_sequence_[packed_text_indexes] = self.norm(packed_query_sequence[packed_text_indexes]) - packed_query_sequence_[packed_vae_token_indexes] = self.norm_moe_gen( - packed_query_sequence[packed_vae_token_indexes] - ) - packed_query_sequence = packed_query_sequence_ - else: - packed_query_sequence = self.norm(packed_query_sequence) + text_indices = packed_text_indexes if self.use_moe and mode == "gen" else None + vae_indices = packed_vae_token_indexes if self.use_moe and mode == "gen" else None + packed_query_sequence = self.norm(packed_query_sequence, text_indices, vae_indices) return BaseNavitOutputWithPast( packed_query_sequence=packed_query_sequence, @@ -910,57 +841,84 @@ def forward( return outputs def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - """Load weights for vLLM parallel layers. + """Load weights for MoT parallel layers. 
+ + Stacked parameter remapping (checkpoint name → model parameter): + - q/k/v_proj → qkv_proj (text, shard q/k/v) + - q/k/v_proj_moe_gen → qkv_proj.gen_exp (gen, shard q/k/v) - Handles stacked parameter remapping for QKVParallelLinear: - - q_proj, k_proj, v_proj -> qkv_proj (shard ids: q, k, v) - - q_proj_moe_gen, k_proj_moe_gen, v_proj_moe_gen -> qkv_proj_moe_gen - Other parallel layers (gate_proj, up_proj, down_proj, embed_tokens, etc.) - keep HF checkpoint names and use weight_loader for TP sharding. + Direct remapping (no shard dimension): + - o_proj_moe_gen → o_proj.gen_exp + - {norm}_moe_gen.weight → {norm}.gen_weight (all MoTRMSNorm layers) + + Text norm weights (input_layernorm.weight, q_norm.weight, etc.) and + other names (embed_tokens, lm_head) pass through unchanged. """ stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - # More specific _moe_gen patterns FIRST to avoid substring - # ambiguity (`.q_proj` is a substring of `.q_proj_moe_gen`). - (".qkv_proj_moe_gen", ".q_proj_moe_gen", "q"), - (".qkv_proj_moe_gen", ".k_proj_moe_gen", "k"), - (".qkv_proj_moe_gen", ".v_proj_moe_gen", "v"), + # (param_name, weight_name, shard_id) + # _moe_gen patterns MUST come first — `.q_proj` is a substring + # of `.q_proj_moe_gen`, so the more specific pattern must match first. + (".qkv_proj.gen_exp", ".q_proj_moe_gen", "q"), + (".qkv_proj.gen_exp", ".k_proj_moe_gen", "k"), + (".qkv_proj.gen_exp", ".v_proj_moe_gen", "v"), (".qkv_proj", ".q_proj", "q"), (".qkv_proj", ".k_proj", "k"), (".qkv_proj", ".v_proj", "v"), - # MLP gate/up projections — fused into MergedColumnParallelLinear. - # HF checkpoints store separate gate_proj / up_proj weights; - # these entries remap them to the fused gate_up_proj parameter. 
- (".gate_up_proj", ".gate_proj", 0), - (".gate_up_proj", ".up_proj", 1), + (".mlp_moe_gen.gate_up_proj", ".mlp_moe_gen.gate_proj", 0), + (".mlp_moe_gen.gate_up_proj", ".mlp_moe_gen.up_proj", 1), + (".mlp.gate_up_proj", ".mlp.gate_proj", 0), + (".mlp.gate_up_proj", ".mlp.up_proj", 1), + ] + + direct_remap = [ + (".o_proj_moe_gen.", ".o_proj.gen_exp."), + # Norm _moe_gen.weight → {norm_name}.gen_weight + (".input_layernorm_moe_gen.", ".input_layernorm.gen_"), + (".q_norm_moe_gen.", ".q_norm.gen_"), + (".k_norm_moe_gen.", ".k_norm.gen_"), + (".norm_moe_gen.", ".norm.gen_"), ] - self.stacked_params_mapping = stacked_params_mapping + params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() + def handle_weight(name, loaded_weight, shard_id=None): + param = params_dict.get(name) + if param is not None: + weight_loader = getattr(param, "weight_loader", default_weight_loader) + if shard_id is not None: + weight_loader(param, loaded_weight, shard_id) + else: + weight_loader(param, loaded_weight) + loaded_params.add(name) + for name, loaded_weight in weights: - loaded = False + # match direct remap + handled = False + for old_substr, new_substr in direct_remap: + if old_substr in name: + name = name.replace(old_substr, new_substr) + handle_weight(name, loaded_weight) + handled = True + break + + if handled: + continue + + # match stacked params mapping for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue - stacked_name = name.replace(weight_name, param_name) - param = params_dict.get(stacked_name) - if param is None: - break - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight, shard_id) - name = stacked_name - loaded = True + name = name.replace(weight_name, param_name) + handle_weight(name, loaded_weight, shard_id) + handled = True break - if not loaded: - param = params_dict.get(name) - if param is None: - continue - weight_loader = getattr(param, 
"weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) + if handled: + continue - loaded_params.add(name) + # no-name-match cases are handled here + handle_weight(name, loaded_weight) return loaded_params diff --git a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py index a3d2259e643..8da64ef13d5 100644 --- a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py +++ b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py @@ -819,18 +819,33 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: tp_aware_params = {name for name, p in self.named_parameters() if hasattr(p, "weight_loader")} # Expand allowed/tp_aware_params with stacked param source names. - # QKVParallelLinear merges q_proj+k_proj+v_proj into qkv_proj; the - # checkpoint stores the original separate names. We must recognise - # those names so _filtered_weights does not drop them. + # The model fuses several checkpoint projections into merged layers: + # QKV: q/k/v_proj → qkv_proj, q/k/v_proj_moe_gen → qkv_proj.gen_exp + # and remaps non-stacked weights like: + # {norm}_moe_gen.weight → {norm}.gen_weight (MoTRMSNorm layers) + # We expand allowed names so _filtered_weights does not drop them. 
_stacked_expansions = [ + # text QKV (".qkv_proj", ".q_proj"), (".qkv_proj", ".k_proj"), (".qkv_proj", ".v_proj"), - (".qkv_proj_moe_gen", ".q_proj_moe_gen"), - (".qkv_proj_moe_gen", ".k_proj_moe_gen"), - (".qkv_proj_moe_gen", ".v_proj_moe_gen"), - (".gate_up_proj", ".gate_proj"), - (".gate_up_proj", ".up_proj"), + # gen QKV + (".qkv_proj.gen_exp", ".q_proj_moe_gen"), + (".qkv_proj.gen_exp", ".k_proj_moe_gen"), + (".qkv_proj.gen_exp", ".v_proj_moe_gen"), + # gen o_proj (non-stacked, but still remapped) + (".o_proj.gen_exp", ".o_proj_moe_gen"), + # text FFN gate+up + (".mlp.gate_up_proj", ".mlp.gate_proj"), + (".mlp.gate_up_proj", ".mlp.up_proj"), + # gen FFN gate+up + (".mlp_moe_gen.gate_up_proj", ".mlp_moe_gen.gate_proj"), + (".mlp_moe_gen.gate_up_proj", ".mlp_moe_gen.up_proj"), + # MoTRMSNorm gen_weight ← checkpoint _moe_gen.weight + (".input_layernorm.gen_", ".input_layernorm_moe_gen."), + (".q_norm.gen_", ".q_norm_moe_gen."), + (".k_norm.gen_", ".k_norm_moe_gen."), + (".norm.gen_", ".norm_moe_gen."), ] stacked_source_names: set[str] = set() for name in list(allowed):