diff --git a/studio/backend/tests/test_transformers_version.py b/studio/backend/tests/test_transformers_version.py index f3dae537c7..ba2da1295a 100644 --- a/studio/backend/tests/test_transformers_version.py +++ b/studio/backend/tests/test_transformers_version.py @@ -34,6 +34,7 @@ _tokenizer_class_cache, needs_transformers_5, ) +from utils.models import model_config # --------------------------------------------------------------------------- @@ -188,3 +189,50 @@ def test_local_checkpoint_resolved_via_config(self, tmp_path: Path): # We test the full resolution chain here: resolved = _resolve_base_model(str(tmp_path)) assert needs_transformers_5(resolved) is True + + +class TestVisionModelDetection: + def test_is_vlm_config_accepts_transformers_objects(self): + config = _types.SimpleNamespace( + model_type = "qwen3_5", + architectures = ["Qwen3_5ForConditionalGeneration"], + vision_config = {}, + ) + assert model_config._is_vlm_config(config) is True + + def test_is_vlm_config_accepts_raw_config_dicts(self): + config = { + "model_type": "gemma4", + "architectures": ["Gemma4ForConditionalGeneration"], + "vision_config": {}, + } + assert model_config._is_vlm_config(config) is True + + def test_is_vlm_config_rejects_audio_only_conditional_generation(self): + config = { + "model_type": "whisper", + "architectures": ["WhisperForConditionalGeneration"], + } + assert model_config._is_vlm_config(config) is False + + def test_is_vision_model_falls_back_to_raw_metadata_for_t5_models(self): + with ( + patch("utils.transformers_version.needs_transformers_5", return_value = True), + patch( + "utils.models.model_config.load_model_config", + side_effect = RuntimeError("direct load failed"), + ), + patch( + "utils.models.model_config._is_vision_model_subprocess", + return_value = None, + ), + patch( + "utils.models.model_config._load_model_config_metadata", + return_value = { + "model_type": "qwen3_5", + "architectures": ["Qwen3_5ForConditionalGeneration"], + "vision_config": {}, + }, + ), + ): + assert model_config.is_vision_model("unsloth/Qwen3.5-4B") is True