int8 gemm slower than fp16 on A100. #935
Comments
Hi @beegerous, thank you for reporting the issue. When using https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm_template.h one should specify the GEMM configuration. In TensorRT-LLM we use a profiler to find the best GEMM configuration for int8 for a given problem size: https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/plugins/common/gemmPluginProfiler.cpp. For example, this profiler is used here; could you please try it out? Let me know if you have more questions.
@nekorobov Thanks for the reply. Sorry, I left out some information. Before computing the GEMM, I use this function (https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp#L196) to get the best tile config, and I make sure it runs only once for a specific [m,n,k]. Is there a difference between cutlass_heuristic and gemmPluginProfiler? Judging from their input/output variables, both seem to produce the best CutlassGemmConfig. By the way, the code here (https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm_template.h#L268) is a little confusing: it is the only shape that differs from what the case's variable name says. But I changed it to 64 and re-ran those cases, and saw no performance difference.
Do you include this function in your time measurements? It might take a significant amount of time.
Yes, cutlass_heuristic uses estimation heuristics to find the best config without actually running it. While this estimation is relatively fast, it is not always accurate. gemmPluginProfiler simply profiles each GEMM config and chooses the fastest. gemmPluginProfiler should be executed offline, but it gives more accurate results.
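The profile-each-config approach described above can be sketched as follows. This is only an illustration of the selection strategy; the config objects and the `run_gemm` callback are hypothetical stand-ins, not TensorRT-LLM's actual `CutlassGemmConfig` API:

```python
import time

def pick_fastest(configs, run_gemm, iters=100):
    """Profile each candidate GEMM config and return the fastest one.

    `configs` is an iterable of opaque config objects; `run_gemm(cfg)`
    executes one GEMM with that config. Each config gets one untimed
    warm-up call so one-time setup cost does not skew the comparison.
    """
    best_cfg, best_time = None, float("inf")
    for cfg in configs:
        run_gemm(cfg)                        # warm up (caches, lazy init)
        start = time.perf_counter()
        for _ in range(iters):
            run_gemm(cfg)
        elapsed = (time.perf_counter() - start) / iters
        if elapsed < best_time:
            best_cfg, best_time = cfg, elapsed
    return best_cfg
```

Unlike a heuristic estimate, this always picks the empirically fastest config for the exact [m,n,k] being profiled, at the cost of running every candidate.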
Indeed, this mismatch is a bug. Thank you for reporting it. I believe it does make a difference, just not in the cases you've profiled.
@nekorobov
Have you mistaken the unit? Microseconds (us), not seconds?
@nekorobov Hi, I also tested this int8 kernel and ran into some problems:
As you can see, this kernel is indeed faster for these four cases. But when I use it in real models, the perf drops a lot for the first two shapes: [M:16, N:6144, K:4096] : 24us. For the last two shapes, the perf doesn't drop. TIPS: I have done some experiments, and I found that if I call the kernel 10 times in the model, not just once, every call after the first performs well. If this is a cache/warmup issue, why don't the last two cases have this problem?
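The first-call-slow pattern described above is what any one-time setup cost (module load, autotuning cache fill, CUDA context work) looks like when a kernel is timed per call. This toy sketch only illustrates the measurement effect; it does not model the actual kernel:

```python
import time

_initialized = False

def lazy_kernel():
    """Stand-in for a kernel whose first call pays a one-time setup cost."""
    global _initialized
    if not _initialized:
        time.sleep(0.01)   # simulated one-time setup (10 ms)
        _initialized = True

def time_call(fn):
    start = time.perf_counter()
    fn()
    return time.perf_counter() - start

first = time_call(lazy_kernel)    # pays the setup cost
later = time_call(lazy_kernel)    # steady-state cost only
# `first` is far larger than `later`; discarding the first call (or
# averaging many calls) hides the one-time cost.
```

If only some shapes show the effect, the setup being amortized is likely shape-dependent, which is why per-shape profiling or warmup in the deployment path matters.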
I need a Python operator that supports int8 GEMM with per-token/per-channel quantization. So I wrapped the code from https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm_template.h into something like https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/kernels/linear.cu#L476
Then I tested the speed of the int8 GEMM against pure fp16 (torch.nn.functional.linear). The timing covers only the GEMM, not quantization. I repeated each GEMM 1e5 times (1e6 times in the small cases).
The test code is something like:
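The original snippet is not included in the thread. A minimal sketch of such a benchmark is below; the timing helper is generic and runnable, while the GPU usage in the comments is an assumption (the `int8_linear` wrapper name and tensor setup are illustrative, not the reporter's actual code):

```python
import time

def bench(fn, iters=100_000, warmup=100):
    """Average seconds per call of `fn`, after `warmup` untimed calls."""
    for _ in range(warmup):
        fn()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters

# Hypothetical GPU usage (names assumed, not from the issue):
#   a16, b16       -- fp16 tensors on CUDA
#   a8, b8, scales -- quantized int8 inputs plus dequant scales
# CUDA kernel launches are asynchronous, so torch.cuda.synchronize()
# must run inside the timed lambda (or bracket the loop); otherwise
# only launch overhead is measured, not kernel time.
#   t_fp16 = bench(lambda: torch.nn.functional.linear(a16, b16))
#   t_int8 = bench(lambda: int8_linear(a8, b8, scales))
```

Missing synchronization is a common source of confusing GPU benchmark numbers, so it is worth ruling out first.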
At first I thought the reason was that the input tensor is too small, and indeed when m, n, and k all equal 4096, the int8 GEMM is finally faster than torch. But when I set just one dim to 8192, the int8 GEMM is slower again. The last three cases should involve similar amounts of computation, and the torch results reflect that, but the int8 GEMM cost is quite unstable. I expected the int8 GEMM to be about 2x faster than fp16 according to https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf .
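The 2x expectation follows from the A100's peak dense Tensor Core throughput in the linked datasheet (312 TFLOPS FP16 vs 624 TOPS INT8). A back-of-the-envelope lower bound for the [4096, 8192, 4096] case, assuming 100% utilization:

```python
m, n, k = 4096, 8192, 4096
ops = 2 * m * n * k              # each multiply-add counted as 2 ops

fp16_peak = 312e12               # A100 dense FP16 Tensor Core, ops/s
int8_peak = 624e12               # A100 dense INT8 Tensor Core, ops/s

t_fp16 = ops / fp16_peak         # theoretical floor, seconds
t_int8 = ops / int8_peak
print(f"fp16 floor: {t_fp16 * 1e3:.2f} ms, int8 floor: {t_int8 * 1e3:.2f} ms")
```

Real kernels land above these floors, but by the peak numbers alone an int8 GEMM ending up slower than the fp16 one, as measured here, is not explained by hardware throughput; it points at kernel/config selection or measurement issues instead.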
I checked the [4096, 8192, 4096] case with nsys and didn't find anything weird.
I'm confused by these results and am trying to understand the underlying reason. Or, if it is caused by some build misconfiguration, how would I check for that?
Env:
Ubuntu 20.04.6 LTS
NVIDIA A100-SXM4-80GB
base commit: c896530
I build and run the code in a Docker image built from docker/Dockerfile.multi .
I build my library based on scripts/build_wheel.py .