Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
I.e. it feels reasonable to always call `at::cuda::gemm` rather than `at::cuda::bgemm` when num_batches == 1 After the change, benchmarking torch built with CUDA-12 using [following perf script](https://gist.github.com/malfet/6a17156d7f5663b8b12054a1beff3fe1) on A100 are as follows: | Shape | bmm_time | mm_time | slow down (%) | | -------------- | --------- | --------- | ------------- | | 1x1x4096 | 14.18 | 14.31 | -0.89 | | 1x1x8192 | 14.37 | 14.37 | -0.05 | | 1x1x16384 | 14.03 | 14.12 | -0.68 | | 1x1x32768 | 14.19 | 14.24 | -0.35 | | 1x1x65536 | 14.85 | 14.52 | 2.30 | | 1x1x131072 | 14.03 | 14.07 | -0.33 | | 128x128x128 | 11.34 | 11.06 | 2.56 | | 256x256x256 | 14.85 | 14.40 | 3.15 | | 512x512x512 | 27.22 | 27.22 | -0.01 | | 1024x1024x1024 | 129.66 | 129.50 | 0.12 | | 2048x2048x2048 | 972.18 | 973.24 | -0.11 | | 129x127x129 | 11.21 | 11.25 | -0.39 | | 257x255x257 | 14.50 | 14.43 | 0.44 | | 513x511x513 | 29.01 | 29.01 | 0.01 | | 1025x1023x1025 | 137.65 | 137.64 | 0.01 | | 2049x2047x2049 | 982.58 | 982.65 | -0.01 | | 4097x3x4097 | 86.65 | 86.64 | 0.01 | | 8193x3x8193 | 384.02 | 383.96 | 0.02 | | 16385x3x16385 | 1106.73 | 1107.32 | -0.05 | | 32769x3x32769 | 4739.49 | 4739.48 | 0.00 | | 65537x3x65537 | 17377.78 | 17378.74 | -0.01 | | 4097x5x4097 | 87.09 | 87.12 | -0.03 | | 8193x5x8193 | 301.38 | 301.36 | 0.01 | | 16385x5x16385 | 1107.38 | 1108.04 | -0.06 | | 32769x5x32769 | 4743.73 | 4744.07 | -0.01 | | 65537x5x65537 | 17392.32 | 17395.42 | -0.02 | | 4097x7x4097 | 87.17 | 87.19 | -0.02 | | 8193x7x8193 | 301.94 | 302.00 | -0.02 | | 16385x7x16385 | 1107.17 | 1106.79 | 0.03 | | 32769x7x32769 | 4747.15 | 4747.13 | 0.00 | | 65537x7x65537 | 17403.85 | 17405.02 | -0.01 | Fixes perf problem reported in #114911 Pull Request resolved: #114992 Approved by: https://github.com/Skylion007, https://github.com/eqy Co-authored-by: Nikita Shulga <[email protected]>
- Loading branch information