-
Notifications
You must be signed in to change notification settings - Fork 644
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add torchao quant (int4/int8/fp8) to llama models (#1341)
Co-authored-by: Lianmin Zheng <[email protected]>
- Loading branch information
1 parent
e4d68af
commit a7c47e0
Showing
10 changed files
with
151 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
""" | ||
Common utilities for torchao. | ||
""" | ||
|
||
import torch | ||
from torchao.quantization import ( | ||
int4_weight_only, | ||
int8_dynamic_activation_int8_weight, | ||
int8_weight_only, | ||
quantize_, | ||
) | ||
|
||
|
||
def torchao_quantize_param_data(param, torchao_config): | ||
dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False) | ||
dummy_linear.weight = param | ||
if "int8wo" in torchao_config: | ||
quantize_(dummy_linear, int8_weight_only()) | ||
elif "int8dq" in torchao_config: | ||
quantize_(dummy_linear, int8_dynamic_activation_int8_weight()) | ||
elif "int4wo" in torchao_config: | ||
group_size = int(torchao_config.split("-")[-1]) | ||
assert group_size in [ | ||
32, | ||
64, | ||
128, | ||
256, | ||
], f"int4wo groupsize needs to be one of [32, 64, 128, 256] but got {group_size}" | ||
quantize_(dummy_linear, int4_weight_only(group_size=group_size)) | ||
elif "fp8wo" in torchao_config: | ||
from torchao.quantization import float8_weight_only | ||
|
||
# this requires newer hardware | ||
# [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89 | ||
quantize_(dummy_linear, float8_weight_only()) | ||
return dummy_linear.weight |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import unittest | ||
from types import SimpleNamespace | ||
|
||
import requests | ||
|
||
from sglang.srt.utils import kill_child_process | ||
from sglang.test.run_eval import run_eval | ||
from sglang.test.test_utils import ( | ||
DEFAULT_MODEL_NAME_FOR_TEST, | ||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, | ||
DEFAULT_URL_FOR_TEST, | ||
popen_launch_server, | ||
) | ||
|
||
|
||
class TestTorchCompile(unittest.TestCase): | ||
@classmethod | ||
def setUpClass(cls): | ||
cls.model = DEFAULT_MODEL_NAME_FOR_TEST | ||
cls.base_url = DEFAULT_URL_FOR_TEST | ||
cls.process = popen_launch_server( | ||
cls.model, | ||
cls.base_url, | ||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, | ||
other_args=["--torchao-config", "int4wo-128"], | ||
) | ||
|
||
@classmethod | ||
def tearDownClass(cls): | ||
kill_child_process(cls.process.pid) | ||
|
||
def test_mmlu(self): | ||
args = SimpleNamespace( | ||
base_url=self.base_url, | ||
model=self.model, | ||
eval_name="mmlu", | ||
num_examples=64, | ||
num_threads=32, | ||
) | ||
|
||
metrics = run_eval(args) | ||
assert metrics["score"] >= 0.65 | ||
|
||
def run_decode(self, max_new_tokens): | ||
response = requests.post( | ||
self.base_url + "/generate", | ||
json={ | ||
"text": "The capital of France is", | ||
"sampling_params": { | ||
"temperature": 0, | ||
"max_new_tokens": max_new_tokens, | ||
}, | ||
"ignore_eos": True, | ||
}, | ||
) | ||
return response.json() | ||
|
||
def test_throughput(self): | ||
import time | ||
|
||
max_tokens = 256 | ||
|
||
tic = time.time() | ||
res = self.run_decode(max_tokens) | ||
tok = time.time() | ||
print(res["text"]) | ||
throughput = max_tokens / (tok - tic) | ||
print(f"Throughput: {throughput} tokens/s") | ||
assert throughput >= 210 | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters