     Int8WeightOnlyConfig,
     quantize_,
 )
-from torchao.quantization.quantize_.workflows.int8.int8_tensor import (
-    Int8Tensor,
-)
 from torchao.quantization.utils import compute_error
 from torchao.testing.utils import TorchAOIntegrationTestCase
 
@@ -56,7 +53,7 @@ class TestInt8Tensor(TorchAOIntegrationTestCase):
     def setUp(self):
         super().setUp()
 
-        self.test_shape = (4, 3)
+        self.test_shape = (32, 20)
         self.dtype = torch.bfloat16
         self.batch_size = 32
 
@@ -66,9 +63,26 @@ def setUp(self):
         self.bias = torch.randn(self.test_shape[0], dtype=self.dtype)
         self.block_size = list(self.test_shape)
 
-    def test_creation_and_attributes(self):
+    @common_utils.parametrize(
+        "config",
+        [
+            Int8DynamicActivationInt8WeightConfig(version=2),
+            Int8WeightOnlyConfig(version=2),
+        ],
+    )
+    def test_creation_and_attributes(self, config):
         """Test tensor creation, dtypes, and ranges"""
-        tensor = Int8Tensor.from_hp(self.weight_fp, self.block_size)
+        linear = torch.nn.Linear(
+            self.test_shape[1],
+            self.test_shape[0],
+            bias=False,
+            dtype=self.dtype,
+            device="cuda",
+        )
+        linear.weight.data = self.weight_fp.cuda()
+        quantize_(linear, config)
+
+        tensor = linear.weight
 
         self.assertEqual(tensor.shape, self.test_shape)
         self.assertEqual(tensor.qdata.dtype, torch.int8)
@@ -117,46 +131,6 @@ def test_int8_linear_variants(
117131 f"Quantization error is too high got a SQNR of { compute_error (output_fp , output_quantized )} "
118132 )
119133
-    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
-    @common_utils.parametrize(
-        "config",
-        [
-            Int8DynamicActivationInt8WeightConfig(version=2),
-            Int8WeightOnlyConfig(version=2),
-        ],
-    )
-    @common_utils.parametrize(
-        "sizes",
-        [
-            ((128,), 256, 128),  # 2D
-            ((32, 128), 64, 256),  # 3D
-        ],
-    )
-    def test_int8_linear_quantization_accuracy(
-        self,
-        dtype: torch.dtype,
-        sizes: tuple,
-        config,
-    ):
-        """Test quantization preserves reasonable accuracy"""
-        M, N, K = sizes
-        input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
-
-        # Create a linear layer
-        m = ToyTwoLinearModel(K, N, K, dtype=dtype, device="cuda").eval()
-        m_q = copy.deepcopy(m)
-
-        # Quantize
-        quantize_(m_q, config)
-
-        output_fp = m(input_tensor)
-        output_quantized = m_q(input_tensor)
-
-        error = compute_error(output_fp, output_quantized)
-        assert error > 20, (
-            f"Quantization quality is too low, SQNR: {error} dB (expected > {20} dB)"
-        )
-
     @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
     @common_utils.parametrize(
         "config",
@@ -218,36 +192,42 @@ def test_slice(self, config, device, dtype):
         with self.assertRaises(NotImplementedError):
             _ = dummy.weight[::2]
 
-    def test_index_select(self):
-        """test that `x_0 = x[0]` works when `x` is a 2D `Int8Tensor`."""
+    @common_utils.parametrize(
+        "config",
+        [
+            Int8DynamicActivationInt8WeightConfig(version=2),
+            Int8WeightOnlyConfig(version=2),
+        ],
+    )
+    def test_index_select(self, config):
+        """test that `x_0 = x[0]` works when `x` is a 2D quantized tensor."""
         N, K = 256, 512
         x = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)
-        x_int8 = Int8Tensor.from_hp(x, block_size=[N, K])
+        linear = torch.nn.Linear(K, N, bias=False, dtype=torch.bfloat16, device="cuda")
+        linear.weight.data = x
+        quantize_(linear, config)
+
+        x_int8 = linear.weight
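+        # indexing the quantized weight should match indexing its dequantized counterpart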
         x_int8_0 = x_int8[0]
         torch.testing.assert_close(
             x_int8.dequantize()[0], x_int8_0.dequantize(), atol=0, rtol=0
         )
 
-    def test_invalid_input_handling(self):
-        """Test input validation with specific error types"""
-        invalid_tensor = torch.randn(5)
-        incompatible_block_size = [1]
-
-        with self.assertRaises(
-            ValueError, msg="Should reject incompatible tensor dimensions"
-        ):
-            Int8Tensor.from_hp(invalid_tensor, incompatible_block_size)
-
-        with self.assertRaises(
-            ValueError, msg="Should reject mismatched block size dimensions"
-        ):
-            Int8Tensor.from_hp(self.weight_fp, [1])
-
-    def test_dequantization_accuracy(self):
+    @common_utils.parametrize(
+        "config",
+        [
+            Int8DynamicActivationInt8WeightConfig(version=2),
+            Int8WeightOnlyConfig(version=2),
+        ],
+    )
+    def test_dequantization_accuracy(self, config):
         """Test dequantization accuracy separately"""
-        test_data = torch.tensor([[1.0, -1.0]], dtype=torch.bfloat16)
-        tensor = Int8Tensor.from_hp(test_data, [1, 2])
+        test_data = torch.tensor([[1.0, -1.0]], dtype=torch.bfloat16, device="cuda")
+        linear = torch.nn.Linear(2, 1, bias=False, dtype=torch.bfloat16, device="cuda")
+        linear.weight.data = test_data
+        quantize_(linear, config)
 
+        tensor = linear.weight
         dequantized = tensor.dequantize()
         self.assertEqual(dequantized.shape, test_data.shape)
         self.assertLess(
@@ -268,23 +248,14 @@ def test_available_gpu_kernels(self):
         x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
 
         out, code = run_and_get_code(m, x)
-        kernels = {}
-
-        # Check for Triton kernels
-        if "torch.ops.triton" in code[0]:
-            kernels["triton"] = True
-            print("Triton kernels are available for int8 quantization")
-        else:
-            kernels["triton"] = False
-            print("Triton kernels are NOT available for int8 quantization")
+        has_triton = "triton" in code[0].lower()  # Triton
+        has_fbgemm = "fbgemm" in code[0].lower()  # FBGEMM
+        has_int_mm = "_int_mm" in code[0]  # int8 matmul
 
-        # Check for FBGEMM kernels
-        if "torch.ops.fbgemm" in code[0]:
-            kernels["fbgemm"] = True
-            print("FBGEMM kernels are available for int8 quantization")
-        else:
-            kernels["fbgemm"] = False
-            print("FBGEMM kernels are NOT available for int8 quantization")
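+        # accept any of the three backends; at least one int8 kernel should show up in the compiled code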
+        self.assertTrue(
+            has_triton or has_fbgemm or has_int_mm,
+            f"No int8 quantization kernels found. has_triton={has_triton}, has_fbgemm={has_fbgemm}, has_int_mm={has_int_mm}",
+        )
 
 
 if __name__ == "__main__":