[NF4] Various bugs in how NF4 handles .to() to move to a different device #1310

Open · gau-nernst opened this issue Nov 19, 2024 · 0 comments · May be fixed by #1312

Labels
bug Something isn't working

gau-nernst (Collaborator) commented Nov 19, 2024

Reproduction

import torch
from torch import nn
from torchao.dtypes.nf4tensor import to_nf4

x = torch.randn(1024, 1024)
x_nf4 = to_nf4(x)
print(x_nf4.cuda())  # this dequantizes NF4 to a plain tensor -> unwanted
print(x_nf4.to(device="cuda"))  # this raises an IndexError
print(x_nf4.to("cuda"))  # this does the right thing

# .cpu() does not move .nf4 to CPU, because call_from_inner_tensors does not call the method on .nf4
x = torch.randn(1024, 1024).cuda()
x_nf4 = to_nf4(x).cpu()
print(x_nf4.quantized_data.device)  # cpu
print(x_nf4.nf4.device)  # cuda:0
print(x_nf4.to(torch.float32))  # error due to device mismatch

# does not work with nn.Module
linear = nn.Linear(1024, 1024)
linear.weight = nn.Parameter(to_nf4(linear.weight.detach()), requires_grad=False)
linear.cuda()  # NF4 weight is not moved to CUDA
# linear.to("cuda")  # same problem

print(linear.weight.device)  # cuda:0
print(linear.weight.quantized_data.device)  # cpu
print(linear.weight.to(torch.float32).device)  # cpu

Summary:

  1. NF4Tensor.cuda() will dequantize -> this is unwanted
  2. NF4Tensor.to(device="cuda") will raise IndexError, since args[1] does not exist (see the argument-parsing sketch after this list)
  3. NF4Tensor.cpu() does not move the .nf4 attribute -> the tensor can no longer be dequantized
  4. Device moves do not work with nn.Module.to(device) or nn.Module.cuda()
  • IMO, the semantics that NF4Tensor.to(torch.float32) dequantizes is the culprit behind these troubles, and it is also inconsistent with AQT behavior. If .to(dtype) did not dequantize (only changed the apparent dtype), we would only need to implement aten._to_copy instead of Tensor.cpu, Tensor.to, and a myriad of others (a sketch of this approach follows this list). Though I understand this design is meant to make NF4 feel more like a true dtype.
  • I think it makes more sense to designate NF4Tensor.dequantize() as the method that dequantizes the tensor (this is also consistent with plain Tensor behavior, though plain Tensor.dequantize() always returns FP32; see the example below), instead of the current situation, where NF4Tensor.dequantize() is a static method that applies the lookup table while NF4Tensor.get_original_weight() does the actual dequantization.
  • Changing this would be BC-breaking, so we should probably leave it as is.
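
On point 2: a minimal sketch (not the torchao implementation) of how a Tensor.to override could parse positional and keyword arguments uniformly, using the same internal helper that nn.Module.to relies on, so that both .to("cuda") and .to(device="cuda") resolve without ever indexing args[1]:

import torch

# torch._C._nn._parse_to accepts device/dtype/non_blocking/memory_format,
# whether they are passed positionally or as keywords.
device, dtype, non_blocking, memory_format = torch._C._nn._parse_to("cuda")
print(device)  # device(type='cuda')
device, dtype, non_blocking, memory_format = torch._C._nn._parse_to(device="cuda")
print(device)  # same result, no IndexError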
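
And a rough sketch of the "only implement aten._to_copy" idea for a generic wrapper subclass holding inner tensors; the class and attribute names below are illustrative, not torchao's actual NF4Tensor internals:

import torch

class PackedTensor(torch.Tensor):
    # Toy wrapper subclass: an int8 payload plus a per-tensor scale.
    @staticmethod
    def __new__(cls, int_data, scale):
        return torch.Tensor._make_wrapper_subclass(
            cls, int_data.shape, dtype=torch.float32, device=int_data.device
        )

    def __init__(self, int_data, scale):
        self.int_data = int_data
        self.scale = scale

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.ops.aten._to_copy.default:
            self = args[0]
            device = kwargs.get("device", self.device)
            # Move *every* inner tensor (in NF4Tensor's case this would
            # include the .nf4 lookup table); a dtype argument would only
            # change the apparent dtype, never dequantize.
            return PackedTensor(self.int_data.to(device), self.scale.to(device))
        raise NotImplementedError(f"{func} is not handled in this sketch")

In principle, x.cuda(), x.cpu(), x.to("cuda"), x.to(device="cuda"), and nn.Module.to(device) should all decompose to aten._to_copy under torch dispatch, so a single override would cover every entry point and the tensor would stay quantized.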
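
For comparison, the plain quantized-tensor behavior mentioned above: Tensor.dequantize() is an instance method and always returns FP32:

import torch

q = torch.quantize_per_tensor(torch.randn(4), scale=0.1, zero_point=0, dtype=torch.qint8)
print(q.dequantize().dtype)  # torch.float32, regardless of the quantized dtype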
gau-nernst added the bug label on Nov 19, 2024
gau-nernst linked pull request #1312 on Nov 19, 2024 that will close this issue