-
-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Add tests for is_vision_model() caching behaviour #4855
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
a25c064
Add tests for is_vision_model() caching behaviour
danielhanchen 14d6731
Fix review feedback: remove dead helper, fix exception test
rolandtannous 44f0458
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 203d678
Use strict mock specs so tests exercise intended detection paths
rolandtannous File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,238 @@ | ||
| # SPDX-License-Identifier: AGPL-3.0-only | ||
| # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 | ||
|
|
||
| """Tests for is_vision_model() caching behaviour. | ||
|
|
||
| The vision detection cache (``_vision_detection_cache``) mirrors the existing | ||
| ``_audio_detection_cache`` pattern used by ``detect_audio_type()``. These | ||
| tests verify that: | ||
|
|
||
| * Repeated calls for the same model hit the cache (no redundant work). | ||
| * Different models each trigger their own detection. | ||
| * Both True and False results are cached. | ||
| * The subprocess path (transformers 5.x models) is also cached. | ||
| * Exceptions that fall back to False are cached. | ||
| """ | ||
|
|
||
| import sys | ||
| import types as _types | ||
| from pathlib import Path | ||
| from unittest.mock import patch, MagicMock | ||
|
|
||
| import pytest | ||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # sys.path + logger stub — same pattern as the rest of the test suite | ||
| # --------------------------------------------------------------------------- | ||
| _BACKEND_DIR = str(Path(__file__).resolve().parent.parent) | ||
| if _BACKEND_DIR not in sys.path: | ||
| sys.path.insert(0, _BACKEND_DIR) | ||
|
|
||
| _loggers_stub = _types.ModuleType("loggers") | ||
| _loggers_stub.get_logger = lambda name: __import__("logging").getLogger(name) | ||
| sys.modules.setdefault("loggers", _loggers_stub) | ||
|
|
||
| from utils.models.model_config import ( | ||
| is_vision_model, | ||
| _is_vision_model_uncached, | ||
| _vision_detection_cache, | ||
| ) | ||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # Helpers | ||
| # --------------------------------------------------------------------------- | ||
|
|
||
|
|
||
| @pytest.fixture(autouse = True) | ||
| def _clear_vision_cache(): | ||
| """Ensure every test starts with a fresh cache.""" | ||
| _vision_detection_cache.clear() | ||
| yield | ||
| _vision_detection_cache.clear() | ||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # Cache hit / miss tests | ||
| # --------------------------------------------------------------------------- | ||
|
|
||
|
|
||
| class TestVisionCacheHitMiss: | ||
| """Verify the cache prevents redundant detection calls.""" | ||
|
|
||
| @patch("utils.models.model_config._is_vision_model_uncached", return_value = True) | ||
| def test_second_call_uses_cache(self, mock_uncached): | ||
| """Calling is_vision_model() twice for the same model should invoke | ||
| the uncached function only once.""" | ||
| assert is_vision_model("org/my-vlm") is True | ||
| assert is_vision_model("org/my-vlm") is True | ||
| mock_uncached.assert_called_once_with("org/my-vlm", None) | ||
|
|
||
| @patch("utils.models.model_config._is_vision_model_uncached", return_value = False) | ||
| def test_different_models_each_detected(self, mock_uncached): | ||
| """Different model names should each trigger detection.""" | ||
| is_vision_model("model-a") | ||
| is_vision_model("model-b") | ||
| assert mock_uncached.call_count == 2 | ||
|
|
||
| @patch("utils.models.model_config._is_vision_model_uncached", return_value = True) | ||
| def test_cache_returns_correct_value(self, mock_uncached): | ||
| """The cached value must match what _is_vision_model_uncached returned.""" | ||
| first = is_vision_model("org/vlm") | ||
| second = is_vision_model("org/vlm") | ||
| assert first is True | ||
| assert second is True | ||
|
|
||
|
|
||
| class TestVisionCacheStoresFalse: | ||
| """Non-VLM results (False) must also be cached to avoid re-detection.""" | ||
|
|
||
| @patch("utils.models.model_config._is_vision_model_uncached", return_value = False) | ||
| def test_false_result_cached(self, mock_uncached): | ||
| assert is_vision_model("org/text-only") is False | ||
| assert is_vision_model("org/text-only") is False | ||
| mock_uncached.assert_called_once() | ||
| assert _vision_detection_cache[("org/text-only", None)] is False | ||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # Subprocess path (transformers 5.x) caching | ||
| # --------------------------------------------------------------------------- | ||
|
|
||
|
|
||
| class TestVisionCacheSubprocessPath: | ||
| """Models needing transformers 5.x go through _is_vision_model_subprocess. | ||
| The cache should prevent the subprocess from being spawned more than once | ||
| per model per process.""" | ||
|
|
||
| @patch("utils.models.model_config._is_vision_model_subprocess", return_value = True) | ||
| @patch("utils.transformers_version.needs_transformers_5", return_value = True) | ||
| def test_subprocess_called_once_with_cache(self, mock_needs_t5, mock_subprocess): | ||
| """Subprocess should only fire on the first call; second is cached.""" | ||
| # First call: goes through uncached → subprocess | ||
| assert is_vision_model("unsloth/Qwen3.5-2B") is True | ||
| # Second call: cache hit, no subprocess | ||
| assert is_vision_model("unsloth/Qwen3.5-2B") is True | ||
|
|
||
| mock_subprocess.assert_called_once() | ||
| assert _vision_detection_cache[("unsloth/Qwen3.5-2B", None)] is True | ||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # Exception handling — cache the False fallback | ||
| # --------------------------------------------------------------------------- | ||
|
|
||
|
|
||
| class TestVisionCacheOnException: | ||
| """When detection raises an exception, _is_vision_model_uncached catches | ||
| it and returns False. That False must be cached so subsequent calls don't | ||
| retry and fail again.""" | ||
|
|
||
| @patch( | ||
| "utils.models.model_config.load_model_config", | ||
| side_effect = OSError("network down"), | ||
| ) | ||
| @patch("utils.transformers_version.needs_transformers_5", return_value = False) | ||
| def test_exception_result_cached(self, mock_needs_t5, mock_load_config): | ||
| """A real exception inside _is_vision_model_uncached should be caught, | ||
| return False, and that False should be cached for subsequent calls.""" | ||
| # First call: load_model_config raises → except branch → False | ||
| assert is_vision_model("broken/model") is False | ||
| # Second call: cache hit, load_model_config not called again | ||
| assert is_vision_model("broken/model") is False | ||
| mock_load_config.assert_called_once() | ||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # Direct detection path (non-transformers-5 models) caching | ||
| # --------------------------------------------------------------------------- | ||
|
|
||
|
|
||
| class TestVisionCacheDirectPath: | ||
| """For models that do NOT need transformers 5.x, the detection goes through | ||
| load_model_config directly. The cache must work the same way.""" | ||
|
|
||
| @patch("utils.transformers_version.needs_transformers_5", return_value = False) | ||
| @patch("utils.models.model_config.load_model_config") | ||
| def test_direct_vlm_detection_cached(self, mock_load_config, mock_needs_t5): | ||
| """A standard VLM detected via architecture suffix should be cached.""" | ||
| cfg = MagicMock(spec = []) # strict: only explicitly set attrs exist | ||
| cfg.model_type = "gemma3" | ||
| cfg.architectures = ["Gemma3ForConditionalGeneration"] | ||
| mock_load_config.return_value = cfg | ||
|
|
||
| assert is_vision_model("google/gemma-3-4b-it") is True | ||
| assert is_vision_model("google/gemma-3-4b-it") is True | ||
| # load_model_config should only be called once | ||
| mock_load_config.assert_called_once() | ||
|
|
||
| @patch("utils.transformers_version.needs_transformers_5", return_value = False) | ||
| @patch("utils.models.model_config.load_model_config") | ||
| def test_direct_non_vlm_detection_cached(self, mock_load_config, mock_needs_t5): | ||
| """A standard text model (no VLM indicators) should cache False.""" | ||
| cfg = MagicMock(spec = []) # spec=[] means no attributes at all | ||
| cfg.model_type = "llama" | ||
| cfg.architectures = ["LlamaForCausalLM"] | ||
| mock_load_config.return_value = cfg | ||
|
|
||
| # LlamaForCausalLM doesn't end with VLM suffixes, no vision_config, etc. | ||
| assert is_vision_model("meta-llama/Llama-3-8B") is False | ||
| assert is_vision_model("meta-llama/Llama-3-8B") is False | ||
| mock_load_config.assert_called_once() | ||
|
|
||
| @patch("utils.transformers_version.needs_transformers_5", return_value = False) | ||
| @patch("utils.models.model_config.load_model_config") | ||
| def test_vision_config_attr_detected_and_cached( | ||
| self, mock_load_config, mock_needs_t5 | ||
| ): | ||
| """Models with vision_config (LLaVA, Qwen2-VL, etc.) should be cached as True.""" | ||
| cfg = MagicMock(spec = []) # strict: only explicitly set attrs exist | ||
| cfg.model_type = "qwen2_vl" | ||
| cfg.architectures = ["Qwen2VLForCausalLM"] # Doesn't match VLM suffixes | ||
| cfg.vision_config = {"hidden_size": 1024} | ||
| mock_load_config.return_value = cfg | ||
|
|
||
| assert is_vision_model("Qwen/Qwen2-VL-7B") is True | ||
| assert is_vision_model("Qwen/Qwen2-VL-7B") is True | ||
| mock_load_config.assert_called_once() | ||
|
|
||
| @patch("utils.transformers_version.needs_transformers_5", return_value = False) | ||
| @patch("utils.models.model_config.load_model_config") | ||
| def test_audio_model_excluded_and_cached(self, mock_load_config, mock_needs_t5): | ||
| """Audio-only models (csm, whisper) with ForConditionalGeneration | ||
| should be excluded from VLM detection and cached as False.""" | ||
| cfg = MagicMock(spec = []) # strict: only explicitly set attrs exist | ||
| cfg.model_type = "whisper" | ||
| cfg.architectures = ["WhisperForConditionalGeneration"] | ||
| mock_load_config.return_value = cfg | ||
|
|
||
| assert is_vision_model("openai/whisper-large-v3") is False | ||
| assert is_vision_model("openai/whisper-large-v3") is False | ||
| mock_load_config.assert_called_once() | ||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # hf_token handling | ||
| # --------------------------------------------------------------------------- | ||
|
|
||
|
|
||
| class TestVisionCacheTokenHandling: | ||
| """The cache is keyed on (model_name, hf_token). | ||
| Different tokens for the same model should trigger separate detections | ||
| to handle gated models correctly.""" | ||
|
|
||
| @patch("utils.models.model_config._is_vision_model_uncached", return_value = True) | ||
| def test_different_tokens_trigger_new_detection(self, mock_uncached): | ||
| """Calls with different tokens should trigger separate detections to | ||
| handle gated models correctly (e.g. unauthenticated probe → False, | ||
| then authenticated call should re-check).""" | ||
| assert is_vision_model("gated/model", hf_token = "token-a") is True | ||
| assert is_vision_model("gated/model", hf_token = "token-b") is True | ||
| assert mock_uncached.call_count == 2 | ||
|
|
||
| @patch("utils.models.model_config._is_vision_model_uncached", return_value = True) | ||
| def test_same_token_uses_cache(self, mock_uncached): | ||
| """Repeated calls with identical model + token should hit cache.""" | ||
| assert is_vision_model("gated/model", hf_token = "token-a") is True | ||
| assert is_vision_model("gated/model", hf_token = "token-a") is True | ||
| mock_uncached.assert_called_once() | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.