From 1a3e3c7cffac8b985ef8bd8dff9891c63c51e830 Mon Sep 17 00:00:00 2001
From: Andrey Talman
Date: Tue, 2 Jan 2024 21:54:15 +0000
Subject: [PATCH] [CUDA] baddbmm should fall back to addmm for batch=1 (#114992) (#116518)

That is, it is reasonable to always call `at::cuda::gemm` rather than
`at::cuda::bgemm` when `num_batches == 1`.

After the change, benchmark results for torch built with CUDA 12, collected on an A100
with the [following perf script](https://gist.github.com/malfet/6a17156d7f5663b8b12054a1beff3fe1),
are as follows:

| Shape | bmm_time | mm_time | slowdown (%) |
| -------------- | --------- | --------- | ------------- |
| 1x1x4096 | 14.18 | 14.31 | -0.89 |
| 1x1x8192 | 14.37 | 14.37 | -0.05 |
| 1x1x16384 | 14.03 | 14.12 | -0.68 |
| 1x1x32768 | 14.19 | 14.24 | -0.35 |
| 1x1x65536 | 14.85 | 14.52 | 2.30 |
| 1x1x131072 | 14.03 | 14.07 | -0.33 |
| 128x128x128 | 11.34 | 11.06 | 2.56 |
| 256x256x256 | 14.85 | 14.40 | 3.15 |
| 512x512x512 | 27.22 | 27.22 | -0.01 |
| 1024x1024x1024 | 129.66 | 129.50 | 0.12 |
| 2048x2048x2048 | 972.18 | 973.24 | -0.11 |
| 129x127x129 | 11.21 | 11.25 | -0.39 |
| 257x255x257 | 14.50 | 14.43 | 0.44 |
| 513x511x513 | 29.01 | 29.01 | 0.01 |
| 1025x1023x1025 | 137.65 | 137.64 | 0.01 |
| 2049x2047x2049 | 982.58 | 982.65 | -0.01 |
| 4097x3x4097 | 86.65 | 86.64 | 0.01 |
| 8193x3x8193 | 384.02 | 383.96 | 0.02 |
| 16385x3x16385 | 1106.73 | 1107.32 | -0.05 |
| 32769x3x32769 | 4739.49 | 4739.48 | 0.00 |
| 65537x3x65537 | 17377.78 | 17378.74 | -0.01 |
| 4097x5x4097 | 87.09 | 87.12 | -0.03 |
| 8193x5x8193 | 301.38 | 301.36 | 0.01 |
| 16385x5x16385 | 1107.38 | 1108.04 | -0.06 |
| 32769x5x32769 | 4743.73 | 4744.07 | -0.01 |
| 65537x5x65537 | 17392.32 | 17395.42 | -0.02 |
| 4097x7x4097 | 87.17 | 87.19 | -0.02 |
| 8193x7x8193 | 301.94 | 302.00 | -0.02 |
| 16385x7x16385 | 1107.17 | 1106.79 | 0.03 |
| 32769x7x32769 | 4747.15 | 4747.13 | 0.00 |
| 65537x7x65537 | 17403.85 | 17405.02 | -0.01 |

Fixes the perf problem reported in https://github.com/pytorch/pytorch/issues/114911

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114992
Approved by: https://github.com/Skylion007, https://github.com/eqy

Co-authored-by: Nikita Shulga
---
 aten/src/ATen/native/cuda/Blas.cpp                   | 40 ++++++++++++-------
 .../_internal/common_methods_invocations.py         |  6 +--
 2 files changed, 28 insertions(+), 18 deletions(-)
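Note: below is a minimal sketch of the kind of batch=1 comparison summarized in the
table above; it is not the linked gist. It times `torch.bmm` on a single-batch problem
against the equivalent `torch.mm` using `torch.utils.benchmark.Timer`. The shapes,
dtype, and timing setup are illustrative assumptions, not taken from the original script.

```python
# Minimal sketch (not the linked gist): compare bmm on a batch of 1 vs. the
# equivalent 2-D mm. Assumes a CUDA build of PyTorch and an available GPU;
# shapes and dtype below are illustrative, not taken from the original script.
import torch
from torch.utils.benchmark import Timer


def bench_batch1(n: int, k: int, m: int, dtype=torch.float32) -> None:
    a = torch.rand(1, n, k, dtype=dtype, device="cuda")
    b = torch.rand(1, k, m, dtype=dtype, device="cuda")
    # Time the batched path (bgemm before this patch, gemm after it for batch=1).
    t_bmm = Timer("torch.bmm(a, b)", globals={"a": a, "b": b}).blocked_autorange()
    # Time the plain matrix-multiply path for the same problem.
    t_mm = Timer(
        "torch.mm(a2d, b2d)",
        globals={"a2d": a.squeeze(0), "b2d": b.squeeze(0)},
    ).blocked_autorange()
    slowdown = (t_bmm.mean / t_mm.mean - 1) * 100
    print(
        f"{n}x{k}x{m} (batch=1): bmm {t_bmm.mean * 1e6:.2f}us, "
        f"mm {t_mm.mean * 1e6:.2f}us, slowdown {slowdown:.2f}%"
    )


if __name__ == "__main__":
    for size in (4096, 8192, 16384):
        bench_batch1(size, 3, size)
```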
diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp
index 38cce45ab6e77a..e5163c339da990 100644
--- a/aten/src/ATen/native/cuda/Blas.cpp
+++ b/aten/src/ATen/native/cuda/Blas.cpp
@@ -16,6 +16,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -369,12 +370,10 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
 }
 
 const Tensor& baddbmm_out_cuda_impl(const Tensor& result, const Tensor& self, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha) {
-  IntArrayRef batch1_sizes = batch1.sizes();
-
   // handle pathological cases that blas may not like
   if (result.numel() == 0) {
     return result;
-  } else if (batch1_sizes[2] == 0) {
+  } else if (batch1.size(2) == 0) {
     if (beta.to<c10::complex<double>>() == 0.0) {
       return result.zero_();
     } else {
@@ -421,17 +420,30 @@ const Tensor& baddbmm_out_cuda_impl(const Tensor& result, const Tensor& self, co
       const scalar_t* batch1_ptr = batch1_->const_data_ptr<scalar_t>();
       const scalar_t* batch2_ptr = batch2_->const_data_ptr<scalar_t>();
       scalar_t* result_ptr = result_->mutable_data_ptr<scalar_t>();
-      at::cuda::blas::bgemm<scalar_t>(
-        transpose_batch1 ? batch1_->is_conj() ? 'c' : 't' : 'n',
-        transpose_batch2 ? batch2_->is_conj() ? 'c' : 't' : 'n',
-        m, n, k,
-        alpha_val,
-        batch1_ptr, lda, batch1_->strides()[0],
-        batch2_ptr, ldb, batch2_->strides()[0],
-        beta_val,
-        result_ptr, ldc, result_->strides()[0],
-        num_batches
-      );
+      const auto transa = transpose_batch1 ? batch1_->is_conj() ? 'c' : 't' : 'n';
+      const auto transb = transpose_batch2 ? batch2_->is_conj() ? 'c' : 't' : 'n';
+      // If batch is 1 call gemm rather than bgemm
+      if (num_batches == 1) {
+        at::cuda::blas::gemm<scalar_t>(
+            transa, transb,
+            m, n, k,
+            alpha_val,
+            batch1_ptr, lda,
+            batch2_ptr, ldb,
+            beta_val,
+            result_ptr, ldc);
+      } else {
+        at::cuda::blas::bgemm<scalar_t>(
+          transa, transb,
+          m, n, k,
+          alpha_val,
+          batch1_ptr, lda, batch1_->strides()[0],
+          batch2_ptr, ldb, batch2_->strides()[0],
+          beta_val,
+          result_ptr, ldc, result_->strides()[0],
+          num_batches
+        );
+      }
     });
   if (!result.is_same(*result_)) {
     result.copy_(*result_);
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index e4c51d1a71d833..3b934eac19b4e4 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -26,7 +26,7 @@
     skipCPUIfNoMklSparse, toleranceOverride, tol)
 from torch.testing._internal.common_cuda import (
-    PLATFORM_SUPPORTS_FLASH_ATTENTION, SM53OrLater, SM60OrLater, SM80OrLater, SM90OrLater, with_tf32_off, TEST_CUDNN,
+    PLATFORM_SUPPORTS_FLASH_ATTENTION, SM53OrLater, SM80OrLater, SM90OrLater, with_tf32_off, TEST_CUDNN,
     _get_torch_cuda_version, _get_torch_rocm_version,
 )
 from torch.testing._internal.common_utils import (
@@ -15937,9 +15937,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
            op=lambda tensors, equation: torch.einsum(equation, tensors),
            dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
            dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
-           backward_dtypesIfCUDA=floating_and_complex_types_and(torch.half, *[torch.bfloat16]
-                                                                if (SM60OrLater or
-                                                                    TEST_WITH_ROCM) else []),
+           backward_dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
            supports_out=False,
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
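A quick numerical sanity check for the new batch=1 path is sketched below. This is an
illustrative example, not part of the patch or its test changes: with a single batch,
`torch.baddbmm` should agree with the equivalent `torch.addmm`. It assumes a CUDA device
is available; shapes and beta/alpha values are arbitrary.

```python
# Sketch of a sanity check (not part of the patch): with batch=1, baddbmm now
# dispatches to gemm on CUDA and should still match the equivalent addmm.
# Assumes a CUDA device is available; shapes and beta/alpha are arbitrary.
import torch

if torch.cuda.is_available():
    torch.manual_seed(0)
    inp = torch.rand(1, 128, 256, device="cuda")
    b1 = torch.rand(1, 128, 64, device="cuda")
    b2 = torch.rand(1, 64, 256, device="cuda")
    out_batched = torch.baddbmm(inp, b1, b2, beta=0.5, alpha=2.0)
    out_plain = torch.addmm(inp[0], b1[0], b2[0], beta=0.5, alpha=2.0)
    torch.testing.assert_close(out_batched[0], out_plain)
    print("baddbmm(batch=1) matches addmm")
```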