Performance regression for large batch sizes #8

Closed
gau-nernst opened this issue May 11, 2024 · 5 comments

@gau-nernst

Hello,

Thank you for the great work. I'm integrating the FP6 linear kernel into torchao (pytorch/ao#223). One thing I have observed is that the kernel is slower than PyTorch's default at large batch sizes. On my 4070Ti Super, the breaking point is at batch size = 256. You can see the detailed benchmark reports in my PR to torchao above. @msaroufim also saw similar results on an H100. I used the splitK values specified in tests/python/run.sh.

Have you observed similar results? I believe the kernel was tuned for the A100, but I don't have access to one, so I can't check.

I have also run tests/python/run.sh with the kernel compiled from this repo. Results are below:

################################################################
Namespace(OC=10240, IC=8192, BS=1, splitK=5)
cuBLAS  time: 0.27 ms            cuBLAS  TFLOPs: 0.6
fp6-llm time: 0.11 ms            fp6-llm TFLOPs: 1.5
speedup: 2.39
relative error: 0.000126
################################################################
Namespace(OC=8192, IC=8192, BS=1, splitK=6)
cuBLAS  time: 0.22 ms            cuBLAS  TFLOPs: 0.6
fp6-llm time: 0.02 ms            fp6-llm TFLOPs: 5.6
speedup: 9.15
relative error: 0.000130
################################################################
Namespace(OC=57344, IC=8192, BS=1, splitK=7)
cuBLAS  time: 1.53 ms            cuBLAS  TFLOPs: 0.6
fp6-llm time: 0.57 ms            fp6-llm TFLOPs: 1.6
speedup: 2.69
relative error: 0.000201
################################################################
Namespace(OC=8192, IC=28672, BS=1, splitK=6)
cuBLAS  time: 0.77 ms            cuBLAS  TFLOPs: 0.6
fp6-llm time: 0.28 ms            fp6-llm TFLOPs: 1.7
speedup: 2.71
relative error: 0.000260

################################################################
Namespace(OC=10240, IC=8192, BS=256, splitK=4)
cuBLAS  time: 0.58 ms            cuBLAS  TFLOPs: 74.0
fp6-llm time: 0.61 ms            fp6-llm TFLOPs: 70.7
speedup: 0.96
relative error: 0.000010
################################################################
Namespace(OC=8192, IC=8192, BS=256, splitK=3)
cuBLAS  time: 0.39 ms            cuBLAS  TFLOPs: 88.6
fp6-llm time: 0.45 ms            fp6-llm TFLOPs: 75.8
speedup: 0.86
relative error: 0.000009
################################################################
Namespace(OC=57344, IC=8192, BS=256, splitK=2)
cuBLAS  time: 2.76 ms            cuBLAS  TFLOPs: 87.3
fp6-llm time: 3.19 ms            fp6-llm TFLOPs: 75.4
speedup: 0.86
relative error: 0.000008
################################################################
Namespace(OC=8192, IC=28672, BS=256, splitK=3)
cuBLAS  time: 1.37 ms            cuBLAS  TFLOPs: 87.7
fp6-llm time: 1.51 ms            fp6-llm TFLOPs: 79.5
speedup: 0.91
relative error: 0.000033
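
For reference, the TFLOPs numbers above follow directly from the GEMM shape: the matmul does 2 * OC * IC * BS floating-point operations, divided by the measured time. A minimal sketch of that arithmetic (not part of run.sh):

```python
def tflops(OC: int, IC: int, BS: int, time_ms: float) -> float:
    """TFLOP/s of a (BS x IC) @ (IC x OC) GEMM given its runtime."""
    flops = 2 * OC * IC * BS               # each multiply-add counts as 2 FLOPs
    return flops / (time_ms * 1e-3) / 1e12

# First entry above: OC=10240, IC=8192, BS=1
print(tflops(10240, 8192, 1, 0.27))   # ~0.62 (cuBLAS,  reported as 0.6)
print(tflops(10240, 8192, 1, 0.11))   # ~1.5  (fp6-llm, reported as 1.5)
```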
@Summer-Summer
Member

Thanks for your efforts. I am delighted that our kernel also works well on 4070Ti GPUs, as I have never tested it on them myself.
The phenomenon you have observed is exactly what we would expect. Please refer to Figure 1 in our paper.
As the batch size increases beyond a certain point (e.g., 256 on the A100), the GEMM becomes compute-bound, and the computational throughput of the Tensor Cores becomes the bottleneck of kernel execution.
Our FP6-LLM kernel and cuBLAS both use FP16 Tensor Cores, so their theoretical peak speed at large batch sizes is the same. We have tried to keep FP6-LLM from being significantly slower than cuBLAS at large batch sizes, but it can never be faster than cuBLAS as long as the core computation also runs on FP16 Tensor Cores.
Note that the breaking point can differ between GPUs. On the H100 it should be larger than 256, since the ratio of Tensor Core throughput to DRAM bandwidth is even higher there. Our kernel is not yet optimized for the H100; I will work on that when I have more spare time.
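
As a rough roofline-style illustration of where that breaking point comes from (the hardware numbers below are ballpark public A100 figures, roughly 312 TFLOPS FP16 Tensor Core peak and about 1.5 TB/s of HBM bandwidth; they are assumptions for this sketch, not values from the paper):

```python
# Back-of-envelope crossover estimate for a (BS x IC) @ (IC x OC) GEMM in which
# weight traffic dominates memory reads (i.e., BS << IC, OC).
peak_flops = 312e12   # assumed A100 FP16 Tensor Core peak, FLOP/s (approximate)
mem_bw = 1.55e12      # assumed A100 HBM bandwidth, bytes/s (approximate)

def crossover_bs(bytes_per_weight: float) -> float:
    # FLOPs ~= 2 * BS * IC * OC and bytes ~= IC * OC * bytes_per_weight,
    # so arithmetic intensity ~= 2 * BS / bytes_per_weight.
    # The GEMM turns compute-bound once that exceeds peak_flops / mem_bw.
    return (peak_flops / mem_bw) * bytes_per_weight / 2

print(crossover_bs(2.0))    # FP16 weights:        BS ~ 200
print(crossover_bs(0.75))   # FP6 weights (6 bit): BS ~ 75
```

By this estimate an FP16 GEMM stays memory-bound up to roughly BS ≈ 200 on the A100, which lines up with the ~256 crossover observed above; the FP6 kernel hits the compute roof even earlier, after which both kernels are limited by the same FP16 Tensor Cores.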

@gau-nernst
Author

Thank you for your reply. It's good to know that it is indeed a known limitation.

In the case of an LLM, the batch size in the matmul is batch size x sequence length. How, then, would this kernel be faster than FP16 for LLM inference? Specifically, I'm referring to the End2End Inference results (Section 7.3), Figures 12, 13, and 14. Does the "batch size" in Figures 12 and 13 refer to the batch size of the matmul (BxD), or the batch size of the transformer activations (BxLxD)? Given the sentence "We set the prefill/prompt length of each request to 0.5K, and generate 1.5K tokens", I have the impression that it is the latter (BxLxD).

@Summer-Summer
Member

Summer-Summer commented May 13, 2024

The shape of the activations is (B*L, Hidden) during the prompt/prefill phase, but it becomes (B*1, Hidden) during the decoding (token-generation) phase when a KV cache is used. Our kernel mainly accelerates the token-generation phase of LLM inference.
For LLM inference, the prompt-processing phase is executed only once (all input tokens are processed in parallel), but a decoding step must be executed for every output token. The decoding phase can therefore easily dominate the overall inference time, which is why we still get end-to-end speedups.
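
A schematic of the two phases in PyTorch terms (hypothetical sizes; a plain FP16 matmul stands in for the quantized kernel, and a CUDA device is assumed):

```python
import torch

B, L, hidden, out_features = 4, 512, 8192, 10240   # hypothetical sizes
W = torch.randn(out_features, hidden, dtype=torch.float16, device="cuda")

# Prefill: all prompt tokens are processed at once -> effective GEMM batch = B*L
prefill_act = torch.randn(B * L, hidden, dtype=torch.float16, device="cuda")
prefill_out = prefill_act @ W.T   # (B*L, out_features): large batch, compute-bound

# Decode with a KV cache: one new token per sequence -> effective GEMM batch = B
decode_act = torch.randn(B, hidden, dtype=torch.float16, device="cuda")
decode_out = decode_act @ W.T     # (B, out_features): small batch, memory-bound
```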

@gau-nernst
Author

I see, that makes perfect sense! Thank you again. I will close this issue.

Another small question. Do you plan to add support for bias (A @ W.T + b) in the future? It's a bit inconvenient to launch a separate kernel just to add bias.
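
For concreteness, this is the unfused pattern I mean (a plain F.linear stands in for the FP6 matmul here, and a CUDA device is assumed):

```python
import torch
import torch.nn.functional as F

x = torch.randn(8, 8192, dtype=torch.float16, device="cuda")
w = torch.randn(10240, 8192, dtype=torch.float16, device="cuda")
b = torch.randn(10240, dtype=torch.float16, device="cuda")

# With bias support, the add happens inside the GEMM epilogue:
fused = F.linear(x, w, b)

# Without it, the matmul (here F.linear without bias, standing in for the FP6
# kernel) must be followed by a second elementwise kernel launch just for the bias.
unfused = F.linear(x, w) + b

print((fused - unfused).abs().max())   # ~0, up to fp16 rounding
```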

@Summer-Summer
Member

Thanks for your valuable suggestion. I will add this feature.
