Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion python/sglang/srt/models/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@
)
from sglang.srt.utils import add_prefix, flatten_nested_list, logger

# Name of the one config class whose AutoModel mapping lookup is allowed to
# fail; it is compared against ``config_cls.__name__`` when a ValueError is
# caught while building the config-name -> architecture-name mapping.
_KNOWN_BROKEN_AUTOMODEL_CONFIG = "VoxtralRealtimeTextConfig"
# Substring that must appear in that ValueError's message for the entry to be
# skipped with a warning; any other failure is re-raised.
_KNOWN_BROKEN_AUTOMODEL_ERROR = "Could not find VoxtralRealtimeTextModel"


class LlavaBaseForCausalLM(nn.Module):
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
Expand Down Expand Up @@ -657,7 +660,22 @@ def _config_cls_name_to_arch_name_mapping(
) -> Dict[str, str]:
mapping = {}
for config_cls in auto_model_type._model_mapping.keys():
archs = auto_model_type._model_mapping.get(config_cls, None)
try:
archs = auto_model_type._model_mapping.get(config_cls, None)
except ValueError as exc:
if (
auto_model_type is not AutoModel
or config_cls.__name__ != _KNOWN_BROKEN_AUTOMODEL_CONFIG
or _KNOWN_BROKEN_AUTOMODEL_ERROR not in str(exc)
):
raise
logger.warning(
"Skipping broken %s mapping for config %s: %s",
auto_model_type.__name__,
config_cls.__name__,
exc,
)
continue
if archs is not None:
if isinstance(archs, tuple):
mapping[config_cls.__name__] = tuple(
Expand Down
91 changes: 91 additions & 0 deletions test/registered/unit/models/test_llava.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import unittest
from unittest.mock import patch

from sglang.srt.models.llava import AutoModel, LlavaForConditionalGeneration
from sglang.test.ci.ci_register import register_cuda_ci
from sglang.test.test_utils import CustomTestCase

# Register this test module with the CUDA CI helper under the
# "stage-b-test-1-gpu-small" suite.
# NOTE(review): the unit of est_time is not visible from this file - confirm.
register_cuda_ci(est_time=1, suite="stage-b-test-1-gpu-small")


class PixtralVisionConfig:
    """Stub config class that FakeMapping resolves to a tuple of architectures."""


class VoxtralRealtimeTextConfig:
    """Stub config class whose lookup in FakeMapping raises the injected error.

    Only the class name matters: it is what the code under test compares
    against when deciding whether a failed lookup may be skipped.
    """


class GoodConfig:
    """Stub config class that FakeMapping resolves to a single architecture."""


class PixtralVisionModel:
    """Stub architecture class returned by FakeMapping inside a one-item tuple."""


class GoodArch:
    """Stub architecture class returned by FakeMapping as a bare class."""


class FakeMapping:
    """Minimal stand-in for ``AutoModel._model_mapping`` used by these tests.

    It exposes only ``keys()`` and ``get()``. Looking up
    ``VoxtralRealtimeTextConfig`` raises the exception supplied at
    construction time; the other two known configs resolve to stub
    architectures.
    """

    def __init__(self, voxtral_error):
        # Exception instance raised when the Voxtral config is looked up.
        self.voxtral_error = voxtral_error

    def keys(self):
        """Return the config classes in a fixed order, Voxtral first."""
        known_configs = [VoxtralRealtimeTextConfig, PixtralVisionConfig, GoodConfig]
        return known_configs

    def get(self, config_cls, default=None):
        """Return the architecture(s) for ``config_cls`` or ``default``.

        Raises:
            The exception passed to ``__init__`` when ``config_cls`` is
            ``VoxtralRealtimeTextConfig``.
        """
        if config_cls is VoxtralRealtimeTextConfig:
            raise self.voxtral_error

        # One entry maps to a tuple and one to a bare class, so both result
        # shapes are exercised.
        resolvable = (
            (PixtralVisionConfig, (PixtralVisionModel,)),
            (GoodConfig, GoodArch),
        )
        for known_cls, archs in resolvable:
            if config_cls is known_cls:
                return archs
        return default


# Error injected into FakeMapping for the "known broken" case. Its message
# contains the "Could not find VoxtralRealtimeTextModel" substring that the
# code under test looks for before skipping the entry.
KNOWN_VOXTRAL_ERROR = ValueError(
    "Could not find VoxtralRealtimeTextModel neither in "
    "<module 'transformers.models.voxtral_realtime'> nor in "
    "<module 'transformers'>!"
)


class TestLlavaForConditionalGeneration(CustomTestCase):
    """Tests for error handling in ``_config_cls_name_to_arch_name_mapping``."""

    def setUp(self):
        # The mapping builder exposes cache_clear(); reset it so a result
        # cached by one test cannot be served to the next.
        builder = LlavaForConditionalGeneration._config_cls_name_to_arch_name_mapping
        builder.cache_clear()

    def _build_mapping(self, mapping):
        """Run the mapping builder with ``AutoModel._model_mapping`` patched.

        Args:
            mapping: Object installed as ``AutoModel._model_mapping`` for the
                duration of the call.

        Returns:
            Whatever ``_config_cls_name_to_arch_name_mapping(AutoModel)`` returns.
        """
        with patch.object(AutoModel, "_model_mapping", mapping):
            # object.__new__ creates the instance without running __init__.
            bare_model = object.__new__(LlavaForConditionalGeneration)
            result = bare_model._config_cls_name_to_arch_name_mapping(AutoModel)
        return result

    @patch("sglang.srt.models.llava.logger.warning")
    def test_skip_known_broken_voxtral_automodel_mapping_entry(self, mock_warning):
        """The known Voxtral failure is skipped with exactly one warning."""
        result = self._build_mapping(FakeMapping(KNOWN_VOXTRAL_ERROR))

        # Healthy entries are still mapped by class name.
        self.assertEqual(result[GoodConfig.__name__], GoodArch.__name__)
        expected_pixtral_archs = (PixtralVisionModel.__name__,)
        self.assertEqual(result[PixtralVisionConfig.__name__], expected_pixtral_archs)
        # The broken entry is left out of the mapping.
        self.assertNotIn(VoxtralRealtimeTextConfig.__name__, result)

        mock_warning.assert_called_once()
        expected_args = (
            "Skipping broken %s mapping for config %s: %s",
            AutoModel.__name__,
            VoxtralRealtimeTextConfig.__name__,
            unittest.mock.ANY,  # the exception object itself is not pinned
        )
        self.assertEqual(mock_warning.call_args.args, expected_args)

    def test_other_voxtral_mapping_failures_still_raise(self):
        """A ValueError with a different message is re-raised, not skipped."""
        unexpected_error = ValueError("some other failure")
        failing_mapping = FakeMapping(unexpected_error)
        with self.assertRaisesRegex(ValueError, "some other failure"):
            self._build_mapping(failing_mapping)


# Allow running this test file directly as a script.
if __name__ == "__main__":
    unittest.main()
Loading