unwrap_tensor_subclass and nested tensor subclasses issue #515
cc @bdhirsh
Hmm, I pulled the latest main of torchao and patched in that change, but when I run it I get a different error (before we hit compile):
But two things to call out here: (1) @vayuda, you actually do not need to use `unwrap_tensor_subclass` for plain torch.compile. That workaround (a parametrization that swaps the subclass parameters for their plain constituent tensors) is only needed for torch.export / AOTI; for regular torch.compile usage it's perfectly fine to have model parameters/buffers that are python tensor subclasses. (2) That parametrization code doesn't appear to play very well with dynamo today. Fixing that is tracked here: pytorch/pytorch#129682
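For readers, a minimal sketch of that distinction might look like the following, assuming a recent torchao where `quantize_`, `int8_weight_only`, and `unwrap_tensor_subclass` are available (import paths have moved between versions):

```python
import torch
from torchao.quantization import quantize_, int8_weight_only
from torchao.utils import unwrap_tensor_subclass  # path may differ across torchao versions

model = torch.nn.Sequential(torch.nn.Linear(16, 16)).eval()
quantize_(model, int8_weight_only())  # weights become tensor-subclass parameters
x = torch.randn(1, 16)

# Regular torch.compile: subclass parameters/buffers are fine as-is.
out = torch.compile(model)(x)

# torch.export / AOTI: the parametrization-based workaround is still needed;
# it replaces the subclass parameters with their plain constituent tensors.
unwrapped = unwrap_tensor_subclass(model)
ep = torch.export.export(unwrapped, (x,))
```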
oh, this is because of a bit of a change in torchao/quantization/quant_api.py (Line 462 in 4f53882).

I updated the repro to reflect this change.
So when I comment out `unwrap_tensor_subclass` the error still reproduces. I'm not sure how an eager backend works; has dynamo extracted the graph but then doesn't do anything with it? I would like to use the full power of torch.compile and fuse kernels together, but I'm not interested in autograd, and I don't need AOTI support for now. I guess I can retest after pytorch/pytorch#129682 is fixed.
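For context on the eager-backend question: `torch.compile(backend="eager")` still runs dynamo's graph capture, but then executes the captured graph with ordinary eager kernels (no inductor codegen), which makes it useful for isolating dynamo-only failures. A small sketch:

```python
import torch

def f(x):
    return torch.sin(x) + torch.cos(x)

x = torch.randn(8)

# Default: dynamo captures the graph and inductor compiles/fuses kernels.
out_inductor = torch.compile(f)(x)

# "eager" backend: dynamo still captures the graph, but it is simply run
# op-by-op with eager kernels; handy for bisecting dynamo vs. inductor bugs.
out_eager = torch.compile(f, backend="eager")(x)

# "aot_eager": additionally runs aot_autograd (which handles the tensor-subclass
# desugaring) but still executes with eager kernels.
out_aot_eager = torch.compile(f, backend="aot_eager")(x)
```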
Interesting, I can repro too when I comment out `unwrap_tensor_subclass`. @IvanKobzarev is going to take a look.
@vayuda, for the following code: is this trying to simulate the case where the `int_data` field will itself be another tensor subclass?
Yes, exactly. Here is the original version:
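A minimal, purely illustrative sketch of the pattern under discussion (a wrapper subclass whose `int_data` field is itself another wrapper subclass) might look like this; both class names are hypothetical and this is not the code from the thread:

```python
import torch

class InnerTensor(torch.Tensor):
    """Illustrative wrapper around a plain uint8 tensor (stand-in for a packed layout)."""

    @staticmethod
    def __new__(cls, data):
        return torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=data.dtype, device=data.device
        )

    def __init__(self, data):
        self.packed = data

    def __tensor_flatten__(self):
        return ["packed"], None

    @classmethod
    def __tensor_unflatten__(cls, inner_tensors, ctx, outer_size, outer_stride):
        return cls(inner_tensors["packed"])

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        args = [a.packed if isinstance(a, cls) else a for a in args]
        return func(*args, **kwargs)


class OuterTensor(torch.Tensor):
    """Illustrative subclass whose int_data field is itself a tensor subclass."""

    @staticmethod
    def __new__(cls, int_data):
        return torch.Tensor._make_wrapper_subclass(
            cls, int_data.shape, dtype=torch.float32, device=int_data.device
        )

    def __init__(self, int_data):
        self.int_data = int_data  # an InnerTensor, i.e. a nested subclass

    def __tensor_flatten__(self):
        return ["int_data"], None

    @classmethod
    def __tensor_unflatten__(cls, inner_tensors, ctx, outer_size, outer_stride):
        return cls(inner_tensors["int_data"])

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        args = [a.int_data.to(torch.float32) if isinstance(a, cls) else a for a in args]
        return func(*args, **kwargs)


inner = InnerTensor(torch.randint(0, 16, (4, 4), dtype=torch.uint8))
outer = OuterTensor(inner)
print(torch.add(outer, torch.ones(4, 4)))  # dispatch unwraps both levels
```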
@vayuda can you check if this fix works: pytorch/pytorch#132096? You'll need to install pytorch from source (with the PR applied).
@vayuda As a detail: I tested compilation of the original repro without ao's `unwrap_tensor_subclass`; for me, after pytorch/pytorch#132096, compilation works fine and aot_autograd handles all the nested subclass logic.
I built pytorch from your PR branch. I now get a different error than last time, but it is probably unrelated to this issue. Strangely, this repro errors while my full code doesn't, so I'm still satisfied with the fix haha. Error:
get_plain_tensors() should result in a DFS of the leaves. The error was that plain tensors (leaves) on the same level were returned before the plain tensors of subclasses, even when those subclasses come earlier in the "flatten" list. Original issue from AO: pytorch/ao#515. Test: TBD, need to make an asymmetric subclass with both dense tensors and subclasses. Pull Request resolved: #132096. Approved by: https://github.com/bdhirsh
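A rough sketch of the traversal this fix describes: plain leaf tensors should be collected depth-first, descending into each nested subclass before moving on to its siblings. The helper below is illustrative only, not the implementation from the PR:

```python
import torch
from torch.utils._python_dispatch import is_traceable_wrapper_subclass

def dfs_plain_tensors(t: torch.Tensor) -> list:
    """Collect plain leaf tensors from a possibly-nested subclass in depth-first order."""
    if not is_traceable_wrapper_subclass(t):
        return [t]  # plain tensor: a leaf
    attrs, _ctx = t.__tensor_flatten__()
    leaves = []
    for name in attrs:
        # Descend into each inner tensor before moving on to the next sibling,
        # so a nested subclass's leaves come out before later plain siblings.
        leaves.extend(dfs_plain_tensors(getattr(t, name)))
    return leaves
```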
Hi, this seems to be a blocker for working with torch tensor subclasses in torch.export. A simple example that results in the error:
import torch
import numpy as np
from torchao.dtypes.uint4 import UInt4Tensor


class ExampleModel(torch.nn.Module):
    def __init__(self):
        """Init"""
        super().__init__()
        x_uint8 = torch.randint(0, 16, (4, 8)).to(torch.uint8)
        self.x = UInt4Tensor.from_unpacked(x_uint8)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Invoke."""
        transposed = self.x.view(8, 4).to(torch.uint8)
        return torch.add(input, transposed)


def main():
    input_tensor = torch.randn(8, 4)
    model = ExampleModel()
    with torch.no_grad():
        exported_program = torch.export.export(
            model.eval(), args=(), kwargs={"input": input_tensor},
        )
    print("===exported_program====")
    print(exported_program)


if __name__ == "__main__":
    main()

Let us know if there's something we can try (the above issue is on the 2.4 release). Also, let us know if we should log this as a separate issue, or whether this is the right place to track it. Cheers.
Hi @raziel! For torch.export (and AOTI) you'd still need to use `unwrap_tensor_subclass`. It looks like you are trying to use uint4 quantization; please take a look at our current doc for integer affine quantization here: https://github.com/pytorch/ao/tree/main/torchao/quantization#affine-quantization. Here is a test for torch.export that uses the torchao API and produces some quantization ops that can be lowered to other backends (uint4 is still represented with uint8 plus a qmin/qmax restriction): ao/test/integration/test_integration.py, lines 1416 to 1468 at 2d33197
Our UInt4 tensor subclass is an initial prototype example and is not ready for use yet; we are actively working on sub-byte dtypes, starting from #468.
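To illustrate the "uint8 + qmin/qmax restriction" point, here is a plain-torch sketch (deliberately not torchao's own quant primitives) of asymmetric affine quantization whose values live in a torch.uint8 tensor but are restricted to the uint4 range [0, 15]:

```python
import torch

def quantize_to_uint4_range(w: torch.Tensor):
    # "uint4" values stored in a torch.uint8 tensor, restricted to [0, 15].
    qmin, qmax = 0, 15
    scale = (w.max() - w.min()).clamp(min=1e-8) / (qmax - qmin)
    zero_point = qmin - torch.round(w.min() / scale)
    q = torch.clamp(torch.round(w / scale) + zero_point, qmin, qmax).to(torch.uint8)
    return q, scale, zero_point

w = torch.randn(4, 8)
q, scale, zp = quantize_to_uint4_range(w)
w_hat = (q.to(torch.float32) - zp) * scale   # dequantize back to float
print(q.dtype, int(q.min()), int(q.max()))   # torch.uint8, values within [0, 15]
```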
Hi @jerryzh168! Yeah, I saw pytorch/pytorch#129682. I originally posted there, but since this seemed like an issue more specific to our example, I moved the post here. Is there a timeline for pytorch/pytorch#129682, @bdhirsh? Based on the flag, it seems you're not targeting removing the need for unwrap_tensor_subclass until 2.6?
Regarding the use of unwrap_tensor_subclass, I believe we tried it but had some issues. @xiaoqiqi177 can provide more details. Thanks!
Oh I see, makes sense. I'm not sure about the timeline yet; last I heard it's a few weeks. cc @tugsbayasgalan, can you give a timeline for pytorch/pytorch#129682?
@jerryzh168 Regarding the use of unwrap_tensor_subclass, here are several problems I encountered for my use case:
(Let me know if I need to open a new thread for the details.) Overall, it would be nice to have more general support for tensor subclasses in torch.export, including a uint4 tensor like the one in ao, without unwrap_tensor_subclass, so that the weights could be exported represented in torch.uint4.
Yes, we can definitely support more than linear; please feel free to open a PR. You just need to extend this: Line 256 in 360a003
So there is some preliminary work from Intel to pack the int4 weights by default (pytorch/pytorch#129940; we haven't integrated this properly into torchao yet), but we also have intx work going on that packs the lower bits: #468. That's probably going to land soon, and I think this is what we can start with. For export, I'm not sure what kind of format you need; if it's better to have single pack/unpack ops instead of seeing all the bit-shifting operations in the IR, we could use this: Line 147 in 360a003
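As a concrete picture of the pack/unpack point above, here is a plain-torch sketch (not torchao's or Intel's actual packing code) of storing two uint4 values per uint8 byte; a dedicated pack/unpack op pair would hide exactly this bit shifting from the exported IR:

```python
import torch

def pack_uint4(x: torch.Tensor) -> torch.Tensor:
    # x: uint8 tensor with values in [0, 15]; last dimension must be even.
    assert x.dtype == torch.uint8 and x.shape[-1] % 2 == 0
    lo, hi = x[..., ::2], x[..., 1::2]
    return (hi << 4) | lo  # two 4-bit values per byte

def unpack_uint4(packed: torch.Tensor) -> torch.Tensor:
    lo = packed & 0x0F
    hi = (packed >> 4) & 0x0F
    return torch.stack([lo, hi], dim=-1).flatten(-2)  # restore original ordering

x = torch.randint(0, 16, (4, 8), dtype=torch.uint8)
packed = pack_uint4(x)                    # shape (4, 4): half the bytes
assert torch.equal(unpack_uint4(packed), x)
```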
Not exactly sure where this comes from; we'd need to check the source code first.
Yeah, I feel a new thread might make sense; the issue in this thread has actually been resolved.
This is resolved now; please refer to https://github.com/pytorch/ao/tree/main/torchao/quantization#workaround-with-unwrap_tensor_subclass-for-export-aoti-and-torchcompile-pytorch-24-and-before-only for docs. Closing for now.
I'm noticing strange behavior when trying to create a tensor subclass which holds another tensor subclass.
Here is a minified repro (add this to the bottom of torchao/dtypes/affine_quantized_tensor.py):
When running this code, I get the following error when calling the model after compiling:
There might be an issue with how `unwrap_tensor_subclass` handles cases where there are nested tensor subclasses, but I'm not sure why this doesn't work, given that AffineQuantizedTensor is able to hold an AQTLayout tensor and works just fine. You can check out https://github.com/vayuda/ao/blob/intx/torchao/dtypes/affine_quantized_tensor.py#L593 for what I'm trying to do.