diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py index c76312c0a833..2d1e69dbcdfd 100644 --- a/python/sglang/srt/models/llava.py +++ b/python/sglang/srt/models/llava.py @@ -53,6 +53,9 @@ ) from sglang.srt.utils import add_prefix, flatten_nested_list, logger +_KNOWN_BROKEN_AUTOMODEL_CONFIG = "VoxtralRealtimeTextConfig" +_KNOWN_BROKEN_AUTOMODEL_ERROR = "Could not find VoxtralRealtimeTextModel" + class LlavaBaseForCausalLM(nn.Module): def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs): @@ -657,7 +660,22 @@ def _config_cls_name_to_arch_name_mapping( ) -> Dict[str, str]: mapping = {} for config_cls in auto_model_type._model_mapping.keys(): - archs = auto_model_type._model_mapping.get(config_cls, None) + try: + archs = auto_model_type._model_mapping.get(config_cls, None) + except ValueError as exc: + if ( + auto_model_type is not AutoModel + or config_cls.__name__ != _KNOWN_BROKEN_AUTOMODEL_CONFIG + or _KNOWN_BROKEN_AUTOMODEL_ERROR not in str(exc) + ): + raise + logger.warning( + "Skipping broken %s mapping for config %s: %s", + auto_model_type.__name__, + config_cls.__name__, + exc, + ) + continue if archs is not None: if isinstance(archs, tuple): mapping[config_cls.__name__] = tuple( diff --git a/test/registered/unit/models/test_llava.py b/test/registered/unit/models/test_llava.py new file mode 100644 index 000000000000..a929dcfbd811 --- /dev/null +++ b/test/registered/unit/models/test_llava.py @@ -0,0 +1,91 @@ +import unittest +from unittest.mock import patch + +from sglang.srt.models.llava import AutoModel, LlavaForConditionalGeneration +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.test_utils import CustomTestCase + +register_cuda_ci(est_time=1, suite="stage-b-test-1-gpu-small") + + +class PixtralVisionConfig: + pass + + +class VoxtralRealtimeTextConfig: + pass + + +class GoodConfig: + pass + + +class PixtralVisionModel: + pass + + +class GoodArch: + pass + + +class FakeMapping: + def __init__(self, voxtral_error): + self.voxtral_error = voxtral_error + + def keys(self): + return [VoxtralRealtimeTextConfig, PixtralVisionConfig, GoodConfig] + + def get(self, config_cls, default=None): + if config_cls is VoxtralRealtimeTextConfig: + raise self.voxtral_error + if config_cls is PixtralVisionConfig: + return (PixtralVisionModel,) + if config_cls is GoodConfig: + return GoodArch + return default + + +KNOWN_VOXTRAL_ERROR = ValueError( + "Could not find VoxtralRealtimeTextModel neither in " + " nor in " + "!" +) + + +class TestLlavaForConditionalGeneration(CustomTestCase): + def setUp(self): + LlavaForConditionalGeneration._config_cls_name_to_arch_name_mapping.cache_clear() + + def _build_mapping(self, mapping): + with patch.object(AutoModel, "_model_mapping", mapping): + llava_model = object.__new__(LlavaForConditionalGeneration) + return llava_model._config_cls_name_to_arch_name_mapping(AutoModel) + + @patch("sglang.srt.models.llava.logger.warning") + def test_skip_known_broken_voxtral_automodel_mapping_entry(self, mock_warning): + mapping = self._build_mapping(FakeMapping(KNOWN_VOXTRAL_ERROR)) + + self.assertEqual(mapping[GoodConfig.__name__], GoodArch.__name__) + self.assertEqual( + mapping[PixtralVisionConfig.__name__], (PixtralVisionModel.__name__,) + ) + self.assertNotIn(VoxtralRealtimeTextConfig.__name__, mapping) + + mock_warning.assert_called_once() + self.assertEqual( + mock_warning.call_args.args, + ( + "Skipping broken %s mapping for config %s: %s", + AutoModel.__name__, + VoxtralRealtimeTextConfig.__name__, + unittest.mock.ANY, + ), + ) + + def test_other_voxtral_mapping_failures_still_raise(self): + with self.assertRaisesRegex(ValueError, "some other failure"): + self._build_mapping(FakeMapping(ValueError("some other failure"))) + + +if __name__ == "__main__": + unittest.main()