Skip to content

Commit

Permalink
Renaming quantize to quantize_ (#467)
Browse files Browse the repository at this point in the history
Summary:
Addressing feedback for `quantize` API from #391 (comment)

this is an API that changes model inplace, so we want to change the name to reflect that. inplace model quantization is important especially for LLM since it will be hard to load the model to memory. we typically load the model to meta device and then load the quantized weight.

Test Plan:
python test/quantization/test_quant_api.py
python test/integration/test_integration.py

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
jerryzh168 authored Jul 4, 2024
1 parent cdb6e98 commit 6fa2d96
Show file tree
Hide file tree
Showing 12 changed files with 67 additions and 67 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ All with no intrusive code changes and minimal accuracy degradation.
Quantizing your models is a 1 liner that should work on any model with an `nn.Linear` including your favorite HuggingFace model. You can find a more comprehensive usage instructions [here](torchao/quantization/) and a HuggingFace inference example [here](scripts/hf_eval.py)

```python
from torchao.quantization.quant_api import quantize, int4_weight_only
m = quantize(m, int4_weight_only())
from torchao.quantization.quant_api import quantize_, int4_weight_only
quantize_(m, int4_weight_only())
```

Benchmarks are run on a machine with a single A100 GPU using the script in `_models/llama` which generates text in a latency-optimized way (batchsize=1)
Expand Down Expand Up @@ -83,7 +83,7 @@ swap_linear_with_semi_sparse_linear(model, {"seq.0": SemiSparseLinear})

* [MX](torchao/prototype/mx_formats) implementing training and inference support with tensors using the [OCP MX spec](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) data types, which can be described as groupwise scaled float8/float6/float4/int8, with the scales being constrained to powers of two. This work is prototype as the hardware support is not available yet.
* [nf4](torchao/dtypes/nf4tensor.py) which was used to [implement QLoRA](https://github.com/pytorch/torchtune/blob/main/docs/source/tutorials/qlora_finetune.rst) one of the most popular finetuning algorithms without writing custom Triton or CUDA code. Accessible talk [here](https://x.com/HamelHusain/status/1800315287574847701)
* [fp6](torchao/prototype/quant_llm/) for 2x faster inference over fp16 with an easy to use API `quantize(model, fp6_llm_weight_only())`
* [fp6](torchao/prototype/quant_llm/) for 2x faster inference over fp16 with an easy to use API `quantize_(model, fp6_llm_weight_only())`

## Composability

Expand Down
16 changes: 8 additions & 8 deletions test/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
int4_weight_only,
int8_weight_only,
int8_dynamic_activation_int8_weight,
quantize,
quantize_,
_replace_with_custom_fn_if_matches_filter,
)
# APIs to be deprecated (used for torch 2.2.2 and 2.3)
Expand Down Expand Up @@ -98,21 +98,21 @@

def _int8wo_api(mod):
if TORCH_VERSION_AFTER_2_4:
quantize(mod, int8_weight_only(), set_inductor_config=False)
quantize_(mod, int8_weight_only(), set_inductor_config=False)
unwrap_tensor_subclass(mod)
else:
change_linear_weights_to_int8_woqtensors(mod)

def _int8da_int8w_api(mod):
if TORCH_VERSION_AFTER_2_4:
quantize(mod, int8_dynamic_activation_int8_weight(), set_inductor_config=False)
quantize_(mod, int8_dynamic_activation_int8_weight(), set_inductor_config=False)
unwrap_tensor_subclass(mod)
else:
change_linear_weights_to_int8_dqtensors(mod)

def _int4wo_api(mod):
if TORCH_VERSION_AFTER_2_4:
quantize(mod, int4_weight_only(), set_inductor_config=False)
quantize_(mod, int4_weight_only(), set_inductor_config=False)
unwrap_tensor_subclass(mod)
else:
change_linear_weights_to_int4_woqtensors(mod)
Expand All @@ -127,8 +127,8 @@ def _int4wo_api(mod):
def undo_recommended_configs():
torch._inductor.config.coordinate_descent_tuning = False
torch._inductor.config.coordinate_descent_check_all_directions = False
torch._inductor.config.force_fuse_int_mm_with_mul = False
torch._inductor.config.fx_graph_cache = False
torch._inductor.config.force_fuse_int_mm_with_mul = False
torch._inductor.config.fx_graph_cache = False
torch._inductor.config.triton.unique_kernel_names = False
torch.set_float32_matmul_precision("highest")

Expand Down Expand Up @@ -844,7 +844,7 @@ def api(mod):
kwargs_copy = kwargs.copy()
kwargs_copy["group_size"] = groupsize
del kwargs_copy["groupsize"]
quantize(mod, int4_weight_only(**kwargs_copy))
quantize_(mod, int4_weight_only(**kwargs_copy))
unwrap_tensor_subclass(mod)
else:
change_linear_weights_to_int4_woqtensors(mod, **kwargs)
Expand All @@ -865,7 +865,7 @@ def test_dynamic_quant(self):
m = nn.Sequential(nn.Linear(K, N))

y_ref = m(x)
quantize(m, int8_dynamic_activation_int8_weight())
quantize_(m, int8_dynamic_activation_int8_weight())
y_test = m(x)

sqnr = compute_error(y_ref, y_test)
Expand Down
4 changes: 2 additions & 2 deletions test/prototype/test_quant_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)
from torchao.prototype.quant_llm.quant_llm import _pack_tc_fpx, _pack_tc_fp6
from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32
from torchao.quantization.quant_api import quantize
from torchao.quantization.quant_api import quantize_


