Skip to content

Commit

Permalink
Factor out the specific configurations to helper functions
Browse files Browse the repository at this point in the history
Summary:
int4wo, int8wo, int8dyn, 8da4w are specific configurations for quantize function, we factor that out in the PR so they are easy to use

Test Plan:
python test/quantization/test_quant_api.py

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
jerryzh168 committed May 28, 2024
1 parent 5b04ff0 commit b4042ab
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 91 deletions.
100 changes: 9 additions & 91 deletions test/quantization/test_quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@
Quantizer,
TwoStepQuantizer,
quantize,
get_apply_8da4w_quant,
get_apply_int4wo_quant,
get_apply_int8wo_quant,
get_apply_int8dyn_quant,
)
from torchao.quantization.utils import (
TORCH_VERSION_AFTER_2_3,
Expand Down Expand Up @@ -416,42 +420,11 @@ def test_eval_wrapper(self):
# TODO: move to a separate test file
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
def test_quantized_tensor_subclass_8da4w(self):
# weight settings
groupsize = 32
mapping_type = MappingType.SYMMETRIC
block_size = (1, groupsize)
target_dtype = torch.int8
eps = torch.finfo(torch.float32).eps
quant_min = -8
quant_max = 7

# TODO: make a general helper function?
# input settings
def get_per_token_block_size(x):
block_size = []
for i in range(len(x.shape)-1):
block_size.append(1)
block_size.append(x.shape[-1])
return block_size

# input settings
input_mapping_type = MappingType.ASYMMETRIC
input_target_dtype = torch.int8
input_quant_func = lambda x: to_aq(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype)

m = ToyLinearModel().eval()
m_copy = copy.deepcopy(m)
example_inputs = m.example_inputs()

def apply_weight_quant(weight):
return to_aq(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps)

def apply_act_quant(weight):
return to_laq(weight, input_quant_func)

# note: order is important
m = quantize(m, apply_weight_quant)
m = quantize(m, apply_act_quant)
m = quantize(m, get_apply_8da4w_quant(groupsize=groupsize))

assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
Expand All @@ -474,27 +447,13 @@ def apply_act_quant(weight):
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_quantized_tensor_subclass_int4(self):
# weight settings
groupsize = 32
mapping_type = MappingType.ASYMMETRIC
block_size = (1, groupsize)
target_dtype = torch.int32
quant_min = 0
quant_max = 15
eps = 1e-6
preserve_zero = False
zero_point_dtype = torch.bfloat16
zero_point_domain = ZeroPointDomain.FLOAT

# use 1024 so that we don't need padding
m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
m_copy = copy.deepcopy(m)
example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs()))

def apply_weight_quant(weight):
return to_aq(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain)

m = quantize(m, apply_weight_quant)
groupsize = 32
m = quantize(m, get_apply_int4wo_quant(groupsize=groupsize))
assert isinstance(m.linear1.weight, AffineQuantizedTensor)
assert isinstance(m.linear2.weight, AffineQuantizedTensor)

Expand All @@ -511,21 +470,11 @@ def apply_weight_quant(weight):
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_quantized_tensor_subclass_int8(self):
# weight settings
mapping_type = MappingType.SYMMETRIC
target_dtype = torch.int8
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.int64

m = ToyLinearModel().eval().to(torch.bfloat16)
m_copy = copy.deepcopy(m)
example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs()))

def apply_weight_quant(weight):
block_size = (1, weight.shape[1])
return to_aq(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)

m = quantize(m, apply_weight_quant)
m = quantize(m, get_apply_int8wo_quant())

assert isinstance(m.linear1.weight, AffineQuantizedTensor)
assert isinstance(m.linear2.weight, AffineQuantizedTensor)
Expand All @@ -543,43 +492,12 @@ def apply_weight_quant(weight):
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_quantized_tensor_subclass_int8_dyn_quant(self):
# weight settings
mapping_type = MappingType.SYMMETRIC
def get_weight_block_size(x):
return (1, x.shape[1])
target_dtype = torch.int8
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.int64

# input settings
def get_per_token_block_size(x):
block_size = list(x.shape)
for i in range(len(block_size)-1):
block_size[i] = 1
return block_size

