-
Notifications
You must be signed in to change notification settings - Fork 185
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added first bits of Uint2Tensor and BitnetTensor #282
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/282
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit ae4ead1 with merge base 664f073 (): This comment was automatically generated by Dr. CI and updates every 15 minutes. |
test/dtypes/test_uint2.py
Outdated
_apply_weight_only_uint2_quant(m) | ||
y_wo = m(x) | ||
# sqnr = compute_error(y_ref, y_wo) | ||
#opt = torch.compile(m, fullgraph=True, mode="max-autotune") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's the error you were getting?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
AssertionError: Please convert all Tensors to FakeTensors first or instantiate FakeTensorMode with 'allow_non_fake_inputs'. Found in aten.sym_storage_off.set.default(_to_functional_tensor(FakeTensor(..., size=(16, 4), dtype=torch.uint8)))
It generally failed enabling @torch.compile on pack and unpack functions
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@bdhirsh any ideas? I've seen this scary error before and never understood what it meant beyond AOTAutograd is not doing its thing
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
opened this issue pytorch/pytorch#127374
torchao/dtypes/uint2.py
Outdated
return output | ||
|
||
else: | ||
#@torch.compile |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
did these all fail torch.compile checks?
Btw is the idea here to trigger the else condition on cpu only?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes that is right. I'll check torch.compile once again.
torchao/dtypes/uint2.py
Outdated
return (input.view(torch.uint8).to(torch.float32) - zero_point) * scale | ||
|
||
|
||
class UInt2Tensor(torch.Tensor): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe we want to generalize UInt4Tensor to work with uint1 to uint7 directly, but for now maybe we can put this in prototype to unblocking kernel development
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point, will generalize it.
TODO List
|
a4ceeb1
to
fcd7c08
Compare
Co-authored-by: James Melvin Ebenezer <[email protected]>
Co-authored-by: James Melvin Ebenezer <[email protected]>
I've written a testing network (1 layer MLP) to test full functionality. The primary failure right now is that transpose doesn't work due to padding issues and the matmul operation is not found despite being defined in the bitnet linear fn. |
Added several operations for UInt2Tensor, still needs work.
test/dtypes/test_uint2.py
Outdated
class TestUInt2(QuantizationTestCase): | ||
def test_gpu_quant(self): | ||
device = 'cuda' if torch.cuda.is_available() else 'cpu' | ||
for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
favor pytest parametrize that way we'll know which failed from CI logs if any
test/dtypes/test_uint2.py
Outdated
for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]: | ||
x = torch.randn(*x_shape).to(device) | ||
m = nn.Sequential(nn.Linear(4, 16)).to(device) | ||
y_ref = m(x) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
need some compile test
torchao/prototype/dtypes/uint2.py
Outdated
triton_pack_uint2[grid](uint8_data, output, n_elements, BLOCK_SIZE=1024) | ||
return output | ||
|
||
else: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
are the kernels compilable on CPU? just wondering if these should ever get called today or they're mostly in so we can eventually revert the triton kernels?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, That's right. Will eventually revert the triton kernels with torch.compile.
return torch.equal(self.elem, other.elem) | ||
|
||
@classmethod | ||
def __torch_dispatch__(cls, func, types, args, kwargs=None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not feedback to you guys but I kind of hate how this looks - @cpuhrsch @jerryzh168 is there a more readable design pattern we could use? Something like an abstract class perhaps? Or a Dispatcher class?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I like the implements
pattern from NF4 and AQT.
ao/torchao/dtypes/nf4tensor.py
Lines 46 to 54 in 12f44ab
def implements(aten_ops): | |
"""Use this decorator to implement a function for an aten op in __torch_dispatch__""" | |
def decorator(func): | |
for op in aten_ops: | |
NF4_OPS_TABLE[op] = func | |
return func | |
return decorator |
Then you can do stuff like
@implements([torch.ops.aten.to.dtype])
def to_dtype(func, *args, **kwargs):
if not args[0][0].is_contiguous():
assert args[0][0].t().is_contiguous()
return torch.ops.aten.to.dtype(args[0][0].t(), args[0][1]).t()
return args[0][0].get_original_weight().to(args[0][1])
People can then also use this wrapper out of tree like in our tutorial:
Line 24 in 12f44ab
@torchao.dtypes.nf4tensor.implements([torch.ops.aten.gelu.default]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
torch_dispatch then becomes fairly simple
ao/torchao/dtypes/nf4tensor.py
Lines 757 to 782 in 12f44ab
@classmethod | |
def __torch_dispatch__(cls, func, types, args, kwargs=None): | |
"""TODO we are not supporting torch dispatch at the moment | |
instead we have created a Autograd.Function to handle the linear | |
""" | |
# All ops in the NF4_OPS_TABLE expect NF4 Tensors as inputs | |
# And don't support mixed tensor subclasses. This will trigger the handler for | |
# the next type in the dispatch list | |
def allowed_subclasses(type): | |
return ( | |
issubclass(cls, type) | |
or issubclass(torch._subclasses.fake_tensor.FakeTensor, type) | |
or issubclass( | |
torch._subclasses.functional_tensor.FunctionalTensor, type | |
) | |
) | |
if not all(allowed_subclasses(t) for t in types): | |
return NotImplemented("Up to the next one to handle") | |
if func in NF4_OPS_TABLE: | |
return NF4_OPS_TABLE[func](func, args, kwargs) | |
raise NotImplementedError( | |
f"NF4Tensor dispatch: attempting to run {func}, this is not supported" | |
) |
Please note that raising an exception at the end is very important. Otherwise you'll just return None
, which is valid Python, even though you might have meant to say "This isn't implemented".
Also, it's probably preferred (note: trust but verify this advice!) to use return NotImplemented
instead of raising an exception within __torch_dispatch__
, because it'll get caught and handled by the PyTorch dispatcher.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@cpuhrsch incorporated your feedback in 1fdeb91. @msaroufim
Have a look if its ok.
torchao/prototype/dtypes/uintgen.py
Outdated
import torch | ||
|
||
""" | ||
Contains generic functions to pack and unpack uint8 tensors into uint2, uint3, uint4, uint5, uint6, and uint7 tensors. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do you mean uintx into uint8?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes!, my bad
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
return packed_data | ||
|
||
|
||
def unpack_uint6(packed_data: torch.Tensor) -> torch.Tensor: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
btw how does this wcompare to @vayuda's work?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is a bit of rework. After discussio on cuda-mode between @andreaskoepf and @vayuda. Some of the odd bit functionality will be merged into @vayuda's algorithimic bitpacking.py. However will use this implementation to test correctness and speed with both approaches.
|
||
def pack_uint6(uint8_data: torch.Tensor) -> torch.Tensor: | ||
"""Pack the 6 lowest bits of 4 input bytes into 3 bytes | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
very cool docstrings overall!
torchao/prototype/dtypes/uintgen.py
Outdated
assert torch.all(k == check) | ||
|
||
|
||
if __name__ == "__main__": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we move these to test/
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
|
torchao/prototype/dtypes/bitnet.py
Outdated
# Quantize the input tensor to int2 | ||
quant = x.sign() + 1 | ||
|
||
if target_dtype == torch.uint2: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think you need torch.uint2 it doesn't do anything. You can remove the 0.3 skip test as well in the test file
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
torchao/prototype/dtypes/bitnet.py
Outdated
return BitnetTensor(tensor) | ||
raise NotImplementedError(f"to {dtype} not supported") | ||
|
||
if __name__ == "__main__": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd rather have all of this functionality be in the test file
As in make sure you can instantiate a BItNet tensor, copy it it, transpose it, multiply and convert to from the main test
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
added test cases for BitnetTensor, UIntTensor
torchao/prototype/dtypes/uint2.py
Outdated
@implements([torch.ops.aten.detach.default]) | ||
def detach(func, args, kwargs): | ||
(tensor,) = args | ||
return tensor.elem.detach() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sidecar comment:
-
most of your view op implementations above wrap the output back into a
Uint2Tensor
("propagating" the subclass-ness through the model when views are encountered) -
detach
is just another view op -
so you probably want your impl for detach to wrap the output in your subclass?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes that is correct. Fixed it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
very cool :)
nice a segfault lol - although this is not from your code @jerryzh168 have you seen this before? https://github.com/pytorch/ao/actions/runs/9557079384/job/26344629090?pr=282#step:12:2377 |
* Added first bits of Uint2Tensor and BitnetTensor Co-authored-by: James Melvin Ebenezer <[email protected]> * add conversion to standard signed and unsigned dtypes * added triton kernel for pack and unpack * fix: test cases and device allocation for triton kernels * fix: moved uint2 to prototype folder * Add packing and unpacking functions for uint{2,3,4,5,6,7}. Co-authored-by: James Melvin Ebenezer <[email protected]> * housekeeping: renamed uint_small to uintgen and simple comments * Update uint2.py Added several operations for UInt2Tensor, still needs work. * added pytest ,compile tests and some cleanup * fix: implements pattern for uint2 and BitnetTensor * fix: torch.uint2 available after torch 2.3 * fix: test cases for BitnetTensor, UInt2Tensor and bitpacking gen * fix: removed torch.uint2 * fix: wrap detach in UIntTensor, torch.compile test * fix: CI errors on compile tests * fix: skip tests less than torch 2.4 * Added pytest fixture * remove tensor core flag --------- Co-authored-by: James Melvin Ebenezer <[email protected]> Co-authored-by: Z <[email protected]> Co-authored-by: Pawan Jayakumar <[email protected]> Co-authored-by: Mark Saroufim <[email protected]>
just saw this comment... I just saw the same error when I'm porting the uintx to pytorch (#635), but Charlie says that autoquant subclass is weird so I disabled it |
Created a
UInt2Tensor
class (similar to theUInt4Tensor
class). Added aBitnetTensor
class and a first unit test which quantizes the weights of ann.Linear()
layer and executes the matmul.Currently generates an error if the commented
@torch.compile
lines above theunpack_uint2()
andpack_uint2()
functions are uncommented.