
Change NF4Tensor dtype and add support for linear #62

Merged
12 commits merged into main on Mar 22, 2024
Conversation

@cpuhrsch (Contributor) commented Mar 18, 2024

Current behavior

[ 6:08PM (nightly20240318py310) /scratch/cpuhrsch/dev/ao - git:main]$ cat /tmp/asdf4.py
import torch
import torchao

a = torch.randn(32)
print(a)

a_nf4 = torchao.dtypes.to_nf4(a, 16, 2)
print(a_nf4, type(a_nf4))

a_bfloat16 = a_nf4.to(torch.bfloat16)
print(a_bfloat16, type(a_bfloat16))
[ 6:09PM (nightly20240318py310) /scratch/cpuhrsch/dev/ao - git:main]$ python /tmp/asdf4.py
tensor([ 0.7889,  0.0880,  0.8752,  0.2042,  0.1245,  0.1266,  0.9101, -0.2222,
        -0.9368, -0.5316, -0.2631, -1.7176,  0.7433,  0.4885, -0.1441,  0.7750,
        -0.8555, -0.4909, -0.5342, -0.2490,  0.2436, -0.9419, -1.9183,  0.2931,
         0.5316,  0.4517, -1.3808,  0.2183, -0.2289,  0.8697, -0.1336,  0.0282])
tensor([ 0.7578,  0.1367,  0.9688,  0.1367,  0.1367,  0.1367,  0.9688, -0.1572,
        -0.8984, -0.4902, -0.3164, -1.7188,  0.7578,  0.4238, -0.1572,  0.7578,
        -0.7578, -0.5469, -0.5469, -0.1758,  0.3105, -1.0078, -1.9219,  0.3105,
         0.4727,  0.4727, -1.3359,  0.1533, -0.1758,  0.8477, -0.1758,  0.0000]) <class 'torchao.dtypes.nf4tensor.NF4Tensor'>
tensor([ 0.7578,  0.1367,  0.9688,  0.1367,  0.1367,  0.1367,  0.9688, -0.1572,
        -0.8984, -0.4902, -0.3164, -1.7188,  0.7578,  0.4238, -0.1572,  0.7578,
        -0.7578, -0.5469, -0.5469, -0.1758,  0.3105, -1.0078, -1.9219,  0.3105,
         0.4727,  0.4727, -1.3359,  0.1533, -0.1758,  0.8477, -0.1758,  0.0000]) <class 'torchao.dtypes.nf4tensor.NF4Tensor'>

New behavior

tensor([-1.1915, -1.1581,  0.3459,  1.7048,  1.1047, -0.0950,  0.7118,  0.1842,
         1.7371, -0.8684,  0.0813, -0.0257,  0.6888, -0.8177,  0.1687, -0.2958,
        -1.2135, -2.1445, -1.4045,  0.5073, -0.3807, -1.0531,  0.5965,  0.6871,
        -0.2533, -0.9507, -0.3514,  1.7766,  0.6843,  0.6631,  0.9751, -2.8361])
tensor([-1.2031, -1.2031,  0.2793,  1.7344,  0.9766, -0.1582,  0.7656,  0.1377,
         1.7344, -0.9062,  0.1377,  0.0000,  0.7656, -0.9062,  0.1377, -0.3203,
        -1.1250, -1.9766, -1.4922,  0.4590, -0.2598, -1.1250,  0.6992,  0.6992,
        -0.2598, -0.8125, -0.2598,  1.6016,  0.6992,  0.6992,  0.9609, -2.8438]) <class 'torchao.dtypes.nf4tensor.NF4Tensor'>
tensor([-1.2031, -1.2031,  0.2793,  1.7344,  0.9766, -0.1582,  0.7656,  0.1377,
         1.7344, -0.9062,  0.1377,  0.0000,  0.7656, -0.9062,  0.1377, -0.3203,
        -1.1250, -1.9766, -1.4922,  0.4590, -0.2598, -1.1250,  0.6992,  0.6992,
        -0.2598, -0.8125, -0.2598,  1.6016,  0.6992,  0.6992,  0.9609, -2.8438],
       dtype=torch.bfloat16) <class 'torch.Tensor'>

The previous behavior happened because the dtype of NF4Tensor was reported as bfloat16, so the conversion to(torch.bfloat16) was treated as a no-op and never dispatched to any underlying functions.
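A minimal sketch of the underlying mechanism (illustration only, not the PR's code): Tensor.to returns the input unchanged when the requested dtype already matches what the tensor reports, so no conversion kernel, and hence no subclass dequantization path, is ever dispatched.

import torch

# If the tensor already reports the requested dtype, .to() hands back the
# same object and nothing is dispatched to the subclass.
x = torch.randn(4, dtype=torch.bfloat16)
assert x.to(torch.bfloat16) is x  # same object, nothing happened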

@cpuhrsch cpuhrsch requested a review from rohan-varma March 18, 2024 18:05
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Mar 18, 2024
@cpuhrsch cpuhrsch requested a review from drisspg March 18, 2024 18:05
@cpuhrsch cpuhrsch changed the title Change NF4Tensor dtype to uint4 Change NF4Tensor dtype to bit2x4 and add support for linear Mar 18, 2024
@@ -546,7 +580,7 @@ class LinearNF4(torch.autograd.Function):
def forward(ctx, input: torch.Tensor, weight: NF4Tensor):
"""Save the quantized nf4 weight for backward pass"""
ctx.nf4_weight = weight
return F.linear(input, weight.get_original_weight())
@cpuhrsch (Contributor, Author) commented:
@drisspg - So we used to dequantize for each linear call? I guess that makes sense since it's essentially weight-only quant.
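As a rough sketch of that weight-only-quant pattern (illustrative only; the class name below is hypothetical and this is not the PR's exact code): the 4-bit weight is dequantized on every forward and backward call via get_original_weight(), and the matmuls themselves run in the higher precision.

import torch
import torch.nn.functional as F

class WeightOnlyQuantLinear(torch.autograd.Function):  # hypothetical name
    @staticmethod
    def forward(ctx, input, weight_q):
        # Save the quantized weight; dequantize it just for this matmul.
        ctx.weight_q = weight_q
        return F.linear(input, weight_q.get_original_weight())

    @staticmethod
    def backward(ctx, grad_output):
        # The quantized weight never requires grad, so only grad_input is returned.
        return grad_output @ ctx.weight_q.get_original_weight(), None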

@@ -160,7 +192,7 @@ def __new__(
tensor_meta.original_shape,
tensor_meta.original_strides,
tensor_meta.storage_offset,
dtype=tensor_meta.dtype,
dtype=torch.bits2x4,
@drisspg (Contributor) commented Mar 19, 2024:
For provenance: I still don't like this 🙃

I think that NF4Tensor's outer wrapper subclass should have the same dtype as the tensor it was created from.

@cpuhrsch (Contributor, Author):

I agree. We need a better extensibility story for dtypes.

Contributor:

yeah I think we want to deprecate these, why not use torch.uint2?

@cpuhrsch (Contributor, Author):

nf4 is a 4-bit type. I suppose another mitigation is a type guard at the torch_dispatch level and using torch.bits8 just so the allocator will always spit out bytes (not like it has a choice).
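For illustration, here is a rough sketch (not torchao's actual layout) of packing two 4-bit codes into each uint8 byte, which is why the allocator only ever deals in whole bytes regardless of the nominal dtype:

import torch

def pack_nf4_codes(codes: torch.Tensor) -> torch.Tensor:
    # codes: uint8 tensor of values in [0, 15]; two codes share one byte.
    codes = codes.flatten()
    if codes.numel() % 2:
        codes = torch.cat([codes, codes.new_zeros(1)])
    return (codes[::2] << 4) | codes[1::2]

def unpack_nf4_codes(packed: torch.Tensor) -> torch.Tensor:
    # Reverses the packing above: high nibble first, then low nibble.
    return torch.stack([packed >> 4, packed & 0x0F], dim=-1).flatten()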

Contributor:

torch.bits2x4 means 8 bits though. These dtypes (including bits1x8 and bits4x2) should actually be removed, since torch.bits8 means the same thing: they are all uninterpreted dtypes.

So what are you trying to express here? Two 2-bit values packed into 4 bits?

Contributor:

This sounds like a uint4 tensor with a different packing format. Can you reuse the uint4 tensor as the underlying dtype (by inheriting from UInt4Tensor, probably)? Can you also write down all the use cases for the nf4 dtype so we get some idea of how we can support it?

bits8 is generally not recommended right now either, btw, since all the bit-shifting ops etc. are already available in uint8, so we'd recommend uint8 if you want an 8-bit dtype.

Contributor:

> For provenance: I still don't like this 🙃. I think that NF4Tensor's outer wrapper subclass should have the same dtype as the tensor it was created from.

I agree. Having this represent the high-precision dtype has worked well for Float8Tensor.

@drisspg (Contributor) commented Mar 21, 2024:

Yeah, UInt4Tensor is the same as NF4Tensor. AFAIK Ed copied the packing format from NF4Tensor in nuggets, and that was the basis of UInt4Tensor.

NF4Tensor was copied over to ao rather than inherited, for speed of enabling torchtune. But I agree that NF4 should likely inherit from uint4.

That being said, this same outer-tensor dtype issue applies to UInt4Tensor just as it does here.

@cpuhrsch (Contributor, Author):

So Float8Tensor's dtype is bfloat16?

Contributor:

Yes. For Float8Tensor specifically this is required, because we need to trick autograd's assert x.dtype == x.grad.dtype restriction. But it's also conceptually simple to reason about: "this is an emulation of a bfloat16 tensor with scaled float8".
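A minimal sketch of that pattern (the class and field names here are hypothetical, not torchao's Float8Tensor API): the outer wrapper subclass advertises the original high-precision dtype while the quantized payload lives in inner tensors.

import torch

class ScaledFloat8Sketch(torch.Tensor):  # hypothetical name, illustration only
    @staticmethod
    def __new__(cls, data_fp8, scale, orig_dtype=torch.bfloat16):
        # Report the high-precision dtype on the wrapper so checks like
        # autograd's x.dtype == x.grad.dtype are satisfied.
        return torch.Tensor._make_wrapper_subclass(
            cls, data_fp8.shape, dtype=orig_dtype, device=data_fp8.device
        )

    def __init__(self, data_fp8, scale, orig_dtype=torch.bfloat16):
        self._data = data_fp8   # e.g. a torch.float8_e4m3fn tensor
        self._scale = scale     # per-tensor scale used for dequantization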

@@ -188,6 +221,7 @@ def __init__(
self.scaler_mean = scaler_mean
self.quantized_data = quantized_data
self.nf4 = nf4
self.transpose = transpose
@cpuhrsch (Contributor, Author):

Hm, maybe I'll put this into the strides and rely on is_contiguous instead.
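A quick illustration of that idea (not the implementation that landed): a transpose can be encoded purely in the strides and detected via is_contiguous().

import torch

w = torch.empty(4, 8)
wt = w.t()                     # same storage; shape (8, 4), strides swapped
print(wt.stride())             # (1, 8)
print(wt.is_contiguous())      # False, which is enough to flag "transposed"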

inpt_tensor = torch.rand(128, dtype=torch.bfloat16)
inpt_tensor_nf4 = to_nf4(inpt_tensor, 32, 2)
assert type(inpt_tensor_nf4) != torch.Tensor
assert type(inpt_tensor_nf4.to(torch.bfloat16)) == torch.Tensor
Member:

Also check the dtype of inpt_tensor_nf4.to(torch.bfloat16)?
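Presumably something along these lines (a sketch of the suggested extra assertion, extending the quoted test above; not code from the PR):

# in addition to the type checks above, also assert the reported dtype
assert inpt_tensor_nf4.to(torch.bfloat16).dtype == torch.bfloat16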

out3 = torch.compile(torch.nn.functional.linear, mode='max-autotune')(inp, a_nf4)

# torch.testing.assert_allclose(out1, out2)
# torch.testing.assert_allclose(out1, out3)
Member:

Do these tests pass?

@cpuhrsch (Contributor, Author):

Not without custom atol/rtol. But this happens in a test higher up. I need to refactor these tests a bit more before landing.
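For reference, the looser comparison mentioned here would look something like this (the tolerances are illustrative placeholders, not values from the PR):

# continuing the snippet above; atol/rtol are placeholders
torch.testing.assert_close(out1, out2, atol=1e-2, rtol=1e-2)
torch.testing.assert_close(out1, out3, atol=1e-2, rtol=1e-2)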

@@ -428,7 +462,7 @@ def quantize_tensor_nearest(
# pyre-fixme[40]: Static method `dequantize` cannot override a non-static method
# defined in `torch._C.TensorBase`.
def dequantize(value: torch.Tensor, nf4: torch.Tensor) -> torch.Tensor:
"""Dequantize a nf4 value to float16 format"""
"""Dequantize a nf4 value to bfloat16 format"""
Member:

Is NF4Tensor still restricted to bf16 only for the higher precision? Are there any blockers to supporting fp32?

@cpuhrsch (Contributor, Author):

We should be able to support arbitrary precision for conversion, but of course the fidelity of nf4 is independent of the dtype that was passed during construction.
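Assuming the .to() path does generalize beyond bf16 (an assumption, not something this PR adds), usage would look like:

# hypothetical usage once arbitrary-precision conversion is supported
a_fp32 = a_nf4.to(torch.float32)
assert a_fp32.dtype == torch.float32
# fidelity is still limited by the nf4 quantization, regardless of output dtype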

@rohan-varma (Member) left a review:

Looks good overall from my end.

@cpuhrsch cpuhrsch changed the title Change NF4Tensor dtype to bit2x4 and add support for linear Change NF4Tensor dtype and add support for linear Mar 22, 2024
@cpuhrsch cpuhrsch merged commit f7e12c8 into main Mar 22, 2024
3 checks passed
dbyoung18 pushed a commit to dbyoung18/ao that referenced this pull request Jul 31, 2024