Add codebook (look up table based) quantization flow in torchao #1195
Comments
Hey @jerryzh168, I am new to torchao, but this sounds like an issue I would want to investigate with my partner @Harthi7. We will take a look and let you know how it goes. Cheers!
Hi, I am also new to torchao and I would like to work on this issue.
Hi! I'm interested in contributing to the implementation of the codebook quantization. Would it be helpful if I worked on, e.g., adding test cases? Happy to coordinate with @DerekLiu35 to avoid duplicating effort.
I'd also be happy to coordinate.
I think the two immediate things are adding AQLM support and speedup. Adding AQLM to torchao will be a bit more convenient for users than using the AQLM repo and then converting, I think.
Great! Let me know what I can start with.
I'll focus on speeding up token generation; I can coordinate more if @pawarmanasi07 also wants to work on that.
I can help with that!
@DerekLiu35 Could you share your thoughts on which aspects of the dequantization kernels from AQLM we should focus on first? We could also divide up different parts of the optimization work between us.
I think we can focus on the 1x16 group size CUDA kernels, with Triton as a fallback. We could divide the optimization work by one of us focusing on forward-pass kernels and the other on backward-pass kernels, though I'm not sure why you would need backward-pass kernels. We could also split by kernel type, e.g. 1x16 group size vs. 1x1 group size (which has no reference CUDA kernels in AQLM). I'm not sure what the best way to divide the work between us is, though. I'll probably start with the 1x16 forward-pass kernel.
Sounds good! I think focusing on the 1x16 group size kernels makes sense as a starting point. I can work on the 1x1 group size kernels while you tackle the 1x16 forward-pass implementation. As for the backward-pass kernels, you raise a good point about whether they're necessary. Since this is post-training quantization, we likely don't need backward-pass optimization unless we're planning to support fine-tuning scenarios.
Hi @DerekLiu35 and @jerryzh168, to confirm my tasks: I'll focus on optimizing dequantization for the 1x1 group size and implementing the new CUDA kernels for 1x1, while Derek focuses on the 1x16 forward-pass kernel implementation. Is this the correct understanding of the work division? I just want to ensure I'm heading in the right direction before starting.
However, would it make more sense to start with a Triton implementation for 1x1 first (since we need it as a fallback anyway)?
Yeah, I think it would make sense to start with the Triton fallback first.
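A minimal sketch of what such a Triton fallback could look like, assuming scalar (1x1-style) codes where each weight element stores an integer index into a small codebook of centroids. This is illustrative only, not AQLM's actual kernel; all names here are placeholders:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def _codebook_dequant_kernel(codes_ptr, codebook_ptr, out_ptr, n_elements,
                             BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    # Load the per-element codebook indices, then gather the centroid values.
    idx = tl.load(codes_ptr + offsets, mask=mask, other=0)
    vals = tl.load(codebook_ptr + idx, mask=mask)
    tl.store(out_ptr + offsets, vals, mask=mask)

def codebook_dequantize(codes: torch.Tensor, codebook: torch.Tensor) -> torch.Tensor:
    # codes: integer tensor of codebook indices; codebook: 1-D float centroids.
    out = torch.empty(codes.shape, dtype=codebook.dtype, device=codes.device)
    n = codes.numel()
    grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)
    _codebook_dequant_kernel[grid](codes, codebook, out, n, BLOCK_SIZE=1024)
    return out
```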
Similar to affine quantization, we can implement codebook (look-up table based) quantization, which is another popular type of quantization, especially for lower bit widths like 4 bits or below (used in https://github.com/Vahe1994/AQLM, https://arxiv.org/abs/2402.04396, etc.). We can start with post-training quantization and use k-means clustering to find the codebook / lookup table. You can check out #391 for the overall structure of the torchao stack. Reference code for k-means can be found here.
After this, we can also add support for the more advanced algorithms mentioned above.
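The linked reference code isn't reproduced here, but as a rough illustration of the post-training flow described above, here is a minimal k-means (Lloyd's algorithm) sketch for scalar codes; the function name and signature are placeholders:

```python
import torch

def kmeans_codebook(weight: torch.Tensor, nbits: int = 4, iters: int = 20):
    """Fit a 2**nbits-entry codebook to a weight tensor with Lloyd's algorithm.

    Returns (codebook, codes) such that codebook[codes] approximates weight.
    Written for clarity, not memory efficiency.
    """
    w = weight.flatten().float()
    k = 2 ** nbits
    # Initialize centroids by sampling k weight values at random.
    codebook = w[torch.randperm(w.numel(), device=w.device)[:k]].clone()
    for _ in range(iters):
        # Assignment step: index of the nearest centroid for each weight.
        codes = (w.unsqueeze(1) - codebook.unsqueeze(0)).abs().argmin(dim=1)
        # Update step: move each centroid to the mean of its assigned weights.
        for j in range(k):
            assigned = w[codes == j]
            if assigned.numel() > 0:
                codebook[j] = assigned.mean()
    codes = (w.unsqueeze(1) - codebook.unsqueeze(0)).abs().argmin(dim=1)
    return codebook, codes.reshape(weight.shape)
```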
API
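The API code block did not survive the page extraction. As a rough placeholder, one plausible shape following torchao's existing `quantize_(model, config)` pattern might be the following; the `codebook_weight_only` name, import path, and arguments are hypothetical and would be settled in the PR:

```python
import torch
from torch import nn
from torchao.quantization import quantize_
# Hypothetical import: module path, name, and signature to be decided in the PR.
from torchao.prototype.quantization.codebook import codebook_weight_only

model = nn.Sequential(nn.Linear(1024, 1024))
# Replace each linear weight with codebook indices plus a k-means codebook.
quantize_(model, codebook_weight_only(bit_width=4))
```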
Implementation details:
The details of args etc. still need to be fleshed out, but that can be done in the PR. I'd suggest gradually adding things and gathering feedback.
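One implementation detail already implied by the design: for scalar codes, the dequantize path is just a gather into the codebook, which the kernels discussed above would accelerate. A minimal sketch with illustrative names:

```python
import torch
import torch.nn.functional as F

def dequantize_codebook(codes: torch.Tensor, codebook: torch.Tensor,
                        output_dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
    # Fancy indexing gathers the centroid value for every stored index.
    return codebook[codes].to(output_dtype)

# e.g. a quantized linear that stores (codes, codebook) instead of a weight:
# y = F.linear(x, dequantize_codebook(codes, codebook))
```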
Code Location: add a `codebook` folder under https://github.com/pytorch/ao/tree/main/torchao/prototype/quantization
Tasks