Skip to content

Commit

Permalink
fix jagged tensor shared memory issue (pytorch#1286)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: pytorch#1286

Reviewed By: jspark1105, mjanderson09

Differential Revision: D39335880

fbshipit-source-id: 6527a483b113d07b81091676923a2a350bfaaf56
  • Loading branch information
jianyuh authored and facebook-github-bot committed Sep 8, 2022
1 parent f2c7c11 commit 8637956
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 1 deletion.
70 changes: 69 additions & 1 deletion fbgemm_gpu/src/jagged_tensor_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,7 @@ __global__ void jagged_dense_dense_elementwise_jagged_output_opt_gather_kernel_(
bool jagged_dense_dense_elementwise_jagged_output_matches_opt(
const int& num_jagged_dim,
const Tensor& x_values,
const std::vector<Tensor>& x_offsets,
const Tensor& y_0_reshaped,
const Tensor& y_1_reshaped,
const Tensor& output_values) {
Expand Down Expand Up @@ -641,6 +642,25 @@ bool jagged_dense_dense_elementwise_jagged_output_matches_opt(
matches &= (y_0_reshaped.size(0) < INT_MAX);
matches &= (y_0_reshaped.size(1) < INT_MAX);

int max_shared_bytes;
cudaDeviceGetAttribute(
&max_shared_bytes,
cudaDevAttrMaxSharedMemoryPerBlockOptin,
y_0_reshaped.get_device());
int shared_kb = max_shared_bytes >> 10;
// Use 2/3 of the available GPU shared mem; leave rooms for L1$.
int used_shared_kb = round_down(shared_kb * 2 / 3, 16);
TORCH_CHECK(used_shared_kb > 0);
int used_shared_bytes = used_shared_kb << 10;
AT_DISPATCH_INDEX_TYPES(
x_offsets[0].scalar_type(), "check_shared_memory", [&] {
auto B = y_0_reshaped.size(0);
// the default shared memory on V100/A100 is 48 KB from
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#shared-memory-8-x
if ((B + 1) * sizeof(index_t) >= used_shared_bytes) {
matches = false;
}
});
return matches;
}

Expand Down Expand Up @@ -705,7 +725,12 @@ void jagged_dense_elementwise_jagged_output_opt_(
// Canonicalize y to 3D, collapsing jagged dimensions.
const Tensor y_reshaped = y.view({y.size(0), -1, y.size(-1)});
if (jagged_dense_dense_elementwise_jagged_output_matches_opt(
num_jagged_dim, x_values, y_reshaped, y_reshaped, output_values)) {
num_jagged_dim,
x_values,
x_offsets,
y_reshaped,
y_reshaped,
output_values)) {
AT_DISPATCH_INDEX_TYPES(
x_offsets[0].scalar_type(), "jagged_indices_fast_path", [=] {
auto nnz = output_values.size(0);
Expand All @@ -722,6 +747,27 @@ void jagged_dense_elementwise_jagged_output_opt_(

// Binary search
size_t dynamic_smem_size = (B + 1) * sizeof(index_t);
auto cur_max_shared_bytes =
at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock;
if (dynamic_smem_size > cur_max_shared_bytes) {
int max_shared_bytes;
cudaDeviceGetAttribute(
&max_shared_bytes,
cudaDevAttrMaxSharedMemoryPerBlockOptin,
y_reshaped.get_device());

int shared_kb = max_shared_bytes >> 10;
// Use 2/3 of the available GPU shared mem; leave rooms for L1$.
int used_shared_kb = round_down(shared_kb * 2 / 3, 16);
TORCH_CHECK(used_shared_kb > 0);
int used_shared_bytes = used_shared_kb << 10;
cudaFuncSetAttribute(
jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_<
index_t>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
used_shared_bytes); // V100: 64 KB; A100: 96 KB.
TORCH_CHECK(dynamic_smem_size <= used_shared_bytes);
}
dim3 threads_bs = dim3(1024, 1, 1);
dim3 blocks_bs = dim3(div_round_up(nnz, threads_bs.x), 1, 1);
jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_<
Expand Down Expand Up @@ -871,6 +917,7 @@ void jagged_dense_dense_elementwise_jagged_output_opt_(
if (jagged_dense_dense_elementwise_jagged_output_matches_opt(
num_jagged_dim,
x_values,
x_offsets,
y_0_reshaped,
y_1_reshaped,
output_values)) {
Expand All @@ -892,6 +939,27 @@ void jagged_dense_dense_elementwise_jagged_output_opt_(
// Binary search
size_t dynamic_smem_size = (B + 1) * sizeof(index_t);
auto cur_max_shared_bytes =
at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock;
if (dynamic_smem_size > cur_max_shared_bytes) {
int max_shared_bytes;
cudaDeviceGetAttribute(
&max_shared_bytes,
cudaDevAttrMaxSharedMemoryPerBlockOptin,
y_0_reshaped.get_device());
int shared_kb = max_shared_bytes >> 10;
// Use 2/3 of the available GPU shared mem; leave rooms for L1$.
int used_shared_kb = round_down(shared_kb * 2 / 3, 16);
TORCH_CHECK(used_shared_kb > 0);
int used_shared_bytes = used_shared_kb << 10;
cudaFuncSetAttribute(
jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_<
index_t>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
used_shared_bytes); // V100: 64 KB; A100: 96 KB.
TORCH_CHECK(dynamic_smem_size <= used_shared_bytes);
}
dim3 threads_bs = dim3(1024, 1, 1);
dim3 blocks_bs = dim3(div_round_up(nnz, threads_bs.x), 1, 1);
jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_<
Expand Down
30 changes: 30 additions & 0 deletions fbgemm_gpu/test/jagged_tensor_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,36 @@ def test_dense_to_jagged_opt(
precompute_total_L,
)

# (8000+1) * 8 (size of the element of LongTensor/int64_t offsets)
# = ~62.5KB > 48KB default shared memory on V100/A100.
# pyre-ignore [56]
@given(
num_jagged_dim=st.just(1),
outer_dense_size=st.just(8000),
inner_dense_size=st.just(16),
dtype=st.just(torch.half),
use_cpu=st.just(False),
precompute_total_L=st.booleans(),
)
@settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None)
def test_dense_to_jagged_opt_large_batch(
self,
num_jagged_dim: int,
outer_dense_size: int,
inner_dense_size: int,
dtype: torch.dtype,
use_cpu: bool,
precompute_total_L: bool,
) -> None:
self._test_dense_to_jagged(
num_jagged_dim,
outer_dense_size,
inner_dense_size,
dtype,
use_cpu,
precompute_total_L,
)

# pyre-ignore [56]
@given(
num_jagged_dim=st.integers(1, 5),
Expand Down

0 comments on commit 8637956

Please sign in to comment.