Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions python/sglang/srt/layers/torchao_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,17 @@ def proj_filter(
return "proj" in fqn


# TODO: implement a more general filter function
def proj_filter_conv3d(
    module: torch.nn.Module,
    fqn: str,
):
    """Filter for torchao ``quantize_``: select "proj" modules, never Conv3d.

    Conv3d layers (e.g. in vision towers) are excluded from quantization; a
    warning is logged for each one that is skipped.
    """
    is_conv3d = isinstance(module, torch.nn.Conv3d)
    if is_conv3d:
        logger.warning(f"Quantize: skipping {fqn} because it's a Conv3d")
    # Quantize only when the module is not a Conv3d and its name contains "proj".
    return (not is_conv3d) and "proj" in fqn


def apply_torchao_config_to_model(
model: torch.nn.Module,
torchao_config: str,
Expand Down Expand Up @@ -63,7 +74,7 @@ def apply_torchao_config_to_model(
if torchao_config == "" or torchao_config is None:
return model
elif "int8wo" in torchao_config:
quantize_(model, int8_weight_only(), filter_fn=filter_fn)
quantize_(model, int8_weight_only(), filter_fn=proj_filter_conv3d)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Instead of hardcoding proj_filter_conv3d here and in other quantization branches, consider making it the default filter_fn for the apply_torchao_config_to_model function. This would make the fix more robust by applying it to other quantization methods like int4wo and int8dq as well, and would simplify the code.

You could change the function signature on line 49 to:

filter_fn: Optional[Callable] = proj_filter_conv3d,

And then change this line and others to use the default filter_fn.

Suggested change
quantize_(model, int8_weight_only(), filter_fn=proj_filter_conv3d)
quantize_(model, int8_weight_only(), filter_fn=filter_fn)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But the default filter_fn is needed by other quantization methods.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why delete filter_fn here?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Due to the need to determine whether to use the original filter or the new filter based on different quantization configurations, passing this decision from an upper layer would create some coupling, as the upper layer would need to be aware of the different quantization settings. However, we still need to consider how to handle modules like Deconv3D. The current proj_filter_conv3d approach is somewhat hacky; it would be appropriate to add a comment or a TODO note here.

elif "int8dq" in torchao_config:
quantize_(model, int8_dynamic_activation_int8_weight(), filter_fn=filter_fn)
elif "int4wo" in torchao_config:
Expand Down Expand Up @@ -101,7 +112,7 @@ def apply_torchao_config_to_model(
elif "fp8wo" in torchao_config:
# this requires newer hardware
# [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
quantize_(model, float8_weight_only(), filter_fn=filter_fn)
quantize_(model, float8_weight_only(), filter_fn=proj_filter_conv3d)
elif "fp8dq" in torchao_config:
granularity = torchao_config.split("-")[-1]
GRANULARITY_MAP = {
Expand All @@ -116,7 +127,7 @@ def apply_torchao_config_to_model(
float8_dynamic_activation_float8_weight(
granularity=GRANULARITY_MAP[granularity]
),
filter_fn=filter_fn,
filter_fn=proj_filter_conv3d,
)
else:
raise ValueError(f"Unexpected config: {torchao_config}")
Expand Down
21 changes: 21 additions & 0 deletions test/srt/test_torchao.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@

import requests

from sglang import Engine
from sglang.lang.chat_template import get_chat_template_by_model_path
from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
DEFAULT_IMAGE_URL,
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_SMALL_VLM_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
Expand Down Expand Up @@ -70,5 +74,22 @@ def test_throughput(self):
assert throughput >= 210


class TestTorchAOForVLM(CustomTestCase):
    """Smoke test: torchao quantization on a vision-language model.

    Runs one multimodal generate with ``torchao_config="fp8wo"``; presumably
    this guards the quantization filter that skips Conv3d modules — confirm
    against torchao_utils.
    """

    def test_vlm_generate(self):
        model_path = DEFAULT_SMALL_VLM_MODEL_NAME_FOR_TEST
        chat_template = get_chat_template_by_model_path(model_path)
        text = f"{chat_template.image_token}What is in this picture? Answer: "

        engine = Engine(
            model_path=model_path,
            max_total_tokens=512,
            enable_multimodal=True,
            torchao_config="fp8wo",
        )
        try:
            out = engine.generate([text], image_data=[DEFAULT_IMAGE_URL])
        finally:
            # Always release the engine, even if generate raises — otherwise a
            # failed generation leaks the loaded model for the rest of the run.
            engine.shutdown()
        self.assertGreater(len(out), 0)


# Allow running this test module directly (outside the test runner harness).
if __name__ == "__main__":
    unittest.main()
Loading