NOTE(review): reconstructed from a whitespace-collapsed copy of this patch.
Leading indentation of the context lines in cache_dit_backend.py is assumed,
not known -- apply with `git apply --ignore-whitespace` or regenerate the patch.
The post-image blob id on the test file's index line is stale, because the
test body was revised during review.

diff --git a/.gitignore b/.gitignore
index c0ee968064..b5e002235e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -203,6 +203,7 @@ checkpoints/
 # Cache directories
 cache/
 !vllm_omni/diffusion/cache/
+!tests/diffusion/cache/
 .cache/
 diffusion_cache/
 kv_cache/
diff --git a/tests/diffusion/cache/test_cache_dit.py b/tests/diffusion/cache/test_cache_dit.py
new file mode 100644
index 0000000000..0b7ef72358
--- /dev/null
+++ b/tests/diffusion/cache/test_cache_dit.py
@@ -0,0 +1,50 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+"""
+Model specific tests for CacheDiT enablement.
+"""
+
+from unittest.mock import Mock, patch
+
+import pytest
+
+import vllm_omni.diffusion.cache.cache_dit_backend as cd_backend
+from vllm_omni.diffusion.data import DiffusionCacheConfig
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+
+SEPARATE_CFG_ENABLERS = [
+    cd_backend.enable_cache_for_ltx2,
+    cd_backend.enable_cache_for_wan22,
+    cd_backend.enable_cache_for_longcat_image,
+]
+
+SAMPLE_CACHE_CONFIG = DiffusionCacheConfig()
+
+
+@pytest.mark.parametrize("enabler", SEPARATE_CFG_ENABLERS)
+@patch("vllm_omni.diffusion.cache.cache_dit_backend.BlockAdapter")
+@patch("vllm_omni.diffusion.cache.cache_dit_backend.cache_dit")
+def test_separate_cfg(mock_cache_dit, mock_block_adapter, enabler):
+    """Ensure that custom enablers for models with separate CFG pass
+    the param through to cache_dit correctly.
+
+    ``BlockAdapter`` and ``cache_dit`` are replaced with mocks, so this
+    only inspects the keyword arguments each enabler forwards.
+
+    Regression test for: https://github.com/vllm-project/vllm-omni/pull/2860
+    """
+    mock_pipeline = Mock()
+    enabler(mock_pipeline, SAMPLE_CACHE_CONFIG)
+
+    mock_cache_dit.enable_cache.assert_called_once()
+    # Fail with a clear assertion (not an AttributeError on ``None``) if the
+    # enabler never built a BlockAdapter.
+    mock_block_adapter.assert_called()
+    adapter_kwargs = mock_block_adapter.call_args.kwargs
+    # ``.get`` turns a missing kwarg -- the regression guarded here -- into an
+    # assertion failure rather than a KeyError.
+    assert adapter_kwargs.get("has_separate_cfg") is True, (
+        f"{enabler.__name__} must pass has_separate_cfg=True to BlockAdapter"
+    )
diff --git a/vllm_omni/diffusion/cache/cache_dit_backend.py b/vllm_omni/diffusion/cache/cache_dit_backend.py
index e9f79da4f3..3daf883e0d 100644
--- a/vllm_omni/diffusion/cache/cache_dit_backend.py
+++ b/vllm_omni/diffusion/cache/cache_dit_backend.py
@@ -281,6 +281,7 @@ def enable_cache_for_longcat_image(pipeline: Any, cache_config: Any) -> Callable
                 ],
                 forward_pattern=[ForwardPattern.Pattern_1, ForwardPattern.Pattern_1],
                 params_modifiers=[modifier],
+                has_separate_cfg=True,
             )
         ),
         cache_config=db_cache_config,
@@ -632,6 +633,7 @@ def enable_cache_for_ltx2(pipeline: Any, cache_config: Any) -> Callable[[int], N
             forward_pattern=ForwardPattern.Pattern_0,
             # Treat audio_hidden_states as encoder_hidden_states in Pattern_0
             check_forward_pattern=False,
+            has_separate_cfg=True,
         ),
         cache_config=db_cache_config,
         calibrator_config=calibrator_config,