int8 gemm slower than fp16 on A100. #935 (Closed)

beegerous opened this issue Jan 23, 2024 · 6 comments
Labels: triaged (Issue has been triaged by maintainers)

@beegerous commented Jan 23, 2024
I need a Python operator that supports int8 GEMM with per-token/per-channel quantization, so I wrapped the code in https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm_template.h into something like https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/kernels/linear.cu#L476
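For context, the per-token (activation) / per-channel (weight) quantization I feed into this operator looks roughly like the sketch below; `quantize_per_row` is an illustrative helper of mine, not code from TensorRT-LLM or torch-int:

```python
import torch

def quantize_per_row(t: torch.Tensor):
    """Symmetric int8 quantization with one scale per row (illustrative).

    Applied per-token to the activation [M, K] and per-channel to the
    weight [N, K]; both store K along the last dimension here.
    """
    # absmax scale per row, mapped onto the int8 range
    scale = t.abs().amax(dim=-1, keepdim=True).clamp(min=1e-5) / 127.0
    q = torch.clamp(torch.round(t / scale), -128, 127).to(torch.int8)
    return q, scale.to(t.dtype)

x_fp16 = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
w_fp16 = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
x, alpha = quantize_per_row(x_fp16)  # per-token scales, shape [M, 1]
y, beta  = quantize_per_row(w_fp16)  # per-channel scales, shape [N, 1]
```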

Then I compared the speed of the int8 GEMM against pure fp16 (torch.nn.functional.linear). The timing covers only the GEMM, not the quantization. I repeat each GEMM 1e5 times (1e6 times for the smallest case).

| M | N | K | torch cost (s) | trt-llm cost (s) |
|---|---|---|---|---|
| 1024 | 1024 | 1024 | 19.999567985534668 | 22.63555335998535 |
| 2048 | 2048 | 2048 | 8.349320888519287 | 7.4798314571380615 |
| 4096 | 4096 | 4096 | 65.3261570930481 | 45.53051781654358 |
| 8192 | 4096 | 4096 | 125.70239543914795 | 137.0772671699524 |
| 4096 | 8192 | 4096 | 125.74432516098022 | 117.87490010261536 |
| 4096 | 4096 | 8192 | 118.52623224258423 | 86.75222182273865 |

The test code is something like this:

# x: quantized activation [M, K] with its per-token scale alpha
# y: quantized weight     [N, K] with its per-channel scale beta
x, alpha = gen_input_tensor([M, K])
y, beta  = gen_input_tensor([N, K])
n = 100000  # 1e6 repetitions for the smallest case

# int8 path: the custom extension takes the quantized tensors and their
# scales (plus a bias) and returns an fp16 result.
with cost("int8gemm"):
    for _ in range(n):
        d = mylib.linear_a8_w8_bofp16(x, y, beta, alpha, bias)

# fp16 baseline: dequantize once, then time torch.nn.functional.linear.
x = x * alpha
y = y * beta

with cost("torch"):
    for _ in range(n):
        c = torch.nn.functional.linear(x, y)
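(`cost` is just a timing context manager; a minimal sketch of such a helper is below. My real helper may differ in details, but the important part is calling torch.cuda.synchronize() before reading the clock, since kernel launches are asynchronous.)

```python
import time
from contextlib import contextmanager

import torch

@contextmanager
def cost(name: str):
    # Drain any pending GPU work so it is not attributed to this block,
    # and synchronize again at the end because launches are asynchronous.
    torch.cuda.synchronize()
    start = time.time()
    try:
        yield
    finally:
        torch.cuda.synchronize()
        print(f"{name}: {time.time() - start:.6f} s")
```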

At first I thought the reason was that the input tensors were too small, and indeed when M = N = K = 4096 the int8 GEMM is finally faster than torch. But when I set just one dimension to 8192, the int8 GEMM becomes slower again. The last three cases should involve a similar amount of computation, and the torch results reflect that, but the int8 GEMM cost is quite unstable. I also expected the int8 GEMM to be about 2x faster than fp16, according to https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf .
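As a rough sanity check on the [4096, 4096, 4096] case: one GEMM is 2·4096³ ≈ 137 GFLOP, so 1e5 iterations in 65.3 s works out to roughly 210 TFLOPS for fp16 (about two thirds of the 312 TFLOPS dense FP16 Tensor Core peak in that datasheet), while 45.5 s works out to roughly 302 TOPS for int8 (less than half of the 624 TOPS dense INT8 peak). So the fp16 path seems to run much closer to its peak than the int8 path.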

[nsys screenshot]
I checked the [4096, 8192, 4096] case with nsys and didn't find anything weird.

I'm confused by these results and would like to understand the underlying reason. Or, if it is caused by some build error, how would I check that?

Env:
Ubuntu 20.04.6 LTS
NVIDIA A100-SXM4-80GB
base commit: c896530

I build and run the code in a Docker image built from docker/Dockerfile.multi .
I build mylib based on scripts/build_wheel.py .

@nekorobov (Collaborator) commented Jan 25, 2024

Hi @beegerous, thank you for reporting the issue. When using https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm_template.h one should specify ThreadblockShape, WarpShape and Stages to configure the GEMM. The choice of these parameters can influence performance a lot. Meanwhile, torch uses runtime heuristics (e.g. in the cuDNN library) to choose the optimal GEMM shape for the given problem size. Thus, I guess, the performance difference you see comes from comparing a suboptimal int8 GEMM against an optimal fp16 GEMM.

In TensorRT-LLM we use a profiler to find the best int8 GEMM configuration for a given problem size: https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/plugins/common/gemmPluginProfiler.cpp. For example, this profiler is used here. Could you please try it out? Let me know if you have more questions.

nekorobov added the triaged label on Jan 25, 2024
@beegerous (Author) commented

> Hi @beegerous, thank you for reporting the issue. When using https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm_template.h one should specify ThreadblockShape, WarpShape and Stages to configure the GEMM. The choice of these parameters can influence performance a lot. Meanwhile, torch uses runtime heuristics (e.g. in the cuDNN library) to choose the optimal GEMM shape for the given problem size. Thus, I guess, the performance difference you see comes from comparing a suboptimal int8 GEMM against an optimal fp16 GEMM.
>
> In TensorRT-LLM we use a profiler to find the best int8 GEMM configuration for a given problem size: https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/plugins/common/gemmPluginProfiler.cpp. For example, this profiler is used here. Could you please try it out? Let me know if you have more questions.

@nekorobov Thanks for the reply.

Sorry, I missed some information. Before running the GEMM I use this function (https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp#L196) to get the best tile config, and I make sure it runs only once for a given [m, n, k].

Is there a difference between cutlass_heuristic and gemmPluginProfiler? Judging from their input/output variables, I think both are meant to find the best CutlassGemmConfig.

By the way, the code here (https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm_template.h#L268) is a little confusing: it is the only shape that differs from what the case's variable name says. I changed it to 64 and re-ran those cases, but saw no difference in performance.

@nekorobov (Collaborator) commented Jan 25, 2024

> Before running the GEMM I use this function (https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp#L196) to get the best tile config, and I make sure it runs only once for a given [m, n, k].

Do you include this function in your time measurements? It can take a significant amount of time. Which candidate_configs do you provide?

> Is there a difference between cutlass_heuristic and gemmPluginProfiler? Judging from their input/output variables, I think both are meant to find the best CutlassGemmConfig.

Yes. cutlass_heuristic uses estimation heuristics to find the best config without running it; it is a relatively fast estimation method, but it is not always accurate. gemmPluginProfiler simply profiles each GEMM config and chooses the fastest one. gemmPluginProfiler has to be executed offline, but it gives more accurate results.
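Conceptually, what gemmPluginProfiler does is a brute-force timing sweep over the candidate configs. A Python sketch of the idea (hypothetical: it assumes your mylib binding accepts an extra config index argument, which your real extension may not expose):

```python
import torch

def pick_best_config(x, y, beta, alpha, bias, num_configs, iters=100):
    # Hypothetical sketch: assumes mylib.linear_a8_w8_bofp16 takes an extra
    # argument selecting which CutlassGemmConfig (tile shape, stages) to run.
    # `mylib` is the custom extension described earlier in this thread.
    best_id, best_ms = None, float("inf")
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    for cfg in range(num_configs):
        for _ in range(3):  # warm up so one-time costs are not measured
            mylib.linear_a8_w8_bofp16(x, y, beta, alpha, bias, cfg)
        start.record()
        for _ in range(iters):
            mylib.linear_a8_w8_bofp16(x, y, beta, alpha, bias, cfg)
        end.record()
        torch.cuda.synchronize()
        ms = start.elapsed_time(end) / iters
        if ms < best_ms:
            best_id, best_ms = cfg, ms
    return best_id  # cache this per (m, n, k) and reuse it at run time
```

The real gemmPluginProfiler does this in C++ over the CUTLASS tile configurations; the point is that each candidate is actually run and timed rather than estimated.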

> It is the only shape that differs from what the case's variable name says. I changed it to 64 and re-ran those cases, but saw no difference in performance.

Indeed, this mismatch is a bug. Thank you for reporting it. I believe it does make a difference, just not for the cases you've profiled.

@beegerous (Author) commented

@nekorobov Thanks for your advice. I switched from cutlass_heuristic to the profiler, and the int8 GEMM on A100 became much faster.
Here are the latest results (the trt-llm cost does not include the profiler time):

| M | N | K | torch cost (s) | trt-llm cost (s) |
|---|---|---|---|---|
| 1024 | 1024 | 1024 | 19.89363145828247 | 20.7686026096344 |
| 2048 | 2048 | 2048 | 8.329726219177246 | 7.037260055541992 |
| 4096 | 4096 | 4096 | 63.426324129104614 | 44.36107349395752 |
| 8192 | 4096 | 4096 | 125.48689246177673 | 96.80497765541077 |
| 4096 | 8192 | 4096 | 126.33650040626526 | 85.49396276473999 |
| 4096 | 4096 | 8192 | 118.56665873527527 | 86.97948837280273 |

@foreverlms commented
> @nekorobov Thanks for your advice. I switched from cutlass_heuristic to the profiler, and the int8 GEMM on A100 became much faster. Here are the latest results (the trt-llm cost does not include the profiler time):
>
> | M | N | K | torch cost (s) | trt-llm cost (s) |
> |---|---|---|---|---|
> | 1024 | 1024 | 1024 | 19.89363145828247 | 20.7686026096344 |
> | 2048 | 2048 | 2048 | 8.329726219177246 | 7.037260055541992 |
> | 4096 | 4096 | 4096 | 63.426324129104614 | 44.36107349395752 |
> | 8192 | 4096 | 4096 | 125.48689246177673 | 96.80497765541077 |
> | 4096 | 8192 | 4096 | 126.33650040626526 | 85.49396276473999 |
> | 4096 | 4096 | 8192 | 118.56665873527527 | 86.97948837280273 |

Have you mistaken the unit? Should that be microseconds rather than seconds?

@foreverlms commented
@nekorobov Hi, I also tested this int8 kernel and ran into some problems.
In my benchmark, I compared this int8 GEMM kernel with the w8a16 mixed GEMM kernel and got:

| M, N, K | W8A16 (ms) | W8A8 (ms) | Boost (%) |
|---|---|---|---|
| M:16, N:6144, K:4096 | 0.02534 | 0.01514 | 40.244 |
| M:16, N:4096, K:4096 | 0.02431 | 0.01454 | 40.1752 |
| M:16, N:28672, K:4096 | 0.11408 | 0.09675 | 15.1917 |
| M:16, N:4096, K:14336 | 0.06912 | 0.05302 | 23.2827 |

As you can see, this kernel is indeed faster for these four cases. But when I use the kernel in real models, the performance for the first two shapes drops a lot:

[M:16, N:6144, K:4096]: 24 us
[M:16, N:4096, K:4096]: 20 us

For the last two shapes, the performance does not drop.
Could you please give me some clues on how to get the performance back?

Note: I have done some experiments and found that if I call the kernel 10 times in the model instead of just once, every call after the first one is fast again. If this is a cache/warmup issue, why don't the last two cases have the same problem?
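For reference, the warmup I am experimenting with is roughly the sketch below (the shape list and the `int8_linear` call are placeholders for my own binding; the idea is just to run every GEMM shape once at model-load time so first-call overhead is paid before serving):

```python
import torch

# Placeholder shapes; in practice, collect the (M, N, K) set the model will use.
WARMUP_SHAPES = [(16, 6144, 4096), (16, 4096, 4096), (16, 28672, 4096), (16, 4096, 14336)]

def warmup_int8_gemm(int8_linear):
    """Run each GEMM shape once so any first-call cost is paid up front."""
    for m, n, k in WARMUP_SHAPES:
        x = torch.zeros(m, k, dtype=torch.int8, device="cuda")
        w = torch.zeros(n, k, dtype=torch.int8, device="cuda")
        alpha = torch.ones(m, 1, dtype=torch.float16, device="cuda")  # per-token scales
        beta = torch.ones(n, 1, dtype=torch.float16, device="cuda")   # per-channel scales
        bias = torch.zeros(n, dtype=torch.float16, device="cuda")
        int8_linear(x, w, beta, alpha, bias)  # placeholder call signature
    torch.cuda.synchronize()
```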
