From a6f03c12f4039be5c4f448d45825317e48d4b772 Mon Sep 17 00:00:00 2001 From: drisspg Date: Tue, 4 Jun 2024 12:34:13 -0700 Subject: [PATCH] switch to save for backward since are now a tensor input --- torchao/dtypes/nf4tensor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py index 8d62a842f4..84c398426f 100644 --- a/torchao/dtypes/nf4tensor.py +++ b/torchao/dtypes/nf4tensor.py @@ -863,13 +863,13 @@ class LinearNF4(torch.autograd.Function): @staticmethod def forward(ctx, input: torch.Tensor, weight: NF4Tensor): """Save the quantized nf4 weight for backward pass""" - ctx.nf4_weight = weight + ctx.save_for_backward(weight) return F.linear(input, weight.to(input.dtype)) @staticmethod def backward(ctx, grad_output): """The nf4 weight will never require grad so we can just return the grad_output @ weight.to(grad_output.dtype)""" - weight: NF4Tensor = ctx.nf4_weight + weight: NF4Tensor = ctx.saved_tensors[0] return grad_output @ weight.to(grad_output.dtype), None