diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py
index 275e716aa1..018688a678 100644
--- a/torchao/quantization/GPTQ.py
+++ b/torchao/quantization/GPTQ.py
@@ -893,10 +893,15 @@ def _replace_linear_8da4w(
     linear_class: Type[torch.nn.Module],
     copy_weights: bool = False,
 ):
-    for name, child in module.named_children():
-        if isinstance(child, nn.Linear):
-            if _check_linear_int4_k(child.in_features, groupsize) or padding_allowed:
-                new_linear = linear_class(
+
+    # Import the util function here to avoid a circular dependency
+    from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
+
+    def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
+        return isinstance(child, nn.Linear) and (_check_linear_int4_k(child.in_features, groupsize) or padding_allowed)
+
+    def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
+        new_linear = linear_class(
                     child.in_features,
                     child.out_features,
                     bias=False,
@@ -905,22 +910,14 @@ def _replace_linear_8da4w(
                     precision=precision,
                     scales_precision=scales_precision,
                 )
-                # In distributed training, the model may be instantiated
-                # on the meta device, in which case there is no need to
-                # copy the weights, and doing so will result in an error
-                if copy_weights and child.weight.device != torch.device("meta"):
-                    new_linear.weight = child.weight
-                setattr(module, name, new_linear)
-        else:
-            _replace_linear_8da4w(
-                child,
-                groupsize,
-                padding_allowed,
-                precision,
-                scales_precision,
-                linear_class,
-                copy_weights,
-            )
+        # In distributed training, the model may be instantiated
+        # on the meta device, in which case there is no need to
+        # copy the weights, and doing so will result in an error
+        if copy_weights and child.weight.device != torch.device("meta"):
+            new_linear.weight = child.weight
+        return new_linear
+
+    _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)
 
 def replace_linear_8da4w(
     module: torch.nn.Module,
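
For context, this refactor replaces the hand-rolled recursion with torchao's generic filter/replacement traversal: the helper walks the module tree and swaps any child for which the filter returns True. Below is a minimal, self-contained sketch of that pattern, assuming the signature used in the diff, `_replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)`; the `replace_if_matches` helper, its FQN convention, and the toy model are illustrative, not the library's actual implementation.

import torch
import torch.nn as nn

def replace_if_matches(module: nn.Module, replacement_fn, filter_fn, cur_fqn: str = "") -> nn.Module:
    # Illustrative stand-in for _replace_with_custom_fn_if_matches_filter.
    # If this module matches the filter, replace it and stop recursing into it.
    if filter_fn(module, cur_fqn):
        return replacement_fn(module)
    # Otherwise, recurse into children and re-attach any replacements in place.
    for name, child in module.named_children():
        new_child = replace_if_matches(child, replacement_fn, filter_fn, f"{cur_fqn}{name}.")
        if new_child is not child:
            setattr(module, name, new_child)
    return module

# Usage mirroring the diff's filter_fn/replacement_fn split; here every
# nn.Linear is swapped for an nn.Identity purely for demonstration.
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
model = replace_if_matches(
    model,
    replacement_fn=lambda child: nn.Identity(),
    filter_fn=lambda child, fqn: isinstance(child, nn.Linear),
)
print(model)

Structuring the change this way lets _replace_linear_8da4w share one traversal implementation with the rest of the quantization APIs, so only the match condition and the construction of the replacement module remain specific to this function.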