Skip to content

Commit

Permalink
rename autoquant v2
Browse files Browse the repository at this point in the history
  • Loading branch information
jerryzh168 committed Nov 19, 2024
1 parent b8f6dcc commit cb72ebc
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 12 deletions.
3 changes: 2 additions & 1 deletion torchao/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

from torchao.quantization import (
autoquant,
autoquant_v2,
_autoquant_v2,
quantize_,
)
from . import dtypes
Expand All @@ -41,6 +41,7 @@
__all__ = [
"dtypes",
"autoquant",
"_autoquant_v2",
"quantize_",
"testing",
]
Expand Down
10 changes: 5 additions & 5 deletions torchao/_models/llama/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def main(


if quantization:
from torchao.quantization.quant_api import (
from torchao.quantization import (
quantize_,
int8_weight_only,
int8_dynamic_activation_int8_weight,
Expand All @@ -216,7 +216,7 @@ def main(
fpx_weight_only,
uintx_weight_only,
autoquant,
autoquant_v2,
_autoquant_v2,
unwrap_tensor_subclass,
float8_weight_only,
float8_dynamic_activation_float8_weight,
Expand Down Expand Up @@ -330,11 +330,11 @@ def main(
)

if "autoquant_v2-int4" == quantization:
model = autoquant_v2(model, manual=True, qtensor_class_list = torchao.quantization.V2_DEFAULT_INT4_AUTOQUANT_CLASS_LIST, example_input=inputs)
model = _autoquant_v2(model, manual=True, qtensor_class_list = torchao.quantization.V2_DEFAULT_INT4_AUTOQUANT_CLASS_LIST, example_input=inputs)
elif "autoquant_v2-float8" == quantization:
model = autoquant_v2(model, manual=True, qtensor_class_list = torchao.quantization.V2_OTHER_AUTOQUANT_CLASS_LIST, example_input=inputs)
model = _autoquant_v2(model, manual=True, qtensor_class_list = torchao.quantization.V2_OTHER_AUTOQUANT_CLASS_LIST, example_input=inputs)
else:
model = autoquant_v2(model, manual=True, example_input=inputs)
model = _autoquant_v2(model, manual=True, example_input=inputs)

print("running generate")
generate(
Expand Down
8 changes: 4 additions & 4 deletions torchao/_models/sam/eval_combo.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
int8_dynamic_activation_int8_weight,
int4_weight_only,
autoquant,
autoquant_v2,
_autoquant_v2,
)
from torchao.sparsity import sparsify_, apply_fake_sparsity, semi_sparse_weight
from torchao.dtypes import SemiSparseLayout, MarlinSparseLayout
Expand Down Expand Up @@ -347,11 +347,11 @@ def mlp_only(mod, name):
elif compress is not None and "autoquant_v2" in compress:
example_input = torch.randn(1, 3, 1024, 1024, dtype=torch.bfloat16, device=device)
if "autoquant_v2-int4" == compress:
autoquant_v2(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.quantization.V2_DEFAULT_INT4_AUTOQUANT_CLASS_LIST)
_autoquant_v2(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.quantization.V2_DEFAULT_INT4_AUTOQUANT_CLASS_LIST)
elif "autoquant_v2-float8" == compress:
autoquant_v2(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.quantization.V2_OTHER_AUTOQUANT_CLASS_LIST)
_autoquant_v2(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.quantization.V2_OTHER_AUTOQUANT_CLASS_LIST)
else:
autoquant_v2(predictor.model.image_encoder, example_input=example_input, manual=True)
_autoquant_v2(predictor.model.image_encoder, example_input=example_input, manual=True)

predictor.model.image_encoder(example_input)
predictor.model.image_encoder.finalize_autoquant()
Expand Down
4 changes: 2 additions & 2 deletions torchao/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
OTHER_AUTOQUANT_CLASS_LIST as V2_OTHER_AUTOQUANT_CLASS_LIST,
)
from .autoquant_v2 import (
autoquant_v2,
autoquant_v2 as _autoquant_v2,
)
from .GPTQ import (
Int4WeightOnlyGPTQQuantizer,
Expand Down Expand Up @@ -103,7 +103,7 @@
"DEFAULT_INT4_AUTOQUANT_CLASS_LIST",
"OTHER_AUTOQUANT_CLASS_LIST",
# experimental api
"autoquant_v2",
"_autoquant_v2",
"V2_DEFAULT_AUTOQUANT_CLASS_LIST",
"V2_DEFAULT_INT4_AUTOQUANT_CLASS_LIST",
"V2_OTHER_AUTOQUANT_CLASS_LIST",
Expand Down

0 comments on commit cb72ebc

Please sign in to comment.