diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py
index 19e732d078..018688a678 100644
--- a/torchao/quantization/GPTQ.py
+++ b/torchao/quantization/GPTQ.py
@@ -893,7 +893,8 @@ def _replace_linear_8da4w(
     linear_class: Type[torch.nn.Module],
     copy_weights: bool = False,
 ):
-
+
+    #import the util function here to avoid circular dependency
     from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
 
     def filter_fn(child: torch.nn.Module, cur_fqn:str) -> bool:
@@ -915,41 +916,9 @@ def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
             if copy_weights and child.weight.device != torch.device("meta"):
                 new_linear.weight = child.weight
             return new_linear
-            #setattr(module, name, new_linear)
 
     _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)
-
-    '''
-    for name, child in module.named_children():
-        if isinstance(child, nn.Linear):
-            if _check_linear_int4_k(child.in_features, groupsize) or padding_allowed:
-                new_linear = linear_class(
-                    child.in_features,
-                    child.out_features,
-                    bias=False,
-                    device=child.weight.device,
-                    groupsize=groupsize,
-                    precision=precision,
-                    scales_precision=scales_precision,
-                )
-                # In distributed training, the model may be instantiated
-                # on the meta device, in which case there is no need to
-                # copy the weights, and doing so will result in an error
-                if copy_weights and child.weight.device != torch.device("meta"):
-                    new_linear.weight = child.weight
-                setattr(module, name, new_linear)
-            else:
-                _replace_linear_8da4w(
-                    child,
-                    groupsize,
-                    padding_allowed,
-                    precision,
-                    scales_precision,
-                    linear_class,
-                    copy_weights,
-                )
-    '''
 
 def replace_linear_8da4w(
     module: torch.nn.Module,
     groupsize: int,