_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
Expand Down Expand Up @@ -91,7 +91,7 @@ def test_quant_llm_quantize(self, ebits, mbits, bias):

linear = torch.nn.Linear(IC, OC, bias=bias, device=device)
fpx_linear = copy.deepcopy(linear)
quantize(fpx_linear, quant_llm_fpx_weight_only(ebits, mbits))
quantize_(fpx_linear, quant_llm_fpx_weight_only(ebits, mbits))

x = torch.randn(N, IC, device=device, dtype=torch.half)
expected = fpx_linear(x)
Expand Down
18 changes: 9 additions & 9 deletions test/quantization/test_quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
Int8WeightOnlyQuantizedLinearWeight,
Int4WeightOnlyQuantizedLinearWeight,
)
from torchao import quantize
from torchao import quantize_
from torchao.quantization.quant_api import (
_replace_with_custom_fn_if_matches_filter,
Quantizer,
Expand Down Expand Up @@ -89,7 +89,7 @@ def convert(self, model: torch.nn.Module) -> torch.nn.Module:

class TorchCompileDynamicQuantizer(Quantizer):
def quantize(self, model: torch.nn.Module) -> torch.nn.Module:
quantize(model, int8_dynamic_activation_int8_weight())
quantize_(model, int8_dynamic_activation_int8_weight())
return model

class ToyLinearModel(torch.nn.Module):
Expand Down Expand Up @@ -152,7 +152,7 @@ class TestQuantFlow(TestCase):
def test_dynamic_quant_gpu_singleline(self):
m = ToyLinearModel().eval()
example_inputs = m.example_inputs()
m = quantize(m, int8_dynamic_activation_int8_weight())
quantize_(m, int8_dynamic_activation_int8_weight())
quantized = m(*example_inputs)
# AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64
# While executing %choose_qparams_tensor_1 : [num_users=2] = call_function[target=torch.ops.quantized_decomposed.choose_qparams.tensor](args = (%arg0_3, -128, 127, 0.000244140625, torch.int8), kwargs = {})
Expand Down Expand Up @@ -195,7 +195,7 @@ def test_int8_wo_quant_save_load(self):
)
m = ToyLinearModel().eval().cpu()
def api(model):
model = quantize(model, int8_weight_only())
quantize_(model, int8_weight_only())
unwrap_tensor_subclass(model)

api(m)
Expand Down Expand Up @@ -501,7 +501,7 @@ def test_quantized_tensor_subclass_8da4w(self):
m = ToyLinearModel().eval()
m_copy = copy.deepcopy(m)
example_inputs = m.example_inputs()
m = quantize(m, int8_dynamic_activation_int4_weight(group_size=group_size))
quantize_(m, int8_dynamic_activation_int4_weight(group_size=group_size))

assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
Expand Down Expand Up @@ -530,7 +530,7 @@ def test_quantized_tensor_subclass_int4(self):
example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cuda")

group_size = 32
m = quantize(m, int4_weight_only(group_size=group_size))
quantize_(m, int4_weight_only(group_size=group_size))
assert isinstance(m.linear1.weight, AffineQuantizedTensor)
assert isinstance(m.linear2.weight, AffineQuantizedTensor)

