Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 94 additions & 48 deletions studio/backend/utils/models/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,7 @@ def load_model_config(
"cogvlm2",
"minicpmv",
}
_AUDIO_ONLY_MODEL_TYPES = {"csm", "whisper"}

# Pre-computed .venv_t5 path and backend dir for subprocess version switching.
_VENV_T5_DIR = str(Path.home() / ".unsloth" / "studio" / ".venv_t5")
Expand Down Expand Up @@ -548,7 +549,7 @@ def load_model_config(

def _is_vision_model_subprocess(
model_name: str, hf_token: Optional[str] = None
) -> bool:
) -> Optional[bool]:
"""Run is_vision_model check in a subprocess with transformers 5.x.

Same pattern as training/inference workers: spawn a clean subprocess
Expand Down Expand Up @@ -580,7 +581,7 @@ def _is_vision_model_subprocess(
model_name,
stderr or result.stdout.strip(),
)
return False
return None

data = json.loads(result.stdout.strip())
if "error" in data:
Expand All @@ -589,7 +590,7 @@ def _is_vision_model_subprocess(
model_name,
data["error"],
)
return False
return None

is_vlm = data["is_vision"]
logger.info(
Expand All @@ -604,11 +605,66 @@ def _is_vision_model_subprocess(

except subprocess.TimeoutExpired:
logger.warning("Vision check subprocess timed out for '%s'", model_name)
return False
return None
except Exception as exc:
logger.warning("Vision check subprocess failed for '%s': %s", model_name, exc)
return None


def _is_vlm_config(config: Any) -> bool:
    """Decide whether a model config describes a vision-language model.

    Args:
        config: Either a plain ``dict`` (looked up by key) or any other
            object (looked up by attribute).

    Returns:
        ``False`` when ``model_type`` is one of ``_AUDIO_ONLY_MODEL_TYPES``.
        Otherwise ``True`` when any architecture name ends with one of
        ``_VLM_ARCH_SUFFIXES``, when the config carries ``vision_config``,
        ``img_processor`` or ``image_token_index``, or when ``model_type``
        is in ``_VLM_MODEL_TYPES``; ``False`` in every other case.
    """
    marker_fields = ("vision_config", "img_processor", "image_token_index")

    if isinstance(config, dict):
        kind = config.get("model_type")
        arch_names = config.get("architectures")
        # Presence of the key is enough; its value is not inspected.
        marker_hits = [field in config for field in marker_fields]
    else:
        kind = getattr(config, "model_type", None)
        arch_names = getattr(config, "architectures", None)
        marker_hits = [hasattr(config, field) for field in marker_fields]

    # The audio-only check runs before every positive signal, so it wins
    # even when an architecture suffix or marker field would match.
    if kind in _AUDIO_ONLY_MODEL_TYPES:
        return False

    if arch_names and any(
        arch.endswith(_VLM_ARCH_SUFFIXES) for arch in arch_names
    ):
        return True

    return any(marker_hits) or kind in _VLM_MODEL_TYPES


def _load_model_config_metadata(
    model_name: str, hf_token: Optional[str] = None
) -> Optional[dict[str, Any]]:
    """Load a model's raw ``config.json`` as a plain dict.

    The file is parsed with ``json`` only; no ``transformers`` config class
    is instantiated.

    Args:
        model_name: A local model directory (as judged by ``is_local_path``)
            or a Hugging Face Hub repo id.
        hf_token: Optional Hub token, forwarded to ``hf_hub_download`` only
            when it is truthy.

    Returns:
        The parsed JSON object, or ``None`` when the local directory has no
        ``config.json``, when the file's top-level JSON value is not an
        object, or when any exception is raised while locating, downloading,
        reading or parsing the file (the exception is logged as a warning).
    """
    try:
        if is_local_path(model_name):
            config_path = Path(normalize_path(model_name)) / "config.json"
            if not config_path.is_file():
                return None
        else:
            from huggingface_hub import hf_hub_download

            download_kwargs: dict[str, Any] = {}
            if hf_token:
                download_kwargs["token"] = hf_token

            config_path = Path(
                hf_hub_download(
                    repo_id = model_name,
                    filename = "config.json",
                    **download_kwargs,
                )
            )

        # Decode explicitly as UTF-8: without an encoding argument
        # read_text() uses the platform locale, which can mis-decode or
        # reject non-ASCII characters in the file.
        config_data = json.loads(config_path.read_text(encoding = "utf-8"))
    except Exception as exc:
        logger.warning("Could not load raw config metadata for %s: %s", model_name, exc)
        return None

    # Honour the declared return type: callers use dict methods such as
    # .get() on the result, so a non-object JSON value is treated as missing.
    if not isinstance(config_data, dict):
        logger.warning(
            "Raw config metadata for %s is not a JSON object (got %s)",
            model_name,
            type(config_data).__name__,
        )
        return None
    return config_data


def is_vision_model(model_name: str, hf_token: Optional[str] = None) -> bool:
"""
Expand All @@ -628,60 +684,50 @@ def is_vision_model(model_name: str, hf_token: Optional[str] = None) -> bool:
# recognize their architectures.
from utils.transformers_version import needs_transformers_5

