@@ -59,10 +59,6 @@ def setUp(self):
         self.batch_size = 32
 
         torch.manual_seed(42)
-        self.weight_fp = torch.randn(*self.test_shape, dtype=self.dtype)
-        self.input_fp = torch.randn(*self.test_shape, dtype=self.dtype)
-        self.bias = torch.randn(self.test_shape[0], dtype=self.dtype)
-        self.block_size = list(self.test_shape)
 
     @common_utils.parametrize(
         "config",
@@ -80,16 +76,13 @@ def test_creation_and_attributes(self, config):
             dtype=self.dtype,
             device="cuda",
         )
-        linear.weight.data = self.weight_fp.cuda()
         quantize_(linear, config)
 
-        tensor = linear.weight
+        w = linear.weight
 
-        self.assertEqual(tensor.shape, self.test_shape)
-        self.assertEqual(tensor.qdata.dtype, torch.int8)
-        self.assertTrue(
-            torch.all(tensor.qdata >= -128) and torch.all(tensor.qdata <= 127)
-        )
+        self.assertEqual(w.shape, self.test_shape)
+        self.assertEqual(w.qdata.dtype, torch.int8)
+        self.assertTrue(torch.all(w.qdata >= -128) and torch.all(w.qdata <= 127))
 
     @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
     @common_utils.parametrize("compile", [True, False])
@@ -122,6 +115,9 @@ def test_int8_linear_variants(
 
         quantize_(model_q, config)
 
+        self.assertEqual(model_q.linear2.weight.scale.shape, (K,))
+        self.assertEqual(model_q.linear2.weight.scale.ndim, 1)
+
         if compile:
             model_q = torch.compile(model_q, fullgraph=True)
 
@@ -132,23 +128,6 @@ def test_int8_linear_variants(
             f"Quantization error is too high got a SQNR of {compute_error(output_fp, output_quantized)}"
         )
 
-    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
-    @common_utils.parametrize(
-        "config",
-        [
-            Int8DynamicActivationInt8WeightConfig(version=2),
-            Int8WeightOnlyConfig(version=2),
-        ],
-    )
-    def test_per_row_scale_shape(self, dtype, config):
-        """Test per-row quantization maintains 1D scale"""
-        N, K = 64, 128
-        linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device="cuda")
-        quantize_(linear, config)
-
-        self.assertEqual(linear.weight.scale.shape, (N,))
-        self.assertEqual(linear.weight.scale.ndim, 1)
-
     @common_utils.parametrize(
         "config",
         [
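
For reference, the per-row scale invariant that the removed test_per_row_scale_shape covered, and that the inline assertions added to test_int8_linear_variants now check, can be reproduced standalone. A minimal sketch, assuming a CUDA device and that quantize_ and Int8WeightOnlyConfig are importable from torchao.quantization as in the rest of this test file:

import torch
from torchao.quantization import quantize_, Int8WeightOnlyConfig

# Per-row (per-output-channel) int8 quantization of an (N, K) weight should
# produce one scale per output row: a 1D tensor of length N.
N, K = 64, 128
linear = torch.nn.Linear(K, N, bias=False, dtype=torch.bfloat16, device="cuda")
quantize_(linear, Int8WeightOnlyConfig(version=2))

assert linear.weight.scale.ndim == 1
assert linear.weight.scale.shape == (N,)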