Autoquant v2 initial version #1240


Merged · 15 commits · Nov 21, 2024
102 changes: 86 additions & 16 deletions torchao/_models/llama/generate.py
@@ -205,28 +205,31 @@ def main(


     if quantization:
-        from torchao.quantization.quant_api import (
+        from torchao.quantization import (
             quantize_,
+            autoquant,
             int8_weight_only,
             int8_dynamic_activation_int8_weight,
             int4_weight_only,
             int8_dynamic_activation_int4_weight,
             fpx_weight_only,
             uintx_weight_only,
-            autoquant,
             float8_weight_only,
             float8_dynamic_activation_float8_weight,
         )
+        from torchao.prototype.quantization.autoquant_v2 import autoquant_v2
+        from torchao.utils import unwrap_tensor_subclass

         from torchao.quantization.granularity import PerTensor, PerRow
-        from torchao.utils import unwrap_tensor_subclass
         if "spinquant" in quantization:
             from torchao.prototype.spinquant import apply_spinquant
             apply_spinquant(model)
         if "int8wo" in quantization:
             quantize_(model, int8_weight_only())
-        if "int8dq" in quantization:
+        elif "int8dq" in quantization:
             quantize_(model, int8_dynamic_activation_int8_weight())
-        if "int4wo" in quantization:
+        elif "int4wo" in quantization:
             if "hqq" in quantization:
                 use_hqq=True
             else:
@@ -246,14 +249,14 @@ def main(
                         layout=MarlinQQQLayout(),
                     ),
                 )
             else:
                 from torchao.dtypes import MarlinSparseLayout
                 quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
         if "fp6" in quantization:
             quantize_(model, fpx_weight_only(3, 2))
-        if "embed-int8wo" in quantization:
+        elif "embed-int8wo" in quantization:
             quantize_(model, int8_weight_only(group_size=64), filter_fn=lambda x, *args: isinstance(x, torch.nn.Embedding))
-        if quantization.startswith("awq"):
+        elif quantization.startswith("awq"):
             from torchao._models._eval import TransformerEvalWrapper
             from torchao.utils import TORCH_VERSION_AT_LEAST_2_3
             from torchao.prototype.awq.example import get_calib_dataset
@@ -274,13 +277,13 @@ def main(
                 input_prep_func=prepare_inputs_for_model,
                 device=device,
             ).run_eval(
                 tasks=['wikitext'],
                 limit=1,
             )
             is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
             use_hqq = "hqq" in quantization
             quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size = group_size, use_hqq=use_hqq), is_observed_linear)
-        if "uintx" in quantization:
+        elif "uintx" in quantization:
             # uintx-nbits-group_size, e.g. "uintx-2-64"
             if "hqq" in quantization:
                 # uintx-nbits-group_size-hqq
@@ -294,9 +297,9 @@ def main(
             dtype = _NBITS_TO_DTYPE[nbits]
             group_size = int(_quant_args[2])
             quantize_(model, uintx_weight_only(dtype, group_size, use_hqq=use_hqq))
-        if "float8wo" in quantization:
+        elif "float8wo" in quantization:
             quantize_(model, float8_weight_only())
-        if "float8dq" in quantization:
+        elif "float8dq" in quantization:
             granularity = str(quantization.split("-")[-1])
             if granularity=="tensor":
                 granularity = PerTensor()
@@ -305,13 +308,79 @@ def main(
             else:
                 granularity = PerTensor()
             quantize_(model, float8_dynamic_activation_float8_weight(granularity=granularity))
-        if "autoquant" in quantization:
+        elif "autoquant_v2" in quantization:
+            from torchao._models._eval import InputRecorder
+            from torchao._models.llama.model import prepare_inputs_for_model
+
+            calibration_seq_length = 256
+            calibration_limit = 1
+            inputs = InputRecorder(
+                tokenizer,
+                calibration_seq_length,
+                prepare_inputs_for_model,
+                False,  # pad_calibration_inputs
+                model.config.vocab_size,
+                device="cuda"
+            ).record_inputs(
+                ["wikitext"],
+                1,
+            ).get_inputs()[0].values[0]
+            inputs = prepare_inputs_for_model(inputs)
+            with torch.device("cuda"):
+                model.setup_caches(
+                    max_batch_size=1, max_seq_length=calibration_seq_length
+                )
+
+            if "autoquant_v2-int4" == quantization:
+                model = autoquant_v2(model, manual=True, qtensor_class_list = torchao.prototype.quantization.autoquant_v2.DEFAULT_INT4_AUTOQUANT_CLASS_LIST, example_input=inputs)
+            elif "autoquant_v2-float8" == quantization:
+                model = autoquant_v2(model, manual=True, qtensor_class_list = torchao.prototype.quantization.autoquant_v2.OTHER_AUTOQUANT_CLASS_LIST, example_input=inputs)
+            else:
+                model = autoquant_v2(model, manual=True, example_input=inputs)
+
+            print("running generate")
+            generate(
+                model,
+                encode_tokens(tokenizer, prompt, bos=True, device=device),
+                max_new_tokens,
+                batch_size,
+                interactive=False,
+                temperature=temperature,
+                top_k=top_k,
+            )
+
+            print("running finalize autoquant")
+            # do autoquantization
+            model.finalize_autoquant()
+        elif "autoquant" in quantization:
+            from torchao._models._eval import InputRecorder
+            from torchao._models.llama.model import prepare_inputs_for_model
+
+            calibration_seq_length = 256
+            calibration_limit = 1
+            inputs = InputRecorder(
+                tokenizer,
+                calibration_seq_length,
+                prepare_inputs_for_model,
+                False,  # pad_calibration_inputs
+                model.config.vocab_size,
+                device="cuda"
+            ).record_inputs(
+                ["wikitext"],
+                1,
+            ).get_inputs()[0].values[0]
+            inputs = prepare_inputs_for_model(inputs)
+            with torch.device("cuda"):
+                model.setup_caches(
+                    max_batch_size=1, max_seq_length=calibration_seq_length
+                )
+
             if "autoquant-int4" == quantization:
-                model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST)
+                model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST, example_input=inputs)
             elif "autoquant-float8" == quantization:
-                model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST)
+                model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST, example_input=inputs)
             else:
-                model = autoquant(model, manual=True)
+                model = autoquant(model, manual=True, example_input=inputs)

             generate(
                 model,
@@ -325,6 +394,7 @@ def main(

             # do autoquantization
             model.finalize_autoquant()
+
         else:
             if not TORCH_VERSION_AT_LEAST_2_5:
                 unwrap_tensor_subclass(model)
@@ -489,7 +559,7 @@ def callback(x):
     parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.')
     parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')
     parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
     parser.add_argument('-q', '--quantization', type=str,
         help=(
             'Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, '
             +'autoquant-int4, autoquant-float8, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, spinquant, '
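Taken together, the generate.py changes implement a record → quantize → profile → finalize loop: InputRecorder captures a short wikitext calibration batch, autoquant_v2 wraps the model (with qtensor_class_list optionally narrowing the candidate kernels, chosen on the CLI via -q autoquant_v2, -q autoquant_v2-int4, or -q autoquant_v2-float8), one generate() pass exposes real shapes to every layer, and finalize_autoquant() commits the fastest option per layer. Below is a minimal sketch of that flow on a toy model; ToyModel, its sizes, and the synthetic example input are illustrative assumptions rather than part of the PR, and a CUDA device is assumed as in the diff.

```
# Hedged sketch of the autoquant_v2 flow added above, on a hypothetical toy
# model (the real script calibrates with InputRecorder on wikitext instead
# of random data).
import torch
from torchao.prototype.quantization.autoquant_v2 import autoquant_v2

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(1024, 1024)
        self.linear2 = torch.nn.Linear(1024, 1024)

    def forward(self, x):
        return self.linear2(self.linear1(x))

model = ToyModel().to(device="cuda", dtype=torch.bfloat16)
example_input = torch.randn(1, 1024, device="cuda", dtype=torch.bfloat16)

# manual=True defers kernel selection so the model can first be run on
# realistic shapes; example_input seeds the shape/timing records.
model = autoquant_v2(model, manual=True, example_input=example_input)
model(example_input)          # profile the candidate quantized kernels
model.finalize_autoquant()    # pick the fastest option per layer
```

As in the diff, passing qtensor_class_list=DEFAULT_INT4_AUTOQUANT_CLASS_LIST or OTHER_AUTOQUANT_CLASS_LIST from the prototype module restricts the search to int4 or float8 candidates.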
4 changes: 3 additions & 1 deletion torchao/_models/sam/README.md
@@ -17,5 +17,7 @@ sh setup.sh

 Finally, you can run benchmarks with
 ```
-sh benchmark_sam.sh
+sh benchmark.sh
 ```
+
+You can check out the results in results.csv
32 changes: 31 additions & 1 deletion torchao/_models/sam/eval_combo.py
@@ -9,7 +9,14 @@
 import time
 import resource

-from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight, int4_weight_only
+import torchao
+from torchao.quantization import (
+    quantize_,
+    int8_dynamic_activation_int8_weight,
+    int4_weight_only,
+    autoquant,
+)
+from torchao.prototype.quantization.autoquant_v2 import autoquant_v2
 from torchao.sparsity import sparsify_, apply_fake_sparsity, semi_sparse_weight
 from torchao.dtypes import SemiSparseLayout, MarlinSparseLayout
 from torchao.utils import unwrap_tensor_subclass
@@ -336,6 +343,29 @@ def mlp_only(mod, name):
                       mlp_lin2_only)
         if not TORCH_VERSION_AT_LEAST_2_5:
             predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder)

+    elif compress is not None and "autoquant_v2" in compress:
+        example_input = torch.randn(1, 3, 1024, 1024, dtype=torch.bfloat16, device=device)
+        if "autoquant_v2-int4" == compress:
+            autoquant_v2(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.prototype.quantization.autoquant_v2.DEFAULT_INT4_AUTOQUANT_CLASS_LIST)
+        elif "autoquant_v2-float8" == compress:
+            autoquant_v2(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.prototype.quantization.autoquant_v2.OTHER_AUTOQUANT_CLASS_LIST)
+        else:
+            autoquant_v2(predictor.model.image_encoder, example_input=example_input, manual=True)
+
+        predictor.model.image_encoder(example_input)
+        predictor.model.image_encoder.finalize_autoquant()
+
+    elif compress is not None and "autoquant" in compress:
+        example_input = torch.randn(1, 3, 1024, 1024, dtype=torch.bfloat16, device=device)
+        if "autoquant-int4" == compress:
+            autoquant(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST)
+        elif "autoquant-float8" == compress:
+            autoquant(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST)
+        else:
+            autoquant(predictor.model.image_encoder, example_input=example_input, manual=True)
+        predictor.model.image_encoder(example_input)
+        predictor.model.image_encoder.finalize_autoquant()
     else:
         assert compress is None, f"Unsupported compress mode {compress}"

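The eval_combo.py path exercises the same API in place on a submodule: autoquant_v2 mutates predictor.model.image_encoder directly (its return value is not reassigned), a single forward pass on a representative input drives the profiling, and finalize_autoquant() commits the per-layer choices. Below is a hedged sketch of that pattern with a plain Sequential standing in for the SAM image encoder; importing DEFAULT_INT4_AUTOQUANT_CLASS_LIST directly from the prototype module is assumed to work because the diff references it as a module attribute.

```
# Hedged sketch of the in-place autoquant_v2 pattern from eval_combo.py;
# the Sequential below is a hypothetical stand-in for the SAM image encoder.
import torch
from torchao.prototype.quantization.autoquant_v2 import (
    DEFAULT_INT4_AUTOQUANT_CLASS_LIST,
    autoquant_v2,
)

encoder = torch.nn.Sequential(
    torch.nn.Linear(1024, 4096),
    torch.nn.GELU(),
    torch.nn.Linear(4096, 1024),
).to(device="cuda", dtype=torch.bfloat16)
example_input = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)

# Restrict the search to int4 candidates, as the autoquant_v2-int4 compress
# mode does; omit qtensor_class_list to search the default candidate list.
autoquant_v2(
    encoder,
    example_input=example_input,
    manual=True,
    qtensor_class_list=DEFAULT_INT4_AUTOQUANT_CLASS_LIST,
)
encoder(example_input)        # one forward pass profiles each candidate
encoder.finalize_autoquant()  # commit the fastest kernels in place
```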