diff --git a/studio/backend/utils/models/model_config.py b/studio/backend/utils/models/model_config.py index f97ea993eb..eb65951245 100644 --- a/studio/backend/utils/models/model_config.py +++ b/studio/backend/utils/models/model_config.py @@ -490,6 +490,7 @@ def load_model_config( "cogvlm2", "minicpmv", } +_AUDIO_ONLY_MODEL_TYPES = {"csm", "whisper"} # Pre-computed .venv_t5 path and backend dir for subprocess version switching. _VENV_T5_DIR = str(Path.home() / ".unsloth" / "studio" / ".venv_t5") @@ -548,7 +549,7 @@ def load_model_config( def _is_vision_model_subprocess( model_name: str, hf_token: Optional[str] = None -) -> bool: +) -> Optional[bool]: """Run is_vision_model check in a subprocess with transformers 5.x. Same pattern as training/inference workers: spawn a clean subprocess @@ -580,7 +581,7 @@ def _is_vision_model_subprocess( model_name, stderr or result.stdout.strip(), ) - return False + return None data = json.loads(result.stdout.strip()) if "error" in data: @@ -589,7 +590,7 @@ def _is_vision_model_subprocess( model_name, data["error"], ) - return False + return None is_vlm = data["is_vision"] logger.info( @@ -604,11 +605,66 @@ def _is_vision_model_subprocess( except subprocess.TimeoutExpired: logger.warning("Vision check subprocess timed out for '%s'", model_name) - return False + return None except Exception as exc: logger.warning("Vision check subprocess failed for '%s': %s", model_name, exc) + return None + + +def _is_vlm_config(config: Any) -> bool: + if isinstance(config, dict): + model_type = config.get("model_type") + architectures = config.get("architectures") + has_vision_config = "vision_config" in config + has_img_processor = "img_processor" in config + has_image_token_index = "image_token_index" in config + else: + model_type = getattr(config, "model_type", None) + architectures = getattr(config, "architectures", None) + has_vision_config = hasattr(config, "vision_config") + has_img_processor = hasattr(config, "img_processor") + has_image_token_index = hasattr(config, "image_token_index") + + if model_type in _AUDIO_ONLY_MODEL_TYPES: return False + if architectures: + if any(x.endswith(_VLM_ARCH_SUFFIXES) for x in architectures): + return True + + if ( + has_vision_config or has_img_processor or has_image_token_index + ) or model_type in _VLM_MODEL_TYPES: + return True + return False + + +def _load_model_config_metadata( + model_name: str, hf_token: Optional[str] = None +) -> Optional[dict[str, Any]]: + try: + if is_local_path(model_name): + config_path = Path(normalize_path(model_name)) / "config.json" + if config_path.is_file(): + return json.loads(config_path.read_text()) + return None + + from huggingface_hub import hf_hub_download + + download_kwargs: dict[str, Any] = {} + if hf_token: + download_kwargs["token"] = hf_token + + config_path = hf_hub_download( + repo_id = model_name, + filename = "config.json", + **download_kwargs, + ) + return json.loads(Path(config_path).read_text()) + except Exception as exc: + logger.warning("Could not load raw config metadata for %s: %s", model_name, exc) + return None + def is_vision_model(model_name: str, hf_token: Optional[str] = None) -> bool: """ @@ -628,60 +684,50 @@ def is_vision_model(model_name: str, hf_token: Optional[str] = None) -> bool: # recognize their architectures. from utils.transformers_version import needs_transformers_5 - if needs_transformers_5(model_name): - logger.info( - "Model '%s' needs transformers 5.x — checking vision via subprocess", - model_name, - ) - return _is_vision_model_subprocess(model_name, hf_token = hf_token) + needs_t5 = needs_transformers_5(model_name) try: config = load_model_config(model_name, use_auth = True, token = hf_token) - - # Exclude audio-only models that share ForConditionalGeneration suffix - # (e.g. CsmForConditionalGeneration, WhisperForConditionalGeneration) - _audio_only_model_types = {"csm", "whisper"} - model_type = getattr(config, "model_type", None) - if model_type in _audio_only_model_types: + if _is_vlm_config(config): + model_type = getattr(config, "model_type", None) + architectures = getattr(config, "architectures", []) + logger.info( + "Model %s detected as VLM in-process: model_type=%s architectures=%s", + model_name, + model_type, + architectures, + ) + return True + if not needs_t5: return False - # Check 1: Architecture class name patterns - if hasattr(config, "architectures"): - is_vlm = any(x.endswith(_VLM_ARCH_SUFFIXES) for x in config.architectures) - if is_vlm: - logger.info( - f"Model {model_name} detected as VLM: architecture {config.architectures}" - ) - return True - - # Check 2: Has vision_config (most VLMs: LLaVA, Gemma-3, Qwen2-VL, etc.) - if hasattr(config, "vision_config"): - logger.info(f"Model {model_name} detected as VLM: has vision_config") - return True + except Exception as e: + logger.warning( + f"Could not determine if {model_name} is vision model in-process: {e}" + ) - # Check 3: Has img_processor (Phi-3.5 Vision uses this instead of vision_config) - if hasattr(config, "img_processor"): - logger.info(f"Model {model_name} detected as VLM: has img_processor") - return True + if needs_t5: + logger.info( + "Model '%s' needs transformers 5.x — checking vision via subprocess", + model_name, + ) + subprocess_result = _is_vision_model_subprocess(model_name, hf_token = hf_token) + if subprocess_result is not None: + return subprocess_result - # Check 4: Has image_token_index (common in VLMs for image placeholder tokens) - if hasattr(config, "image_token_index"): - logger.info(f"Model {model_name} detected as VLM: has image_token_index") + config_data = _load_model_config_metadata(model_name, hf_token = hf_token) + if config_data is not None: + if _is_vlm_config(config_data): + logger.info( + "Model %s detected as VLM from raw config metadata: model_type=%s architectures=%s", + model_name, + config_data.get("model_type"), + config_data.get("architectures", []), + ) return True - - # Check 5: Known VLM model_type values that may not match above checks - if hasattr(config, "model_type"): - if config.model_type in _VLM_MODEL_TYPES: - logger.info( - f"Model {model_name} detected as VLM: model_type={config.model_type}" - ) - return True - return False - except Exception as e: - logger.warning(f"Could not determine if {model_name} is vision model: {e}") - return False + return False VALID_AUDIO_TYPES = ("snac", "csm", "bicodec", "dac", "whisper", "audio_vlm")