 
 import torch
 from torch._inductor.utils import run_and_get_code
+from torch.testing import FileCheck
 from torch.testing._internal import common_utils
 
 from torchao.quantization import (
@@ -145,12 +146,8 @@ def test_per_row_scale_shape(self, dtype, config):
         linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device="cuda")
         quantize_(linear, config)
 
-        # Dynamic: per-row (1D scale [N]), Weight-only: per-tensor (scalar)
-        if isinstance(config, Int8DynamicActivationInt8WeightConfig):
-            self.assertEqual(linear.weight.scale.shape, (N,))
-            self.assertEqual(linear.weight.scale.ndim, 1)
-        else:
-            self.assertEqual(linear.weight.scale.numel(), 1)
+        self.assertEqual(linear.weight.scale.shape, (N,))
+        self.assertEqual(linear.weight.scale.ndim, 1)
 
     @common_utils.parametrize(
         "config",
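# Note: a minimal sketch of why the updated test expects a 1-D scale of shape (N,).
# Per-row int8 quantization keeps one scale per output row of the (N, K) weight.
# Generic PyTorch for illustration only, not torchao's internal implementation.
import torch

def per_row_int8_quantize(w):
    # w: (N, K) float weight -> int8 qdata plus one scale per row
    scale = w.abs().amax(dim=1) / 127.0                                  # shape (N,)
    qdata = torch.clamp(torch.round(w / scale[:, None]), -128, 127).to(torch.int8)
    return qdata, scale

qdata, scale = per_row_int8_quantize(torch.randn(512, 256))
assert scale.shape == (512,) and scale.ndim == 1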
@@ -162,7 +159,7 @@ def test_per_row_scale_shape(self, dtype, config):
     @common_utils.parametrize("device", ["cpu", "cuda"])
     @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
     def test_slice(self, config, device, dtype):
-        """Test tensor slicing"""
+        """Test tensor slicing with per-row quantization"""
         tensor_size = 256
         slice_sizes = (64, 128)
 
@@ -176,19 +173,8 @@ def test_slice(self, config, device, dtype):
 
         self.assertEqual(weight1.qdata, dummy.weight.qdata.narrow(0, 0, slice_sizes[0]))
         self.assertEqual(weight2.qdata, dummy.weight.qdata.narrow(1, 0, slice_sizes[1]))
-
-        # Int8DynamicActivationInt8WeightConfig uses per-row (PerRow)
-        # Int8WeightOnlyConfig uses per-tensor (PerTensor)
-        if isinstance(config, Int8DynamicActivationInt8WeightConfig):
-            # PerRow: dim 0 slicing affects scale, dim 1 doesn't
-            self.assertEqual(
-                weight1.scale, dummy.weight.scale.narrow(0, 0, slice_sizes[0])
-            )
-            self.assertEqual(weight2.scale, dummy.weight.scale)
-        else:
-            # PerTensor: scale unchanged by slicing
-            self.assertEqual(weight1.scale, dummy.weight.scale)
-            self.assertEqual(weight2.scale, dummy.weight.scale)
+        self.assertEqual(weight1.scale, dummy.weight.scale.narrow(0, 0, slice_sizes[0]))
+        self.assertEqual(weight2.scale, dummy.weight.scale)
         with self.assertRaises(NotImplementedError):
             _ = dummy.weight[::2]
 
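# Why the simplified assertions above hold for per-row scales (illustrative sketch using
# plain tensors, not torchao's quantized tensor subclass): slicing rows (dim 0) selects a
# subset of per-row scales, so the scale is narrowed the same way; slicing columns (dim 1)
# only shortens each row, so the scale is untouched.
import torch

N, K = 256, 256
qdata = torch.randint(-128, 128, (N, K), dtype=torch.int8)
scale = torch.rand(N)                          # one scale per row

scale_for_row_slice = scale.narrow(0, 0, 64)   # rows 0..63 -> scale shape (64,)
scale_for_col_slice = scale                    # column slice -> scale unchanged, shape (N,)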
@@ -230,13 +216,15 @@ def test_dequantization_accuracy(self, config):
         tensor = linear.weight
         dequantized = tensor.dequantize()
         self.assertEqual(dequantized.shape, test_data.shape)
-        self.assertLess(
-            torch.abs(dequantized - test_data).max().item(),
-            0.1,
-            msg=f"Dequantization error exceeds tolerance of {0.1}",
+        assert compute_error(dequantized, test_data) > 20, (
+            f"Dequantization error is too high to get a SQNR of {compute_error(dequantized, test_data)}"
         )
 
-    def test_available_gpu_kernels(self):
+    @common_utils.parametrize(
+        "kernel",
+        ["triton_per_fused", "extern_kernels._int_mm", "triton_poi_fused"],
+    )
+    def test_available_gpu_kernels(self, kernel):
         """Check which GPU kernels are available"""
         M, K, N = 128, 256, 512
         m = torch.nn.Sequential(
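# The accuracy check now asserts an SQNR above 20 dB rather than a fixed absolute-error
# bound. Assuming torchao's compute_error reports signal-to-quantization-noise ratio in
# decibels, it is essentially the quantity sketched below (higher = closer match).
import torch

def sqnr_db(reference, approx):
    # 20 dB roughly means the error norm is at most ~10% of the signal norm.
    return 20 * torch.log10(torch.linalg.norm(reference) / torch.linalg.norm(reference - approx))

x = torch.randn(64, 64)
print(sqnr_db(x, x + 0.01 * torch.randn_like(x)))   # ~40 dB, well above the 20 dB bar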
@@ -248,14 +236,7 @@ def test_available_gpu_kernels(self):
         x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
 
         out, code = run_and_get_code(m, x)
-        has_triton = "triton" in code[0].lower()  # Triton
-        has_fbgemm = "fbgemm" in code[0].lower()  # FBGEMM
-        has_int_mm = "_int_mm" in code[0]  # Int8 MatMul
-
-        self.assertTrue(
-            has_triton or has_fbgemm or has_int_mm,
-            f"No int8 quantization kernels found. has_triton={has_triton}, has_fbgemm={has_fbgemm}, has_int_mm={has_int_mm}",
-        )
+        FileCheck().check(kernel).run(code[0])
 
 
 if __name__ == "__main__":
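# For reference: FileCheck (torch.testing) asserts that a pattern occurs in a string,
# which is how the parametrized kernel names are verified against the Inductor-generated
# code returned by run_and_get_code. Minimal usage sketch on a hand-written string:
from torch.testing import FileCheck

generated_code = "... triton_per_fused_mul_sum_0(...) ... extern_kernels._int_mm(a, b) ..."
FileCheck().check("extern_kernels._int_mm").run(generated_code)   # passes; raises if absent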