Copy weights and preserve device for 8da4w QAT linear (#211)
* Copy weights and preserve device for 8da4w QAT linear

Summary: This fixes two correctness bugs. First, we never copied
over the weights from the existing linear, so we would start from
random weights even when loading from checkpoints. Second, we
never preserved the device of the original linear. This is
important for settings like FSDP, where we expect non-zero ranks
to have their parameters on the meta device in order to
initialize these parameters correctly.
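
As a rough sketch of the pattern the fix applies (swap_in_qat_linear and qat_linear_cls are placeholder names for illustration; the actual change is in the GPTQ.py hunk further down):

    import torch
    import torch.nn as nn

    def swap_in_qat_linear(module: nn.Module, name: str, child: nn.Linear, qat_linear_cls):
        # Build the QAT linear on the same device as the module it replaces,
        # so meta-device models (e.g. non-zero FSDP ranks) stay on meta.
        new_linear = qat_linear_cls(
            child.in_features,
            child.out_features,
            bias=False,
            device=child.weight.device,
        )
        # Meta tensors carry no data, so there is nothing to copy on meta;
        # otherwise reuse the existing weight instead of starting from random init.
        if child.weight.device != torch.device("meta"):
            new_linear.weight = child.weight
        setattr(module, name, new_linear)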

Test Plan:
python test/quantization/test_qat.py -k test_qat_8da4w_quantizer
python test/quantization/test_qat.py -k test_qat_8da4w_quantizer_meta_weights

Reviewers: jerryzh168, cpuhrsch

Subscribers: jerryzh168, cpuhrsch, supriyar

* Update test_qat.py
andrewor14 authored May 6, 2024
1 parent 58b0899 commit ce78e79
Showing 3 changed files with 31 additions and 23 deletions.
test/quantization/test_qat.py: 12 additions & 11 deletions

@@ -169,17 +169,6 @@ def test_qat_8da4w_quantizer(self):
         qat_model = qat_quantizer.prepare(m)
         ptq_model = ptq_quantizer.quantize(m2)

-        # Force the weights to be the same
-        self._set_ptq_weight(
-            ptq_model.linear1, qat_model.linear1.weight, group_size,
-        )
-        self._set_ptq_weight(
-            ptq_model.sub.linear, qat_model.sub.linear.weight, group_size,
-        )
-        self._set_ptq_weight(
-            ptq_model.linear2, qat_model.linear2.weight, group_size,
-        )
-
         # Compare model values
         torch.manual_seed(self.SEED)
         x = m.example_inputs()
@@ -200,6 +189,18 @@ def test_qat_8da4w_quantizer(self):
         for k in ptq_state_dict.keys():
             torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0)

+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
+    def test_qat_8da4w_quantizer_meta_weights(self):
+        from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+
+        with torch.device("meta"):
+            m = M()
+        self.assertTrue(all(v.is_meta for v in m.state_dict().values()))
+        group_size = 16
+        qat_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
+        qat_model = qat_quantizer.prepare(m)
+        self.assertTrue(all(v.is_meta for v in qat_model.state_dict().values()))
+
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
     def test_qat_8da4w_quantizer_disable_fake_quant(self):
         """
torchao/quantization/GPTQ.py: 16 additions & 11 deletions

@@ -1127,22 +1127,26 @@ def _replace_linear_8da4w(
     precision: torch.dtype,
     scales_precision: torch.dtype,
     linear_class: Type[torch.nn.Module],
+    copy_weights: bool = False,
 ):
     for name, child in module.named_children():
         if isinstance(child, nn.Linear):
             if _check_linear_int4_k(child.in_features, groupsize) or padding_allowed:
-                setattr(
-                    module,
-                    name,
-                    linear_class(
-                        child.in_features,
-                        child.out_features,
-                        bias=False,
-                        groupsize=groupsize,
-                        precision=precision,
-                        scales_precision=scales_precision,
-                    ),
+                new_linear = linear_class(
+                    child.in_features,
+                    child.out_features,
+                    bias=False,
+                    device=child.weight.device,
+                    groupsize=groupsize,
+                    precision=precision,
+                    scales_precision=scales_precision,
                 )
+                # In distributed training, the model may be instantiated
+                # on the meta device, in which case there is no need to
+                # copy the weights, and doing so will result in an error
+                if copy_weights and child.weight.device != torch.device("meta"):
+                    new_linear.weight = child.weight
+                setattr(module, name, new_linear)
         else:
             _replace_linear_8da4w(
                 child,
@@ -1151,6 +1155,7 @@ def _replace_linear_8da4w(
                 precision,
                 scales_precision,
                 linear_class,
+                copy_weights,
             )

 def replace_linear_8da4w(
torchao/quantization/prototype/qat.py: 3 additions & 1 deletion

@@ -54,6 +54,7 @@ def prepare(
             self.precision,
             self.scales_precision,
             Int8DynActInt4WeightQATLinear,
+            copy_weights = True,
         )
         return model

@@ -111,6 +112,7 @@ def __init__(
         in_features: int,
         out_features: int,
         bias: bool = False,
+        device: torch.device = None,
         groupsize: int = 256,
         precision: torch.dtype = torch.float32,
         scales_precision: torch.dtype = torch.float32,
@@ -119,7 +121,7 @@ def __init__(
             in_features,
             out_features,
             bias,
-            device=None,
+            device=device,
             dtype=precision,
         )
         assert (

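For context, a hedged sketch of how the meta-device path is meant to be exercised on non-zero ranks under FSDP (illustrative only; MyModel is a placeholder, and the later materialization step depends on the training setup and is not part of this change):

    import torch
    from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer

    # Build the model without allocating real storage (e.g. on non-zero FSDP ranks).
    with torch.device("meta"):
        model = MyModel()

    quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=16)
    model = quantizer.prepare(model)

    # The prepared QAT linears keep their parameters on the meta device, so a
    # wrapper such as FSDP (or a manual model.to_empty(device="cuda") followed
    # by loading a checkpoint) can materialize and initialize them afterwards.
    assert all(p.is_meta for p in model.parameters())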