diff --git a/fbgemm_gpu/src/jagged_tensor_ops.cu b/fbgemm_gpu/src/jagged_tensor_ops.cu
index 46dc54a38e..184a52f0bb 100644
--- a/fbgemm_gpu/src/jagged_tensor_ops.cu
+++ b/fbgemm_gpu/src/jagged_tensor_ops.cu
@@ -613,6 +613,7 @@ __global__ void jagged_dense_dense_elementwise_jagged_output_opt_gather_kernel_(
 bool jagged_dense_dense_elementwise_jagged_output_matches_opt(
     const int& num_jagged_dim,
     const Tensor& x_values,
+    const std::vector<Tensor>& x_offsets,
     const Tensor& y_0_reshaped,
     const Tensor& y_1_reshaped,
     const Tensor& output_values) {
@@ -641,6 +642,35 @@ bool jagged_dense_dense_elementwise_jagged_output_matches_opt(
   matches &= (y_0_reshaped.size(0) < INT_MAX);
   matches &= (y_0_reshaped.size(1) < INT_MAX);
 
+  int max_shared_bytes;
+#ifndef __HIP_PLATFORM_HCC__
+  cudaDeviceGetAttribute(
+      &max_shared_bytes,
+      cudaDevAttrMaxSharedMemoryPerBlockOptin,
+      y_0_reshaped.get_device());
+#else
+  // MI100 has 64 KB local memory (shared memory) per workgroup
+  max_shared_bytes = 64 << 10;
+#endif
+  int shared_kb = max_shared_bytes >> 10;
+#ifndef __HIP_PLATFORM_HCC__
+  // Use 2/3 of the available GPU shared mem; leave rooms for L1$.
+  int used_shared_kb = round_down(shared_kb * 2 / 3, 16);
+  TORCH_CHECK(used_shared_kb > 0);
+#else
+  // MI100 has independent shared mem and L1
+  int used_shared_kb = shared_kb;
+#endif
+  int used_shared_bytes = used_shared_kb << 10;
+  AT_DISPATCH_INDEX_TYPES(
+      x_offsets[0].scalar_type(), "check_shared_memory", [&] {
+        auto B = y_0_reshaped.size(0);
+        // the default shared memory on V100/A100 is 48 KB from
+        // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#shared-memory-8-x
+        if ((B + 1) * sizeof(index_t) >= used_shared_bytes) {
+          matches = false;
+        }
+      });
   return matches;
 }
 
@@ -705,7 +735,12 @@ void jagged_dense_elementwise_jagged_output_opt_(
   // Canonicalize y to 3D, collapsing jagged dimensions.
   const Tensor y_reshaped = y.view({y.size(0), -1, y.size(-1)});
   if (jagged_dense_dense_elementwise_jagged_output_matches_opt(
-          num_jagged_dim, x_values, y_reshaped, y_reshaped, output_values)) {
+          num_jagged_dim,
+          x_values,
+          x_offsets,
+          y_reshaped,
+          y_reshaped,
+          output_values)) {
     AT_DISPATCH_INDEX_TYPES(
         x_offsets[0].scalar_type(), "jagged_indices_fast_path", [=] {
           auto nnz = output_values.size(0);
@@ -722,6 +757,36 @@ void jagged_dense_elementwise_jagged_output_opt_(
 
           // Binary search
           size_t dynamic_smem_size = (B + 1) * sizeof(index_t);
+          auto cur_max_shared_bytes =
+              at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock;
+          if (dynamic_smem_size > cur_max_shared_bytes) {
+            int max_shared_bytes;
+#ifndef __HIP_PLATFORM_HCC__
+            cudaDeviceGetAttribute(
+                &max_shared_bytes,
+                cudaDevAttrMaxSharedMemoryPerBlockOptin,
+                y_reshaped.get_device());
+#else
+            // MI100 has 64 KB local memory (shared memory) per workgroup
+            max_shared_bytes = 64 << 10;
+#endif
+            int shared_kb = max_shared_bytes >> 10;
+#ifndef __HIP_PLATFORM_HCC__
+            // Use 2/3 of the available GPU shared mem; leave rooms for L1$.
+            int used_shared_kb = round_down(shared_kb * 2 / 3, 16);
+            TORCH_CHECK(used_shared_kb > 0);
+#else
+            // MI100 has independent shared mem and L1
+            int used_shared_kb = shared_kb;
+#endif
+            int used_shared_bytes = used_shared_kb << 10;
+            cudaFuncSetAttribute(
+                jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_<
+                    index_t>,
+                cudaFuncAttributeMaxDynamicSharedMemorySize,
+                used_shared_bytes); // V100: 64 KB; A100: 96 KB.
+            TORCH_CHECK(dynamic_smem_size <= used_shared_bytes);
+          }
           dim3 threads_bs = dim3(1024, 1, 1);
           dim3 blocks_bs = dim3(div_round_up(nnz, threads_bs.x), 1, 1);
           jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_<
@@ -871,6 +936,7 @@ void jagged_dense_dense_elementwise_jagged_output_opt_(
   if (jagged_dense_dense_elementwise_jagged_output_matches_opt(
           num_jagged_dim,
           x_values,
+          x_offsets,
           y_0_reshaped,
           y_1_reshaped,
           output_values)) {
@@ -892,6 +958,36 @@ void jagged_dense_dense_elementwise_jagged_output_opt_(
 
           // Binary search
           size_t dynamic_smem_size = (B + 1) * sizeof(index_t);
+          auto cur_max_shared_bytes =
+              at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock;
+          if (dynamic_smem_size > cur_max_shared_bytes) {
+            int max_shared_bytes;
+#ifndef __HIP_PLATFORM_HCC__
+            cudaDeviceGetAttribute(
+                &max_shared_bytes,
+                cudaDevAttrMaxSharedMemoryPerBlockOptin,
+                y_0_reshaped.get_device());
+#else
+            // MI100 has 64 KB local memory (shared memory) per workgroup
+            max_shared_bytes = 64 << 10;
+#endif
+            int shared_kb = max_shared_bytes >> 10;
+#ifndef __HIP_PLATFORM_HCC__
+            // Use 2/3 of the available GPU shared mem; leave rooms for L1$.
+            int used_shared_kb = round_down(shared_kb * 2 / 3, 16);
+            TORCH_CHECK(used_shared_kb > 0);
+#else
+            // MI100 has independent shared mem and L1
+            int used_shared_kb = shared_kb;
+#endif
+            int used_shared_bytes = used_shared_kb << 10;
+            cudaFuncSetAttribute(
+                jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_<
+                    index_t>,
+                cudaFuncAttributeMaxDynamicSharedMemorySize,
+                used_shared_bytes); // V100: 64 KB; A100: 96 KB.
+            TORCH_CHECK(dynamic_smem_size <= used_shared_bytes);
+          }
           dim3 threads_bs = dim3(1024, 1, 1);
           dim3 blocks_bs = dim3(div_round_up(nnz, threads_bs.x), 1, 1);
           jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_<
diff --git a/fbgemm_gpu/test/jagged_tensor_ops_test.py b/fbgemm_gpu/test/jagged_tensor_ops_test.py
index 13bba6e8fb..b16e68b35e 100644
--- a/fbgemm_gpu/test/jagged_tensor_ops_test.py
+++ b/fbgemm_gpu/test/jagged_tensor_ops_test.py
@@ -678,6 +678,36 @@ def test_dense_to_jagged_opt(
             precompute_total_L,
         )
 
+    # (8000+1) * 8 (size of the element of LongTensor/int64_t offsets)
+    # = ~62.5KB > 48KB default shared memory on V100/A100.
+    # pyre-ignore [56]
+    @given(
+        num_jagged_dim=st.just(1),
+        outer_dense_size=st.just(8000),
+        inner_dense_size=st.just(16),
+        dtype=st.just(torch.half),
+        use_cpu=st.just(False),
+        precompute_total_L=st.booleans(),
+    )
+    @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None)
+    def test_dense_to_jagged_opt_large_batch(
+        self,
+        num_jagged_dim: int,
+        outer_dense_size: int,
+        inner_dense_size: int,
+        dtype: torch.dtype,
+        use_cpu: bool,
+        precompute_total_L: bool,
+    ) -> None:
+        self._test_dense_to_jagged(
+            num_jagged_dim,
+            outer_dense_size,
+            inner_dense_size,
+            dtype,
+            use_cpu,
+            precompute_total_L,
+        )
+
     # pyre-ignore [56]
     @given(
         num_jagged_dim=st.integers(1, 5),
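Reviewer note (not part of the patch): the change above follows the standard CUDA opt-in shared-memory pattern when the (B + 1) offsets no longer fit in the 48 KB default: query cudaDevAttrMaxSharedMemoryPerBlockOptin, budget a portion of it, raise the kernel's cudaFuncAttributeMaxDynamicSharedMemorySize, and only then launch with the larger dynamic shared-memory size. The standalone sketch below illustrates that flow under assumptions; toy_offsets_kernel and the host scaffolding are hypothetical and only mirror the diff's 2/3 budget and the B = 8000 test case for illustration.

    // toy_smem_optin.cu -- minimal sketch, assuming CUDA 9+ runtime API.
    #include <cuda_runtime.h>
    #include <cstdio>
    #include <vector>

    // Hypothetical kernel: stages B + 1 int64_t offsets in dynamic shared
    // memory, loosely like the binary-search kernel in the diff.
    __global__ void toy_offsets_kernel(const int64_t* offsets, int64_t B, int64_t* out) {
      extern __shared__ int64_t smem_offsets[];
      for (int64_t i = threadIdx.x; i <= B; i += blockDim.x) {
        smem_offsets[i] = offsets[i];
      }
      __syncthreads();
      if (threadIdx.x == 0 && blockIdx.x == 0) {
        *out = smem_offsets[B]; // e.g. the total number of jagged rows
      }
    }

    int main() {
      int device = 0;
      cudaGetDevice(&device);

      // 1. Query the opt-in per-block limit (larger than the 48 KB default).
      int max_shared_bytes = 0;
      cudaDeviceGetAttribute(
          &max_shared_bytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);

      // 2. Budget part of it; mirrors round_down(shared_kb * 2 / 3, 16) in the diff.
      int shared_kb = max_shared_bytes >> 10;
      int used_shared_kb = (shared_kb * 2 / 3) / 16 * 16;
      int used_shared_bytes = used_shared_kb << 10;

      // 3. Opt the kernel in to the larger dynamic shared-memory size.
      cudaFuncSetAttribute(
          toy_offsets_kernel,
          cudaFuncAttributeMaxDynamicSharedMemorySize,
          used_shared_bytes);

      // 4. Launch with a dynamic shared-memory size that may exceed 48 KB.
      const int64_t B = 8000;                    // mirrors the large-batch test
      size_t smem = (B + 1) * sizeof(int64_t);   // ~62.5 KB
      std::vector<int64_t> h_offsets(B + 1);
      for (int64_t i = 0; i <= B; ++i) h_offsets[i] = i;

      int64_t *d_offsets = nullptr, *d_out = nullptr;
      cudaMalloc(&d_offsets, smem);
      cudaMalloc(&d_out, sizeof(int64_t));
      cudaMemcpy(d_offsets, h_offsets.data(), smem, cudaMemcpyHostToDevice);

      if (smem <= static_cast<size_t>(used_shared_bytes)) {
        toy_offsets_kernel<<<1, 1024, smem>>>(d_offsets, B, d_out);
        cudaDeviceSynchronize();
      }

      int64_t result = 0;
      cudaMemcpy(&result, d_out, sizeof(int64_t), cudaMemcpyDeviceToHost);
      printf("last offset: %lld\n", static_cast<long long>(result));

      cudaFree(d_offsets);
      cudaFree(d_out);
      return 0;
    }

The host-side guard (step 4's size check) plays the same role as the TORCH_CHECK in the patch: if even the opted-in budget cannot hold the offsets, the fast path must not be taken, which is exactly what the new x_offsets-aware matches_opt check enforces upstream.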