Skip to content

Commit

Permalink
refactor _replace_linear_8da4w (#451)
Browse files Browse the repository at this point in the history
* refactor _replace_linear_8da4w

* clean up version

---------
  • Loading branch information
Hanxian97 authored Jul 1, 2024
1 parent dee13e1 commit 39b02de
Showing 1 changed file with 17 additions and 20 deletions.
37 changes: 17 additions & 20 deletions torchao/quantization/GPTQ.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,10 +893,15 @@ def _replace_linear_8da4w(
linear_class: Type[torch.nn.Module],
copy_weights: bool = False,
):
for name, child in module.named_children():
if isinstance(child, nn.Linear):
if _check_linear_int4_k(child.in_features, groupsize) or padding_allowed:
new_linear = linear_class(

# Import the util function here to avoid a circular dependency
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter

def filter_fn(child: torch.nn.Module, cur_fqn:str) -> bool:
return isinstance(child, nn.Linear) and (_check_linear_int4_k(child.in_features, groupsize) or padding_allowed)

def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
new_linear = linear_class(
child.in_features,
child.out_features,
bias=False,
Expand All @@ -905,22 +910,14 @@ def _replace_linear_8da4w(
precision=precision,
scales_precision=scales_precision,
)
# In distributed training, the model may be instantiated
# on the meta device, in which case there is no need to
# copy the weights, and doing so will result in an error
if copy_weights and child.weight.device != torch.device("meta"):
new_linear.weight = child.weight
setattr(module, name, new_linear)
else:
_replace_linear_8da4w(
child,
groupsize,
padding_allowed,
precision,
scales_precision,
linear_class,
copy_weights,
)
# In distributed training, the model may be instantiated
# on the meta device, in which case there is no need to
# copy the weights, and doing so will result in an error
if copy_weights and child.weight.device != torch.device("meta"):
new_linear.weight = child.weight
return new_linear

_replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)

def replace_linear_8da4w(
module: torch.nn.Module,
Expand Down

0 comments on commit 39b02de

Please sign in to comment.