     QuantizationStrategy,
     apply_quantization_config,
 )
+from compressed_tensors.config import CompressionFormat
 from compressed_tensors.quantization.lifecycle.forward import fake_quantize
 from safetensors.torch import save_file
+from compressed_tensors.compressors.model_compressors.model_compressor import (
+    ModelCompressor,
+)
 from torch.nn import Linear, Module, Sequential


@@ -90,15 +94,17 @@ def test_end_to_end_asymmetric_quantization(

     model = SimpleModel()
     original_weights = {
-        "layer1": model.layer1.weight.clone(),
-        "layer2": model.layer2.weight.clone(),
+        "layer1": model.layer1.weight.detach().clone(),
+        "layer2": model.layer2.weight.detach().clone(),
     }

     quant_config = create_asymmetric_quant_config(
         num_bits=4,
         strategy=strategy,
         group_size=group_size
     )
+    # Set pack-quantized format for ModelCompressor usage
+    quant_config.format = CompressionFormat.pack_quantized.value
     apply_quantization_config(model, quant_config)

     if strategy == QuantizationStrategy.GROUP:
@@ -126,35 +132,33 @@ def test_end_to_end_asymmetric_quantization(
     assert compressed_state_dict["layer1.weight_zero_point"].dtype == torch.int32
     assert compressed_state_dict["layer2.weight_zero_point"].dtype == torch.int32

-    save_file(compressed_state_dict, tmp_path / "model.safetensors")
-
-    reconstructed_gen = compressor.decompress(
-        tmp_path, names_to_scheme=quantized_modules_to_scheme
-    )
-
-    reconstructed_weights = {}
-    for module_name, module_data in reconstructed_gen:
-        reconstructed_weights[module_name] = module_data
-
-    assert "layer1" in reconstructed_weights
-    assert "layer2" in reconstructed_weights
-    assert "weight" in reconstructed_weights["layer1"]
-    assert "weight" in reconstructed_weights["layer2"]
-
-    assert reconstructed_weights["layer1"]["weight"].shape == original_weights["layer1"].shape
-    assert reconstructed_weights["layer2"]["weight"].shape == original_weights["layer2"].shape
-
     new_model = SimpleModel()
-    new_model.layer1.weight.data = reconstructed_weights["layer1"]["weight"]
-    new_model.layer2.weight.data = reconstructed_weights["layer2"]["weight"]
-
-    test_input = torch.randn(1, 512)
-    with torch.no_grad():
-        output = new_model(test_input)
-
-    assert output.shape == (1, 128)
-    assert not torch.isnan(output).any()
-    assert not torch.isinf(output).any()
+    apply_quantization_config(new_model, quant_config)
+
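+    # Transfer the compressed state dict onto the fresh model: overwrite
+    # matching parameters and register the remaining compression tensors
+    # (scales, zero points, etc.) that the module does not yet have.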
+    for module_name in ["layer1", "layer2"]:
+        module = getattr(new_model, module_name)
+        prefix = f"{module_name}."
+        for key, value in compressed_state_dict.items():
+            if key.startswith(prefix):
+                param_name = key[len(prefix):]
+                if hasattr(module, param_name):
+                    getattr(module, param_name).data = value.clone()
+                else:
+                    module.register_parameter(
+                        param_name, torch.nn.Parameter(value.clone(), requires_grad=False)
+                    )
+
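+    # Decompress in place; this should restore floating-point weights from
+    # the packed int4 representation.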
+    mc = ModelCompressor(quantization_config=quant_config)
+    mc.decompress_model(new_model)
+
+    assert new_model.layer1.weight.shape == original_weights["layer1"].shape
+    assert new_model.layer2.weight.shape == original_weights["layer2"].shape
+    assert new_model.layer1.weight.dtype.is_floating_point
+    assert new_model.layer2.weight.dtype.is_floating_point
+    assert not torch.isnan(new_model.layer1.weight).any()
+    assert not torch.isnan(new_model.layer2.weight).any()
+    assert not torch.isinf(new_model.layer1.weight).any()
+    assert not torch.isinf(new_model.layer2.weight).any()


 @pytest.mark.parametrize("num_bits", [4, 8])
@@ -174,6 +178,7 @@ def test_asymmetric_quantization_accuracy(num_bits, mock_per_group_calibration):
         strategy=QuantizationStrategy.GROUP,
         group_size=128,
     )
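+    # Pack-quantized format so the checkpoint round-trips through ModelCompressor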
+    quant_config.format = CompressionFormat.pack_quantized.value

     class SingleLayer(Module):
         def __init__(self):
@@ -194,31 +199,26 @@ def __init__(self):
         model.state_dict().copy(), names_to_scheme=quantized_modules_to_scheme
     )

-    save_file(compressed_state_dict, tmp_path / "model.safetensors")
-
-    reconstructed_gen = compressor.decompress(
-        tmp_path, names_to_scheme=quantized_modules_to_scheme
-    )
-
-    reconstructed = {}
-    for module_name, module_data in reconstructed_gen:
-        reconstructed[module_name] = module_data
-
-    assert "layer" in reconstructed
-    assert "weight" in reconstructed["layer"]
-    assert reconstructed["layer"]["weight"].shape == shape
-
-    decompressed_weights = reconstructed["layer"]["weight"]
+    new_model = SingleLayer()
+    apply_quantization_config(new_model, quant_config)
+
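+    # Same loading pattern as the end-to-end test above, for the single layer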
+    module = new_model.layer
+    for key, value in compressed_state_dict.items():
+        if key.startswith("layer."):
+            param_name = key[len("layer."):]
+            if hasattr(module, param_name):
+                getattr(module, param_name).data = value.clone()
+            else:
+                module.register_parameter(
+                    param_name, torch.nn.Parameter(value.clone(), requires_grad=False)
+                )
+
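+    # Unpack back to floating point before measuring reconstruction error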
+    mc = ModelCompressor(quantization_config=quant_config)
+    mc.decompress_model(new_model)
+
+    decompressed_weights = new_model.layer.weight
+    assert decompressed_weights.shape == shape
     assert not torch.isnan(decompressed_weights).any()
     assert not torch.isinf(decompressed_weights).any()
-
-    assert decompressed_weights.abs().max() < 100
-    assert decompressed_weights.abs().max() > 0.01
-
-
-if __name__ == "__main__":
-    test_end_to_end_asymmetric_quantization(QuantizationStrategy.GROUP, 128)
-    test_end_to_end_asymmetric_quantization(QuantizationStrategy.CHANNEL, None)
-    test_asymmetric_quantization_accuracy(4)
-    test_asymmetric_quantization_accuracy(8)
-    print("All tests passed!")
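+    # Sanity bound: the reconstruction error's std must stay below the std of
+    # the difference of two independent uniform random tensors (~0.41)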
+    threshold = torch.std(torch.rand(shape) - torch.rand(shape))
+    assert torch.std(biased_weights - decompressed_weights) < threshold