[NF4] .to() fixes #1312

Merged: 2 commits, Nov 26, 2024
36 changes: 36 additions & 0 deletions test/dtypes/test_nf4.py
@@ -504,17 +504,22 @@ def test_to_cuda(self):
        self.assertEqual(nf4_tensor.device.type, "cpu")
        nf4_tensor = nf4_tensor.to("cuda", non_blocking=True)
        self.assertEqual(nf4_tensor.device.type, "cuda")
        self.assertEqual(type(nf4_tensor), NF4Tensor)
        nf4_tensor.get_original_weight()  # make sure we can dequantize

        nf4_tensor = to_nf4(torch.randn(512 * 512))
        self.assertEqual(nf4_tensor.device.type, "cpu")
        nf4_tensor = nf4_tensor.to("cuda")
        self.assertEqual(nf4_tensor.device.type, "cuda")
        self.assertEqual(type(nf4_tensor), NF4Tensor)
        nf4_tensor.get_original_weight()

        nf4_tensor = to_nf4(torch.randn(512 * 512))
        self.assertEqual(nf4_tensor.device.type, "cpu")
        nf4_tensor = nf4_tensor.to("cuda", torch.bfloat16)
        self.assertEqual(nf4_tensor.device.type, "cuda")
        self.assertEqual(nf4_tensor.dtype, torch.bfloat16)
        self.assertEqual(type(nf4_tensor), torch.Tensor)  # dequantized

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_to_cpu(self):
@@ -524,6 +529,37 @@ def test_to_cpu(self):
        for attr in _INNER_TENSOR_NAMES_FOR_SHARDING:
            inner_tensor = getattr(nf4_tensor, attr)
            self.assertEqual(inner_tensor.device.type, "cpu")
        nf4_tensor.get_original_weight()  # make sure we can dequantize

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_to_module(self):
        linear = nn.Linear(512, 512, bias=False)
        linear.weight = nn.Parameter(
            to_nf4(linear.weight.detach()), requires_grad=False
        )
        linear.cuda()
        self.assertEqual(linear.weight.device.type, "cuda")
        weight = linear.weight.get_original_weight()
        self.assertEqual(weight.device.type, "cuda")

        linear.cpu()
        self.assertEqual(linear.weight.device.type, "cpu")
        weight = linear.weight.get_original_weight()
        self.assertEqual(weight.device.type, "cpu")

        linear = nn.Linear(512, 512, bias=False)
        linear.weight = nn.Parameter(
            to_nf4(linear.weight.detach()), requires_grad=False
        )
        linear.to("cuda")
        self.assertEqual(linear.weight.device.type, "cuda")
        weight = linear.weight.get_original_weight()
        self.assertEqual(weight.device.type, "cuda")

        linear.to("cpu")
        self.assertEqual(linear.weight.device.type, "cpu")
        weight = linear.weight.get_original_weight()
        self.assertEqual(weight.device.type, "cpu")

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)])
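Taken together, these tests pin down the intended contract: moving an NF4Tensor (or a module whose weight wraps one) to another device keeps it quantized, while supplying a dtype dequantizes it to a plain tensor. Below is a usage sketch of that contract, not part of the PR; it requires a CUDA device and assumes to_nf4 and NF4Tensor are importable from torchao.dtypes.nf4tensor, as in the tests.

import torch
import torch.nn as nn
from torchao.dtypes.nf4tensor import NF4Tensor, to_nf4

# Device-only moves keep the tensor quantized.
nf4 = to_nf4(torch.randn(512, 512))
nf4_cuda = nf4.to("cuda")
assert type(nf4_cuda) is NF4Tensor

# Supplying a dtype dequantizes to a regular torch.Tensor.
dense = nf4_cuda.to("cuda", torch.bfloat16)
assert type(dense) is torch.Tensor and dense.dtype == torch.bfloat16

# Module-level moves (.cuda() / .cpu() / .to(device)) behave the same way.
linear = nn.Linear(512, 512, bias=False)
linear.weight = nn.Parameter(to_nf4(linear.weight.detach()), requires_grad=False)
linear.cuda()
assert linear.weight.device.type == "cuda"
linear.weight.get_original_weight()  # still dequantizable after the move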
55 changes: 32 additions & 23 deletions torchao/dtypes/nf4tensor.py
@@ -980,35 +980,44 @@ def decorator(func):
 @implements_torch_function(torch.Tensor.to)
 def function_to_dtype(*args, **kwargs):
     tensor = args[0]
-    if isinstance(args[1], torch.dtype):
-        # Tensor.to(dtype, non_blocking, copy, memory_format)
-        return tensor.get_original_weight().to(*args[1:], **kwargs)
-    elif (
-        isinstance(args[1], torch.device)
-        or (
-            isinstance(args[1], str)
-            and (args[1] == "cpu" or args[1].startswith("cuda"))
-        )
-    ) and len(args) == 2:
-        # Tensor.to(device, non_blocking)
-        device = args[1]
-        updated_attrs = call_from_inner_tensors(tensor, "to", args[1:], kwargs)
-        updated_attrs["device"] = device
-        return NF4Tensor(*construct_nf4_args(tensor, updated_attrs))
-    else:
-        # Tensor.to(device, dtype, non_blocking, copy, memory_format)
-        # Tensor.to(other, non_blocking, copy)
-        raise NotImplementedError(
-            f"NF4Tensor.to({args[1:]}, {kwargs}) is not supported, passing to dispatch"
-        )
+    device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
+        *args[1:], **kwargs
+    )
+
+    # dtype is specified -> dequantize
+    if dtype is not None:
+        return tensor.get_original_weight().to(
+            device, dtype, non_blocking, memory_format=convert_to_format
+        )
+
+    # dtype is not specified -> keep NF4
+    updated_attrs = dict(device=device)
+    tensor_attrs, _ = tensor.__tensor_flatten__()
+    for attr in tensor_attrs:
+        inner_tensor = getattr(tensor, attr)
+        updated_attrs[attr] = inner_tensor.to(
+            device, dtype, non_blocking, memory_format=convert_to_format
+        )
+    return NF4Tensor(*construct_nf4_args(tensor, updated_attrs))
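For context, torch._C._nn._parse_to is the same private helper that nn.Module.to() uses to normalize the overloaded Tensor.to() signatures into a (device, dtype, non_blocking, memory_format) tuple, so the handler only has to check whether a dtype was supplied. A small illustration follows; since this is a private API, the exact repr may vary across PyTorch versions.

import torch

# Device-only call: dtype comes back as None, so the NF4-preserving branch runs.
print(torch._C._nn._parse_to("cuda", non_blocking=True))
# roughly: (device(type='cuda'), None, True, None)

# dtype supplied: the handler dequantizes via get_original_weight().
print(torch._C._nn._parse_to("cuda", torch.bfloat16))
# roughly: (device(type='cuda'), torch.bfloat16, False, None)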


 @implements_torch_function(torch.Tensor.cpu)
 def function_cpu(*args, **kwargs):
-    nf4tensor = args[0]
-    updated_attrs = call_from_inner_tensors(nf4tensor, "cpu", args[1:], kwargs)
-    updated_attrs["device"] = "cpu"
-    return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs))
+    # Tensor.cpu(self, memory_format)
+    return args[0].to("cpu", *args[1:], **kwargs)
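Because the .to() handler above now keeps NF4 whenever no dtype is given, .cpu() can simply delegate to it instead of duplicating the per-inner-tensor copy logic. A quick sanity sketch of the equivalence, assuming a CUDA device and the post-PR behavior:

import torch
from torchao.dtypes.nf4tensor import NF4Tensor, to_nf4

nf4_gpu = to_nf4(torch.randn(512 * 512)).cuda()
a = nf4_gpu.cpu()        # routed through the Tensor.to handler with device="cpu"
b = nf4_gpu.to("cpu")    # same code path
assert type(a) is NF4Tensor and type(b) is NF4Tensor
assert a.device.type == b.device.type == "cpu"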


+@implements_torch_function(torch.Tensor.cuda)
+def function_cuda(*args, **kwargs):
Contributor:

Does this func call_from_inner_tensors not work?

Collaborator (Author):

call_from_inner_tensors() does not call the method on the .scaler_mean and .nf4 attributes, hence I use __tensor_flatten__ instead.

Contributor:

I see. That's another anti-pattern, since call_from_inner_tensors is typically applied to all the inner tensors and not just the sharding-specific ones. But this makes sense. Maybe we should just update it and add a flag that says whether to ignore sharding or not.

Contributor:

Otherwise, thank you. Yeah, I really didn't like .to() calling dequant secretly, so I think this aligns more with what we've seen in the rest of the library.

Collaborator (Author):

Yeah, I also penned some of my thoughts in #1310.

To clarify, this PR does not change the "to calling dequant secretly (when dtype is specified)" behavior. Apart from the fixes for the .scaler_mean and .nf4 attributes, this PR only changes .cuda() to not dequantize (previously .cuda() would propagate to aten._to_copy, which dequantizes), so it's more consistent with .cpu() as well as the general rule of not dequantizing when dtype is not specified.

+    # Tensor.cuda(self, device, non_blocking, memory_format)
+    tensor = args[0]
+    updated_attrs = dict()
+    tensor_attrs, _ = tensor.__tensor_flatten__()
+    for attr in tensor_attrs:
+        inner_tensor = getattr(tensor, attr)
+        updated_attrs[attr] = inner_tensor.cuda(*args[1:], **kwargs)
+    updated_attrs["device"] = updated_attrs[tensor_attrs[0]].device
+    return NF4Tensor(*construct_nf4_args(tensor, updated_attrs))
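One detail worth noting in the new .cuda() handler: .cuda() may be called without an explicit device index, so the wrapper's device is read back from the first inner tensor after it has moved rather than parsed from the arguments. The upshot for users, sketched below under the assumption of an available CUDA device, is that .cuda() now mirrors .cpu() and keeps the tensor quantized.

import torch
from torchao.dtypes.nf4tensor import NF4Tensor, to_nf4

nf4 = to_nf4(torch.randn(512 * 512))
moved = nf4.cuda()               # no explicit index; resolves to the current CUDA device
assert type(moved) is NF4Tensor  # stays quantized instead of dequantizing via aten._to_copy
assert moved.device.type == "cuda"
moved.get_original_weight()      # dequantization still works on demand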


@implements_torch_function(F.linear)
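As background for the review thread above, the difference between the two helpers comes down to which inner tensors get visited. The sketch below is hypothetical and only illustrates the distinction described in the thread: the two function names are made up for illustration, while _INNER_TENSOR_NAMES_FOR_SHARDING and __tensor_flatten__ are the names used in the code and tests above and are assumed importable from torchao.dtypes.nf4tensor.

import torch
from torchao.dtypes.nf4tensor import _INNER_TENSOR_NAMES_FOR_SHARDING, to_nf4

def move_via_sharding_attrs(nf4_tensor, device):
    # Approximates what call_from_inner_tensors is described to do in the thread:
    # it only visits the sharding-related attributes, so .scaler_mean and .nf4
    # would be left behind on the old device.
    return {
        attr: getattr(nf4_tensor, attr).to(device)
        for attr in _INNER_TENSOR_NAMES_FOR_SHARDING
    }

def move_via_tensor_flatten(nf4_tensor, device):
    # __tensor_flatten__ enumerates every inner tensor attribute, so all buffers
    # (including .scaler_mean and .nf4) move together with the wrapper.
    tensor_attrs, _ = nf4_tensor.__tensor_flatten__()
    return {attr: getattr(nf4_tensor, attr).to(device) for attr in tensor_attrs}

nf4 = to_nf4(torch.randn(512 * 512))
partial = move_via_sharding_attrs(nf4, "cpu")  # subset of the inner tensors
full = move_via_tensor_flatten(nf4, "cpu")     # all of them
assert set(partial) <= set(full)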