if needs_transformers_5(model_name):
logger.info(
"Model '%s' needs transformers 5.x — checking vision via subprocess",
model_name,
)
return _is_vision_model_subprocess(model_name, hf_token = hf_token)
needs_t5 = needs_transformers_5(model_name)

try:
config = load_model_config(model_name, use_auth = True, token = hf_token)

# Exclude audio-only models that share ForConditionalGeneration suffix
# (e.g. CsmForConditionalGeneration, WhisperForConditionalGeneration)
_audio_only_model_types = {"csm", "whisper"}
model_type = getattr(config, "model_type", None)
if model_type in _audio_only_model_types:
if _is_vlm_config(config):
model_type = getattr(config, "model_type", None)
architectures = getattr(config, "architectures", [])
logger.info(
"Model %s detected as VLM in-process: model_type=%s architectures=%s",
model_name,
model_type,
architectures,
)
return True
if not needs_t5:
return False

# Check 1: Architecture class name patterns
if hasattr(config, "architectures"):
is_vlm = any(x.endswith(_VLM_ARCH_SUFFIXES) for x in config.architectures)
if is_vlm:
logger.info(
f"Model {model_name} detected as VLM: architecture {config.architectures}"
)
return True

# Check 2: Has vision_config (most VLMs: LLaVA, Gemma-3, Qwen2-VL, etc.)
if hasattr(config, "vision_config"):
logger.info(f"Model {model_name} detected as VLM: has vision_config")
return True
except Exception as e:
logger.warning(
f"Could not determine if {model_name} is vision model in-process: {e}"
)

# Check 3: Has img_processor (Phi-3.5 Vision uses this instead of vision_config)
if hasattr(config, "img_processor"):
logger.info(f"Model {model_name} detected as VLM: has img_processor")
return True
if needs_t5:
logger.info(
"Model '%s' needs transformers 5.x — checking vision via subprocess",
model_name,
)
subprocess_result = _is_vision_model_subprocess(model_name, hf_token = hf_token)
if subprocess_result is not None:
return subprocess_result

# Check 4: Has image_token_index (common in VLMs for image placeholder tokens)
if hasattr(config, "image_token_index"):
logger.info(f"Model {model_name} detected as VLM: has image_token_index")
config_data = _load_model_config_metadata(model_name, hf_token = hf_token)
if config_data is not None:
if _is_vlm_config(config_data):
logger.info(
"Model %s detected as VLM from raw config metadata: model_type=%s architectures=%s",
model_name,
config_data.get("model_type"),
config_data.get("architectures", []),
)
return True

# Check 5: Known VLM model_type values that may not match above checks
if hasattr(config, "model_type"):
if config.model_type in _VLM_MODEL_TYPES:
logger.info(
f"Model {model_name} detected as VLM: model_type={config.model_type}"
)
return True

return False

except Exception as e:
logger.warning(f"Could not determine if {model_name} is vision model: {e}")
return False
return False


VALID_AUDIO_TYPES = ("snac", "csm", "bicodec", "dac", "whisper", "audio_vlm")
Expand Down
Loading