Added first bits of Uint2Tensor and BitnetTensor #282

Merged: 21 commits into pytorch:main on Jun 18, 2024

Conversation

andreaskoepf (Contributor):

Created a UInt2Tensor class (similar to the UInt4Tensor class). Added a BitnetTensor class and a first unit test that quantizes the weights of an nn.Linear() layer and executes the matmul.

It currently generates an error if the commented-out @torch.compile lines above the unpack_uint2() and pack_uint2() functions are uncommented.
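
For context, a minimal sketch of what pack_uint2()/unpack_uint2() compute (an illustration only; the PR's actual bit order and layout may differ):

import torch

def pack_uint2(uint8_data: torch.Tensor) -> torch.Tensor:
    # Pack four 2-bit values (stored one per byte) into a single byte.
    d = uint8_data.contiguous().view(-1, 4)
    return ((d[:, 0] << 6) | (d[:, 1] << 4) | (d[:, 2] << 2) | d[:, 3]).to(torch.uint8)

def unpack_uint2(uint8_data: torch.Tensor) -> torch.Tensor:
    # Recover the four 2-bit values from each packed byte.
    return torch.stack(
        [(uint8_data >> 6) & 0b11, (uint8_data >> 4) & 0b11,
         (uint8_data >> 2) & 0b11, uint8_data & 0b11],
        dim=-1,
    ).view(-1)

x = torch.randint(0, 4, (16,), dtype=torch.uint8)
assert torch.equal(unpack_uint2(pack_uint2(x)), x)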

pytorch-bot (bot) commented May 26, 2024:

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/282

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit ae4ead1 with merge base 664f073:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label May 26, 2024
_apply_weight_only_uint2_quant(m)
y_wo = m(x)
# sqnr = compute_error(y_ref, y_wo)
#opt = torch.compile(m, fullgraph=True, mode="max-autotune")

Member:

What's the error you were getting?

Contributor:

AssertionError: Please convert all Tensors to FakeTensors first or instantiate FakeTensorMode with 'allow_non_fake_inputs'. Found in aten.sym_storage_off.set.default(_to_functional_tensor(FakeTensor(..., size=(16, 4), dtype=torch.uint8)))
In general, it failed when enabling @torch.compile on the pack and unpack functions.

msaroufim (Member) commented May 27, 2024:

@bdhirsh any ideas? I've seen this scary error before and never understood what it meant beyond AOTAutograd is not doing its thing

Contributor:

opened this issue pytorch/pytorch#127374

return output

else:
#@torch.compile

msaroufim (Member) commented May 29, 2024:

did these all fail torch.compile checks?

Btw is the idea here to trigger the else condition on cpu only?

Contributor:

Yes, that's right. I'll check torch.compile once again.

return (input.view(torch.uint8).to(torch.float32) - zero_point) * scale


class UInt2Tensor(torch.Tensor):

jerryzh168 (Contributor) commented May 29, 2024:

Maybe we want to generalize UInt4Tensor to work with uint1 to uint7 directly, but for now we can put this in prototype to unblock kernel development.

Contributor:

Good point, will generalize it.
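
One naive way to generalize (a sketch with a hypothetical pack_uintx() name; a real implementation would likely use per-width shift tables rather than an element-wise bit stream):

import torch

def pack_uintx(uint8_data: torch.Tensor, bits: int) -> torch.Tensor:
    # Lay the low `bits` bits of every element into one flat bit stream,
    # then regroup that stream into uint8 bytes (MSB first).
    assert 1 <= bits <= 7
    flat = uint8_data.contiguous().view(-1).to(torch.int64)
    shifts = torch.arange(bits - 1, -1, -1, device=flat.device)
    bit_stream = ((flat.unsqueeze(-1) >> shifts) & 1).view(-1)
    assert bit_stream.numel() % 8 == 0, "pad input so the total bit count is a multiple of 8"
    weights = 2 ** torch.arange(7, -1, -1, device=flat.device)
    return (bit_stream.view(-1, 8) * weights).sum(dim=-1).to(torch.uint8)

packed = pack_uintx(torch.randint(0, 8, (8,), dtype=torch.uint8), bits=3)  # 8 x 3 bits -> 3 bytes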

melvinebenezer (Contributor) commented Jun 1, 2024:

TODO List

CoffeeVampir3 (Contributor) commented Jun 2, 2024:

I've written a testing network (1 layer MLP) to test full functionality. The primary failure right now is that transpose doesn't work due to padding issues and the matmul operation is not found despite being defined in the bitnet linear fn.
NotImplementedError: aten.amin.default

https://github.com/CoffeeVampir3/ao-bitnet/blob/main/bitnet_staging/bitnet_trained_to_ao_test.py

melvinebenezer and others added 2 commits June 3, 2024 19:31
Added several operations for UInt2Tensor, still needs work.
class TestUInt2(QuantizationTestCase):
    def test_gpu_quant(self):
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]:

Member:

Favor pytest parametrize; that way we'll know from the CI logs which case failed, if any.
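
A sketch of what that might look like, assuming the test moves from the unittest-style class to a plain pytest function (_apply_weight_only_uint2_quant is the helper from this PR):

import pytest
import torch
import torch.nn as nn

@pytest.mark.parametrize("x_shape", [[2, 4], [5, 5, 5, 4], [1, 4, 4]])
def test_gpu_quant(x_shape):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(*x_shape).to(device)
    m = nn.Sequential(nn.Linear(4, 16)).to(device)
    y_ref = m(x)
    _apply_weight_only_uint2_quant(m)  # helper under test in this PR
    y_wo = m(x)
    # each shape is now reported as a separate test case in CI logs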

for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]:
    x = torch.randn(*x_shape).to(device)
    m = nn.Sequential(nn.Linear(4, 16)).to(device)
    y_ref = m(x)

Member:

need some compile test

triton_pack_uint2[grid](uint8_data, output, n_elements, BLOCK_SIZE=1024)
return output

else:

Member:

Are the kernels compilable on CPU? Just wondering if these should ever get called today, or if they're mostly here so we can eventually replace the triton kernels?

Contributor:

Yes, that's right. We will eventually replace the triton kernels with torch.compile.
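
A rough sketch of the dispatch being described (the wrapper name and output sizing are assumptions; triton_pack_uint2 and pack_uint2 are the functions quoted in this thread):

import torch
import triton

def pack_uint2_any(uint8_data: torch.Tensor) -> torch.Tensor:
    n_elements = uint8_data.numel()
    if uint8_data.is_cuda:
        # triton path: one packed byte per four input bytes
        output = torch.empty(n_elements // 4, dtype=torch.uint8, device=uint8_data.device)
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        triton_pack_uint2[grid](uint8_data, output, n_elements, BLOCK_SIZE=1024)
        return output
    # eager fallback: the CPU path today; once torch.compile handles the
    # subclass, it can generate fused kernels and the triton path can go
    return pack_uint2(uint8_data)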

return torch.equal(self.elem, other.elem)

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):

Member:

Not feedback to you guys but I kind of hate how this looks - @cpuhrsch @jerryzh168 is there a more readable design pattern we could use? Something like an abstract class perhaps? Or a Dispatcher class?

Contributor:

Yes, I like the implements pattern from NF4 and AQT.

def implements(aten_ops):
    """Use this decorator to implement a function for an aten op in __torch_dispatch__"""
    def decorator(func):
        for op in aten_ops:
            NF4_OPS_TABLE[op] = func
        return func
    return decorator

Then you can do stuff like

@implements([torch.ops.aten.to.dtype])
def to_dtype(func, *args, **kwargs):
    if not args[0][0].is_contiguous():
        assert args[0][0].t().is_contiguous()
        return torch.ops.aten.to.dtype(args[0][0].t(), args[0][1]).t()
    return args[0][0].get_original_weight().to(args[0][1])

People can then also use this wrapper out of tree like in our tutorial:

@torchao.dtypes.nf4tensor.implements([torch.ops.aten.gelu.default])

cpuhrsch (Contributor) commented Jun 4, 2024:

__torch_dispatch__ then becomes fairly simple:

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
    """TODO we are not supporting torch dispatch at the moment
    instead we have created an Autograd.Function to handle the linear
    """
    # All ops in the NF4_OPS_TABLE expect NF4 Tensors as inputs,
    # and don't support mixed tensor subclasses. This will trigger
    # the handler for the next type in the dispatch list.
    def allowed_subclasses(type):
        return (
            issubclass(cls, type)
            or issubclass(torch._subclasses.fake_tensor.FakeTensor, type)
            or issubclass(
                torch._subclasses.functional_tensor.FunctionalTensor, type
            )
        )

    if not all(allowed_subclasses(t) for t in types):
        return NotImplemented  # up to the next subclass in `types` to handle
    if func in NF4_OPS_TABLE:
        return NF4_OPS_TABLE[func](func, args, kwargs)
    raise NotImplementedError(
        f"NF4Tensor dispatch: attempting to run {func}, this is not supported"
    )

Please note that raising an exception at the end is very important. Otherwise you'll just return None, which is valid Python, even though you might have meant to say "This isn't implemented".

Also, it's probably preferred (note: trust but verify this advice!) to use return NotImplemented instead of raising an exception within __torch_dispatch__, because it'll get caught and handled by the PyTorch dispatcher.

Contributor:

@cpuhrsch incorporated your feedback in 1fdeb91. @msaroufim have a look if it's OK.

import torch

"""
Contains generic functions to pack and unpack uint8 tensors into uint2, uint3, uint4, uint5, uint6, and uint7 tensors.

Member:

do you mean uintx into uint8?

Contributor:

Yes, my bad!

Contributor:

fixed

return packed_data


def unpack_uint6(packed_data: torch.Tensor) -> torch.Tensor:

Member:

BTW, how does this compare to @vayuda's work?

Contributor:

There is a bit of rework. After discussion on cuda-mode between @andreaskoepf and @vayuda, some of the odd-bit functionality will be merged into @vayuda's algorithmic bitpacking.py. However, we will use this implementation to test correctness and speed of both approaches.


def pack_uint6(uint8_data: torch.Tensor) -> torch.Tensor:
"""Pack the 6 lowest bits of 4 input bytes into 3 bytes

Member:

very cool docstrings overall!
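
For reference, a sketch of one possible 4-bytes-to-3-bytes layout matching that docstring (the PR's actual bit order may differ):

import torch

def pack_uint6(uint8_data: torch.Tensor) -> torch.Tensor:
    # Pack the 6 lowest bits of 4 input bytes into 3 bytes (24 bits total).
    a, b, c, d = uint8_data.contiguous().view(-1, 4).unbind(-1)
    return torch.stack(
        [
            (a << 2) | (b >> 4),           # a[5:0] + b[5:4]
            ((b & 0x0F) << 4) | (c >> 2),  # b[3:0] + c[5:2]
            ((c & 0x03) << 6) | d,         # c[1:0] + d[5:0]
        ],
        dim=-1,
    ).view(-1).to(torch.uint8)

def unpack_uint6(packed_data: torch.Tensor) -> torch.Tensor:
    # Inverse of pack_uint6: recover four 6-bit values from each 3-byte group.
    b0, b1, b2 = packed_data.contiguous().view(-1, 3).unbind(-1)
    return torch.stack(
        [b0 >> 2, ((b0 & 0x03) << 4) | (b1 >> 4),
         ((b1 & 0x0F) << 2) | (b2 >> 6), b2 & 0x3F],
        dim=-1,
    ).view(-1)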

assert torch.all(k == check)


if __name__ == "__main__":

Member:

can we move these to test/

Contributor:

fixed

@msaroufim msaroufim marked this pull request as ready for review June 15, 2024 20:53
@msaroufim msaroufim marked this pull request as draft June 15, 2024 21:41

melvinebenezer (Contributor) commented Jun 16, 2024:

1fdeb91

  • implements pattern for uint2 and bitnet

  • move uintgen tests to test/

# Quantize the input tensor to int2
quant = x.sign() + 1

if target_dtype == torch.uint2:

Member:

I don't think you need torch.uint2; it doesn't do anything. You can remove the 2.3 skip test in the test file as well.

Contributor:

fixed
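
With torch.uint2 gone, the quantization step is presumably just the following (a sketch; the helper name is assumed, and pack_uint2 is as sketched earlier in the thread):

import torch

def quantize_per_tensor_uint2(x: torch.Tensor) -> torch.Tensor:
    # BitNet-style ternary weights: sign() yields {-1, 0, +1};
    # adding 1 maps them to {0, 1, 2}, which fits in 2 bits.
    quant = (x.sign() + 1).to(torch.uint8)
    return pack_uint2(quant)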

return BitnetTensor(tensor)
raise NotImplementedError(f"to {dtype} not supported")

if __name__ == "__main__":

Member:

I'd rather have all of this functionality be in the test file.

As in, make sure you can instantiate a BitnetTensor, copy it, transpose it, multiply it, and convert to/from it in the main test.

Contributor:

added test cases for BitnetTensor, UInt2Tensor

@implements([torch.ops.aten.detach.default])
def detach(func, args, kwargs):
    (tensor,) = args
    return tensor.elem.detach()

Contributor:

sidecar comment:

  • most of your view op implementations above wrap the output back into a Uint2Tensor ("propagating" the subclass-ness through the model when views are encountered)

  • detach is just another view op

  • so you probably want your impl for detach to wrap the output in your subclass?

Contributor:

Yes, that's correct. Fixed it.
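
The fix is presumably along these lines (re-wrapping the result so the subclass propagates; the UInt2Tensor constructor call is an assumption):

@implements([torch.ops.aten.detach.default])
def detach(func, args, kwargs):
    (tensor,) = args
    # detach is a view op, so wrap the output back into the subclass
    return UInt2Tensor(tensor.elem.detach())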

bdhirsh (Contributor) left a comment:

very cool :)

@msaroufim msaroufim marked this pull request as ready for review June 17, 2024 19:14
@msaroufim msaroufim self-requested a review June 17, 2024 19:14

msaroufim (Member):

nice a segfault lol - although this is not from your code

@jerryzh168 have you seen this before? https://github.com/pytorch/ao/actions/runs/9557079384/job/26344629090?pr=282#step:12:2377

@msaroufim changed the title from "[WIP] Added first bits of Uint2Tensor and BitnetTensor" to "Added first bits of Uint2Tensor and BitnetTensor" Jun 18, 2024
@msaroufim msaroufim merged commit cb3bd8c into pytorch:main Jun 18, 2024
13 checks passed
dbyoung18 pushed a commit to dbyoung18/ao that referenced this pull request Jul 31, 2024
* Added first bits of Uint2Tensor and BitnetTensor

Co-authored-by: James Melvin Ebenezer <[email protected]>

* add conversion to standard signed and unsigned dtypes

* added triton kernel for pack and unpack

* fix: test cases and device allocation for triton kernels

* fix: moved uint2 to prototype folder

* Add packing and unpacking functions for uint{2,3,4,5,6,7}.

Co-authored-by: James Melvin Ebenezer <[email protected]>

* housekeeping: renamed uint_small to uintgen and simple comments

* Update uint2.py

Added several operations for UInt2Tensor, still needs work.

* added pytest ,compile tests and some cleanup

* fix: implements pattern for uint2 and BitnetTensor

* fix: torch.uint2 available after torch 2.3

* fix: test cases for BitnetTensor, UInt2Tensor and bitpacking gen

* fix: removed torch.uint2

* fix: wrap detach in UIntTensor, torch.compile test

* fix: CI errors on compile tests

* fix: skip tests less than torch 2.4

* Added pytest fixture

* remove tensor core flag

---------

Co-authored-by: James Melvin Ebenezer <[email protected]>
Co-authored-by: Z <[email protected]>
Co-authored-by: Pawan Jayakumar <[email protected]>
Co-authored-by: Mark Saroufim <[email protected]>

jerryzh168 (Contributor) commented:

> nice a segfault lol - although this is not from your code
>
> @jerryzh168 have you seen this before? pytorch/ao/actions/runs/9557079384/job/26344629090?pr=282#step:12:2377

Just saw this comment... I hit the same error when porting uintx to pytorch (#635), but Charlie says the autoquant subclass is weird, so I disabled it.

yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request Dec 9, 2024