diff --git a/neural_compressor/adaptor/pytorch.py b/neural_compressor/adaptor/pytorch.py
index 59da156eebd..e3015a87471 100644
--- a/neural_compressor/adaptor/pytorch.py
+++ b/neural_compressor/adaptor/pytorch.py
@@ -4170,8 +4170,12 @@ def _get_module_scale_zeropoint(self, model, tune_cfg, prefix=""):
                     sub_name = node.target
                 if not hasattr(model, node.target):
                     continue
-                if "scale" in node.target:
-                    tune_cfg["get_attr"][sub_name] = float(getattr(model, node.target))
+                # Improved scale detection: skip parameters such as layer_scale/gamma that are not quantization scales
+                if "scale" in node.target and not any(exclude in node.target for exclude in ["layer_scale", "gamma"]):
+                    try:
+                        tune_cfg["get_attr"][sub_name] = getattr(model, node.target).tolist()
+                    except Exception as e:
+                        logger.warning(f"Could not convert {node.target} to list, skipping... Error: {str(e)}")
                 elif "zero_point" in node.target:
                     tune_cfg["get_attr"][sub_name] = int(getattr(model, node.target))
                 else:
diff --git a/test/adaptor/test_pytorch_layer_scale.py b/test/adaptor/test_pytorch_layer_scale.py
new file mode 100644
index 00000000000..5c86293ad2b
--- /dev/null
+++ b/test/adaptor/test_pytorch_layer_scale.py
@@ -0,0 +1,120 @@
+import unittest
+
+import torch
+import torch.nn as nn
+
+from neural_compressor import quantization
+from neural_compressor.config import PostTrainingQuantConfig
+
+torch.manual_seed(42)
+
+
+class CalibDataloader:
+    """Simple calibration dataloader for testing."""
+
+    def __init__(self, data, label):
+        self.data = data
+        self.label = label
+        self.batch_size = 1  # Since we're yielding single samples
+
+    def __iter__(self):
+        yield self.data, self.label
+
+
+class ConvEncoderWithLayerScale(nn.Module):
+    """Test model with the layer_scale parameter that caused the original issue."""
+
+    def __init__(self, dim=64, hidden_dim=128, kernel_size=3, drop_path=0.0, use_layer_scale=True):
+        super().__init__()
+        self.dwconv = nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim)
+        self.norm = nn.BatchNorm2d(dim)
+        self.pwconv1 = nn.Conv2d(dim, hidden_dim, kernel_size=1)
+        self.act = nn.GELU()
+        self.pwconv2 = nn.Conv2d(hidden_dim, dim, kernel_size=1)
+        self.drop_path = nn.Identity() if drop_path <= 0.0 else nn.Dropout(drop_path)
+        self.use_layer_scale = use_layer_scale
+        if use_layer_scale:
+            self.layer_scale = nn.Parameter(torch.ones(dim).unsqueeze(-1).unsqueeze(-1), requires_grad=True)
+
+    def forward(self, x):
+        input = x
+        x = self.dwconv(x)
+        x = self.norm(x)
+        x = self.pwconv1(x)
+        x = self.act(x)
+        x = self.pwconv2(x)
+        if self.use_layer_scale:
+            x = self.layer_scale * x
+        x = input + self.drop_path(x)
+        return x
+
+
+class ConvEncoderWithLayerGamma(nn.Module):
+    """Test model with the parameter renamed to layer_gamma (the previous workaround)."""
+
+    def __init__(self, dim=64, hidden_dim=128, kernel_size=3, drop_path=0.0, use_layer_scale=True):
+        super().__init__()
+        self.dwconv = nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim)
+        self.norm = nn.BatchNorm2d(dim)
+        self.pwconv1 = nn.Conv2d(dim, hidden_dim, kernel_size=1)
+        self.act = nn.GELU()
+        self.pwconv2 = nn.Conv2d(hidden_dim, dim, kernel_size=1)
+        self.drop_path = nn.Identity() if drop_path <= 0.0 else nn.Dropout(drop_path)
+        self.use_layer_scale = use_layer_scale
+        if use_layer_scale:
+            self.layer_gamma = nn.Parameter(torch.ones(dim).unsqueeze(-1).unsqueeze(-1), requires_grad=True)
+
+    def forward(self, x):
+        input = x
+        x = self.dwconv(x)
+        x = self.norm(x)
+        x = self.pwconv1(x)
+        x = self.act(x)
+        x = self.pwconv2(x)
+        if self.use_layer_scale:
+            x = self.layer_gamma * x
+        x = input + self.drop_path(x)
+        return x
+
+
+class TestPyTorchLayerScale(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.constant_data = torch.randn(1, 64, 32, 32)
+        cls.constant_label = torch.randint(0, 10, (1,))
+
+    def test_layer_scale_error(self):
+        """Regression test: a model with a layer_scale parameter should now quantize without error."""
+        model = ConvEncoderWithLayerScale()
+        model.eval()
+
+        calib_dataloader = CalibDataloader(self.constant_data, self.constant_label)
+
+        # Configure quantization
+        conf = PostTrainingQuantConfig()
+
+        # Quantize; with the improved scale detection this no longer fails
+        q_model = quantization.fit(model, conf, calib_dataloader=calib_dataloader)
+        # Quantization should succeed and return a quantized model instead of None
+        self.assertIsNotNone(q_model, "Quantization should succeed with layer_scale parameter")
+
+    def test_layer_gamma_success(self):
+        """Test that the renamed layer_gamma parameter works correctly."""
+        model = ConvEncoderWithLayerGamma()
+        model.eval()
+
+        calib_dataloader = CalibDataloader(self.constant_data, self.constant_label)
+
+        # Configure quantization
+        conf = PostTrainingQuantConfig()
+
+        # This should succeed with the layer_gamma parameter
+        try:
+            q_model = quantization.fit(model, conf, calib_dataloader=calib_dataloader)
+            self.assertIsNotNone(q_model)
+        except ValueError as e:
+            self.fail(f"Quantization failed with layer_gamma: {str(e)}")
+
+
+if __name__ == "__main__":
+    unittest.main()