Expand All @@ -550,7 +550,7 @@ def test_quantized_tensor_subclass_int8_wo(self):
m_copy = copy.deepcopy(m)
example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs()))

m = quantize(m, int8_weight_only())
quantize_(m, int8_weight_only())

assert isinstance(m.linear1.weight, AffineQuantizedTensor)
assert isinstance(m.linear2.weight, AffineQuantizedTensor)
Expand All @@ -573,7 +573,7 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
m_copy = copy.deepcopy(m)
# setting batch_size to 20 to be compatible with the kernel
example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda")
m = quantize(m, int8_dynamic_activation_int8_weight())
quantize_(m, int8_dynamic_activation_int8_weight())

assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
Expand Down Expand Up @@ -607,7 +607,7 @@ def test_quantized_tensor_subclass_save_load(self):
m_copy = copy.deepcopy(m)
example_inputs = m.example_inputs(dtype=torch.bfloat16)

m = quantize(m, int8_weight_only())
quantize_(m, int8_weight_only())
ref = m(*example_inputs)
with tempfile.NamedTemporaryFile() as f:
torch.save(m.state_dict(), f)
Expand Down
4 changes: 2 additions & 2 deletions torchao/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,14 @@

from torchao.quantization import (
autoquant,
quantize,
quantize_,
)
from . import dtypes

__all__ = [
"dtypes",
"autoquant",
"quantize",
"quantize_",
]

