# NOTE(review): reconstructed from a patch hunk that showed only ``forward``
# and ``backward``; confirm no other members of this class were dropped.
class LinearNF4(torch.autograd.Function):
    """Custom autograd function for a linear op with an NF4-quantized weight."""

    @staticmethod
    def forward(ctx, input: torch.Tensor, weight: "NF4Tensor"):
        """Apply ``F.linear`` using the weight converted to ``input.dtype``.

        Args:
            ctx: Autograd context; the quantized weight is saved on it for
                the backward pass.
            input: Activation tensor fed to the linear op.
            weight: NF4 weight; it is stored as-is and only converted to
                ``input.dtype`` for the matmul.

        Returns:
            The result of ``F.linear(input, weight.to(input.dtype))``.
        """
        # Register the weight through save_for_backward instead of stashing
        # it as a plain attribute on ctx, so autograd manages the saved
        # tensor itself.
        ctx.save_for_backward(weight)
        return F.linear(input, weight.to(input.dtype))

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient for ``input``; the weight gets no gradient.

        The nf4 weight will never require grad, so only the input gradient
        ``grad_output @ weight.to(grad_output.dtype)`` is computed and
        ``None`` is returned in the weight's slot.

        Args:
            ctx: Autograd context holding the weight saved in ``forward``.
            grad_output: Gradient with respect to the output of ``forward``.

        Returns:
            A 2-tuple ``(grad_input, None)``.
        """
        (weight,) = ctx.saved_tensors
        return grad_output @ weight.to(grad_output.dtype), None