Add torchao quant for mixtral
Summary:
Similar to sgl-project#1341, we add torchao quantization to the Mixtral model, and apply the same change to Qwen2-MoE (the architecture that the Qwen1.5-MoE benchmarks below exercise).
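
For context, the diffs below call torchao_quantize_param_data from sglang/srt/layers/torchao_utils.py. A minimal sketch of what that helper does, assuming torchao's quantize_ API (the exact dispatch in the real helper may differ):

import torch
from torchao.quantization import int4_weight_only, int8_weight_only, quantize_

def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
    # Wrap the 2-D weight in a throwaway Linear so quantize_ can rewrite it.
    dummy = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
    dummy.weight = torch.nn.Parameter(param, requires_grad=False)
    if "int8wo" in torchao_config:
        quantize_(dummy, int8_weight_only())
    elif "int4wo" in torchao_config:
        group_size = int(torchao_config.split("-")[-1])  # "int4wo-128" -> 128
        quantize_(dummy, int4_weight_only(group_size=group_size))
    # "fp8wo" would map similarly to torchao's float8 weight-only config.
    return dummy.weight  # now a torchao tensor-subclass Parameter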

Test Plan:
Note: compile is not working yet, and I couldn't get the torch nightly to install and work locally either.
I'll wait for the PyTorch 2.5 release in mid-October, or revisit this later.

python3 -m sglang.bench_latency --model Qwen/Qwen1.5-MoE-A2.7B --batch-size 1 --input 128 --output 8

Warmup ...
Prefill. latency: 0.05532 s, throughput:   2313.73 token/s
Decode.  latency: 0.00896 s, throughput:    111.65 token/s
Decode.  latency: 0.00833 s, throughput:    120.04 token/s
Decode.  latency: 0.00869 s, throughput:    115.06 token/s
Decode.  latency: 0.00842 s, throughput:    118.79 token/s
Decode.  median latency: 0.00855 s, median throughput:    116.89 token/s
Total. latency:  0.090 s, throughput:   1471.26 token/s
Benchmark ...
Prefill. latency: 0.04294 s, throughput:   2980.61 token/s
Decode.  latency: 0.00839 s, throughput:    119.12 token/s
Decode.  latency: 0.00828 s, throughput:    120.78 token/s
Decode.  latency: 0.00857 s, throughput:    116.64 token/s
Decode.  latency: 0.00853 s, throughput:    117.19 token/s
Decode.  latency: 0.00859 s, throughput:    116.39 token/s
Decode.  median latency: 0.00853 s, median throughput:    117.17 token/s
Total. latency:  0.111 s, throughput:   1226.84 token/s

python3 -m sglang.bench_latency --model Qwen/Qwen1.5-MoE-A2.7B --batch-size 1 --input 128 --output 8 --torchao-config int4wo-128

Warmup ...
Prefill. latency: 0.06413 s, throughput:   1996.05 token/s
Decode.  latency: 0.00764 s, throughput:    130.84 token/s
Decode.  latency: 0.00748 s, throughput:    133.73 token/s
Decode.  latency: 0.00725 s, throughput:    137.84 token/s
Decode.  latency: 0.00721 s, throughput:    138.74 token/s
Decode.  median latency: 0.00737 s, median throughput:    135.76 token/s
Total. latency:  0.094 s, throughput:   1408.61 token/s
Benchmark ...
Prefill. latency: 0.05239 s, throughput:   2443.43 token/s
Decode.  latency: 0.00739 s, throughput:    135.25 token/s
Decode.  latency: 0.00720 s, throughput:    138.90 token/s
Decode.  latency: 0.00718 s, throughput:    139.21 token/s
Decode.  latency: 0.00722 s, throughput:    138.42 token/s
Decode.  latency: 0.00745 s, throughput:    134.30 token/s
Decode.  median latency: 0.00731 s, median throughput:    136.82 token/s
Total. latency:  0.111 s, throughput:   1223.51 token/s

A100, no compile
python3 -m sglang.bench_latency --model Qwen/Qwen1.5-MoE-A2.7B --batch-size 1 --input 128 --output 8 --torchao-config fp8wo
max_total_num_tokens=199454
Warmup ...
Prefill. latency: 0.06958 s, throughput:   1839.60 token/s
Decode.  latency: 0.02343 s, throughput:     42.68 token/s
Decode.  latency: 0.02342 s, throughput:     42.70 token/s
Decode.  latency: 0.02368 s, throughput:     42.23 token/s
Decode.  latency: 0.02337 s, throughput:     42.80 token/s
Decode.  median latency: 0.02342 s, median throughput:     42.69 token/s
Total. latency:  0.163 s, throughput:    807.48 token/s
Benchmark ...
Prefill. latency: 0.05767 s, throughput:   2219.36 token/s
Decode.  latency: 0.02293 s, throughput:     43.61 token/s
Decode.  latency: 0.02026 s, throughput:     49.36 token/s
Decode.  latency: 0.02029 s, throughput:     49.29 token/s
Decode.  latency: 0.02024 s, throughput:     49.41 token/s
Decode.  latency: 0.02026 s, throughput:     49.36 token/s
Decode.  median latency: 0.02025 s, median throughput:     49.39 token/s
Total. latency:  0.222 s, throughput:    611.87 token/s

jerryzh168 committed Sep 14, 2024
1 parent 70b6802 commit d40fe1c
Showing 3 changed files with 45 additions and 2 deletions.
4 changes: 2 additions & 2 deletions python/sglang/srt/models/llama.py
@@ -416,8 +416,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             stacked_params = set(entry[0] for entry in stacked_params_mapping)
             for param_suffix in stacked_params:
                 for name in params_dict:
-                    if param_suffix in name:
-                        param = params_dict[name]
+                    param = params_dict[name]
+                    if param_suffix in name and name.endswith("proj.weight") and param.ndim == 2:
                         params_dict[name] = torchao_quantize_param_data(
                             param, self.torchao_config
                         )
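
The tightened condition matters because "param_suffix in name" alone also matches parameters that must not be quantized (biases, 1-D norm weights). A hypothetical illustration of what the new guard filters; the names are made up:

# hypothetical parameter names; only the first should be quantized
candidates = [
    "model.layers.0.self_attn.qkv_proj.weight",  # 2-D proj weight -> quantize
    "model.layers.0.self_attn.qkv_proj.bias",    # fails endswith("proj.weight") -> skip
    "model.layers.0.input_layernorm.weight",     # no stacked suffix, 1-D anyway -> skip
]
for name in candidates:
    ok = "qkv_proj" in name and name.endswith("proj.weight")
    print(name, "->", "quantize" if ok else "skip")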
22 changes: 22 additions & 0 deletions python/sglang/srt/models/mixtral.py
@@ -41,6 +41,8 @@
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.torchao_utils import torchao_quantize_param_data
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -296,6 +298,7 @@ def __init__(
         super().__init__()
         self.config = config
         self.quant_config = quant_config
+        self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = MixtralModel(config, quant_config=quant_config, prefix="model")
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
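
global_server_args_dict is how the --torchao-config flag reaches the model constructor. A runnable toy of the plumbing implied by the diff (the wiring is assumed; only the dict and key names come from the diff):

# toy version; the real dict lives in sglang.srt.managers.schedule_batch
global_server_args_dict = {}

def fake_launch_server(torchao_config: str):
    global_server_args_dict["torchao_config"] = torchao_config

class FakeModel:
    def __init__(self):
        # an empty string disables quantization via `if self.torchao_config:` below
        self.torchao_config = global_server_args_dict["torchao_config"]

fake_launch_server("int4wo-128")
assert FakeModel().torchao_config == "int4wo-128"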
@@ -375,6 +378,25 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                     param, "weight_loader", default_weight_loader
                 )
                 weight_loader(param, loaded_weight)
+                if self.torchao_config:
+                    if name.endswith("proj.weight") and param.ndim == 2:
+                        params_dict[name] = torchao_quantize_param_data(
+                            param, self.torchao_config
+                        )
+
+        if self.torchao_config:
+            # quantizing the loaded, stacked params, e.g. "...qkv_proj"
+            stacked_params = set(entry[0] for entry in stacked_params_mapping)
+            for param_suffix in stacked_params:
+                for name in params_dict:
+                    param = params_dict[name]
+                    if param_suffix in name and name.endswith("proj.weight") and param.ndim == 2:
+                        params_dict[name] = torchao_quantize_param_data(
+                            param, self.torchao_config
+                        )
+
+            self.load_state_dict(params_dict, assign=True)
+


 EntryClass = MixtralForCausalLM
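
Why the second pass over stacked_params_mapping: sharded weights such as q/k/v are merged shard-by-shard into a single stacked parameter during the main loading loop, so a stacked tensor is only complete once that loop finishes; quantizing it inside the loop would capture a partially loaded weight. The mapping follows the usual vLLM-style layout (entries below are illustrative; the real ones live in the model file):

stacked_params_mapping = [
    # (stacked param suffix, shard suffix, shard id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]
# set(entry[0] for entry in stacked_params_mapping) -> {"qkv_proj"}:
# the suffixes of stacked params the second pass still needs to quantize.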
21 changes: 21 additions & 0 deletions python/sglang/srt/models/qwen2_moe.py
@@ -47,6 +47,8 @@
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.torchao_utils import torchao_quantize_param_data
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -359,6 +361,7 @@ def __init__(
         super().__init__()
         self.config = config
         self.quant_config = quant_config
+        self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = Qwen2MoeModel(config, cache_config, quant_config)
         self.lm_head = ParallelLMHead(
             config.vocab_size, config.hidden_size, quant_config=quant_config
@@ -450,6 +453,24 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                     param, "weight_loader", default_weight_loader
                 )
                 weight_loader(param, loaded_weight)
+                if self.torchao_config:
+                    if name.endswith("proj.weight") and param.ndim == 2:
+                        params_dict[name] = torchao_quantize_param_data(
+                            param, self.torchao_config
+                        )
+        if self.torchao_config:
+            # quantizing the loaded, stacked params, e.g. "...qkv_proj"
+            stacked_params = set(entry[0] for entry in stacked_params_mapping)
+            stacked_params = stacked_params.union(entry[0] for entry in expert_params_mapping)
+            for param_suffix in stacked_params:
+                for name in params_dict:
+                    param = params_dict[name]
+                    if param_suffix in name and param.ndim == 2:
+                        params_dict[name] = torchao_quantize_param_data(
+                            param, self.torchao_config
+                        )
+
+            self.load_state_dict(params_dict, assign=True)


 EntryClass = Qwen2MoeForCausalLM
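
A note on assign=True in the final load_state_dict: the quantized entries are torchao tensor subclasses whose layout no longer matches the module's original floating-point parameters, so the default copy-in-place semantics would cast away (or reject) the quantized representation; assign=True rebinds each parameter to the new tensor instead. A tiny self-contained illustration, not from this commit:

import torch

lin = torch.nn.Linear(4, 4, bias=False)
# stand-in for a quantized tensor whose dtype differs from the fp32 weight
new_w = torch.zeros(4, 4, dtype=torch.bfloat16)
lin.load_state_dict({"weight": new_w}, assign=True)
assert lin.weight.dtype == torch.bfloat16  # parameter rebound, not copied into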
