diff --git a/docs/user_guide/diffusion/parallelism_acceleration.md b/docs/user_guide/diffusion/parallelism_acceleration.md index dfacd2183ff..324301158d8 100644 --- a/docs/user_guide/diffusion/parallelism_acceleration.md +++ b/docs/user_guide/diffusion/parallelism_acceleration.md @@ -23,13 +23,23 @@ The following table shows which models are currently supported by parallelism me | **LongCat-Image** | `meituan-longcat/LongCat-Image` | ✅ | ✅ | ❌ | ❌ | | **LongCat-Image-Edit** | `meituan-longcat/LongCat-Image-Edit` | ✅ | ✅ | ❌ | ❌ | | **Ovis-Image** | `OvisAI/Ovis-Image` | ❌ | ❌ | ❌ | ❌ | -| **Qwen-Image** | `Qwen/Qwen-Image` | ✅ | ✅ | ✅ | ❌ | +| **Qwen-Image** | `Qwen/Qwen-Image` | ✅ | ✅ | ✅ | ✅ | | **Qwen-Image-Edit** | `Qwen/Qwen-Image-Edit` | ✅ | ✅ | ✅ | ❌ | | **Qwen-Image-Edit-2509** | `Qwen/Qwen-Image-Edit-2509` | ✅ | ✅ | ✅ | ❌ | | **Qwen-Image-Layered** | `Qwen/Qwen-Image-Layered` | ✅ | ✅ | ✅ | ❌ | | **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` | ❌ | ❌ | ❌ | ✅ (TP=2 only) | | **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` | ❌ | ❌ | ❌ | ❌ | + +!!! note "TP Limitations for Diffusion Models" + We currently implement Tensor Parallelism (TP) only for the DiT (Diffusion Transformer) blocks. This is because the `text_encoder` component in vLLM-Omni uses the original Transformers implementation, which does not yet support TP. + + - Good news: The text_encoder typically has minimal impact on overall inference performance. + - Bad news: When TP is enabled, every TP process retains a full copy of the text_encoder weights, leading to significant GPU memory waste. + + We are actively refactoring this design to address this. For details and progress, please refer to [Issue #771](https://github.com/vllm-project/vllm-omni/issues/771). + + !!! note "Why Z-Image is TP=2 only" Z-Image Turbo is currently limited to `tensor_parallel_size` of **1 or 2** due to model shape divisibility constraints. 
For example, the model has `n_heads=30` and a final projection out dimension of `64`, so valid TP sizes must divide both 30 and 64; the only common divisors are **1 and 2**. diff --git a/examples/offline_inference/image_to_image/image_edit.py b/examples/offline_inference/image_to_image/image_edit.py index 5d2b1052bec..c31d098252b 100644 --- a/examples/offline_inference/image_to_image/image_edit.py +++ b/examples/offline_inference/image_to_image/image_edit.py @@ -181,6 +181,12 @@ def parse_args() -> argparse.Namespace: default=1, help="Number of GPUs used for ring sequence parallelism.", ) + parser.add_argument( + "--tensor_parallel_size", + type=int, + default=1, + help="Number of GPUs used for tensor parallelism (TP) inside the DiT.", + ) parser.add_argument("--layers", type=int, default=4, help="Number of layers to decompose the input image into.") parser.add_argument( "--resolution", @@ -301,7 +307,10 @@ def main(): vae_use_slicing = is_npu() vae_use_tiling = is_npu() parallel_config = DiffusionParallelConfig( - ulysses_degree=args.ulysses_degree, ring_degree=args.ring_degree, cfg_parallel_size=args.cfg_parallel_size + ulysses_degree=args.ulysses_degree, + ring_degree=args.ring_degree, + cfg_parallel_size=args.cfg_parallel_size, + tensor_parallel_size=args.tensor_parallel_size, ) # Configure cache based on backend type @@ -351,7 +360,7 @@ def main(): else: print(f" Input image size: {input_image.size}") print( - f" Parallel configuration: ulysses_degree={args.ulysses_degree}, ring_degree={args.ring_degree}, cfg_parallel_size={args.cfg_parallel_size}" + f" Parallel configuration: ulysses_degree={args.ulysses_degree}, ring_degree={args.ring_degree}, cfg_parallel_size={args.cfg_parallel_size}, tensor_parallel_size={args.tensor_parallel_size}" ) print(f"{'=' * 60}\n") diff --git a/tests/e2e/offline_inference/test_diffusion_cpu_offload.py b/tests/e2e/offline_inference/test_diffusion_cpu_offload.py index 0066d49b161..cefda891571 100644 --- 
a/tests/e2e/offline_inference/test_diffusion_cpu_offload.py +++ b/tests/e2e/offline_inference/test_diffusion_cpu_offload.py @@ -1,11 +1,10 @@ import sys -import threading -import time from pathlib import Path import pytest import torch +from tests.utils import GPUMemoryMonitor from vllm_omni.utils.platform_utils import is_npu, is_rocm # ruff: noqa: E402 @@ -15,39 +14,6 @@ from vllm_omni import Omni - -class GPUMemoryMonitor: - """Poll global device memory usage via CUDA APIs.""" - - def __init__(self, device_index: int, interval: float = 0.05): - self.device_index = device_index - self.interval = interval - self.peak_used_mb = 0.0 - self._stop_event = threading.Event() - self._thread: threading.Thread | None = None - - def start(self) -> None: - def monitor_loop() -> None: - while not self._stop_event.is_set(): - try: - with torch.cuda.device(self.device_index): - free_bytes, total_bytes = torch.cuda.mem_get_info() - used_mb = (total_bytes - free_bytes) / (1024**2) - self.peak_used_mb = max(self.peak_used_mb, used_mb) - except Exception: - pass - time.sleep(self.interval) - - self._thread = threading.Thread(target=monitor_loop, daemon=True) - self._thread.start() - - def stop(self) -> None: - if self._thread is None: - return - self._stop_event.set() - self._thread.join(timeout=2.0) - - models = ["riverclouds/qwen_image_random"] @@ -73,13 +39,7 @@ def inference(offload: bool = True): generator=torch.Generator("cuda").manual_seed(42), ) - monitor.stop() - torch.cuda.synchronize(device_index) - fallback_alloc = torch.cuda.max_memory_allocated(device=device_index) / (1024**2) - fallback_reserved = torch.cuda.max_memory_reserved(device=device_index) / (1024**2) - peak_memory_mb = max(monitor.peak_used_mb, fallback_alloc, fallback_reserved) - - return peak_memory_mb + return monitor.peak_used_mb offload_peak_memory = inference(offload=True) no_offload_peak_memory = inference(offload=False) diff --git a/tests/e2e/offline_inference/test_zimage_tensor_parallel.py 
b/tests/e2e/offline_inference/test_zimage_tensor_parallel.py index d32bb2b8223..60686992278 100644 --- a/tests/e2e/offline_inference/test_zimage_tensor_parallel.py +++ b/tests/e2e/offline_inference/test_zimage_tensor_parallel.py @@ -17,6 +17,7 @@ if str(REPO_ROOT) not in sys.path: sys.path.insert(0, str(REPO_ROOT)) +from tests.utils import GPUMemoryMonitor from vllm_omni import Omni from vllm_omni.diffusion.data import DiffusionParallelConfig from vllm_omni.outputs import OmniRequestOutput @@ -66,7 +67,12 @@ def _extract_single_image(outputs) -> Image.Image: def _run_zimage_generate( *, tp_size: int, height: int, width: int, num_inference_steps: int, seed: int -) -> tuple[Image.Image, float]: +) -> tuple[Image.Image, float, float]: + torch.cuda.empty_cache() + device_index = torch.cuda.current_device() + monitor = GPUMemoryMonitor(device_index=device_index, interval=0.02) + monitor.start() + m = Omni( model=_get_zimage_model(), parallel_config=DiffusionParallelConfig(tensor_parallel_size=tp_size), @@ -107,7 +113,10 @@ def _run_zimage_generate( pass median_time_s = float(np.median(per_request_times_s)) - return _extract_single_image([last_output]), median_time_s + + peak_memory_mb = monitor.peak_used_mb + + return _extract_single_image([last_output]), median_time_s, peak_memory_mb finally: m.close() cleanup_dist_env_and_memory() @@ -125,14 +134,14 @@ def test_zimage_tensor_parallel_tp2(tmp_path: Path): num_inference_steps = 2 seed = 42 - tp1_img, tp1_time_s = _run_zimage_generate( + tp1_img, tp1_time_s, tp1_peak_mem = _run_zimage_generate( tp_size=1, height=height, width=width, num_inference_steps=num_inference_steps, seed=seed, ) - tp2_img, tp2_time_s = _run_zimage_generate( + tp2_img, tp2_time_s, tp2_peak_mem = _run_zimage_generate( tp_size=2, height=height, width=width, @@ -164,3 +173,8 @@ def test_zimage_tensor_parallel_tp2(tmp_path: Path): print(f"Z-Image TP perf (lower is better): tp1_time_s={tp1_time_s:.6f}, tp2_time_s={tp2_time_s:.6f}") assert tp2_time_s < 
class GPUMemoryMonitor:
    """Poll global device memory usage via CUDA APIs.

    A daemon thread repeatedly calls ``torch.cuda.mem_get_info()`` for
    ``device_index`` and remembers the largest ``total - free`` value seen,
    expressed in MiB. ``peak_used_mb`` combines that polled peak with the
    calling process's own ``torch.cuda.max_memory_allocated`` /
    ``max_memory_reserved`` counters.

    The monitor can be driven explicitly (``start()`` / ``stop()``) or used as
    a context manager::

        with GPUMemoryMonitor(device_index=0) as monitor:
            run_workload()
        peak = monitor.peak_used_mb

    Args:
        device_index: Index of the CUDA device to sample.
        interval: Seconds to wait between two samples.
    """

    def __init__(self, device_index: int, interval: float = 0.05):
        self.device_index = device_index
        self.interval = interval
        self._peak_used_mb = 0.0
        self._stop_event = threading.Event()
        self._thread: threading.Thread | None = None

    def __enter__(self) -> "GPUMemoryMonitor":
        self.start()
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        self.stop()

    def start(self) -> None:
        """Begin a new measurement window and launch the polling thread.

        Calling ``start()`` while the polling thread is already alive is a
        no-op, so a second call never spawns a second thread.
        """
        import weakref

        if self._thread is not None and self._thread.is_alive():
            return

        # A (re)start opens a fresh window: forget earlier peaks, including
        # torch's process-wide counters, which otherwise leak the peak of a
        # previous run into this one through the ``peak_used_mb`` fallbacks.
        self._stop_event.clear()
        self._peak_used_mb = 0.0
        self._reset_torch_peak_stats()

        # The worker only keeps a weak reference to the monitor. A strong
        # reference (closure over ``self`` or a bound-method target) would keep
        # the object alive for as long as the thread runs, so ``__del__`` could
        # never fire and a forgotten ``stop()`` would leak the thread.
        self_ref = weakref.ref(self)
        stop_event = self._stop_event
        interval = self.interval

        def monitor_loop() -> None:
            while not stop_event.is_set():
                monitor = self_ref()
                if monitor is None:
                    return
                monitor._sample()
                # Drop the temporary strong reference before sleeping.
                del monitor
                # Event.wait() sleeps like time.sleep() but wakes up as soon
                # as stop() is called.
                stop_event.wait(interval)

        self._thread = threading.Thread(target=monitor_loop, daemon=True)
        self._thread.start()

    def stop(self) -> None:
        """Signal the polling thread to exit and wait up to 2 seconds for it.

        Safe to call repeatedly, before ``start()``, and on an instance whose
        ``__init__`` did not complete.
        """
        thread = getattr(self, "_thread", None)
        if thread is None:
            return
        self._stop_event.set()
        # ``__del__`` may run on the worker thread itself when it happens to
        # drop the last reference; joining the current thread would raise.
        if thread is not threading.current_thread():
            thread.join(timeout=2.0)
        self._thread = None

    def _sample(self) -> None:
        """Take one device-memory sample and fold it into the polled peak."""
        try:
            with torch.cuda.device(self.device_index):
                free_bytes, total_bytes = torch.cuda.mem_get_info()
        except Exception:
            # Deliberately best-effort: a failed sample (CUDA unavailable,
            # device busy or being torn down) must not kill the polling
            # thread. ``peak_used_mb`` still has the torch counters to fall
            # back on.
            return
        used_mb = (total_bytes - free_bytes) / (1024**2)
        self._peak_used_mb = max(self._peak_used_mb, used_mb)

    def _reset_torch_peak_stats(self) -> None:
        """Reset torch's peak allocator counters for the monitored device.

        NOTE: this touches process-wide torch state for ``device_index``.
        """
        if not torch.cuda.is_available():
            return
        try:
            torch.cuda.reset_peak_memory_stats(self.device_index)
        except (RuntimeError, AssertionError, ValueError):
            # Best-effort, same as sampling: an unusable device index only
            # means the fallback counters keep their previous values.
            pass

    @property
    def peak_used_mb(self) -> float:
        """Peak memory in MiB: max of the polled peak and torch's own peaks.

        The two fallbacks come from ``torch.cuda.max_memory_allocated`` and
        ``torch.cuda.max_memory_reserved`` for ``device_index``; they cover
        spikes that fall between two polls.
        """
        fallback_alloc = torch.cuda.max_memory_allocated(device=self.device_index) / (1024**2)
        fallback_reserved = torch.cuda.max_memory_reserved(device=self.device_index) / (1024**2)
        return max(self._peak_used_mb, fallback_alloc, fallback_reserved)

    def __del__(self):
        self.stop()
"gelu-approximate", + inner_dim: int | None = None, + bias: bool = True, + ) -> None: + super().__init__() + + assert activation_fn == "gelu-approximate", "Only gelu-approximate is supported." + + inner_dim = inner_dim or int(dim * mult) + dim_out = dim_out or dim + + layers: list[nn.Module] = [ + ColumnParallelApproxGELU(dim, inner_dim, approximate="tanh", bias=bias), + nn.Identity(), # placeholder for weight loading + RowParallelLinear( + inner_dim, + dim_out, + input_is_parallel=True, + return_bias=False, + ), + ] + + self.net = nn.ModuleList(layers) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + + class QwenImageCrossAttention(nn.Module): def __init__( self, dim: int, # query_dim num_heads: int, head_dim: int, - window_size=(-1, -1), - added_kv_proj_dim: int = None, + added_kv_proj_dim: int, + window_size: tuple[int, int] = (-1, -1), out_bias: bool = True, - qk_norm=True, # rmsnorm - eps=1e-6, - pre_only=False, + qk_norm: bool = True, + eps: float = 1e-6, + pre_only: bool = False, context_pre_only: bool = False, - parallel_attention=False, - out_dim: int = None, + out_dim: int | None = None, ) -> None: - assert dim % num_heads == 0 super().__init__() + assert dim % num_heads == 0 + self.dim = dim - self.num_heads = num_heads - self.head_dim = dim // num_heads + self.head_dim = head_dim + self.total_num_heads = num_heads self.window_size = window_size self.qk_norm = qk_norm self.eps = eps - self.parallel_attention = parallel_attention - # layers - # self.to_q = ReplicatedLinear(dim, dim) - # self.to_k = ReplicatedLinear(dim, dim) - # self.to_v = ReplicatedLinear(dim, dim) self.to_qkv = QKVParallelLinear( hidden_size=dim, head_size=self.head_dim, total_num_heads=num_heads, - disable_tp=True, ) + self.query_num_heads = self.to_qkv.num_heads + self.kv_num_heads = self.to_qkv.num_kv_heads + self.norm_q = RMSNorm(head_dim, eps=eps) if qk_norm else 
nn.Identity() self.norm_k = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity() - self.inner_dim = out_dim if out_dim is not None else head_dim * num_heads - self.inner_kv_dim = self.inner_dim - if added_kv_proj_dim is not None: - assert context_pre_only is not None - # self.add_k_proj = ReplicatedLinear(added_kv_proj_dim, self.inner_kv_dim, bias=True) - # self.add_v_proj = ReplicatedLinear(added_kv_proj_dim, self.inner_kv_dim, bias=True) - # self.add_q_proj = ReplicatedLinear( - # added_kv_proj_dim, self.inner_dim, bias=True - # ) - self.add_kv_proj = QKVParallelLinear( - added_kv_proj_dim, - head_size=self.inner_kv_dim // self.num_heads, - total_num_heads=self.num_heads, - disable_tp=True, - ) - if context_pre_only is not None and not context_pre_only: - self.to_add_out = ReplicatedLinear(self.inner_dim, self.dim, bias=out_bias) - else: - self.to_add_out = None + self.inner_dim = out_dim if out_dim is not None else head_dim * self.total_num_heads - if not pre_only: - self.to_out = nn.ModuleList([]) - self.to_out.append(ReplicatedLinear(self.inner_dim, self.dim, bias=out_bias)) - else: - self.to_out = None + assert context_pre_only is not None + self.add_kv_proj = QKVParallelLinear( + hidden_size=added_kv_proj_dim, + head_size=head_dim, + total_num_heads=num_heads, + ) + self.add_query_num_heads = self.add_kv_proj.num_heads + self.add_kv_num_heads = self.add_kv_proj.num_kv_heads + + assert not context_pre_only + self.to_add_out = RowParallelLinear( + self.inner_dim, + self.dim, + bias=out_bias, + input_is_parallel=True, + return_bias=False, + ) + + assert not pre_only + self.to_out = RowParallelLinear( + self.inner_dim, + self.dim, + bias=out_bias, + input_is_parallel=True, + return_bias=False, + ) self.norm_added_q = RMSNorm(head_dim, eps=eps) self.norm_added_k = RMSNorm(head_dim, eps=eps) self.attn = Attention( - num_heads=num_heads, + num_heads=self.query_num_heads, head_size=self.head_dim, softmax_scale=1.0 / (self.head_dim**0.5), causal=False, + 
num_kv_heads=self.kv_num_heads, ) self.rope = RotaryEmbedding(is_neox_style=False) @@ -377,61 +435,55 @@ def forward( txt_freqs: torch.Tensor, hidden_states_mask: torch.Tensor | None = None, encoder_hidden_states_mask: torch.Tensor | None = None, - ): - # if mask is all true, set it to None + ) -> tuple[torch.Tensor, torch.Tensor]: if hidden_states_mask is not None and hidden_states_mask.all(): hidden_states_mask = None if encoder_hidden_states_mask is not None and encoder_hidden_states_mask.all(): encoder_hidden_states_mask = None - seq_len_txt = encoder_hidden_states.shape[1] - # Compute QKV for image stream (sample projections) - qkv, _ = self.to_qkv(hidden_states) - img_query, img_key, img_value = qkv.chunk(3, dim=-1) + img_qkv, _ = self.to_qkv(hidden_states) + q_size = self.query_num_heads * self.head_dim + kv_size = self.kv_num_heads * self.head_dim + img_query, img_key, img_value = img_qkv.split([q_size, kv_size, kv_size], dim=-1) - # Compute QKV for text stream (context projections) - qkv, _ = self.add_kv_proj(encoder_hidden_states) - txt_query, txt_key, txt_value = qkv.chunk(3, dim=-1) + txt_qkv, _ = self.add_kv_proj(encoder_hidden_states) + add_q_size = self.add_query_num_heads * self.head_dim + add_kv_size = self.add_kv_num_heads * self.head_dim + txt_query, txt_key, txt_value = txt_qkv.split([add_q_size, add_kv_size, add_kv_size], dim=-1) - # Reshape for multi-head attention - img_query = img_query.unflatten(-1, (self.num_heads, -1)) - img_key = img_key.unflatten(-1, (self.num_heads, -1)) - img_value = img_value.unflatten(-1, (self.num_heads, -1)) + img_query = img_query.unflatten(-1, (self.query_num_heads, self.head_dim)) + img_key = img_key.unflatten(-1, (self.kv_num_heads, self.head_dim)) + img_value = img_value.unflatten(-1, (self.kv_num_heads, self.head_dim)) - txt_query = txt_query.unflatten(-1, (self.num_heads, -1)) - txt_key = txt_key.unflatten(-1, (self.num_heads, -1)) - txt_value = txt_value.unflatten(-1, (self.num_heads, -1)) + txt_query = 
txt_query.unflatten(-1, (self.add_query_num_heads, self.head_dim)) + txt_key = txt_key.unflatten(-1, (self.add_kv_num_heads, self.head_dim)) + txt_value = txt_value.unflatten(-1, (self.add_kv_num_heads, self.head_dim)) - # Apply QK normalization img_query = self.norm_q(img_query) img_key = self.norm_k(img_key) txt_query = self.norm_added_q(txt_query) txt_key = self.norm_added_k(txt_key) - # Apply RoPE img_cos = vid_freqs.real.to(img_query.dtype) img_sin = vid_freqs.imag.to(img_query.dtype) txt_cos = txt_freqs.real.to(txt_query.dtype) txt_sin = txt_freqs.imag.to(txt_query.dtype) + img_query = self.rope(img_query, img_cos, img_sin) img_key = self.rope(img_key, img_cos, img_sin) txt_query = self.rope(txt_query, txt_cos, txt_sin) txt_key = self.rope(txt_key, txt_cos, txt_sin) - # Concatenate for joint attention - # Order: [text, image] + seq_len_txt = encoder_hidden_states.shape[1] joint_query = torch.cat([txt_query, img_query], dim=1) joint_key = torch.cat([txt_key, img_key], dim=1) joint_value = torch.cat([txt_value, img_value], dim=1) - # Compute joint attention if ( self.parallel_config is not None and self.parallel_config.sequence_parallel_size > 1 and not get_forward_context().split_text_embed_in_sp ): - # if using sequence parallel, but not splitting text embed, - # we need to pass text embedding to attention layer as joint qkv attn_metadata = AttentionMetadata( joint_query=txt_query, joint_key=txt_key, @@ -443,22 +495,17 @@ def forward( if encoder_hidden_states_mask is not None: attn_metadata.joint_attn_mask = encoder_hidden_states_mask - joint_hidden_states = self.attn( - img_query, - img_key, - img_value, - attn_metadata, - ) + joint_hidden_states = self.attn(img_query, img_key, img_value, attn_metadata) else: attn_metadata = None if hidden_states_mask is not None or encoder_hidden_states_mask is not None: - mask_list = [] + mask_list: list[torch.Tensor] = [] if encoder_hidden_states_mask is not None: mask_list.append(encoder_hidden_states_mask) else: 
mask_list.append( torch.ones( - [encoder_hidden_states.shape[0], encoder_hidden_states.shape[1]], + encoder_hidden_states.shape[:2], dtype=torch.bool, device=encoder_hidden_states.device, ) @@ -468,34 +515,22 @@ def forward( else: mask_list.append( torch.ones( - [hidden_states.shape[0], hidden_states.shape[1]], + hidden_states.shape[:2], dtype=torch.bool, device=hidden_states.device, ) ) - joint_mask = ( - None if len(mask_list) == 0 else torch.cat(mask_list, dim=1) if len(mask_list) > 1 else mask_list[0] - ) + joint_mask = torch.cat(mask_list, dim=1) if len(mask_list) > 1 else mask_list[0] attn_metadata = AttentionMetadata(attn_mask=joint_mask) - joint_hidden_states = self.attn( - joint_query, - joint_key, - joint_value, - attn_metadata, - ) - joint_hidden_states = joint_hidden_states.flatten(2, 3) - joint_hidden_states = joint_hidden_states.to(joint_query.dtype) - # Split attention outputs back - txt_attn_output = joint_hidden_states[:, :seq_len_txt, :] # Text part - img_attn_output = joint_hidden_states[:, seq_len_txt:, :] # Image part + joint_hidden_states = self.attn(joint_query, joint_key, joint_value, attn_metadata) - # Apply output projections - img_attn_output, _ = self.to_out[0](img_attn_output) - if len(self.to_out) > 1: - (img_attn_output,) = self.to_out[1](img_attn_output) # dropout + joint_hidden_states = joint_hidden_states.flatten(2, 3).to(joint_query.dtype) + txt_attn_output = joint_hidden_states[:, :seq_len_txt, :] + img_attn_output = joint_hidden_states[:, seq_len_txt:, :] - txt_attn_output, _ = self.to_add_out(txt_attn_output) + img_attn_output = self.to_out(img_attn_output) + txt_attn_output = self.to_add_out(txt_attn_output) return img_attn_output, txt_attn_output @@ -530,7 +565,7 @@ def __init__( head_dim=attention_head_dim, ) self.img_norm2 = AdaLayerNorm(dim, elementwise_affine=False, eps=eps) - self.img_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + self.img_mlp = FeedForward(dim=dim, dim_out=dim) # Text 
processing modules self.txt_mod = nn.Sequential( @@ -540,7 +575,7 @@ def __init__( self.txt_norm1 = AdaLayerNorm(dim, elementwise_affine=False, eps=eps) # Text doesn't need separate attention - it's handled by img_attn joint computation self.txt_norm2 = AdaLayerNorm(dim, elementwise_affine=False, eps=eps) - self.txt_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + self.txt_mlp = FeedForward(dim=dim, dim_out=dim) self.zero_cond_t = zero_cond_t @@ -892,6 +927,8 @@ def get_rotary_emb_chunk(freqs, padding=0): if original_seq_len is not None: output = output[:, :original_seq_len, :] + torch.cuda.empty_cache() + return Transformer2DModelOutput(sample=output) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: @@ -916,17 +953,22 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loaded_params: set[str] = set() for name, loaded_weight in weights: + original_name = name + lookup_name = name for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: + if weight_name not in original_name: continue - name = name.replace(weight_name, param_name) - param = params_dict[name] + lookup_name = original_name.replace(weight_name, param_name) + param = params_dict[lookup_name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break else: - param = params_dict[name] + if lookup_name not in params_dict and ".to_out.0." in lookup_name: + lookup_name = lookup_name.replace(".to_out.0.", ".to_out.") + param = params_dict[lookup_name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) - loaded_params.add(name) + loaded_params.add(original_name) + loaded_params.add(lookup_name) return loaded_params