diff --git a/python/sglang/multimodal_gen/registry.py b/python/sglang/multimodal_gen/registry.py index 1e2fafeb472f..1211d8981e8f 100644 --- a/python/sglang/multimodal_gen/registry.py +++ b/python/sglang/multimodal_gen/registry.py @@ -11,7 +11,6 @@ import importlib import os import pkgutil -import re from functools import lru_cache from typing import Any, Callable, Dict, List, Optional, Tuple, Type @@ -158,6 +157,13 @@ def register_configs( _MODEL_NAME_DETECTORS.append((model_id, detector)) +def get_model_short_name(model_id: str) -> str: + if "/" in model_id: + return model_id.split("/")[-1] + else: + return model_id + + def _get_config_info(model_path: str) -> Optional[ConfigInfo]: """ Gets the ConfigInfo for a given model path using mappings and detectors. @@ -169,14 +175,16 @@ def _get_config_info(model_path: str) -> Optional[ConfigInfo]: return _CONFIG_REGISTRY.get(model_id) # 2. Partial match: find the best (longest) match against all registered model hf paths. - cleaned_model_path = re.sub(r"--", "/", model_path.lower()) + model_name = get_model_short_name(model_path.lower()) all_model_hf_paths = sorted(_MODEL_HF_PATH_TO_NAME.keys(), key=len, reverse=True) - for model_hf_path in all_model_hf_paths: - if model_hf_path.lower() in cleaned_model_path: + for registered_model_hf_id in all_model_hf_paths: + registered_model_name = get_model_short_name(registered_model_hf_id.lower()) + + if registered_model_name == model_name: logger.debug( - f"Resolved model name '{model_hf_path}' from partial path match." + f"Resolved model name '{registered_model_hf_id}' from partial path match." ) - model_id = _MODEL_HF_PATH_TO_NAME[model_hf_path] + model_id = _MODEL_HF_PATH_TO_NAME[registered_model_hf_id] return _CONFIG_REGISTRY.get(model_id) # 3. Use detectors diff --git a/python/sglang/multimodal_gen/runtime/loader/component_loader.py b/python/sglang/multimodal_gen/runtime/loader/component_loader.py index 9ae7abf1134b..5c311789e962 100644 --- a/python/sglang/multimodal_gen/runtime/loader/component_loader.py +++ b/python/sglang/multimodal_gen/runtime/loader/component_loader.py @@ -120,7 +120,8 @@ def __init__(self, device=None) -> None: self.device = device def should_offload(self, server_args, model_config: ModelConfig | None = None): - raise NotImplementedError() + # offload by default + return True def target_device(self, should_offload): if should_offload: