Add FP16Act-FP6Weight Linear #223
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/223
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure as of commit a8b4dd3 with merge base ad12663. The following job has failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Hey, I was looking at the fp6 code as well. One thing I was going to bring up: the code has a few utility/kernel files which could be reusable for implementing the next quant type.
I'd opt for generalizing things in a future PR, but I'll let @gau-nernst decide what makes sense for them. @Iron-Bound, which future work were you hoping to build on top of this?
@msaroufim I could hack on CFloat8_1_4_3 and CFloat8_1_5_2 if people think it's valuable?
I haven't followed our float8 work closely, but have you gotten the chance to take a look at https://github.com/pytorch-labs/float8_experimental? Granted, I would like an API that looks like
I will leave it for a future PR to refactor. I don't understand much of the parts involved in the kernel, so I won't be touching them and will leave them as is. Regarding float dtype: the actual FP6 used in FP6_LLM is E3M2, without nan/inf. Two pointers:
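As a quick illustration of what E3M2 without nan/inf means numerically, here is a minimal pure-PyTorch fake-quantization sketch. It assumes exponent bias 3 (so the largest representable magnitude is 28); the kernel itself packs real 6-bit values, so this is not its code, just the rounding behavior:

```python
import torch

def quantize_to_e3m2(x: torch.Tensor) -> torch.Tensor:
    """Round to the nearest FP6 E3M2 value (1 sign + 3 exponent + 2 mantissa bits,
    no inf/nan). Illustrative fake-quantization only; the CUDA kernel packs real
    6-bit values instead of returning floats."""
    x = x.float()
    # With bias 3 the largest normal is 2^4 * 1.75 = 28; clamp instead of overflowing to inf.
    x = x.clamp(-28.0, 28.0)
    # Floor exponent of each element; values below 2^-2 fall into the subnormal
    # range, which has a fixed spacing of 2^-4.
    exp = torch.floor(torch.log2(x.abs().clamp(min=2.0 ** -2)))
    step = 2.0 ** (exp - 2)  # 2 mantissa bits -> spacing of 2^(exp-2)
    return torch.round(x / step) * step
```

In practice a per-channel fp16 scale would presumably be applied first so the weight fits in the representable range before rounding.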
Also, another interesting thing to work on is to replicate
fp6_test.py
@@ -0,0 +1,98 @@
# from https://github.com/usyd-fsalab/fp6_llm/blob/main/tests/python/kernel_test.py
Move the relevant files to either the benchmark or test folder.
Ok so I think to merge this, what we can do is:
- Move the relevant benchmark and test files to either the benchmark or test folder
- In CI either do the numerics check on an op level or at a macro eval level (ideally the first for now; see the sketch after this list)
- We can worry about the subclass and torch.compile stuff in a future PR
- Make sure the acknowledgements to the original repos are crystal clear everywhere
- Make the speedup over the fp16/bf16 baselines clear in the PR description
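For the op-level option, a minimal self-contained sketch of the kind of numerics check meant above, reusing `quantize_to_e3m2` from the earlier snippet and a plain `F.linear` reference in place of the CUDA op (whose Python binding isn't spelled out in this thread):

```python
import torch
import torch.nn.functional as F

def e3m2_roundtrip(w: torch.Tensor) -> torch.Tensor:
    # Per-output-channel absmax scale so the weight fits E3M2's [-28, 28] range,
    # fake-quantize (quantize_to_e3m2 from the sketch above), then scale back.
    scale = w.abs().amax(dim=1, keepdim=True) / 28.0
    return quantize_to_e3m2(w / scale) * scale

def check_fp6_linear_numerics(m=4, k=4096, n=4096):
    torch.manual_seed(0)
    act = torch.randn(m, k)
    weight = torch.randn(n, k)

    out_ref = F.linear(act, weight)                  # full-precision baseline
    out_fp6 = F.linear(act, e3m2_roundtrip(weight))  # simulated FP6-weight path

    # FP6 keeps only ~3 significant bits, so the tolerance has to be relaxed;
    # a relative-error check on the whole output is less brittle than a tight
    # elementwise atol/rtol.
    rel_err = (out_fp6 - out_ref).norm() / out_ref.norm()
    assert rel_err < 0.1, f"relative error too large: {rel_err:.4f}"

check_fp6_linear_numerics()
```

In CI the simulated path would be replaced by the actual fp16act_fp6weight_linear op added in this PR, with the same relaxed-tolerance comparison against the dequantized reference.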
Ok, I think we're ready to merge this. The last thing is to add limitations in the README for small batch sizes (see usyd-fsalab/fp6_llm#8) and explain that this should be used to speed up autoregressive decoding. And for the next PR, let's start to do evals with an end-to-end model; I'm hoping we can leverage this PR for that (#189).
* add files from fp6_llm
* try to port weight packing first
* rename
* rename fp6 weight packing
* add fp16act_fp6weight_linear
* fix function def
* delete duplicate file
* move weight quant file
* rename
* add pytorch interface for fp6 weight dequant
* add fake_fp6 to fp6
* move weight_quant to csrc/cuda due to cuda_fp16.h dependency
* add fake_fp6_to_fp6 test
* add test for fp16act_fp6weight_linear
* add test for fp6_weight_dequant
* Fp6WeightOnlyQuantizedLinearWeight (not working yet)
* skip some tests, since the functions are not built w/o CUDA
* add the original test
* implement transpose and clone so that F.linear will work
* remove print
* remove dequantize
* add notes and some rename
* typo
* small cleanup
* improve tensor subclass and add test (which is failing for torch-compile)
* add note
* add note
* add qtorch as dev requirement
* update error message
* add __repr__ and fix transposed issue
* add fp6 perplexity test
* rename variables
* remove subclass
* add correctness test
* remove unwanted changes
* add apache 2.0 notice
* add benchmark script
* add note about FP6 kernel
* relax tolerance

---------

Co-authored-by: Mark Saroufim <[email protected]>
Closes #208
References:
TODO:
benchmarks/benchmark_fp6.py results - 4070 Ti SUPER, PyTorch 2.3, CUDA 12.1
benchmarks/benchmark_fp6.py results - 4090. Courtesy to @Iron-Bound
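For orientation, a minimal timing harness in the spirit of benchmarks/benchmark_fp6.py (shapes and names here are illustrative assumptions, not the script's contents). It only times the fp16 F.linear baseline, with a comment marking where the FP6 kernel call from this PR would slot in:

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def bench_ms(fn, warmup=10, iters=100):
    # CUDA event-based timing with warmup.
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

def main(k=8192, n=8192):
    for m in (1, 2, 4, 8, 16, 64, 256):  # small m ~ autoregressive decoding
        act = torch.randn(m, k, dtype=torch.half, device="cuda")
        weight = torch.randn(n, k, dtype=torch.half, device="cuda")
        fp16_ms = bench_ms(lambda: F.linear(act, weight))
        # fp6_ms = bench_ms(lambda: <FP6 kernel call from this PR>)  # placeholder
        print(f"m={m:4d}  fp16 linear: {fp16_ms:.3f} ms")

if __name__ == "__main__":
    main()
```

The small-m rows matter most here, matching the autoregressive-decoding focus and the small-batch limitation noted above.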