Skip to content

Commit 30b404c

Browse files
authored
Add torchao quant for mixtral and qwen_moe (#1418)
1 parent 70b6802 commit 30b404c

File tree

4 files changed

+50
-20
lines changed

4 files changed

+50
-20
lines changed

python/sglang/srt/layers/torchao_utils.py

+38-1
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,20 @@
22
Common utilities for torchao.
33
"""
44

5+
from typing import Dict, Set
6+
57
import torch
68

79

8-
def torchao_quantize_param_data(param, torchao_config):
10+
def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
11+
"""Quantize a Tensor with torchao quantization specified by torchao_config
12+
13+
Args:
14+
`param`: weight parameter of the linear module
15+
`torchao_config`: type of quantization and their arguments we want to use to
16+
quantize the Tensor, e.g. int4wo-128 means int4 weight only quantization with group_size
17+
128
18+
"""
919
# Lazy import to suppress some warnings
1020
from torchao.quantization import (
1121
int4_weight_only,
@@ -36,3 +46,30 @@ def torchao_quantize_param_data(param, torchao_config):
3646
# [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
3747
quantize_(dummy_linear, float8_weight_only())
3848
return dummy_linear.weight
49+
50+
51+
def apply_torchao_config_(
    self: torch.nn.Module,
    params_dict: Dict[str, torch.Tensor],
    param_suffixes: Set[str],
) -> None:
    """Quantize already-loaded weights in place when ``self.torchao_config`` is set.

    Args:
        `self`: the model we want to quantize; it must expose a `torchao_config`
            attribute and `load_state_dict`
        `params_dict`: dictionary mapping from param_name to the parameter Tensor
        `param_suffixes`: a set of name fragments; a 2-D Tensor is quantized when its
            name contains at least one of them (substring match, not a strict
            `endswith`)

    Returns:
        None. When `self.torchao_config` is falsy nothing happens. Otherwise the
        matching entries of `params_dict` are replaced in place by their quantized
        counterparts and the dict is loaded back into `self` with `assign=True`.
    """
    if not self.torchao_config:
        return

    # Walk the parameters once and test every suffix per parameter, so a name that
    # matches several suffixes is quantized exactly once instead of once per suffix.
    # Snapshot the items because values of `params_dict` are reassigned in the loop.
    for name, param in list(params_dict.items()):
        if any(suffix in name for suffix in param_suffixes) and param.ndim == 2:
            params_dict[name] = torchao_quantize_param_data(
                param, self.torchao_config
            )

    # assign=True is passed so load_state_dict takes the tensors in params_dict as-is.
    self.load_state_dict(params_dict, assign=True)

python/sglang/srt/models/llama.py

+2-19
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
from sglang.srt.layers.layernorm import RMSNorm
4242
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
4343
from sglang.srt.layers.radix_attention import RadixAttention
44-
from sglang.srt.layers.torchao_utils import torchao_quantize_param_data
44+
from sglang.srt.layers.torchao_utils import apply_torchao_config_
4545
from sglang.srt.managers.schedule_batch import global_server_args_dict
4646
from sglang.srt.model_executor.forward_batch_info import InputMetadata
4747

@@ -405,24 +405,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
405405
weight_loader = getattr(param, "weight_loader", default_weight_loader)
406406
weight_loader(param, loaded_weight)
407407

408-
if self.torchao_config:
409-
if name.endswith("proj.weight") and param.ndim == 2:
410-
params_dict[name] = torchao_quantize_param_data(
411-
param, self.torchao_config
412-
)
413-
414-
if self.torchao_config:
415-
# quantizing the loaded, stacked params, e.g. "...qkv_proj"
416-
stacked_params = set(entry[0] for entry in stacked_params_mapping)
417-
for param_suffix in stacked_params:
418-
for name in params_dict:
419-
if param_suffix in name:
420-
param = params_dict[name]
421-
params_dict[name] = torchao_quantize_param_data(
422-
param, self.torchao_config
423-
)
424-
425-
self.load_state_dict(params_dict, assign=True)
408+
apply_torchao_config_(self, params_dict, set(["proj.weight"]))
426409

427410

428411
class Phi3ForCausalLM(LlamaForCausalLM):

python/sglang/srt/models/mixtral.py

+5
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@
4141
from sglang.srt.layers.layernorm import RMSNorm
4242
from sglang.srt.layers.logits_processor import LogitsProcessor
4343
from sglang.srt.layers.radix_attention import RadixAttention
44+
from sglang.srt.layers.torchao_utils import apply_torchao_config_
45+
from sglang.srt.managers.schedule_batch import global_server_args_dict
4446
from sglang.srt.model_executor.forward_batch_info import InputMetadata
4547

4648

@@ -296,6 +298,7 @@ def __init__(
296298
super().__init__()
297299
self.config = config
298300
self.quant_config = quant_config
301+
self.torchao_config = global_server_args_dict["torchao_config"]
299302
self.model = MixtralModel(config, quant_config=quant_config, prefix="model")
300303
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
301304
self.logits_processor = LogitsProcessor(config)
@@ -376,5 +379,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
376379
)
377380
weight_loader(param, loaded_weight)
378381

382+
apply_torchao_config_(self, params_dict, set(["proj.weight"]))
383+
379384

380385
EntryClass = MixtralForCausalLM

python/sglang/srt/models/qwen2_moe.py

+5
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@
4747
from sglang.srt.layers.layernorm import RMSNorm
4848
from sglang.srt.layers.logits_processor import LogitsProcessor
4949
from sglang.srt.layers.radix_attention import RadixAttention
50+
from sglang.srt.layers.torchao_utils import apply_torchao_config_
51+
from sglang.srt.managers.schedule_batch import global_server_args_dict
5052
from sglang.srt.model_executor.forward_batch_info import InputMetadata
5153

5254

@@ -359,6 +361,7 @@ def __init__(
359361
super().__init__()
360362
self.config = config
361363
self.quant_config = quant_config
364+
self.torchao_config = global_server_args_dict["torchao_config"]
362365
self.model = Qwen2MoeModel(config, cache_config, quant_config)
363366
self.lm_head = ParallelLMHead(
364367
config.vocab_size, config.hidden_size, quant_config=quant_config
@@ -451,5 +454,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
451454
)
452455
weight_loader(param, loaded_weight)
453456

457+
apply_torchao_config_(self, params_dict, set(["proj.weight"]))
458+
454459

455460
EntryClass = Qwen2MoeForCausalLM

0 commit comments

Comments
 (0)