diff --git a/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py b/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py index a3cc6cbbb67..5f2b8c68a23 100644 --- a/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py +++ b/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py @@ -1,8 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import functools from collections.abc import Iterable +from functools import lru_cache from math import prod from typing import Any @@ -249,7 +249,7 @@ def forward(self, video_fhw, txt_seq_lens, device): return vid_freqs, txt_freqs - @functools.cache + @lru_cache(maxsize=16) def _compute_video_freqs(self, frame, height, width, idx=0): seq_lens = frame * height * width freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) @@ -268,7 +268,7 @@ def _compute_video_freqs(self, frame, height, width, idx=0): freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1) return freqs.clone().contiguous() - @functools.cache + @lru_cache(maxsize=16) def _compute_condition_freqs(self, frame, height, width): seq_lens = frame * height * width freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) @@ -311,7 +311,6 @@ def __init__(self, theta: int, axes_dim: list[int], scale_rope=False): ], dim=1, ) - self.rope_cache = {} # DO NOT USING REGISTER BUFFER HERE, IT WILL CAUSE COMPLEX NUMBERS LOSE ITS IMAGINARY PART self.scale_rope = scale_rope @@ -349,14 +348,7 @@ def forward(self, video_fhw, txt_seq_lens, device): max_vid_index = 0 for idx, fhw in enumerate(video_fhw): frame, height, width = fhw - rope_key = f"{idx}_{height}_{width}" - - if not torch.compiler.is_compiling(): - if rope_key not in self.rope_cache: - self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx) - video_freq = self.rope_cache[rope_key] - else: - video_freq = self._compute_video_freqs(frame, height, width, idx) + video_freq = self._compute_video_freqs(frame, height, width, idx) video_freq = video_freq.to(device) vid_freqs.append(video_freq) @@ -371,7 +363,7 @@ def forward(self, video_fhw, txt_seq_lens, device): return vid_freqs, txt_freqs - @functools.cache + @lru_cache(maxsize=16) def _compute_video_freqs(self, frame, height, width, idx=0): seq_lens = frame * height * width freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)