Skip to content

Commit 1fbf788

Browse files
authored
[NF4] Add quantize_() API support for NF4 (#1216)
quantize api for nf4
1 parent 8c07d22 commit 1fbf788

File tree

2 files changed

+47
-0
lines changed

2 files changed

+47
-0
lines changed

test/dtypes/test_nf4.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
_INNER_TENSOR_NAMES_FOR_SHARDING,
3131
NF4Tensor,
3232
linear_nf4,
33+
nf4_weight_only,
3334
to_nf4,
3435
)
3536

@@ -281,6 +282,32 @@ def test_empty_like(self, input_size: Union[Tuple[int], int]):
281282
self.assertEqual(new_tensor.get_device(), -1) # that it's on CPU
282283
self.assertEqual(new_tensor.size(), nf4_tensor.size())
283284

285+
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@parametrize("compile", [False, True])
def test_quantize_api(self, compile):
    """Compare a ``quantize_``-converted NF4 linear layer with a dequantized copy.

    A CUDA ``nn.Linear`` is converted in place via
    ``torchao.quantize_(..., nf4_weight_only())``. A deep copy whose weight is
    replaced by ``get_original_weight()`` serves as the reference. Both
    modules (optionally compiled) must agree on the forward output and on the
    gradient with respect to the input.
    """
    features = 512
    device = "cuda"

    quantized = nn.Linear(features, features, device=device)
    torchao.quantize_(quantized, nf4_weight_only())
    assert isinstance(quantized.weight, NF4Tensor)

    # Reference module: same parameters, but the weight is dequantized.
    reference = copy.deepcopy(quantized)
    reference.weight.data = reference.weight.get_original_weight()

    if compile:
        for module in (quantized, reference):
            module.compile()

    # Independent leaf inputs holding identical values, so each module
    # accumulates its own input gradient.
    x_quantized = torch.randn(2, features, device=device).requires_grad_()
    x_reference = x_quantized.detach().clone().requires_grad_()

    y_quantized = quantized(x_quantized)
    y_reference = reference(x_reference)
    self.assertEqual(y_quantized, y_reference)

    # Feed the same upstream gradient through both graphs.
    upstream_grad = torch.randn(2, features, device=device)
    y_quantized.backward(upstream_grad)
    y_reference.backward(upstream_grad)
    self.assertEqual(x_quantized.grad, x_reference.grad)
310+
284311

285312
class TestFSDPOps(TestCase):
286313
@parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)])

torchao/dtypes/nf4tensor.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -954,6 +954,15 @@ def to_nf4(tensor, block_size: int = 64, scaler_block_size: int = 256):
954954
return NF4Tensor.from_tensor(tensor, block_size, scaler_block_size)
955955

956956

957+
def nf4_weight_only(block_size: int = 64, scaler_block_size: int = 256):
    """Build the NF4 weight-only conversion callable for ``quantize_``.

    Args:
        block_size: Forwarded to ``NF4Tensor.from_tensor``.
        scaler_block_size: Forwarded to ``NF4Tensor.from_tensor``.

    Returns:
        Whatever ``_get_linear_subclass_inserter`` returns for a converter
        that turns a tensor into an ``NF4Tensor`` using the block sizes above.
    """
    # Imported inside the function rather than at module level.
    # NOTE(review): presumably this avoids a circular import with
    # torchao.quantization - confirm.
    from torchao.quantization.quant_api import _get_linear_subclass_inserter

    def _convert_weight(weight: torch.Tensor):
        return NF4Tensor.from_tensor(weight, block_size, scaler_block_size)

    inserter = _get_linear_subclass_inserter(_convert_weight)
    return inserter
964+
965+
957966
NF4_TORCH_FUNCTIONS = {}
958967

959968

@@ -1000,6 +1009,17 @@ def function_cpu(*args, **kwargs):
10001009
return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs))
10011010

10021011

1012+
@implements_torch_function(F.linear)
def _(*args, **kwargs):
    """Handle ``F.linear`` by running the matmul through ``LinearNF4``.

    Accepts every call form of ``F.linear(input, weight, bias=None)``,
    positional or keyword. Reading only positional ``args`` would silently
    drop a bias passed as ``bias=...`` and raise ``IndexError`` when
    ``input`` or ``weight`` is passed by keyword.

    Returns:
        ``LinearNF4.apply(input, weight)``, plus ``bias`` when one is given.
    """

    # Parameter names must mirror F.linear's keyword names so that keyword
    # callers bind correctly; this is why the builtin ``input`` is shadowed.
    def _bind(input, weight, bias=None):
        return input, weight, bias

    input, weight, bias = _bind(*args, **kwargs)

    out = LinearNF4.apply(input, weight)
    if bias is not None:
        out = out + bias
    return out
1021+
1022+
10031023
@torch._dynamo.allow_in_graph
10041024
def nf4_constructor(
10051025
tensor_meta: SubclassTensorArgs,

0 commit comments

Comments
 (0)