diff --git a/python/sglang/srt/layers/torchao_utils.py b/python/sglang/srt/layers/torchao_utils.py
index e08abd5ae1d5..2c97159a9472 100644
--- a/python/sglang/srt/layers/torchao_utils.py
+++ b/python/sglang/srt/layers/torchao_utils.py
@@ -36,6 +36,17 @@ def proj_filter(
     return "proj" in fqn
 
 
+# TODO: implement a more general filter function
+def proj_filter_conv3d(
+    module: torch.nn.Module,
+    fqn: str,
+):
+    if isinstance(module, torch.nn.Conv3d):
+        logger.warning(f"Quantize: skipping {fqn} because it's a Conv3d")
+        return False
+    return "proj" in fqn
+
+
 def apply_torchao_config_to_model(
     model: torch.nn.Module,
     torchao_config: str,
@@ -63,7 +74,7 @@ def apply_torchao_config_to_model(
     if torchao_config == "" or torchao_config is None:
         return model
     elif "int8wo" in torchao_config:
-        quantize_(model, int8_weight_only(), filter_fn=filter_fn)
+        quantize_(model, int8_weight_only(), filter_fn=proj_filter_conv3d)
     elif "int8dq" in torchao_config:
         quantize_(model, int8_dynamic_activation_int8_weight(), filter_fn=filter_fn)
     elif "int4wo" in torchao_config:
@@ -101,7 +112,7 @@ def apply_torchao_config_to_model(
     elif "fp8wo" in torchao_config:
         # this requires newer hardware
         # [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
-        quantize_(model, float8_weight_only(), filter_fn=filter_fn)
+        quantize_(model, float8_weight_only(), filter_fn=proj_filter_conv3d)
     elif "fp8dq" in torchao_config:
         granularity = torchao_config.split("-")[-1]
         GRANULARITY_MAP = {
@@ -116,7 +127,7 @@ def apply_torchao_config_to_model(
             float8_dynamic_activation_float8_weight(
                 granularity=GRANULARITY_MAP[granularity]
             ),
-            filter_fn=filter_fn,
+            filter_fn=proj_filter_conv3d,
         )
     else:
         raise ValueError(f"Unexpected config: {torchao_config}")
diff --git a/test/srt/test_torchao.py b/test/srt/test_torchao.py
index 13c7b60b5cbe..53368aaa45ba 100644
--- a/test/srt/test_torchao.py
+++ b/test/srt/test_torchao.py
@@ -3,10 +3,14 @@
 
 import requests
 
+from sglang import Engine
+from sglang.lang.chat_template import get_chat_template_by_model_path
 from sglang.srt.utils import kill_process_tree
 from sglang.test.run_eval import run_eval
 from sglang.test.test_utils import (
+    DEFAULT_IMAGE_URL,
     DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_SMALL_VLM_MODEL_NAME_FOR_TEST,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
     CustomTestCase,
@@ -70,5 +74,22 @@ def test_throughput(self):
         assert throughput >= 210
 
 
+class TestTorchAOForVLM(CustomTestCase):
+    def test_vlm_generate(self):
+        model_path = DEFAULT_SMALL_VLM_MODEL_NAME_FOR_TEST
+        chat_template = get_chat_template_by_model_path(model_path)
+        text = f"{chat_template.image_token}What is in this picture? Answer: "
+
+        engine = Engine(
+            model_path=model_path,
+            max_total_tokens=512,
+            enable_multimodal=True,
+            torchao_config="fp8wo",
+        )
+        out = engine.generate([text], image_data=[DEFAULT_IMAGE_URL])
+        engine.shutdown()
+        self.assertGreater(len(out), 0)
+
+
 if __name__ == "__main__":
     unittest.main()