diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py index c76312c0a833..57af26cff8aa 100644 --- a/python/sglang/srt/models/llava.py +++ b/python/sglang/srt/models/llava.py @@ -654,17 +654,16 @@ def _get_sgl_model_cls(self, config, auto_model_type: Type[AutoModel] = AutoMode @lru_cache def _config_cls_name_to_arch_name_mapping( self, auto_model_type: Type[AutoModel] - ) -> Dict[str, str]: + ) -> Dict[str, Union[str, Tuple[str, ...]]]: mapping = {} - for config_cls in auto_model_type._model_mapping.keys(): - archs = auto_model_type._model_mapping.get(config_cls, None) + lazy_mapping = auto_model_type._model_mapping + raw_config_mapping = lazy_mapping._config_mapping + raw_model_mapping = lazy_mapping._model_mapping + + for model_type, config_cls_name in raw_config_mapping.items(): + archs = raw_model_mapping.get(model_type, None) if archs is not None: - if isinstance(archs, tuple): - mapping[config_cls.__name__] = tuple( - arch.__name__ for arch in archs - ) - else: - mapping[config_cls.__name__] = archs.__name__ + mapping[config_cls_name] = archs return mapping def __init__(