Performance regression for large batch sizes #8
Thanks for your efforts. I am delighted that our kernel also works well on 4070Ti GPUs, as I have never tested it on that hardware myself.
Thank you for your reply. It's good to know that it is indeed a known limitation. In the case of an LLM, the batch size of the matmul is batch size x sequence length. So how would this kernel be faster than fp16 for LLM inference? Specifically, I'm referring to the end-to-end inference results (Section 7.3), Figures 12, 13, and 14. Does the "batch size" in Figures 12 and 13 refer to the batch size of the matmul (BxD), or to the batch size of the transformer activations (BxLxD)? Given the sentence "We set the prefill/prompt length of each request to 0.5K, and generate 1.5K tokens", I have the impression that it is the latter (BxLxD).
The shape of the activations is (B*L, Hidden) during the prompt/prefill phase, but it becomes (B*1, Hidden) during the decoding (token-generation) phase when a KV cache is used. Our kernel mainly accelerates the token-generation phase of LLM inference.
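To make the shapes concrete, here is a small illustration with made-up sizes (a sketch, not code from the paper or this repo):

```python
import torch

B, L, hidden, out_features = 8, 512, 4096, 4096            # made-up example sizes
W = torch.randn(out_features, hidden, dtype=torch.half, device="cuda")

# Prefill: the whole prompt is processed at once, so the GEMM sees B*L rows.
prefill_act = torch.randn(B * L, hidden, dtype=torch.half, device="cuda")
prefill_out = prefill_act @ W.t()      # shape (B*L, out_features), compute-bound

# Decode with KV cache: one new token per sequence, so the GEMM sees only B rows.
decode_act = torch.randn(B, hidden, dtype=torch.half, device="cuda")
decode_out = decode_act @ W.t()        # shape (B, out_features), memory-bound,
                                       # which is where a weight-only FP6 kernel helps
```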
I see, that makes perfect sense! Thank you again. I will close this issue. One more small question: do you plan to add support for bias (A @ W.T + b) in the future? It's a bit inconvenient to launch a separate kernel just to add the bias.
Thanks for your valuable suggestion. I will add this feature.
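Until bias support lands, a common workaround is to add the bias as a separate elementwise op after the GEMM. A minimal sketch is below; `fp6_linear` is a placeholder for whichever binding of the kernel you call, not an API defined in this repo:

```python
import torch

def linear_with_bias(fp6_linear, x, packed_weight, scales, bias, splitK=1):
    # Run the bias-free FP6 GEMM (A @ W.T), then broadcast-add the bias.
    # The extra add launches a second, cheap elementwise kernel, which is
    # exactly the inconvenience discussed above.
    out = fp6_linear(x, packed_weight, scales, splitK)
    return out + bias
```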
Hello,
Thank you for the great work. I'm integrating the FP6 linear kernel into torchao (pytorch/ao#223). One thing I have observed is that the kernel is slower than PyTorch's default at large batch sizes. On my 4070Ti Super, the crossover point is at batch size = 256; you can see the detailed benchmark reports in my torchao PR linked above. @msaroufim also saw similar results on an H100. I use the splitK values specified in tests/python/run.sh.
Have you observed similar results? I believe the kernel was tuned for the A100, but I don't have access to an A100, so I can't check.
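For anyone reproducing this, here is a minimal timing sketch for the fp16 baseline across batch sizes; the FP6 call is left as a commented placeholder since its binding differs between this repo and the torchao integration:

```python
import torch
from torch.utils import benchmark

def median_us(fn, *args):
    # Median runtime in microseconds, measured with torch.utils.benchmark.
    t = benchmark.Timer(stmt="fn(*args)", globals={"fn": fn, "args": args})
    return t.blocked_autorange().median * 1e6

hidden, out_features = 4096, 4096
W = torch.randn(out_features, hidden, dtype=torch.half, device="cuda")

for batch in (1, 16, 64, 256, 1024):
    x = torch.randn(batch, hidden, dtype=torch.half, device="cuda")
    fp16_us = median_us(torch.matmul, x, W.t())
    # fp6_us = median_us(fp6_linear, x, packed_weight, scales, splitK)  # placeholder
    print(f"batch={batch:5d}  fp16 matmul: {fp16_us:9.1f} us")
```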
I have also run tests/python/run.sh with the kernel compiled in this repo. Results are below.