
Remove workaround for unwrapping tensor subclasses #345

Open
jerryzh168 opened this issue Jun 12, 2024 · 3 comments
@jerryzh168 (Contributor) commented:

Currently we need to do:

    from torchao.quantization.utils import unwrap_tensor_subclass

    m_unwrapped = unwrap_tensor_subclass(m)

    # export
    m = torch.export.export(m_unwrapped, example_inputs).module()

    # aot_compile
    torch._export.aot_compile(m_unwrapped, example_inputs)

to make tensor subclasses work with export/aot_compile. This unwrapping should instead be handled in the default export path directly.

@bhack commented Dec 5, 2024:

In some examples in the repo I now see:

    if not TORCH_VERSION_AT_LEAST_2_5:
        unwrap_tensor_subclass(mod)

But with the PyTorch and torchao nightlies, torch.export.export_for_inference still seems to require unwrap_tensor_subclass.
So what is the use case of this conditional unwrap?

@jerryzh168 (Contributor, Author) commented:

@bhack the unconditional unwrap is for torch.export and AOTI, as I mentioned above.

    if not TORCH_VERSION_AT_LEAST_2_5:
        unwrap_tensor_subclass(mod)

is a bit separate: this gate exists because we fixed a torch.compile problem starting from 2.5, so the unwrap is no longer needed for torch.compile on 2.5 and later.
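To illustrate the gate above: the pattern is simply "apply the workaround only on older torch versions". Below is a minimal, self-contained sketch of such a version check; the parsing here is a simplified stand-in I wrote for illustration, not torchao's actual implementation of TORCH_VERSION_AT_LEAST_2_5:

    # Simplified sketch of a version gate like torchao's TORCH_VERSION_AT_LEAST_2_5.
    # The helpers below are illustrative stand-ins, not torchao's real code.

    def version_tuple(version: str) -> tuple:
        """Parse the numeric prefix of a version string like '2.5.0' or '2.6.0.dev20241205'."""
        parts = []
        for piece in version.split("+")[0].split("."):
            digits = "".join(ch for ch in piece if ch.isdigit())
            if not digits:
                break
            parts.append(int(digits))
        return tuple(parts)

    def torch_at_least(version: str, minimum: str = "2.5") -> bool:
        """True if `version` compares at or above `minimum` numerically."""
        return version_tuple(version) >= version_tuple(minimum)

    # With a gate like this, the workaround only runs on older torch, e.g.:
    # if not torch_at_least(torch.__version__):
    #     unwrap_tensor_subclass(mod)

Note this sketch ignores pre-release ordering (a 2.5 dev build compares as >= 2.5); real code would use a proper version-parsing library.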

@jerryzh168 (Contributor, Author) commented:

@tugsbayasgalan is fixing this issue in pytorch/pytorch#141941, so we won't need to call this workaround function once that fix lands.
