Skip to content

Commit 2523c70

Browse files
authored
Merge branch 'main' into patch-1
2 parents 7364e19 + cda787c commit 2523c70

File tree

2 files changed

+286
-57
lines changed

2 files changed

+286
-57
lines changed

test/quantization/test_quant_api.py

+73-13
Original file line number | Diff line number | Diff line change
@@ -395,7 +395,10 @@ def test_eval_wrapper(self):
395395
# TODO: move to a separate test file
396396
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
397397
def test_quantized_tensor_subclass_8da4w(self):
398-
from torchao.quantization.subclass import AffineQuantizedTensor
398+
from torchao.quantization.subclass import (
399+
AffineQuantizedTensor,
400+
LinearActQuantizedTensor,
401+
)
399402
from torchao.quantization.quant_primitives import MappingType
400403
import copy
401404

@@ -409,6 +412,7 @@ def test_quantized_tensor_subclass_8da4w(self):
409412
quant_max = 7
410413

411414
# TODO: make a general helper function?
415+
# input settings
412416
def get_per_token_block_size(x):
413417
block_size = []
414418
for i in range(len(x.shape)-1):
@@ -421,13 +425,18 @@ def get_per_token_block_size(x):
421425
input_target_dtype = torch.int8
422426
input_quant_func = lambda x: AffineQuantizedTensor.from_float(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype)
423427

428+
def dynamic_quant(linear):
429+
# note: order is important
430+
linear.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(linear.weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps), requires_grad=False)
431+
linear.weight = torch.nn.Parameter(LinearActQuantizedTensor.from_float(linear.weight, input_quant_func), requires_grad=False)
432+
424433
m = ToyLinearModel().eval()
425434
m_copy = copy.deepcopy(m)
426435
example_inputs = m.example_inputs()
427-
m.linear1.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(m.linear1.weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, input_quant_func=input_quant_func), requires_grad=False)
428-
m.linear2.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(m.linear2.weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, input_quant_func=input_quant_func), requires_grad=False)
429-
assert isinstance(m.linear1.weight, AffineQuantizedTensor)
430-
assert isinstance(m.linear2.weight, AffineQuantizedTensor)
436+
dynamic_quant(m.linear1)
437+
dynamic_quant(m.linear2)
438+
assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
439+
assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
431440

432441
# reference
433442
from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
@@ -461,9 +470,6 @@ def test_quantized_tensor_subclass_int4(self):
461470
preserve_zero = False
462471
zero_point_dtype = torch.bfloat16
463472

464-
# weight only quantization
465-
input_quant_func = None
466-
467473
# use 1024 so that we don't need padding
468474
m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
469475
m_copy = copy.deepcopy(m)
@@ -475,7 +481,6 @@ def to_quantized(weight):
475481
zero_point_dtype=zero_point_dtype,
476482
preserve_zero=preserve_zero,
477483
zero_point_domain=ZeroPointDomain.FLOAT,
478-
input_quant_func=input_quant_func,
479484
)
480485

481486
m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False)
@@ -506,16 +511,13 @@ def test_quantized_tensor_subclass_int8(self):
506511
eps = torch.finfo(torch.float32).eps
507512
zero_point_dtype = torch.int64
508513

509-
# weight only quantization
510-
input_quant_func = None
511-
512514
m = ToyLinearModel().eval().to(torch.bfloat16)
513515
m_copy = copy.deepcopy(m)
514516
example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs()))
515517

516518
def to_quantized(weight):
517519
block_size = (1, weight.shape[1])
518-
return AffineQuantizedTensor.from_float(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, input_quant_func=input_quant_func)
520+
return AffineQuantizedTensor.from_float(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
519521

520522
m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False)
521523
m.linear2.weight = torch.nn.Parameter(to_quantized(m.linear2.weight), requires_grad=False)
@@ -532,5 +534,63 @@ def to_quantized(weight):
532534
torch.testing.assert_close(res, ref, rtol=0.00001, atol=1e-2)
533535

534536

537+
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_quantized_tensor_subclass_int8_dyn_quant(self):
    """Compare subclass-based int8 dynamic quantization with the reference API.

    Each linear weight is first converted to an ``AffineQuantizedTensor`` and
    then wrapped in a ``LinearActQuantizedTensor`` that carries the input
    quantization function. The resulting model output must be exactly equal
    to that of a copy converted with
    ``change_linear_weights_to_int8_dqtensors``.
    """
    from torchao.quantization.subclass import (
        AffineQuantizedTensor,
        LinearActQuantizedTensor,
    )
    from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
    import copy

    # weight settings
    weight_mapping = MappingType.SYMMETRIC
    weight_dtype = torch.int8
    weight_eps = torch.finfo(torch.float32).eps
    weight_zp_dtype = torch.int64

    def row_block_size(w):
        # One block per row: size 1 along dim 0, full extent along dim 1.
        return (1, w.shape[1])

    # input settings
    def per_token_block_size(t):
        # Size 1 along every leading dim; the last dim keeps its full extent.
        return [1] * (len(t.shape) - 1) + list(t.shape[-1:])

    def quantize_input(t):
        # Symmetric int8 quantization of the activation, restricted to
        # [-127, 127], with scales kept in float.
        return AffineQuantizedTensor.from_float(
            t,
            MappingType.SYMMETRIC,
            per_token_block_size(t),
            torch.int8,
            eps=1e-5,
            quant_min=-127,
            quant_max=127,
            scale_dtype=torch.float,
        )

    # use 1024 so that we don't need padding
    m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
    m_copy = copy.deepcopy(m)
    example_inputs = tuple(
        x.to(torch.bfloat16).to("cuda") for x in m.example_inputs()
    )

    def dynamic_quant(linear):
        # note: order is important -- the weight is quantized first, and the
        # activation-quantizing wrapper is applied to the quantized weight.
        quantized_weight = AffineQuantizedTensor.from_float(
            linear.weight,
            weight_mapping,
            row_block_size(linear.weight),
            weight_dtype,
            eps=weight_eps,
            zero_point_dtype=weight_zp_dtype,
        )
        linear.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
        wrapped_weight = LinearActQuantizedTensor.from_float(
            linear.weight, quantize_input
        )
        linear.weight = torch.nn.Parameter(wrapped_weight, requires_grad=False)

    for linear in (m.linear1, m.linear2):
        dynamic_quant(linear)

    # The outer tensor is the activation wrapper ...
    for linear in (m.linear1, m.linear2):
        assert isinstance(linear.weight, LinearActQuantizedTensor)
    # ... and it holds the affine-quantized weight inside.
    for linear in (m.linear1, m.linear2):
        assert isinstance(linear.weight.original_weight_tensor, AffineQuantizedTensor)

    # reference
    from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
    change_linear_weights_to_int8_dqtensors(m_copy)

    res = m(*example_inputs)
    ref = m_copy(*example_inputs)

    self.assertTrue(torch.equal(res, ref))
593+
594+
535595
if __name__ == "__main__":
536596
unittest.main()

0 commit comments

Comments
 (0)