Feat/blockwise fp8 quant #1668
base: main
Conversation
- First implementation of the DeepSeek blockwise quantizer (not fully functional yet)
- amax has been updated
- Two more quantization recipes have been added
- A couple of changes here and there to keep things consistent
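For context on what "blockwise" means here, below is a minimal sketch of 128x128 fp8 weight quantization in plain PyTorch. The function name, scale convention, and shapes are illustrative assumptions, not this PR's actual API.

```python
import torch

FP8_DTYPE = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8_DTYPE).max  # 448.0 for e4m3fn

def quantize_weight_blockwise(w: torch.Tensor, block: int = 128):
    """Quantize a 2D weight to fp8 with one scale per (block x block) tile.

    Sketch only: assumes both dimensions are divisible by `block`.
    """
    M, N = w.shape
    assert M % block == 0 and N % block == 0
    # View the weight as a grid of (block x block) tiles.
    tiles = w.reshape(M // block, block, N // block, block)
    # One amax (and therefore one scale) per tile.
    amax = tiles.abs().amax(dim=(1, 3), keepdim=True).clamp(min=1e-12)
    scale = amax / FP8_MAX  # dequantization scale per tile
    w_fp8 = (tiles / scale).to(FP8_DTYPE).reshape(M, N)
    return w_fp8, scale.reshape(M // block, N // block)
```

Dequantization multiplies each tile back by its scale before (or fused into) the matmul.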
Could you share what gemm kernel you plan to use in this PR? I think a good first step here is to have a fast gemm. We have an issue tracking this here: #1594
This might be a good place to start: https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/1d044fd82b15f1cedb197a288e50cc96a2c27205/inference/kernel.py#L63
Overall it would be great to be able to support this recipe in torchao. I think having a gemm with compelling performance that supports 128x1 and 128x128 scaling is something we need first, with benchmarks comparing to other recipes such as rowwise scaled, etc.
Relevant PR in SGLang that adds the triton kernels - sgl-project/sglang#2575 (thanks to @HandH1998). I think it makes sense to add this as a starting point to torchao.
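Alongside a fast Triton kernel, a slow reference gemm is useful for numerics checks. Below is a sketch that dequantizes and then runs a plain matmul; the 128x1 activation / 128x128 weight scale layout and the function name are assumptions, not the actual interface of this PR or of the SGLang kernels.

```python
import torch

def blockwise_fp8_gemm_ref(
    x_fp8: torch.Tensor,    # (M, K) fp8 activation, one scale per 1x128 group
    x_scale: torch.Tensor,  # (M, K // block) fp32 dequant scales
    w_fp8: torch.Tensor,    # (N, K) fp8 weight, one scale per 128x128 tile
    w_scale: torch.Tensor,  # (N // block, K // block) fp32 dequant scales
    block: int = 128,
) -> torch.Tensor:
    """Reference blockwise-scaled fp8 gemm: dequantize, then matmul."""
    # Broadcast each activation scale over its 128-wide group along K.
    x = x_fp8.to(torch.float32) * x_scale.repeat_interleave(block, dim=1)
    # Broadcast each weight scale over its 128x128 tile.
    w = w_fp8.to(torch.float32) * w_scale.repeat_interleave(
        block, dim=0
    ).repeat_interleave(block, dim=1)
    # y = x @ w^T, i.e. a standard linear layer without bias.
    return (x @ w.t()).to(torch.bfloat16)
```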
    scale_a = torch.ones(M, 1, device=device)
    scale_b = torch.ones(1, N, device=device)
else:
    assert scaling_granularity == ScalingGranularity.BLOCKWISE, "unsupported"
This file is benchmarking torch._scaled_mm, which does not support blockwise scaling. Is this change intended?
Unintended, but I will rework this PR. There were some details that I had missed when I initially worked on it.
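For context on the benchmark above: torch._scaled_mm takes per-tensor or per-row/per-column scales, as in the snippet, but not blockwise ones. A rough sketch of a rowwise-scaled call on a recent PyTorch build with an fp8-capable GPU; the exact version requirements and argument defaults here are assumptions:

```python
import torch

M, K, N = 1024, 2048, 4096
device = "cuda"

a = torch.randn(M, K, device=device).to(torch.float8_e4m3fn)
# The second operand is expected in column-major layout.
b = torch.randn(N, K, device=device).to(torch.float8_e4m3fn).t()

# Rowwise scaling: one fp32 scale per row of `a` and per column of `b`.
scale_a = torch.rand(M, 1, device=device) + 0.5
scale_b = torch.rand(1, N, device=device) + 0.5

out = torch._scaled_mm(a, b, scale_a=scale_a, scale_b=scale_b,
                       out_dtype=torch.bfloat16)
```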
@Degnel, thanks for adding the gemm, this is great to see. How do you feel about splitting the gemm + gemm benchmarks into a separate PR which we can quickly land in the prototype folder to get some performance data? We are happy to run some benchmarks for you on H100s. I would like to see how the gemm performance stacks up to help us understand whether the overall workflow should go in the prototype folder or whether we should add it to float8.
I also have a high-level question on the type of scaling this PR implements, given how the block size is specified. In https://arxiv.org/html/2412.19437v1, Section 3.3.2, the report specifies that activations are tiled 128x1 and weights are blocked 128x128, so I just wanted to check whether this PR is trying to implement the gemm from the paper as written or making a modification.
Hi @vkuzo, I was just getting back to work on this issue. I also felt like it would be quicker to open a new PR with the gemm + gemm benchmarks, and add the integrations later. For now, I have rented an A100, and it seems like W4A8 is both faster and more precise. That is why I don't think it is relevant to put the code into float8 (I've put it in prototype on my local repository).
At the time, I thought about making it simpler for the first version, but the current gemm that I have does support 128x1 scaling (for activations) and 128x128 scaling (for weights).
I have memory leak issues for now. Once those are resolved, I will make a new PR. |
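The 128x1 activation scaling mentioned above is the counterpart of the 128x128 weight scaling. A minimal sketch of per-(1x128)-group activation quantization, with assumed names and shapes:

```python
import torch

FP8_DTYPE = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8_DTYPE).max

def quantize_activation_1x128(x: torch.Tensor, block: int = 128):
    """Quantize a (M, K) activation with one scale per 1x128 group along K.

    Sketch only: assumes K is divisible by `block`.
    """
    M, K = x.shape
    groups = x.reshape(M, K // block, block)
    amax = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scale = amax / FP8_MAX                  # (M, K // block, 1)
    x_fp8 = (groups / scale).to(FP8_DTYPE).reshape(M, K)
    return x_fp8, scale.squeeze(-1)         # scales: (M, K // block)
```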
Are you interested in training, inference, or both? W4A8 is more for inference.
Only inference for now, but I agree it would be interesting to add a training benchmark.
The new PR containing only the gemm and the benchmark is available at #1763 |
Feat: Implementation of the DeepSeek blockwise quantization for fp8 tensors
WARNING: The code has been tested on the following files:
pytest test/float8/test_base.py
pytest test/float8/test_compile.py
pytest test/float8/test_numerics_integration.py
However, tests have not been performed on the following files due to limitations (Triton is unavailable on Windows and I don't own an NVIDIA GPU):
./test/float8/test_fsdp.sh
./test/float8/test_dtensor.sh
python test/float8/test_fsdp2/test_fsdp2.py