
Commit

clean up version
Hanxian97 committed Jun 27, 2024
1 parent 80c71b6 commit 5be1645
Showing 1 changed file with 2 additions and 33 deletions.
35 changes: 2 additions & 33 deletions torchao/quantization/GPTQ.py
@@ -893,7 +893,8 @@ def _replace_linear_8da4w(
    linear_class: Type[torch.nn.Module],
    copy_weights: bool = False,
):


    # import the util function here to avoid a circular dependency
    from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter

    def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
@@ -915,41 +916,9 @@ def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
        if copy_weights and child.weight.device != torch.device("meta"):
            new_linear.weight = child.weight
        return new_linear
        # setattr(module, name, new_linear)

    _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)


    '''
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            if _check_linear_int4_k(child.in_features, groupsize) or padding_allowed:
                new_linear = linear_class(
                    child.in_features,
                    child.out_features,
                    bias=False,
                    device=child.weight.device,
                    groupsize=groupsize,
                    precision=precision,
                    scales_precision=scales_precision,
                )
                # In distributed training, the model may be instantiated
                # on the meta device, in which case there is no need to
                # copy the weights, and doing so will result in an error
                if copy_weights and child.weight.device != torch.device("meta"):
                    new_linear.weight = child.weight
                setattr(module, name, new_linear)
        else:
            _replace_linear_8da4w(
                child,
                groupsize,
                padding_allowed,
                precision,
                scales_precision,
                linear_class,
                copy_weights,
            )
    '''
def replace_linear_8da4w(
    module: torch.nn.Module,
    groupsize: int,
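The change swaps the hand-rolled recursion (the deleted, commented-out loop) for torchao's shared _replace_with_custom_fn_if_matches_filter helper, imported lazily to avoid a circular dependency. As a rough illustration of the filter/replace pattern that helper follows, here is a minimal sketch; it is not the torchao source, and replace_if_matches is a hypothetical name:

import torch

# Sketch only: walk the module tree and, wherever filter_fn matches,
# swap the module for replacement_fn(module). This mirrors the pattern
# the torchao helper follows; it is not its actual implementation.
def replace_if_matches(
    model: torch.nn.Module,
    replacement_fn,
    filter_fn,
    cur_fqn: str = "",
) -> torch.nn.Module:
    if filter_fn(model, cur_fqn):
        # A match replaces the whole subtree rooted here.
        return replacement_fn(model)
    for name, child in model.named_children():
        child_fqn = f"{cur_fqn}.{name}" if cur_fqn else name
        new_child = replace_if_matches(child, replacement_fn, filter_fn, child_fqn)
        if new_child is not child:
            setattr(model, name, new_child)
    return model

Applied to this diff, filter_fn would match nn.Linear children that pass the _check_linear_int4_k shape check (or whenever padding_allowed is set), and replacement_fn builds the quantized linear_class instance, copying weights only when the source module is not on the meta device, just as the deleted manual loop did.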
