Refactor custom FPx cast #363
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/363
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit bd64efc with merge base 664f073. This comment was automatically generated by Dr. CI and updates every 15 minutes.
2x is a sizeable regression. How about keeping the LUT for the formats we already have it for and adding a generic fallback for the other formats? People can then optimize format by format individually if they want.
@vkuzo I have updated the dequant denormal implementation. There is no speed regression anymore (I updated the results in the 1st post), and I didn't need to use the hard-coded LUT from your implementation. If the torch compiler does constant folding and loop unrolling properly, I think my implementation should match your previous implementation exactly. If possible, could you benchmark on your GPUs to make 100% sure there is no regression?
Here are results on an H100: https://gist.github.com/vkuzo/324256b8defd0231852a23cbb34f49a6, I see no meaningful change in performance, awesome stuff
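For readers following the dequant discussion above, here is a minimal sketch (not the code in this PR) of what a generic, bit-shift-based FPx-unpacked-to-FP32 dequant with denormal handling can look like, assuming the usual 1 sign bit / ebits exponent / mbits mantissa layout. The name `fpx_unpacked_to_f32_sketch` and the exact handling are illustrative assumptions, and special values (inf/NaN) are ignored for brevity:

```python
import torch
from torch import Tensor

def fpx_unpacked_to_f32_sketch(x: Tensor, ebits: int, mbits: int) -> Tensor:
    # x holds one FPx value per uint8 element, stored in the low (1 + ebits + mbits) bits
    assert x.dtype == torch.uint8
    exp_bias = 2 ** (ebits - 1) - 1

    sign = (x >> (ebits + mbits)) & 0x1
    exp = ((x >> mbits) & ((1 << ebits) - 1)).to(torch.float32)
    mant = (x & ((1 << mbits) - 1)).to(torch.float32)

    # normal values: (1 + mantissa / 2^mbits) * 2^(exp - bias)
    normal = (1.0 + mant / 2 ** mbits) * torch.exp2(exp - exp_bias)
    # denormal values (biased exp == 0): mantissa * 2^(1 - bias - mbits)
    denormal = mant * 2.0 ** (1 - exp_bias - mbits)

    result = torch.where(exp == 0, denormal, normal)
    return torch.where(sign.bool(), -result, result)
```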
torchao/prototype/custom_fp_utils.py (outdated)

def _fpx_unpacked_to_f32(x: Tensor, ebits: int, mbits: int) -> Tensor:
    """
    TODO(future): check if LUT for everything is faster than bit shifting,
is this comment still relevant?
maybe add a docblock?
Using a LUT for everything in dequant might be faster, like the current NF4 implementation. I haven't benchmarked it, so I'm not sure.
I didn't add a docblock here since I think this is kind of an internal function, but a simple doc won't hurt. I will add some docs for this and the quant function above. I already added a short description for these 2 functions at the top of the file.
F32_EXP_BIAS = _n_ones(EBITS_F32 - 1)


def _f32_to_fpx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor:
should we have a docblock?
* refactor custom fp cast
* add dequant
* small formating
* compile with fullgraph=True
* add fullgraph=true
* undo
* add another version
* fast path for mbits=1
* add back docstring
Summary: This diff
- refactors install_et.sh into a bunch of utils
- uses those utils in build_android.sh to minimize duplication
- makes sure that we are building with custom sdpa op

Test Plan:

Model export:
python export.py --quant '{"linear:a8w4dq" : {"groupsize": 256}}' --checkpoint-path /home/kimishpatel/models/llama2/stories/stories110M.pt --params-path /home/kimishpatel/models/llama2/stories/params.json --output-pte-path /tmp/stories110m_a8w4dq.pte
python utils/tokenizer.py --tokenizer-model=/tmp/tokenizer.model

linux:
./scripts/install_et.sh
rm -rf build/cmake-out/
cmake -S ./runner-et -B build/cmake-out -G Ninja
cmake --build ./build/cmake-out
./build/cmake-out/runner_et /tmp/stories110m_a8w4dq.pte -z /tmp/tokenizer.bin -t 0 -n 120

android:
./runner-et/build_android.sh
adb push ./build/cmake-out-android/runner_et /data/local/tmp/
adb push /tmp/stories110m_a8w4dq.pte /data/local/tmp/
adb push /tmp/tokenizer.bin /data/local/tmp/
adb shell "cd /data/local/tmp && ./runner_et ./stories110m_a8w4dq.pte -z ./tokenizer.bin -t 0 -n 120"

Will add build commands to ci in the next PR

Reviewers:
Subscribers:
Tasks:
Tags:
Closes #354
TODO:
- Check torch.compile

Benchmark before and after

8841094 (main)
2690b92 (this PR)
Dequant is 2x slower because I replaced the LUT-based denormal handling with more generic logic. @vkuzo Should I add back the LUT-based logic (checking specifically for E2M3, E3M2, E2M1)? If we are interested in performance, perhaps we can generate a LUT for all bit patterns and cache it.
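For context, "generate a LUT for all bit patterns and cache it" could look roughly like the sketch below; `_fpx_lut` and `fpx_unpacked_to_f32_lut` are hypothetical names (not functions in this PR), the LUT is built on CPU with the same 1 sign / ebits / mbits layout as above, and a real version would probably cache it per device:

```python
import functools
import torch
from torch import Tensor

@functools.lru_cache(maxsize=None)
def _fpx_lut(ebits: int, mbits: int) -> Tensor:
    # dequantize every possible (1 + ebits + mbits)-bit code once, on CPU
    exp_bias = 2 ** (ebits - 1) - 1
    codes = torch.arange(2 ** (1 + ebits + mbits), dtype=torch.int64)
    sign = (codes >> (ebits + mbits)) & 0x1
    exp = ((codes >> mbits) & ((1 << ebits) - 1)).float()
    mant = (codes & ((1 << mbits) - 1)).float()
    normal = (1.0 + mant / 2 ** mbits) * torch.exp2(exp - exp_bias)
    denormal = mant * 2.0 ** (1 - exp_bias - mbits)
    vals = torch.where(exp == 0, denormal, normal)
    return torch.where(sign.bool(), -vals, vals)

def fpx_unpacked_to_f32_lut(x: Tensor, ebits: int, mbits: int) -> Tensor:
    # a single gather replaces all per-element bit manipulation at dequant time;
    # the .to(x.device) copy here is for simplicity, a real cache would be per device
    lut = _fpx_lut(ebits, mbits).to(x.device)
    return lut[x.to(torch.long)]
```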
UPDATE
95f4582 (this PR v2)
Now FP4_E2M1 is slower lol. It feels like this should be bandwidth-limited; it might be register-limited too? Will do some profiling + make sure torch.compile runs optimally. Interesting that native PyTorch float8 dequant is slower.
UPDATE 2
dcd5a05 (this PR v3)
Speed recovered 😊
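For the profiling mentioned above, one way to sanity-check whether dequant is bandwidth-limited is to compare achieved GB/s against the GPU's peak bandwidth. This harness is only an illustrative sketch (`bench_dequant` is hypothetical, and the byte count assumes one uint8 read plus one fp32 write per element):

```python
import torch
import torch.utils.benchmark as benchmark

def bench_dequant(dequant_fn, ebits: int, mbits: int, n: int = 2**24, device: str = "cuda"):
    # random unpacked FPx codes, one per uint8 element
    codes = torch.randint(0, 2 ** (1 + ebits + mbits), (n,), dtype=torch.uint8, device=device)
    timer = benchmark.Timer(
        stmt="fn(x, ebits, mbits)",
        globals={"fn": dequant_fn, "x": codes, "ebits": ebits, "mbits": mbits},
    )
    seconds = timer.blocked_autorange().median
    gbps = n * (1 + 4) / seconds / 1e9  # ~1 byte read + 4 bytes written per element
    print(f"E{ebits}M{mbits}: {seconds * 1e3:.3f} ms, ~{gbps:.1f} GB/s")
```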