# test-pytorchbot
36 changes: 18 additions & 18 deletions torchao/_models/llama/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

)
from torchao.quantization.quant_api import (
quantize, int4_weight_only, int8_weight_only, int8_dynamic_activation_int8_weight, unwrap_tensor_subclass
quantize_, int4_weight_only, int8_weight_only, int8_dynamic_activation_int8_weight, unwrap_tensor_subclass

)
from torchao._models._eval import TransformerEvalWrapper, InputRecorder
Expand Down Expand Up @@ -60,13 +60,13 @@ def run_evaluation(

if quantization:
if "int8wo" in quantization:
quantize(model, int8_weight_only())
quantize_(model, int8_weight_only())
if "int8dq" in quantization:
quantize(model, int8_dynamic_activation_int8_weight())
quantize_(model, int8_dynamic_activation_int8_weight())
if "int4wo" in quantization and not "gptq" in quantization:
groupsize=int(quantization.split("-")[-1])
assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
quantize(model.to(device), int4_weight_only(group_size=groupsize))
quantize_(model.to(device), int4_weight_only(group_size=groupsize))
if "int4wo" in quantization and "gptq" in quantization:
groupsize=int(quantization.split("-")[-2])
assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
Expand Down Expand Up @@ -94,8 +94,8 @@ def run_evaluation(
model = torch.compile(model, mode="max-autotune", fullgraph=True)
with torch.no_grad():
TransformerEvalWrapper(
model=model.to(device),
tokenizer=tokenizer,
model=model.to(device),
tokenizer=tokenizer,
max_seq_length=max_length,
input_prep_func=prepare_inputs_for_model,
device=device,
Expand All @@ -122,16 +122,16 @@ def run_evaluation(

args = parser.parse_args()
run_evaluation(
args.checkpoint_path,
args.tasks,
args.limit,
args.device,
args.precision,
args.quantization,
args.compile,
args.max_length,
args.calibration_tasks,
args.calibration_limit,
args.calibration_seq_length,
args.pad_calibration_inputs,
args.checkpoint_path,
args.tasks,
args.limit,
args.device,
args.precision,
args.quantization,
args.compile,
args.max_length,
args.calibration_tasks,
args.calibration_limit,
args.calibration_seq_length,
args.pad_calibration_inputs,
)
14 changes: 7 additions & 7 deletions torchao/_models/llama/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def generate(
T_new = T + max_new_tokens
seq = torch.empty(T_new, dtype=prompt.dtype, device=device)
seq[:T] = prompt.view(-1)

# setup model cache
max_seq_length = min(T_new, model.config.block_size) if not interactive else 350
with torch.device(device):
Expand Down Expand Up @@ -158,7 +158,7 @@ def main(
"""

torchao.quantization.utils.recommended_inductor_config_setter()

assert checkpoint_path.is_file(), checkpoint_path
tokenizer_path = checkpoint_path.parent / "tokenizer.model"
assert tokenizer_path.is_file(), str(tokenizer_path)
Expand All @@ -180,11 +180,11 @@ def main(
prompt_length = encoded.size(0)

torch.manual_seed(1234)


if quantization:
from torchao.quantization.quant_api import (
quantize,
quantize_,
int8_weight_only,
int8_dynamic_activation_int8_weight,
int4_weight_only,
Expand All @@ -193,13 +193,13 @@ def main(
)

if "int8wo" in quantization:
quantize(model, int8_weight_only())
quantize_(model, int8_weight_only())
if "int8dq" in quantization:
quantize(model, int8_dynamic_activation_int8_weight())
quantize_(model, int8_dynamic_activation_int8_weight())
if "int4wo" in quantization:
groupsize=int(quantization.split("-")[-1])
assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
quantize(model, int4_weight_only(group_size=groupsize))
quantize_(model, int4_weight_only(group_size=groupsize))
if "autoquant" == quantization:
model = autoquant(model, manual=True)

Expand Down
6 changes: 3 additions & 3 deletions torchao/prototype/quant_llm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@ This is a FP16 x FPx mixed matmul kernel optimized for io bound workloads per [F
## Usage

```python
from torchao.quantization.quant_api import quantize
from torchao.quantization.quant_api import quantize_
from torchao.prototype.quant_llm import fp6_llm_weight_only, quant_llm_fpx_weight_only

model = ...
model.half() # not necessary, but recommeneded to maintain accuracy
quantize(model, fp6_llm_weight_only()) # convert nn.Lineaer.weight to FP6 E3M2 in-place
quantize_(model, fp6_llm_weight_only()) # convert nn.Lineaer.weight to FP6 E3M2 in-place

# for generic FPx EyMz where x = 1 + y + z
# quantize(model, quant_llm_fpx_weight_only(2, 2)) # use FP5 E2M2 instead
# quantize_(model, quant_llm_fpx_weight_only(2, 2)) # use FP5 E2M2 instead

# fully compatible with torch.compile()
model.compile(mode="max-autotune", fullgraph=True)
Expand Down
10 changes: 5 additions & 5 deletions torchao/quantization/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
from torchao.dtypes import to_affine_quantized
import copy
from torchao.quantization.quant_api import (
quantize,
quantize_,
int4_weight_only,
)

Expand All @@ -106,7 +106,7 @@ m_bf16 = torch.compile(m_bf16, mode='max-autotune')
# apply int4 weight only quant (compatible with tinygemm int4 weight only quant mm kernel in torchao)
group_size = 32
# only works for torch 2.4+
m = quantize(m, int4_weight_only(group_size=group_size))
quantize_(m, int4_weight_only(group_size=group_size))

# temporary workaround for tensor subclass + torch.compile
from torchao.utils import unwrap_tensor_subclass
Expand Down Expand Up @@ -173,7 +173,7 @@ torch._inductor.config.force_fuse_int_mm_with_mul = True

# for torch 2.4+
from torchao.quantization import quantize, int8_dynamic_activation_int8_weight
quantize(model, int8_dynamic_activation_int8_weight())
quantize_(model, int8_dynamic_activation_int8_weight())

# for torch 2.2.2 and 2.3
from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
Expand All @@ -185,7 +185,7 @@ change_linear_weights_to_int8_dqtensors(model)
```python
# for torch 2.4+
from torchao.quantization import quantize, int8_weight_only
quantize(model, int8_weight_only())
quantize_(model, int8_weight_only())

# for torch 2.2.2 and 2.3
from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
Expand All @@ -200,7 +200,7 @@ This technique works best when the torch._inductor.config.use_mixed_mm option is
```python
# for torch 2.4+
from torchao.quantization import quantize, int4_weight_only
quantize(model, int4_weight_only())
quantize_(model, int4_weight_only())

# for torch 2.2.2 and 2.3
from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
Expand Down
2 changes: 1 addition & 1 deletion torchao/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
"quantize_affine",
"dequantize_affine",
"choose_qprams_affine",
"quantize",
"quantize_",
"int8_dynamic_activation_int4_weight",
"int8_dynamic_activation_int8_weight",
"int4_weight_only",
Expand Down
Loading

0 comments on commit 6fa2d96

Please sign in to comment.