input_mapping_type = MappingType.SYMMETRIC
input_target_dtype = torch.int8
input_eps = 1e-5
input_quant_min = -127
input_quant_max = 127
input_quant_func = lambda x: to_aq(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)

# use 1024 so that we don't need padding
m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
m_copy = copy.deepcopy(m)
# setting batch_size to 20 to be compatible with the kernel
example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs(batch_size=20)))

def apply_weight_quant(weight):
block_size = get_weight_block_size(weight)
return to_aq(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)

def apply_act_quant(weight):
return to_laq(weight, input_quant_func)

m = quantize(m, apply_weight_quant)
m = quantize(m, apply_act_quant)
m = quantize(m, get_apply_int8dyn_quant())

assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
Expand Down
110 changes: 110 additions & 0 deletions torchao/quantization/quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@
Int8DynamicallyQuantizedLinearWeight,
Int8WeightOnlyQuantizedLinearWeight,
QuantizedLinearWeightBase,
to_laq,
)

from .quant_primitives import (
MappingType,
ZeroPointDomain,
)
from .weight_only import WeightOnlyInt8QuantLinear
from .unified import Quantizer, TwoStepQuantizer
Expand All @@ -56,6 +62,10 @@
"quantize",
"autoquant",
"_get_subclass_inserter",
"get_apply_8da4w_quant",
"get_apply_int4wo_quant",
"get_apply_int8wo_quant",
"get_apply_int8dyn_quant",
]

if TORCH_VERSION_AFTER_2_3:
Expand Down Expand Up @@ -287,3 +297,103 @@ def filter_fn(module, fqn):
_is_linear if filter_fn is None else filter_fn,
)
return model

def get_apply_8da4w_quant(groupsize=32):

def apply_8da4w_quant(weight):
# avoid circular dep
from torchao.dtypes.aqt import to_aq

# weight settings
mapping_type = MappingType.SYMMETRIC
block_size = (1, groupsize)
target_dtype = torch.int8
eps = torch.finfo(torch.float32).eps
quant_min = -8
quant_max = 7

# TODO: make a general helper function?
# input settings
def get_per_token_block_size(x):
block_size = []
for i in range(len(x.shape)-1):
block_size.append(1)
block_size.append(x.shape[-1])
return block_size

# input settings
input_mapping_type = MappingType.ASYMMETRIC
input_target_dtype = torch.int8
input_quant_func = lambda x: to_aq(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype)

weight = to_aq(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps)
weight = to_laq(weight, input_quant_func)
return weight

return apply_8da4w_quant


def get_apply_int4wo_quant(groupsize=32):
def apply_int4wo_quant(weight):
# avoid circular dep
from torchao.dtypes.aqt import to_aq

groupsize = 32
mapping_type = MappingType.ASYMMETRIC
block_size = (1, groupsize)
target_dtype = torch.int32
quant_min = 0
quant_max = 15
eps = 1e-6
preserve_zero = False
zero_point_dtype = torch.bfloat16
zero_point_domain = ZeroPointDomain.FLOAT
return to_aq(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain)

return apply_int4wo_quant


def get_apply_int8wo_quant():
def apply_int8wo_quant(weight):
# avoid circular dep
from torchao.dtypes.aqt import to_aq

mapping_type = MappingType.SYMMETRIC
target_dtype = torch.int8
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.int64
block_size = (1, weight.shape[1])
return to_aq(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
return apply_int8wo_quant

def get_apply_int8dyn_quant():
def apply_int8dyn_quant(weight):
# avoid circular dep
from torchao.dtypes.aqt import to_aq
# weight settings
mapping_type = MappingType.SYMMETRIC
def get_weight_block_size(x):
return (1, x.shape[1])
target_dtype = torch.int8
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.int64

# input settings
def get_per_token_block_size(x):
block_size = list(x.shape)
for i in range(len(block_size)-1):
block_size[i] = 1
return block_size

input_mapping_type = MappingType.SYMMETRIC
input_target_dtype = torch.int8
input_eps = 1e-5
input_quant_min = -127
input_quant_max = 127
input_quant_func = lambda x: to_aq(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)

block_size = get_weight_block_size(weight)
weight = to_aq(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
weight = to_laq(weight, input_quant_func)
return weight
return apply_int8dyn_quant

0 comments on commit b4042ab

Please sign in to comment.