@@ -300,11 +300,26 @@ def test_run_with_ao_quantization_configs(
300300 artifact = PipelineArtifact (data = models_dict , context = {})
301301 stage .run (artifact )
302302
303- # Verify quantize_ was called with the model and config
304- mock_quantize .assert_called_once_with (self .model , mock_config , mock_filter_fn )
303+ # Verify quantize_ was called once (with the copied model, not the original)
304+ self .assertEqual (mock_quantize .call_count , 1 )
305+ # Verify the config and filter_fn arguments are correct
306+ call_args = mock_quantize .call_args [0 ]
307+ self .assertNotEqual (self .model , call_args [0 ])
308+ self .assertEqual (call_args [1 ], mock_config )
309+ self .assertEqual (call_args [2 ], mock_filter_fn )
305310
306- # Verify unwrap_tensor_subclass was called with the model
307- mock_unwrap .assert_called_once_with (self .model )
311+ # Verify unwrap_tensor_subclass was called once (with the copied model)
312+ self .assertEqual (mock_unwrap .call_count , 1 )
313+
314+ # Verify that the original models_dict is unchanged
315+ self .assertEqual (models_dict , {"forward" : self .model })
316+
317+ # Verify that the result artifact data contains valid models
318+ result_artifact = stage .get_artifacts ()
319+ self .assertIn ("forward" , result_artifact .data )
320+ self .assertIsNotNone (result_artifact .data ["forward" ])
321+ # verify the result model is NOT the same object as the original
322+ self .assertIsNot (result_artifact .data ["forward" ], self .model )
308323
309324
310325class TestQuantizeStage (unittest .TestCase ):
@@ -398,6 +413,10 @@ def test_run_with_quantizers(
398413 self .assertIn ("forward" , result_artifact .data )
399414 self .assertEqual (result_artifact .data ["forward" ], mock_quantized_model )
400415
416+ # Verify that the original model in the input artifact is unchanged
417+ self .assertEqual (artifact .data ["forward" ], self .model )
418+ self .assertIsNot (result_artifact .data ["forward" ], self .model )
419+
401420 def test_run_empty_example_inputs (self ) -> None :
402421 """Test error when example inputs list is empty."""
403422 mock_quantizer = Mock ()
0 commit comments