From 59267a9ab796001d1af46e9e43211d49fcaf4c57 Mon Sep 17 00:00:00 2001 From: Jithun Nair <37884920+jithunnair-amd@users.noreply.github.com> Date: Tue, 25 Jan 2022 14:18:00 -0600 Subject: [PATCH 01/76] Hipification of fbgemm for AMD GPUs/CPUs (#4) * Hipify code * Add correctness check * Revert "Add correctness check" This reverts commit a7f169dcc862e5cc8102a39eb3b7882dfa888f1b. * Fix setup.py * Add run_all.sh * Update Zipf index generation Update the Zipf index generation to generate unique indices in each bag and shuffle indices to avoid spatial locality Code reference: https://github.com/pytorch/FBGEMM/blob/7588d9d804826b428fc0e4fd418e9cc3f7a72e52/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py#L98-L117 * Fix ROCm version check in fbgemm_gpu's setup.py * Fix hipification errors Modify code to fix hipification errors. Some ops/kernels including merge_pooled_embeddings, quantize_ops and embedding_forward_quantized_split ops are diabled currently. These ops will be enabled in the future. * Disable AVX512 for AMD CPUs AMD CPUs do not support AVX512. Thus, it has to be disabled in ROCm. * Update run_all.sh * Fix __launch_bounds__ with kWarpSize. * fix missing '#endif' in codegen/embedding_backward_code_generator.py * fix the dependencies import in setup.py * debug enum cudaMemeryAdvise * bypass the both cudaMemoryAdvise cudaMemAdvise are mapped to hipMemAdvise, in cumem_utils.cu * Build and import successfully but with NAN values. * NAN values are eliminated by bypassing res.vals[0] = hfma2( * Remove debug lines in include/fbgemm_gpu/fbgemm_cuda_utils.cuh Note: The tests of fbgemm-gpu do not pass. They will be addressed in future commits. Co-authored-by: Sarunya Pumma Co-authored-by: Li Li Co-authored-by: liligwu --- fbgemm_gpu/bench/CMakeLists.txt | 9 +- fbgemm_gpu/build.sh | 4 + .../embedding_backward_code_generator.py | 17 ++ .../codegen/embedding_backward_dense_host.cpp | 9 + ...embedding_backward_split_host_template.cpp | 10 +- ..._backward_split_indice_weights_template.cu | 8 + .../embedding_backward_split_template.cu | 36 ++- .../embedding_backward_template_helpers.cuh | 20 +- fbgemm_gpu/codegen/embedding_bounds_check.cu | 4 + ...edding_forward_quantized_split_template.cu | 32 +- .../embedding_forward_split_template.cu | 13 + fbgemm_gpu/defs.bzl | 19 ++ .../fbgemm_gpu/hipcub_namespace_postfix.cuh | 21 ++ .../fbgemm_gpu/hipcub_namespace_prefix.cuh | 16 + .../split_table_batched_embeddings_ops.py | 1 + fbgemm_gpu/include/fbgemm_gpu/enum_utils.h | 6 +- .../include/fbgemm_gpu/fbgemm_cuda_utils.cuh | 113 +++++-- .../fbgemm_gpu/hipcub_namespace_postfix.cuh | 21 ++ .../fbgemm_gpu/hipcub_namespace_postfix.hpp | 21 ++ .../fbgemm_gpu/hipcub_namespace_prefix.cuh | 16 + .../fbgemm_gpu/hipcub_namespace_prefix.hpp | 16 + .../include/fbgemm_gpu/quantize_ops.cuh | 10 +- fbgemm_gpu/include/fbgemm_gpu/sparse_ops.cuh | 4 + fbgemm_gpu/run_all.sh | 39 +++ fbgemm_gpu/setup.py | 289 +++++++++++++++++- fbgemm_gpu/src/cumem_utils.cu | 41 +++ .../src/merge_pooled_embeddings_gpu.cpp | 3 + fbgemm_gpu/src/sparse_ops.cu | 4 + fbgemm_gpu/src/split_embeddings_cache_cuda.cu | 56 ++++ fbgemm_gpu/src/split_embeddings_utils.cuh | 14 + src/EmbeddingSpMDM.cc | 2 + 31 files changed, 834 insertions(+), 40 deletions(-) create mode 100755 fbgemm_gpu/build.sh create mode 100644 fbgemm_gpu/fbgemm_gpu/hipcub_namespace_postfix.cuh create mode 100644 fbgemm_gpu/fbgemm_gpu/hipcub_namespace_prefix.cuh create mode 100644 fbgemm_gpu/include/fbgemm_gpu/hipcub_namespace_postfix.cuh create mode 100644 fbgemm_gpu/include/fbgemm_gpu/hipcub_namespace_postfix.hpp create mode 100644 fbgemm_gpu/include/fbgemm_gpu/hipcub_namespace_prefix.cuh create mode 100644 fbgemm_gpu/include/fbgemm_gpu/hipcub_namespace_prefix.hpp create mode 100755 fbgemm_gpu/run_all.sh diff --git a/fbgemm_gpu/bench/CMakeLists.txt b/fbgemm_gpu/bench/CMakeLists.txt index 0eadbabae..b1a1374d7 100644 --- a/fbgemm_gpu/bench/CMakeLists.txt +++ b/fbgemm_gpu/bench/CMakeLists.txt @@ -15,7 +15,10 @@ macro(add_benchmark BENCHNAME) CUDA_SEPARABLE_COMPILATION OFF CXX_STANDARD 11 CXX_EXTENSIONS NO) - target_link_libraries(${BENCHNAME} fbgemm_gpu -lcurand) + target_link_libraries(${BENCHNAME} fbgemm_gpu -L/opt/rocm/lib -lhiprand -lrocrand) + include_directories(${BENCHNAME} BEFORE + /opt/rocm/include/hiprand + /opt/rocm/include/rocrand) add_dependencies(${BENCHNAME} fbgemm_gpu) if (USE_SANITIZER) @@ -35,12 +38,12 @@ if(FBGEMMGPU_BUILD_BENCHMARKS) set(BENCHMARKS "") - file(GLOB BENCH_LIST "*_benchmark.cu") + file(GLOB BENCH_LIST "hip/*_benchmark.cpp") foreach(BENCH_FILE ${BENCH_LIST}) get_filename_component(BENCH_NAME "${BENCH_FILE}" NAME_WE) get_filename_component(BENCH_FILE_ONLY "${BENCH_FILE}" NAME) add_benchmark("${BENCH_NAME}" - "${BENCH_FILE_ONLY}") + "hip/${BENCH_FILE_ONLY}") list(APPEND BENCHMARKS "${BENCH_NAME}") endforeach() diff --git a/fbgemm_gpu/build.sh b/fbgemm_gpu/build.sh new file mode 100755 index 000000000..f181dd6a9 --- /dev/null +++ b/fbgemm_gpu/build.sh @@ -0,0 +1,4 @@ +#!/bin/bash + +export MAX_JOBS=32 +python3.6 setup.py build develop 2>&1 | tee build.log diff --git a/fbgemm_gpu/codegen/embedding_backward_code_generator.py b/fbgemm_gpu/codegen/embedding_backward_code_generator.py index 9cfae837b..2d87581fd 100644 --- a/fbgemm_gpu/codegen/embedding_backward_code_generator.py +++ b/fbgemm_gpu/codegen/embedding_backward_code_generator.py @@ -397,7 +397,11 @@ def rowwise_adagrad() -> None: momentum1[idx] = new_sum_square_grads; multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps); } +#ifdef __HIP_PLATFORM_HCC__ + multiplier = __shfl(multiplier, 0); +#else multiplier = __shfl_sync(0xFFFFFFFF, multiplier, 0); +#endif """ split_weight_update_cpu = """ at::acc_type g_local_sum_square = 0.0; @@ -474,8 +478,13 @@ def rowwise_weighted_adagrad() -> None: multiplier = learning_rate * lambda / (cbrtf(new_sum_square_grads) + eps); correction = 1.0 - multiplier * weight_decay; } +#ifdef __HIP_PLATFORM_HCC__ + multiplier = __shfl(multiplier, 0); + correction = __shfl(correction, 0); +#else multiplier = __shfl_sync(0xFFFFFFFF, multiplier, 0); correction = __shfl_sync(0xFFFFFFFF, correction, 0); +#endif """ split_weight_update_cpu = """ // weight_decay not supported for cpu version @@ -636,7 +645,11 @@ def partial_rowwise_lamb() -> None: m2 = beta2 * momentum2[idx] + (1.0 - beta2) * g_avg_square; momentum2[idx] = m2; } +#ifdef __HIP_PLATFORM_HCC__ + m2 = __shfl(m2, 0); +#else m2 = __shfl_sync(0xFFFFFFFF, m2, 0); +#endif at::acc_type m2_hat = 1.0 / (sqrtf((m2 / (1.0 - powf(beta2, iter)))) + eps); at::acc_type weight_sum_sq = 0.0; @@ -772,7 +785,11 @@ def partial_rowwise_adam() -> None: momentum2[idx] = v_t; v_hat_t = v_t / (1.0 - powf(beta2, iter)); } +#ifdef __HIP_PLATFORM_HCC__ + v_hat_t = __shfl(v_hat_t, 0); +#else v_hat_t = __shfl_sync(0xFFFFFFFF, v_hat_t, 0); +#endif """ split_weight_update = """ diff --git a/fbgemm_gpu/codegen/embedding_backward_dense_host.cpp b/fbgemm_gpu/codegen/embedding_backward_dense_host.cpp index d8d16346c..8c76f867e 100644 --- a/fbgemm_gpu/codegen/embedding_backward_dense_host.cpp +++ b/fbgemm_gpu/codegen/embedding_backward_dense_host.cpp @@ -110,7 +110,11 @@ class SplitLookupFunction_Dense_Op ctx->saved_data["total_hash_size_bits"] = total_hash_size_bits; ctx->saved_data["pooling_mode"] = pooling_mode; +#ifdef __HIP_PLATFORM_HCC__ + constexpr int32_t BT_block_size = 64; +#else constexpr int32_t BT_block_size = 32; +#endif if (!indice_weights.has_value()) { return {dense_embedding_codegen_forward_unweighted_cuda( dev_weights, @@ -158,8 +162,13 @@ class SplitLookupFunction_Dense_Op TORCH_CHECK(grad_outputs.size() == 1); +#ifdef __HIP_PLATFORM_HCC__ + constexpr int32_t BT_block_size = 64; + constexpr int32_t max_segment_length_per_warp = 64; +#else constexpr int32_t BT_block_size = 32; constexpr int32_t max_segment_length_per_warp = 32; +#endif using torch::autograd::Variable; auto grad_output = grad_outputs[0]; diff --git a/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp b/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp index c31d69592..c3f519039 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp +++ b/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp @@ -186,9 +186,12 @@ class Split{{ "NoBag" if nobag else "" }}LookupFunction_{{ optimizer }}_Op : {% for (var, _) in args.saved_data %} ctx->saved_data["{{ var }}"] = {{ var }}; {% endfor %} - {% if not nobag %} +#ifdef __HIP_PLATFORM_HCC__ + constexpr int32_t BT_block_size = 64; +#else constexpr int32_t BT_block_size = 32; +#endif if (!indice_weights) { return {split_embedding_codegen_forward_unweighted_cuda( dev_weights, uvm_weights, lxu_cache_weights, weights_placements, weights_offsets, @@ -256,8 +259,13 @@ class Split{{ "NoBag" if nobag else "" }}LookupFunction_{{ optimizer }}_Op : TORCH_CHECK(grad_outputs.size() == 1); +#ifdef __HIP_PLATFORM_HCC__ + constexpr int32_t BT_block_size = 64; + constexpr int32_t max_segment_length_per_warp = 64; +#else constexpr int32_t BT_block_size = 32; constexpr int32_t max_segment_length_per_warp = 32; +#endif using torch::autograd::Variable; auto grad_output = gradient_clipping ? clamp(grad_outputs[0], -max_gradient, max_gradient) : grad_outputs[0]; diff --git a/fbgemm_gpu/codegen/embedding_backward_split_indice_weights_template.cu b/fbgemm_gpu/codegen/embedding_backward_split_indice_weights_template.cu index 628eb0f78..486297a46 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_indice_weights_template.cu +++ b/fbgemm_gpu/codegen/embedding_backward_split_indice_weights_template.cu @@ -103,9 +103,17 @@ __launch_bounds__(kForwardMaxThreads) void {{ "dense" if dense else "split" }}_e int32_t cache_idx = (placement == PlacementType::MANAGED_CACHING && l < L) ? lxu_cache_locations[indices_start + l] : 0; {% endif %} for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) { +#ifdef __HIP_PLATFORM_HCC__ + int64_t idx_j = __shfl(idx, j); +#else int64_t idx_j = __shfl_sync(0xFFFFFFFF, idx, j); +#endif {% if not dense %} +#ifdef __HIP_PLATFORM_HCC__ + int32_t cache_idx_j = __shfl(cache_idx, j); +#else int32_t cache_idx_j = __shfl_sync(0xFFFFFFFF, cache_idx, j); +#endif {% endif %} at::acc_type grad_indice_weight = 0.0; diff --git a/fbgemm_gpu/codegen/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/embedding_backward_split_template.cu index b51f9bf20..f69a8dec2 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/embedding_backward_split_template.cu @@ -190,13 +190,26 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ {% endif %} for (int32_t j = 0; j < kWarpSize && sl + j < sl_end; ++j) { {% if not nobag %} +#ifdef __HIP_PLATFORM_HCC__ + int32_t b_j = __shfl(b, j); + int32_t D_start_j = __shfl(D_start, j); +#else int32_t b_j = __shfl_sync(0xFFFFFFFF, b, j); int32_t D_start_j = __shfl_sync(0xFFFFFFFF, D_start, j); +#endif {% else %} +#ifdef __HIP_PLATFORM_HCC__ + int32_t l_j = __shfl(l, j); +#else int32_t l_j = __shfl_sync(0xFFFFFFFF, l, j); +#endif {% endif %} {% if weighted %} +#ifdef __HIP_PLATFORM_HCC__ + at::acc_type idx_weight_j = __shfl(idx_weight, j); +#else at::acc_type idx_weight_j = __shfl_sync(0xFFFFFFFF, idx_weight, j); +#endif {% endif %} #pragma unroll kMaxVecsPerThread @@ -549,13 +562,26 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ for (int32_t j = 0; j < kWarpSize && sl + j < sl_end; ++j) { {% if not nobag %} +#ifdef __HIP_PLATFORM_HCC__ + int32_t b_j = __shfl(b, j); + int32_t D_start_j = __shfl(D_start, j); +#else int32_t b_j = __shfl_sync(0xFFFFFFFF, b, j); int32_t D_start_j = __shfl_sync(0xFFFFFFFF, D_start, j); +#endif {% else %} +#ifdef __HIP_PLATFORM_HCC__ + int32_t l_j = __shfl(l, j); +#else int32_t l_j = __shfl_sync(0xFFFFFFFF, l, j); +#endif {% endif %} {% if weighted %} +#ifdef __HIP_PLATFORM_HCC__ + at::acc_type idx_weight_j = __shfl(idx_weight, j); +#else at::acc_type idx_weight_j = __shfl_sync(0xFFFFFFFF, idx_weight, j); +#endif {% endif %} #pragma unroll kMaxVecsPerThread @@ -736,6 +762,7 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ TORCH_CHECK(D <= {{ max_embedding_dim }}); {% endif %} +#ifndef __HIP_PLATFORM_HCC__ // V100: 96 KB; A100: 160 KB. int max_shared_bytes = 0; cudaDeviceGetAttribute(&max_shared_bytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev_weights.get_device()); @@ -746,6 +773,7 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ int used_shared_kb = round_down(shared_kb * 2 / 3, 16); TORCH_CHECK(used_shared_kb > 0); int used_shared_bytes = used_shared_kb << 10; +#endif {% if not nobag %} auto infos = at::empty_like(indices, indices.options().dtype(at::kInt)); @@ -756,7 +784,7 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ auto linear_indices = at::empty_like(indices); auto linear_indices_sorted = at::empty_like(indices); {% if not nobag %} - linearize_index_kernel<<< + linearize_index_kernel<<< div_round_up(B * T, kMaxThreads), kMaxThreads, 0, @@ -990,10 +1018,12 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ {% else %} if (D <= {{ 128 * kMaxVecsPerThread }}) { {% endif %} +#ifndef __HIP_PLATFORM_HCC__ // Stay under used_shared_kb of shared memory (V100: 64 KB; A100: 96 KB), BT_block_size must be a power of two. while (BT_block_size * sizeof(at::acc_type<{{ "scalar_t" if dense else "cache_t" }}, true>) * 4 * kWarpSize * {{ kMaxVecsPerThread }} >= used_shared_bytes) { BT_block_size /= 2; } +#endif TORCH_CHECK(BT_block_size >= 1); if (std::is_same<{{ "scalar_t" if dense else "emb_t" }}, double>::value) { // Otherwise we see CUDA kernel launch failures despite the above checks. @@ -1022,6 +1052,7 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ // over 48 KB per block are architecture-specific, as such they // must use dynamic shared memory (rather than statically sized // arrays) and require an explicit opt-in using cudaFuncSetAttribute()". +#ifndef __HIP_PLATFORM_HCC__ cudaFuncSetAttribute( split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_kernel_cta_per_row_1< {% if not dense %} @@ -1034,6 +1065,7 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ {{ kMaxVecsPerThread }}>, cudaFuncAttributeMaxDynamicSharedMemorySize, used_shared_bytes); // V100: 64 KB; A100: 96 KB. +#endif C10_CUDA_KERNEL_LAUNCH_CHECK(); split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_kernel_cta_per_row_1< {% if not dense %} @@ -1096,6 +1128,7 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ {% endif %} {{ args.split_kernel_arg_constructors | join(", ") }}); C10_CUDA_KERNEL_LAUNCH_CHECK(); +#ifndef __HIP_PLATFORM_HCC__ cudaFuncSetAttribute( split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_kernel_warp_per_row_1< {% if not dense %} @@ -1108,6 +1141,7 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ {{ kMaxVecsPerThread }}>, cudaFuncAttributeMaxDynamicSharedMemorySize, used_shared_bytes); // V100: 64 KB; A100: 96 KB. +#endif C10_CUDA_KERNEL_LAUNCH_CHECK(); split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_kernel_warp_per_row_1< {% if not dense %} diff --git a/fbgemm_gpu/codegen/embedding_backward_template_helpers.cuh b/fbgemm_gpu/codegen/embedding_backward_template_helpers.cuh index c6cbaed22..3e79e8fce 100644 --- a/fbgemm_gpu/codegen/embedding_backward_template_helpers.cuh +++ b/fbgemm_gpu/codegen/embedding_backward_template_helpers.cuh @@ -69,8 +69,10 @@ inline at::Tensor asynchronous_complete_cumsum(at::Tensor t_in) { return t_out; } -template -__global__ void linearize_index_kernel( +template +__global__ void __launch_bounds__(kMaxThreads) +linearize_index_kernel( const at::PackedTensorAccessor32 hash_size_cumsum, const at::PackedTensorAccessor32 indices, @@ -91,10 +93,17 @@ __global__ void linearize_index_kernel( int32_t lane_id = threadIdx.x % fbgemm_gpu::kWarpSize; for (int32_t j = 0; j < fbgemm_gpu::kWarpSize; ++j) { +#ifdef __HIP_PLATFORM_HCC__ + index_t indices_start_warp = __shfl(indices_start, j); + int32_t b_t_warp = __shfl(b_t, j); + int32_t L_warp = __shfl(L, j); + index_t hash_offset_warp = __shfl(hash_offset, j); +#else index_t indices_start_warp = __shfl_sync(0xFFFFFFFF, indices_start, j); int32_t b_t_warp = __shfl_sync(0xFFFFFFFF, b_t, j); int32_t L_warp = __shfl_sync(0xFFFFFFFF, L, j); index_t hash_offset_warp = __shfl_sync(0xFFFFFFFF, hash_offset, j); +#endif for (int32_t i = lane_id; i < L_warp; i += fbgemm_gpu::kWarpSize) { index_t idx = __ldg(&indices[indices_start_warp + i]); infos[indices_start_warp + i] = b_t_warp; @@ -125,10 +134,17 @@ __global__ void nobag_linearize_index_kernel( int32_t lane_id = threadIdx.x % fbgemm_gpu::kWarpSize; for (int32_t j = 0; j < fbgemm_gpu::kWarpSize; ++j) { +#ifdef __HIP_PLATFORM_HCC__ + index_t indices_start_warp = __shfl(indices_start, j); + int32_t t_warp = __shfl(t, j); + int32_t L_warp = __shfl(L, j); + index_t hash_offset_warp = __shfl(hash_offset, j); +#else index_t indices_start_warp = __shfl_sync(0xFFFFFFFF, indices_start, j); int32_t t_warp = __shfl_sync(0xFFFFFFFF, t, j); int32_t L_warp = __shfl_sync(0xFFFFFFFF, L, j); index_t hash_offset_warp = __shfl_sync(0xFFFFFFFF, hash_offset, j); +#endif for (int32_t i = lane_id; i < L_warp; i += fbgemm_gpu::kWarpSize) { index_t idx = __ldg(&indices[indices_start_warp + i]); int64_t l_t = (indices_start_warp + i) * T + t_warp; diff --git a/fbgemm_gpu/codegen/embedding_bounds_check.cu b/fbgemm_gpu/codegen/embedding_bounds_check.cu index 9a7aaa636..9b1c706e5 100644 --- a/fbgemm_gpu/codegen/embedding_bounds_check.cu +++ b/fbgemm_gpu/codegen/embedding_bounds_check.cu @@ -69,7 +69,11 @@ __global__ void bounds_check_indices_kernel( } auto L = indices_end - indices_start; +#ifdef __HIP_PLATFORM_HCC__ + for (index_t i = (index_t) threadIdx.x; i < L; i += (index_t) fbgemm_gpu::kWarpSize) { +#else for (auto i = threadIdx.x; i < L; i += fbgemm_gpu::kWarpSize) { +#endif auto idx = indices[indices_start + i]; if (idx == -1) { // -1 indicates pruned rows. diff --git a/fbgemm_gpu/codegen/embedding_forward_quantized_split_template.cu b/fbgemm_gpu/codegen/embedding_forward_quantized_split_template.cu index 787ebbea5..fa2b56df9 100644 --- a/fbgemm_gpu/codegen/embedding_forward_quantized_split_template.cu +++ b/fbgemm_gpu/codegen/embedding_forward_quantized_split_template.cu @@ -172,7 +172,7 @@ void cp_async_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard) { // TODO: increase code sharing (templates for accumulator_ty, accumulation, outputs per thread, etc?) template -__launch_bounds__(WarpsPerBlock * 32) +__launch_bounds__(WarpsPerBlock * kWarpSize) __global__ void fp32_split_embedding_codegen_forward_{{ wdesc }}_kernel_small_L( const at::PackedTensorAccessor64 dev_weights, const at::PackedTensorAccessor64 uvm_weights, @@ -291,7 +291,11 @@ __global__ void fp32_split_embedding_codegen_forward_{{ wdesc }}_kernel_small_L( } // equivalent to fence + wait. cp_async_wait<0>(); +#ifdef __HIP_PLATFORM_HCC__ + __syncthreads(); +#else __syncwarp(); +#endif for (uint32_t input_row_idx = 0; input_row_idx < input_rows_in_flight; ++input_row_idx) { #pragma unroll OutputRowsPerThread for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { @@ -368,7 +372,7 @@ __global__ void fp32_split_embedding_codegen_forward_{{ wdesc }}_kernel_small_L( // TODO: increase code sharing (templates for accumulator_ty, accumulation, outputs per thread, etc?) template -__launch_bounds__(WarpsPerBlock * 32) +__launch_bounds__(WarpsPerBlock * kWarpSize) __global__ void fp16_split_embedding_codegen_forward_{{ wdesc }}_kernel_small_L( const at::PackedTensorAccessor64 dev_weights, const at::PackedTensorAccessor64 uvm_weights, @@ -488,7 +492,11 @@ __global__ void fp16_split_embedding_codegen_forward_{{ wdesc }}_kernel_small_L( } // equivalent to fence + wait. cp_async_wait<0>(); +#ifdef __HIP_PLATFORM_HCC__ + __syncthreads(); +#else __syncwarp(); +#endif for (uint32_t input_row_idx = 0; input_row_idx < input_rows_in_flight; ++input_row_idx) { #pragma unroll OutputRowsPerThread for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { @@ -570,7 +578,7 @@ __global__ void fp16_split_embedding_codegen_forward_{{ wdesc }}_kernel_small_L( } template -__launch_bounds__(WarpsPerBlock * 32) +__launch_bounds__(WarpsPerBlock * kWarpSize) __global__ void int_8bit_split_embedding_codegen_forward_{{ wdesc }}_kernel_small_L( const at::PackedTensorAccessor64 dev_weights, const at::PackedTensorAccessor64 uvm_weights, @@ -690,7 +698,11 @@ __global__ void int_8bit_split_embedding_codegen_forward_{{ wdesc }}_kernel_smal } // equivalent to fence + wait. cp_async_wait<0>(); +#ifdef __HIP_PLATFORM_HCC__ + __syncthreads(); +#else __syncwarp(); +#endif for (uint32_t input_row_idx = 0; input_row_idx < input_rows_in_flight; ++input_row_idx) { #pragma unroll OutputRowsPerThread for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { @@ -773,7 +785,7 @@ __global__ void int_8bit_split_embedding_codegen_forward_{{ wdesc }}_kernel_smal } template -__launch_bounds__(WarpsPerBlock * 32) +__launch_bounds__(WarpsPerBlock * kWarpSize) __global__ void int_4bit_split_embedding_codegen_forward_{{ wdesc }}_kernel_small_L( const at::PackedTensorAccessor64 dev_weights, const at::PackedTensorAccessor64 uvm_weights, @@ -893,7 +905,11 @@ __global__ void int_4bit_split_embedding_codegen_forward_{{ wdesc }}_kernel_smal } // equivalent to fence + wait. cp_async_wait<0>(); +#ifdef __HIP_PLATFORM_HCC__ + __syncthreads(); +#else __syncwarp(); +#endif for (uint32_t input_row_idx = 0; input_row_idx < input_rows_in_flight; ++input_row_idx) { #pragma unroll OutputRowsPerThread for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { @@ -1037,9 +1053,17 @@ __global__ void int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_{ found = true; dense_indices[indices_start + l_start + subwarp_id] = slot_dense_idx; } +#ifdef __HIP_PLATFORM_HCC__ + if (__any(found)) { +#else if (__any_sync(subwarp_mask, found)) { +#endif break; +#ifdef __HIP_PLATFORM_HCC__ + } else if (__any(empty)) { +#else } else if (__any_sync(subwarp_mask, empty)) { +#endif dense_indices[indices_start + l_start + subwarp_id] = -1; break; } diff --git a/fbgemm_gpu/codegen/embedding_forward_split_template.cu b/fbgemm_gpu/codegen/embedding_forward_split_template.cu index b30e5ebd7..d4bcca1eb 100644 --- a/fbgemm_gpu/codegen/embedding_forward_split_template.cu +++ b/fbgemm_gpu/codegen/embedding_forward_split_template.cu @@ -127,16 +127,29 @@ __global__ void {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if noba at::acc_type idx_weight = l < L ? indice_weights[indices_start + l] : 0; {% endif %} for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) { +#ifdef __HIP_PLATFORM_HCC__ + int64_t idx_j = __shfl(idx, j); +#else int64_t idx_j = __shfl_sync(0xFFFFFFFF, idx, j); +#endif + {% if nobag %} int64_t output_j = indices_start + l_start + j; {% endif %} {% if not dense %} +#ifdef __HIP_PLATFORM_HCC__ + int32_t cache_idx_j = __shfl(cache_idx, j); +#else int32_t cache_idx_j = __shfl_sync(0xFFFFFFFF, cache_idx, j); +#endif {% endif %} {% if weighted %} +#ifdef __HIP_PLATFORM_HCC__ + at::acc_type idx_weight_j = __shfl(idx_weight, j); +#else at::acc_type idx_weight_j = __shfl_sync(0xFFFFFFFF, idx_weight, j); +#endif {% endif %} {% if not dense %} diff --git a/fbgemm_gpu/defs.bzl b/fbgemm_gpu/defs.bzl index b773f11fc..d3b92d672 100644 --- a/fbgemm_gpu/defs.bzl +++ b/fbgemm_gpu/defs.bzl @@ -15,3 +15,22 @@ def get_fbgemm_gpu_public_headers(): "include/fbgemm_gpu/sparse_ops.cuh", "include/fbgemm_gpu/layout_transform_ops.cuh", ] + +def get_fbgemm_gpu_wrapper_srcs_hip(): + return [ + "src/hip/batched_unary_embedding_wrappers.cpp", + "src/hip/quantize_wrappers.cpp", + "src/hip/sparse_wrappers.cpp", + ] + +def get_fbgemm_gpu_public_headers_hip(): + return [ + "include/hip/fbgemm_gpu/batched_unary_embedding_ops.cuh", + "include/hip/fbgemm_gpu/batched_unary_embedding_wrappers.cuh", + "include/hip/fbgemm_gpu/bench_utils.cuh", + "include/hip/fbgemm_gpu/cuda_utils.cuh", + "include/hip/fbgemm_gpu/quantize_ops.cuh", + "include/hip/fbgemm_gpu/quantize_wrappers.cuh", + "include/hip/fbgemm_gpu/sparse_ops.cuh", + "include/hip/fbgemm_gpu/sparse_wrappers.cuh", + ] diff --git a/fbgemm_gpu/fbgemm_gpu/hipcub_namespace_postfix.cuh b/fbgemm_gpu/fbgemm_gpu/hipcub_namespace_postfix.cuh new file mode 100644 index 000000000..8922edbba --- /dev/null +++ b/fbgemm_gpu/fbgemm_gpu/hipcub_namespace_postfix.cuh @@ -0,0 +1,21 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#undef FBGEMM_GPU_CUB_NS_PREFIX + +#ifdef FBGEMM_CUB_USE_NAMESPACE + +#undef CUB_NS_PREFIX +#undef CUB_NS_POSTFIX + +#define FBGEMM_GPU_CUB_NS_PREFIX fbgemm_gpu:: + +#else + +#define FBGEMM_GPU_CUB_NS_PREFIX + +#endif diff --git a/fbgemm_gpu/fbgemm_gpu/hipcub_namespace_prefix.cuh b/fbgemm_gpu/fbgemm_gpu/hipcub_namespace_prefix.cuh new file mode 100644 index 000000000..c977653fa --- /dev/null +++ b/fbgemm_gpu/fbgemm_gpu/hipcub_namespace_prefix.cuh @@ -0,0 +1,16 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifdef FBGEMM_CUB_USE_NAMESPACE + +#undef CUB_NS_PREFIX +#undef CUB_NS_POSTFIX + +#define CUB_NS_PREFIX namespace fbgemm_gpu { +#define CUB_NS_POSTFIX } // namespace fbgemm_gpu + +#endif diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py index 221a5ea89..8aae5b4cb 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py @@ -1718,6 +1718,7 @@ def max_ty_D(ty: SparseType) -> int: def align_to_cacheline(a: int) -> int: # align each table to 128b cache line boundary. + # TODO: Change cache line boundary for ROCM return round_up(a, 128) weights_tys_int = [weights_tys[t].as_int() for t in self.feature_table_map] diff --git a/fbgemm_gpu/include/fbgemm_gpu/enum_utils.h b/fbgemm_gpu/include/fbgemm_gpu/enum_utils.h index 2d04a13a5..3a6e55cde 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/enum_utils.h +++ b/fbgemm_gpu/include/fbgemm_gpu/enum_utils.h @@ -15,9 +15,11 @@ namespace fbgemm_gpu { #define FBGEMM_GPU_ENUM_CREATE_TAG(module_name) \ struct fbgemm_gpu_enum_tag_##module_name {}; \ - extern template enum_registration* \ + template <> enum_registration* \ enum_registration< \ - struct fbgemm_gpu_enum_tag_##module_name>::registration_list; + struct fbgemm_gpu_enum_tag_##module_name>::registration_list; \ + extern template class enum_registration< \ + struct fbgemm_gpu_enum_tag_##module_name>; #define FBGEMM_GPU_ENUM_TAG(module_name) \ struct fbgemm_gpu_enum_tag_##module_name diff --git a/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh b/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh index 85f02bc17..d033645bf 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh @@ -19,7 +19,11 @@ namespace fbgemm_gpu { #define DEVICE_INLINE __device__ inline __attribute__((always_inline)) // Warp size +#ifdef __HIP_PLATFORM_HCC__ +static constexpr int32_t kWarpSize = 64; +#else static constexpr int32_t kWarpSize = 32; +#endif // Max thread num in one thread block static constexpr int32_t kMaxThreads = 1024; static constexpr float kQParamEps = 1e-8f; @@ -36,7 +40,12 @@ struct Half4 { half2 b; __device__ inline void store(at::Half* p) { -#if CUDA_VERSION >= 9000 +#ifdef __HIP_PLATFORM_HCC__ + p[0] = __low2half(a); + p[1] = __high2half(a); + p[2] = __low2half(b); + p[3] = __high2half(b); +#elif CUDA_VERSION >= 9000 #ifndef __HALF2_TO_UI // cuda_fp16.hpp doesn't export this @@ -79,6 +88,12 @@ struct Vec4T { } DEVICE_INLINE Vec4T(const at::Half* p) { +#ifdef __HIP_PLATFORM_HCC__ + acc.x = __half2float(p[0]); + acc.y = __half2float(p[1]); + acc.z = __half2float(p[2]); + acc.w = __half2float(p[3]); +#else Half4 out; #if CUDA_VERSION >= 9000 asm("ld.global.v2.u32 {%0, %1}, [%2];" @@ -97,6 +112,7 @@ struct Vec4T { acc.y = a.y; acc.z = b.x; acc.w = b.y; +#endif } DEVICE_INLINE void store(float* p) { @@ -173,6 +189,12 @@ struct Vec4T { } DEVICE_INLINE Vec4T(const at::Half* p) { +#ifdef __HIP_PLATFORM_HCC__ + acc.x = __half2float(p[0]); + acc.y = __half2float(p[1]); + acc.z = __half2float(p[2]); + acc.w = __half2float(p[3]); +#else Half4 out; #if CUDA_VERSION >= 9000 asm("ld.global.v2.u32 {%0, %1}, [%2];" @@ -191,6 +213,7 @@ struct Vec4T { acc.y = a.y; acc.z = b.x; acc.w = b.y; +#endif } DEVICE_INLINE Vec4T(const float* p) { @@ -235,6 +258,12 @@ struct Vec4T { } DEVICE_INLINE static void copy(const at::Half* src, at::Half* dst) { +#ifdef __HIP_PLATFORM_HCC__ + dst[0] = src[0]; + dst[1] = src[1]; + dst[2] = src[2]; + dst[3] = src[3]; +#else Half4 out; #if CUDA_VERSION >= 9000 asm("ld.global.v2.u32 {%0, %1}, [%2];" @@ -251,6 +280,7 @@ struct Vec4T { : "l"(dst), "r"(__HALF2_TO_UI(out.a)), "r"(__HALF2_TO_UI(out.b))); #else asm("st.v2.u32 [%0], {%1, %2};" : : "l"(dst), "r"(out.a.x), "r"(out.b.x)); +#endif #endif } @@ -305,6 +335,12 @@ struct Vec4T { } DEVICE_INLINE Vec4T(const at::Half* p) { +#ifdef __HIP_PLATFORM_HCC__ + acc.x = __half2float(p[0]); + acc.y = __half2float(p[1]); + acc.z = __half2float(p[2]); + acc.w = __half2float(p[3]); +#else Half4 out; #if CUDA_VERSION >= 9000 asm("ld.global.v2.u32 {%0, %1}, [%2];" @@ -323,6 +359,7 @@ struct Vec4T { acc.y = a.y; acc.z = b.x; acc.w = b.y; +#endif } DEVICE_INLINE Vec4T(const float* p) { @@ -406,7 +443,9 @@ DEVICE_INLINE Vec4T vec4_acc( template DEVICE_INLINE T shfl_xor(const T val, int laneMask, int width = kWarpSize) { -#if CUDA_VERSION >= 9000 +#ifdef __HIP_PLATFORM_HCC__ + return __shfl_xor(val, laneMask, width); +#elif CUDA_VERSION >= 9000 return __shfl_xor_sync(0xffffffff, val, laneMask, width); #else return __shfl_xor(val, laneMask, width); @@ -418,7 +457,11 @@ template DEVICE_INLINE T warpReduceAllSum(T val) { #pragma unroll for (int mask = ReduceWidth / 2; mask > 0; mask >>= 1) { +#ifdef __HIP_PLATFORM_HCC__ + val += __shfl_xor(val, mask); +#else val += shfl_xor(val, mask); +#endif } return val; } @@ -517,10 +560,17 @@ DEVICE_INLINE void stochastic_rounding_vector( float2 /* not used */) { uint4 random_bits = stochastic_rounding_rand4(&state); Half4 v; +#ifdef __HIP_PLATFORM_HCC__ + v.a = __halves2half2(stochastic_rounding_scalar(value.acc.x, random_bits.x), + stochastic_rounding_scalar(value.acc.y, random_bits.y)); + v.b = __halves2half2(stochastic_rounding_scalar(value.acc.z, random_bits.z), + stochastic_rounding_scalar(value.acc.w, random_bits.w)); +#else v.a.x = stochastic_rounding_scalar(value.acc.x, random_bits.x); v.a.y = stochastic_rounding_scalar(value.acc.y, random_bits.y); v.b.x = stochastic_rounding_scalar(value.acc.z, random_bits.z); v.b.y = stochastic_rounding_scalar(value.acc.w, random_bits.w); +#endif v.store(output); } @@ -532,10 +582,17 @@ DEVICE_INLINE void stochastic_rounding_vector( float2 /* not used */) { uint4 random_bits = stochastic_rounding_rand4(&state); Half4 v; +#ifdef __HIP_PLATFORM_HCC__ + v.a = __halves2half2(stochastic_rounding_scalar(value.acc.x, random_bits.x), + stochastic_rounding_scalar(value.acc.y, random_bits.y)); + v.b = __halves2half2(stochastic_rounding_scalar(value.acc.z, random_bits.z), + stochastic_rounding_scalar(value.acc.w, random_bits.w)); +#else v.a.x = stochastic_rounding_scalar(value.acc.x, random_bits.x); v.a.y = stochastic_rounding_scalar(value.acc.y, random_bits.y); v.b.x = stochastic_rounding_scalar(value.acc.z, random_bits.z); v.b.y = stochastic_rounding_scalar(value.acc.w, random_bits.w); +#endif v.store(output); } @@ -879,8 +936,13 @@ __device__ float2 warp_find_qparams(scalar_t local_min, scalar_t local_max) { qparams.x = (local_max - local_min) / 255.0f; qparams.y = local_min; } +#ifdef __HIP_PLATFORM_HCC__ + qparams.x = __shfl(qparams.x, 0); + qparams.y = __shfl(qparams.y, 0); +#else qparams.x = __shfl_sync(0xFFFFFFFF, qparams.x, 0); qparams.y = __shfl_sync(0xFFFFFFFF, qparams.y, 0); +#endif return qparams; } @@ -940,6 +1002,9 @@ DEVICE_INLINE float8 make_zero_float8() { __forceinline__ __device__ __half2 hfma2(const __half2 a, const __half2 b, const __half2 c) { +#ifdef __HIP_PLATFORM_HCC__ + return __hfma2(a, b, c); +#else #if __CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610 return __hfma2(a, b, c); #else @@ -951,14 +1016,19 @@ hfma2(const __half2 a, const __half2 b, const __half2 c) { fc.y = fa.y * fb.y + fc.y; return __float22half2_rn(fc); #endif +#endif } __forceinline__ __device__ half hmul(half a, half b) { +#ifdef __HIP_PLATFORM_HCC__ + return __hmul(a, b); +#else #if __CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610 return __hmul(a, b); #else return __float2half(__half2float(a) * __half2float(b)); #endif +#endif } // Reinterpret a pair of uint16_t (packed into a uint32_t) as half2, and @@ -1011,28 +1081,31 @@ dequantize_permuted_int4(uint32_t packedVals, __half2 shift_scale) { // on each 4-bit value is expensive on the ALU, and 4-bit to half is expensive // on the XU. b) doing a 256-entry shared memory LUT on 8-bit pairs is // expensive on SMEM throughput. Credit to @jhj. - res.vals[0] = hmul_short2(v & 0x000F000F, 32768); - res.vals[1] = hmul_short2(v & 0x00F000F0, 32768); + res.vals[0] = hmul_short2(v & 0x000F000F, __int2half_rn(32768)); + res.vals[1] = hmul_short2(v & 0x00F000F0, __int2half_rn(32768)); v >>= 8; - res.vals[2] = hmul_short2(v & 0x000F000F, 32768); - res.vals[3] = hmul_short2(v & 0x00F000F0, 32768); + res.vals[2] = hmul_short2(v & 0x000F000F, __int2half_rn(32768)); + res.vals[3] = hmul_short2(v & 0x00F000F0, __int2half_rn(32768)); + // TODO: Enable this for HIP +#ifndef __HIP_PLATFORM_HCC__ res.vals[0] = hfma2( res.vals[0], - __half2(hmul(shift_scale.x, 512), hmul(shift_scale.x, 512)), - __half2(shift_scale.y, shift_scale.y)); + __half2(hmul(__ushort2half_rn(shift_scale.x), __int2half_rn(512)), hmul(__ushort2half_rn(shift_scale.x), __int2half_rn(512))), + __half2(__ushort2half_rn(shift_scale.y), __ushort2half_rn(shift_scale.y))); res.vals[1] = hfma2( res.vals[1], - __half2(hmul(shift_scale.x, 32), hmul(shift_scale.x, 32)), - __half2(shift_scale.y, shift_scale.y)); + __half2(hmul(__ushort2half_rn(shift_scale.x), __int2half_rn(32)), hmul(__ushort2half_rn(shift_scale.x), __int2half_rn(32))), + __half2(__ushort2half_rn(shift_scale.y), __ushort2half_rn(shift_scale.y))); res.vals[2] = hfma2( res.vals[2], - __half2(hmul(shift_scale.x, 512), hmul(shift_scale.x, 512)), - __half2(shift_scale.y, shift_scale.y)); + __half2(hmul(__ushort2half_rn(shift_scale.x), __int2half_rn(512)), hmul(__ushort2half_rn(shift_scale.x), __int2half_rn(512))), + __half2(__ushort2half_rn(shift_scale.y), __ushort2half_rn(shift_scale.y))); res.vals[3] = hfma2( res.vals[3], - __half2(hmul(shift_scale.x, 32), hmul(shift_scale.x, 32)), - __half2(shift_scale.y, shift_scale.y)); + __half2(hmul(__ushort2half_rn(shift_scale.x), __int2half_rn(32)), hmul(__ushort2half_rn(shift_scale.x), __int2half_rn(32))), + __half2(__ushort2half_rn(shift_scale.y), __ushort2half_rn(shift_scale.y))); +#endif return res; } @@ -1041,17 +1114,17 @@ dequantize_permuted_int8(uint32_t packedVals, __half2 shift_scale) { half4 res; uint32_t v = packedVals; // See comment above, this is a minor variation. - res.vals[0] = hmul_short2(v & 0x00FF00FF, 32768); + res.vals[0] = hmul_short2(v & 0x00FF00FF, __int2half_rn(32768)); v >>= 8; - res.vals[1] = hmul_short2(v & 0x00FF00FF, 32768); + res.vals[1] = hmul_short2(v & 0x00FF00FF, __int2half_rn(32768)); res.vals[0] = hfma2( res.vals[0], - __half2(hmul(shift_scale.x, 512), hmul(shift_scale.x, 512)), - __half2(shift_scale.y, shift_scale.y)); + __half2(hmul(__ushort2half_rn(shift_scale.x), __int2half_rn(512)), hmul(__ushort2half_rn(shift_scale.x), __int2half_rn(512))), + __half2(__ushort2half_rn(shift_scale.y), __ushort2half_rn(shift_scale.y))); res.vals[1] = hfma2( res.vals[1], - __half2(hmul(shift_scale.x, 512), hmul(shift_scale.x, 512)), - __half2(shift_scale.y, shift_scale.y)); + __half2(hmul(__ushort2half_rn(shift_scale.x), __int2half_rn(512)), hmul(__ushort2half_rn(shift_scale.x), __int2half_rn(512))), + __half2(__ushort2half_rn(shift_scale.y), __ushort2half_rn(shift_scale.y))); return res; } diff --git a/fbgemm_gpu/include/fbgemm_gpu/hipcub_namespace_postfix.cuh b/fbgemm_gpu/include/fbgemm_gpu/hipcub_namespace_postfix.cuh new file mode 100644 index 000000000..8922edbba --- /dev/null +++ b/fbgemm_gpu/include/fbgemm_gpu/hipcub_namespace_postfix.cuh @@ -0,0 +1,21 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#undef FBGEMM_GPU_CUB_NS_PREFIX + +#ifdef FBGEMM_CUB_USE_NAMESPACE + +#undef CUB_NS_PREFIX +#undef CUB_NS_POSTFIX + +#define FBGEMM_GPU_CUB_NS_PREFIX fbgemm_gpu:: + +#else + +#define FBGEMM_GPU_CUB_NS_PREFIX + +#endif diff --git a/fbgemm_gpu/include/fbgemm_gpu/hipcub_namespace_postfix.hpp b/fbgemm_gpu/include/fbgemm_gpu/hipcub_namespace_postfix.hpp new file mode 100644 index 000000000..8922edbba --- /dev/null +++ b/fbgemm_gpu/include/fbgemm_gpu/hipcub_namespace_postfix.hpp @@ -0,0 +1,21 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#undef FBGEMM_GPU_CUB_NS_PREFIX + +#ifdef FBGEMM_CUB_USE_NAMESPACE + +#undef CUB_NS_PREFIX +#undef CUB_NS_POSTFIX + +#define FBGEMM_GPU_CUB_NS_PREFIX fbgemm_gpu:: + +#else + +#define FBGEMM_GPU_CUB_NS_PREFIX + +#endif diff --git a/fbgemm_gpu/include/fbgemm_gpu/hipcub_namespace_prefix.cuh b/fbgemm_gpu/include/fbgemm_gpu/hipcub_namespace_prefix.cuh new file mode 100644 index 000000000..c977653fa --- /dev/null +++ b/fbgemm_gpu/include/fbgemm_gpu/hipcub_namespace_prefix.cuh @@ -0,0 +1,16 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifdef FBGEMM_CUB_USE_NAMESPACE + +#undef CUB_NS_PREFIX +#undef CUB_NS_POSTFIX + +#define CUB_NS_PREFIX namespace fbgemm_gpu { +#define CUB_NS_POSTFIX } // namespace fbgemm_gpu + +#endif diff --git a/fbgemm_gpu/include/fbgemm_gpu/hipcub_namespace_prefix.hpp b/fbgemm_gpu/include/fbgemm_gpu/hipcub_namespace_prefix.hpp new file mode 100644 index 000000000..c977653fa --- /dev/null +++ b/fbgemm_gpu/include/fbgemm_gpu/hipcub_namespace_prefix.hpp @@ -0,0 +1,16 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifdef FBGEMM_CUB_USE_NAMESPACE + +#undef CUB_NS_PREFIX +#undef CUB_NS_POSTFIX + +#define CUB_NS_PREFIX namespace fbgemm_gpu { +#define CUB_NS_POSTFIX } // namespace fbgemm_gpu + +#endif diff --git a/fbgemm_gpu/include/fbgemm_gpu/quantize_ops.cuh b/fbgemm_gpu/include/fbgemm_gpu/quantize_ops.cuh index 66d4d0448..10d246ffb 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/quantize_ops.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/quantize_ops.cuh @@ -7,8 +7,10 @@ #pragma once #include -#include +#ifndef __HIP_PLATFORM_HCC__ #include +#endif +#include #include #define QUANTIZE_OPS_MAX(a, b) ((a) > (b) ? (a) : (b)) @@ -59,8 +61,14 @@ __global__ inline void _get_8bit_qparam_cuda_kernel( const int output_columns = ncols_aligned + 2 * sizeof(float); // starting values for future reductions + // TODO: Fix this for HIP +#ifdef __HIP_PLATFORM_HCC__ + float minimum_element = 0; + float maximum_element = 0; +#else float minimum_element = CUDART_INF_F; float maximum_element = -CUDART_INF_F; +#endif // always a power of 2 up to size 32. Multiple rows can share the same warp // when smaller than 32. diff --git a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.cuh b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.cuh index 610a3fe06..3c0608919 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.cuh @@ -6,6 +6,10 @@ */ #pragma once +#ifdef __HIP_PLATFORM_HCC__ +#define HIPCUB_ARCH 1 +#endif + #include // clang-format off diff --git a/fbgemm_gpu/run_all.sh b/fbgemm_gpu/run_all.sh new file mode 100755 index 000000000..d7a457c08 --- /dev/null +++ b/fbgemm_gpu/run_all.sh @@ -0,0 +1,39 @@ +#!/bin/bash + +common_opts="--bag-size 55 \ + --batch-size 65536 \ + --num-embeddings 19300000 \ + --num-tables 1 \ + --iters 5" + +# Run on GPU and get PyTorch-level performance +for D in 64 128 192 256 512; do + for fp in "fp32" "fp16"; do + for alpha in 1 1.15; do + echo "D = ${D}, FP = ${fp}, alpha = ${alpha}" + python3.6 bench/split_table_batched_embeddings_benchmark.py device \ + $common_opts \ + --embedding-dim $D \ + --alpha ${alpha} \ + --weights-precision $fp + done + done +done 2>&1 | tee log_fbgemm_gpu_m1.log + +# Run on GPU and get rocprof-level performance +for D in 64 128 192 256 512; do + for fp in "fp32" "fp16"; do + for alpha in 1 1.15; do + rm -rf rocprof + rm -rf rocprof_tmp + echo "D = ${D}, FP = ${fp}, alpha = ${alpha}" + outf="rocprof_fbgemm_gpu_D_${D}_${fp}_alpha_${alpha}.csv" + rocprof --timestamp on -o $outf -d rocprof -t rocprof_tmp \ + python3.6 bench/split_table_batched_embeddings_benchmark.py device \ + $common_opts \ + --embedding-dim $D \ + --alpha ${alpha} \ + --weights-precision $fp + done + done +done diff --git a/fbgemm_gpu/setup.py b/fbgemm_gpu/setup.py index f33dc0f91..7d033fa0b 100644 --- a/fbgemm_gpu/setup.py +++ b/fbgemm_gpu/setup.py @@ -3,19 +3,38 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import glob import os +import shutil +import sysconfig import sys +import re +import tempfile -from skbuild import setup +from codegen.embedding_backward_code_generator import emb_codegen +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension +import torch +from torch.utils.hipify import hipify_python cpu_only_build = False - +cur_dir = os.path.dirname(os.path.realpath(__file__)) cub_include_path = os.getenv("CUB_DIR", None) if cub_include_path is None: print( "CUDA CUB directory environment variable not set. Using default CUB location." ) +build_codegen_path = "build/codegen" +py_path = "python" + +is_rocm_pytorch = False +maj_ver, min_ver, _ = torch.__version__.split('.') +if int(maj_ver) > 1 or (int(maj_ver) == 1 and int(min_ver) >= 5): + from torch.utils.cpp_extension import ROCM_HOME + is_rocm_pytorch = True \ + if ((torch.version.hip is not None) and (ROCM_HOME is not None)) \ + else False # Get the long description from the relevant file cur_dir = os.path.dirname(os.path.realpath(__file__)) @@ -23,9 +42,187 @@ with open(os.path.join(cur_dir, "README.md"), encoding="utf-8") as f: long_description = f.read() -import torch +extra_compile_args = sysconfig.get_config_var("CFLAGS").split() +extra_compile_args += ["-mavx2", "-mf16c", "-mfma"] +if not is_rocm_pytorch: + extra_compile_args += ["-mavx512f", "-mavx512bw", "-mavx512dq", "-mavx512vl"] + +OPTIMIZERS = [ + "adagrad", + "adam", + "approx_rowwise_adagrad", + "approx_sgd", + "lamb", + "lars_sgd", + "partial_rowwise_adam", + "partial_rowwise_lamb", + "rowwise_adagrad", + "sgd", + "rowwise_weighted_adagrad" +] + +cpp_asmjit_files = glob.glob("../third_party/asmjit/src/asmjit/*/*.cpp") + +cpp_fbgemm_files = [ + "../src/EmbeddingSpMDMAvx2.cc", + "../src/EmbeddingSpMDM.cc", + "../src/EmbeddingSpMDMNBit.cc", + "../src/QuantUtils.cc", + "../src/QuantUtilsAvx2.cc", + "../src/RefImplementations.cc", + "../src/RowWiseSparseAdagradFused.cc", + "../src/SparseAdagrad.cc", + "../src/Utils.cc", +] + +if not is_rocm_pytorch: + cpp_fbgemm_files += ["../src/EmbeddingSpMDMAvx512.cc"] + +cpp_cpu_output_files = ( + [ + "gen_embedding_forward_quantized_unweighted_codegen_cpu.cpp", + "gen_embedding_forward_quantized_weighted_codegen_cpu.cpp", + "gen_embedding_backward_dense_split_cpu.cpp", + ] + + [ + "gen_embedding_backward_split_{}_cpu.cpp".format(optimizer) + for optimizer in OPTIMIZERS + ] + + [ + "gen_embedding_backward_{}_split_cpu.cpp".format(optimizer) + for optimizer in OPTIMIZERS + ] +) + +cpp_cuda_output_files = ( + [ + "gen_embedding_forward_dense_weighted_codegen_cuda.cu", + "gen_embedding_forward_dense_unweighted_codegen_cuda.cu", + "gen_embedding_forward_quantized_split_unweighted_codegen_cuda.cu", + "gen_embedding_forward_quantized_split_weighted_codegen_cuda.cu", + "gen_embedding_forward_split_weighted_codegen_cuda.cu", + "gen_embedding_forward_split_unweighted_codegen_cuda.cu", + "gen_embedding_backward_split_indice_weights_codegen_cuda.cu", + "gen_embedding_backward_dense_indice_weights_codegen_cuda.cu", + "gen_embedding_backward_dense_split_unweighted_cuda.cu", + "gen_embedding_backward_dense_split_weighted_cuda.cu", + ] + + [ + "gen_embedding_backward_{}_split_{}_cuda.cu".format(optimizer, weighted) + for optimizer in OPTIMIZERS + for weighted in [ + "weighted", + "unweighted", + ] + ] + + [ + "gen_embedding_backward_split_{}.cpp".format(optimizer) + for optimizer in OPTIMIZERS + ] +) + +py_output_files = ["lookup_{}.py".format(optimizer) for optimizer in OPTIMIZERS] + + +def generate_jinja_files(): + abs_build_path = os.path.join(cur_dir, build_codegen_path) + if not os.path.exists(abs_build_path): + os.makedirs(abs_build_path) + emb_codegen(install_dir=abs_build_path, is_fbcode=False) + + dst_python_path = os.path.join(cur_dir, py_path) + if not os.path.exists(dst_python_path): + os.makedirs(dst_python_path) + for filename in py_output_files: + shutil.copy2(os.path.join(abs_build_path, filename), dst_python_path) + shutil.copy2(os.path.join(cur_dir, "codegen", "lookup_args.py"), dst_python_path) + + +class FBGEMM_GPU_BuildExtension(BuildExtension.with_options(no_python_abi_suffix=True)): + def build_extension(self, ext): + if not is_rocm_pytorch: + generate_jinja_files() + else: + with hipify_python.GeneratedFileCleaner(keep_intermediates=True) as clean_ctx: + hipify_python.hipify( + project_directory=cur_dir, + output_directory=cur_dir, + includes="codegen/*", + show_detailed=True, + is_pytorch_extension=True, + clean_ctx=clean_ctx) + + def replace_pattern(hip_file, pattern_map): + patterns = {} + for regexp in pattern_map: + patterns[regexp] = re.compile(regexp.format(exclude="")) + with tempfile.NamedTemporaryFile(mode="w", delete=False) as tmp_file: + with open(hip_file) as src_file: + for line in src_file: + for regexp in pattern_map: + pattern = pattern_map[regexp] + exclude = pattern[0] + replacement = pattern[1] + in_regexp = regexp.format(exclude="") + if len(pattern_map[regexp]) == 4: + all_ori = pattern[2] + all_new = pattern[3] + else: + all_ori = None + all_new = None + if re.search(in_regexp, line) and \ + (exclude is None or not re.search(regexp.format(exclude=exclude), line)): + ori = line + if all_ori is not None and all_ori in line: + line = line.replace(all_ori, all_new) + else: + line = patterns[regexp].sub(replacement, line) + + tmp_file.write(line) + + shutil.copystat(hip_file, tmp_file.name) + shutil.move(tmp_file.name, hip_file) + + def post_hipify(hip_file): + replace_pattern(hip_file, {"(#include.*\"codegen.*){exclude}[.]cuh": ["_hip", "\\1_hip.cuh"], + "{exclude}cub(::DeviceRunLengthEncode)": ["hip", "hipcub\\1"], + "(#include.*[<\"].*){exclude}cub(.*)[.]cuh": ["hip", "\\1hipcub\\2.hpp"], + "(#include.*[<\"]fbgemm_gpu.*)({exclude}[.]cuh)": ["_hip", "\\1_hip\\2", "cuda", "hip"], + "cudaCpuDeviceId": [None, "hipCpuDeviceId"], + "split_embeddings_utils[.]cuh": [None, "split_embeddings_utils_hip.cuh"]}) + + abs_build_path = os.path.join(cur_dir, build_codegen_path) + for f in cpp_cuda_output_files: + if f.endswith(".cu"): + hip_f = os.path.join(abs_build_path, f.replace("cuda", "hip").replace(".cu", ".hip")) + post_hipify(hip_f) + + for s in ["codegen", "src"]: + for f in os.listdir(s): + if f.endswith(".hip") or f.endswith("hip.cuh"): + hip_f = os.path.join(s, f) + post_hipify(hip_f) + + os.system("hipify-perl src/split_embeddings_utils.cuh > src/split_embeddings_utils_hip.cuh") + post_hipify("src/split_embeddings_utils_hip.cuh") + + super().build_extension(ext) + +if is_rocm_pytorch: + generate_jinja_files() + rocm_include_dirs = ["/opt/rocm/include/hiprand", "/opt/rocm/include/rocrand"] + libraries = [] +else: + rocm_include_dirs = [] + libraries = ["nvidia-ml"] + +include_dirs = [ cur_dir, + os.path.join(cur_dir, "include"), + os.path.join(cur_dir, "include/fbgemm_gpu"), + ] + rocm_include_dirs -torch_root = os.path.dirname(torch.__file__) +if cub_include_path is not None: + include_dirs += [cub_include_path] # Handle command line args before passing to main setup() method. if "--cpu_only" in sys.argv: @@ -34,8 +231,88 @@ setup( name="fbgemm_gpu", + install_requires=[ + "torch", + "Jinja2", + "click", + "hypothesis", + ], version="0.0.1", long_description=long_description, - packages=["fbgemm_gpu"], - cmake_args=[f"-DCMAKE_PREFIX_PATH={torch_root}"], + ext_modules=[ + CUDAExtension( + name="fbgemm_gpu_py", + sources=[ + os.path.join(cur_dir, build_codegen_path, "{}".format(f)) + for f in cpp_cuda_output_files + cpp_cpu_output_files + ] + + cpp_asmjit_files + + cpp_fbgemm_files + + [ + os.path.join(cur_dir, "codegen/embedding_forward_split_cpu.cpp"), + os.path.join(cur_dir, "codegen/embedding_forward_quantized_host_cpu.cpp"), + os.path.join(cur_dir, "codegen/embedding_forward_quantized_host.cpp"), + os.path.join(cur_dir, "codegen/embedding_backward_dense_host_cpu.cpp"), + os.path.join(cur_dir, "codegen/embedding_backward_dense_host.cpp"), + os.path.join(cur_dir, "codegen/embedding_bounds_check_host.cpp"), + os.path.join(cur_dir, "codegen/embedding_bounds_check_host_cpu.cpp"), + os.path.join(cur_dir, "codegen/embedding_bounds_check.cu"), + os.path.join(cur_dir, "src/split_embeddings_cache_cuda.cu"), + os.path.join(cur_dir, "src/split_table_batched_embeddings.cpp"), + os.path.join(cur_dir, "src/cumem_utils.cu"), + os.path.join(cur_dir, "src/cumem_utils_host.cpp"), + os.path.join(cur_dir, "src/quantize_ops_cpu.cpp"), + os.path.join(cur_dir, "src/quantize_ops_gpu.cpp"), + os.path.join(cur_dir, "src/cpu_utils.cpp"), + os.path.join(cur_dir, "src/sparse_ops_cpu.cpp"), + os.path.join(cur_dir, "src/sparse_ops_gpu.cpp"), + os.path.join(cur_dir, "src/sparse_ops.cu"), + os.path.join(cur_dir, "src/merge_pooled_embeddings_gpu.cpp"), + os.path.join(cur_dir, "src/permute_pooled_embedding_ops.cu"), + os.path.join(cur_dir, "src/permute_pooled_embedding_ops_gpu.cpp"), + os.path.join(cur_dir, "src/layout_transform_ops_cpu.cpp"), + os.path.join(cur_dir, "src/layout_transform_ops_gpu.cpp"), + os.path.join(cur_dir, "src/layout_transform_ops.cu"), + ], + include_dirs=[ + cur_dir, + os.path.join(cur_dir, "include"), + os.path.join(cur_dir, "../include"), + os.path.join(cur_dir, "../src"), + os.path.join(cur_dir, "../third_party/asmjit/src"), + os.path.join(cur_dir, "../third_party/asmjit/src/core"), + os.path.join(cur_dir, "../third_party/asmjit/src/x86"), + os.path.join(cur_dir, "../third_party/cpuinfo/include"), + ] + include_dirs, + extra_compile_args={"cxx": extra_compile_args + ["-DFBGEMM_GPU_WITH_CUDA"], + "nvcc": ["-U__CUDA_NO_HALF_CONVERSIONS__"]}, + libraries=libraries, + ) if not cpu_only_build else + CppExtension( + name="fbgemm_gpu_py", + sources=[ + os.path.join(cur_dir, build_codegen_path, "{}".format(f)) + for f in cpp_cpu_output_files + ] + + cpp_asmjit_files + + cpp_fbgemm_files + + [ + os.path.join(cur_dir, "codegen/embedding_forward_split_cpu.cpp"), + os.path.join(cur_dir, "codegen/embedding_forward_quantized_host_cpu.cpp"), + os.path.join(cur_dir, "codegen/embedding_backward_dense_host_cpu.cpp"), + ], + include_dirs=[ + cur_dir, + os.path.join(cur_dir, "include"), + os.path.join(cur_dir, "../include"), + os.path.join(cur_dir, "../src"), + os.path.join(cur_dir, "../third_party/asmjit/src"), + os.path.join(cur_dir, "../third_party/asmjit/src/core"), + os.path.join(cur_dir, "../third_party/asmjit/src/x86"), + os.path.join(cur_dir, "../third_party/cpuinfo/include"), + ], + extra_compile_args={"cxx": extra_compile_args}, + ) + ], + cmdclass={"build_ext": FBGEMM_GPU_BuildExtension}, ) diff --git a/fbgemm_gpu/src/cumem_utils.cu b/fbgemm_gpu/src/cumem_utils.cu index 966876dee..cf7045356 100644 --- a/fbgemm_gpu/src/cumem_utils.cu +++ b/fbgemm_gpu/src/cumem_utils.cu @@ -237,6 +237,35 @@ int64_t uvm_get_guard_index(Tensor& t) { } // namespace +#ifdef __HIP_PLATFORM_HCC__ +void uvm_cuda_mem_advise(Tensor t, int64_t hipMemoryAdvise) { + // Call hipMemAdvise on vm tensor + // See hipMemAdvise enum (automatically exported to python fbgemm_gpu.uvm + // namespace) for valid values and interface stub. + at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard; + int64_t cuda_device_index = uvm_get_guard_index(t); + int hint_device; + if (t.is_cpu()) { + hint_device = hipCpuDeviceId; + } else { + TORCH_CHECK(t.is_cuda()); + hint_device = static_cast(cuda_device_index); + } + + void* ptr = t.data_ptr(); + size_t size_bytes = at::detail::computeStorageNbytes( + t.sizes(), t.strides(), t.dtype().itemsize()); + + device_guard.set_index(cuda_device_index); + + AT_CUDA_CHECK(hipMemAdvise( + ptr, + size_bytes, + static_cast(hipMemoryAdvise), + hint_device)); + return; +} +#else void uvm_cuda_mem_advise(Tensor t, int64_t cudaMemoryAdvise) { // Call cudaMemAdvise on vm tensor // See cudaMemoryAdvise enum (automatically exported to python fbgemm_gpu.uvm @@ -264,6 +293,7 @@ void uvm_cuda_mem_advise(Tensor t, int64_t cudaMemoryAdvise) { hint_device)); return; } +#endif void uvm_cuda_mem_prefetch_async(Tensor t, c10::optional device_t) { // Call cudaMemPrefetchAsync on Tensor @@ -311,6 +341,16 @@ void uvm_mem_advice_dont_fork(Tensor t) { FBGEMM_GPU_ENUM_GLOGAL(uvm) +#ifdef __HIP_PLATFORM_HCC__ +FBGEMM_GPU_ENUM_REGISTER_START(uvm, hipMemoryAdvise){ + FBGEMM_GPU_ENUM_ITEM(cudaMemAdviseSetReadMostly), + FBGEMM_GPU_ENUM_ITEM(cudaMemAdviseUnsetReadMostly), + FBGEMM_GPU_ENUM_ITEM(cudaMemAdviseSetPreferredLocation), + FBGEMM_GPU_ENUM_ITEM(cudaMemAdviseUnsetPreferredLocation), + FBGEMM_GPU_ENUM_ITEM(cudaMemAdviseSetAccessedBy), + FBGEMM_GPU_ENUM_ITEM(cudaMemAdviseUnsetAccessedBy), +} FBGEMM_GPU_ENUM_REGISTER_END +#else FBGEMM_GPU_ENUM_REGISTER_START(uvm, cudaMemoryAdvise){ FBGEMM_GPU_ENUM_ITEM(cudaMemAdviseSetReadMostly), FBGEMM_GPU_ENUM_ITEM(cudaMemAdviseUnsetReadMostly), @@ -319,5 +359,6 @@ FBGEMM_GPU_ENUM_REGISTER_START(uvm, cudaMemoryAdvise){ FBGEMM_GPU_ENUM_ITEM(cudaMemAdviseSetAccessedBy), FBGEMM_GPU_ENUM_ITEM(cudaMemAdviseUnsetAccessedBy), } FBGEMM_GPU_ENUM_REGISTER_END +#endif } // namespace fbgemm_gpu diff --git a/fbgemm_gpu/src/merge_pooled_embeddings_gpu.cpp b/fbgemm_gpu/src/merge_pooled_embeddings_gpu.cpp index c354a06c1..e621cf682 100644 --- a/fbgemm_gpu/src/merge_pooled_embeddings_gpu.cpp +++ b/fbgemm_gpu/src/merge_pooled_embeddings_gpu.cpp @@ -15,6 +15,8 @@ #include #include +// TODO: Enable merge_pooled_embeddings for HIP +#ifndef __HIP_PLATFORM_HCC__ #include #include @@ -353,3 +355,4 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) { c10::DispatchKey::CUDA, TORCH_FN(fbgemm_gpu::merge_pooled_embeddings))); } +#endif diff --git a/fbgemm_gpu/src/sparse_ops.cu b/fbgemm_gpu/src/sparse_ops.cu index 006832be1..876a1dc26 100644 --- a/fbgemm_gpu/src/sparse_ops.cu +++ b/fbgemm_gpu/src/sparse_ops.cu @@ -1135,7 +1135,11 @@ __global__ void reorder_batched_ad_indices_kernel( const int32_t output_segment_start = reordered_cat_ad_offsets[output_segment_offset_start]; +#ifdef __HIP_PLATFORM_HCC__ + for (int32_t i = threadIdx.x; i < input_segment_end - input_segment_start; +#else for (auto i = threadIdx.x; i < input_segment_end - input_segment_start; +#endif i += blockDim.x) { reordered_cat_ad_indices[output_segment_start + i] = cat_ad_indices[input_segment_start + i]; diff --git a/fbgemm_gpu/src/split_embeddings_cache_cuda.cu b/fbgemm_gpu/src/split_embeddings_cache_cuda.cu index dd3cf9a98..4feceb4a6 100644 --- a/fbgemm_gpu/src/split_embeddings_cache_cuda.cu +++ b/fbgemm_gpu/src/split_embeddings_cache_cuda.cu @@ -254,9 +254,15 @@ __global__ __launch_bounds__(kMaxThreads) void linearize_cache_indices_kernel( // hash_offset < 0 for non-caching tables for (int32_t j = 0; j < kWarpSize; ++j) { +#ifdef __HIP_PLATFORM_HCC__ + int64_t indices_start_warp = __shfl(indices_start, j); + int32_t L_warp = __shfl(L, j); + int64_t hash_offset_warp = __shfl(hash_offset, j); +#else int64_t indices_start_warp = __shfl_sync(0xFFFFFFFF, indices_start, j); int32_t L_warp = __shfl_sync(0xFFFFFFFF, L, j); int64_t hash_offset_warp = __shfl_sync(0xFFFFFFFF, hash_offset, j); +#endif if (hash_offset_warp >= 0) { for (int32_t i = lane_id; i < L_warp; i += kWarpSize) { auto idx = __ldg(&indices[indices_start_warp + i]); @@ -441,7 +447,11 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_find_uncached_kernel( lru_state[cache_set][slot] = time_stamp; } +#ifdef __HIP_PLATFORM_HCC__ + if (!__any(found)) { +#else if (!__any_sync(0xFFFFFFFF, found)) { +#endif if (threadIdx.x == 0) { cache_sets[n] = cache_set; } @@ -571,9 +581,14 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_insert_kernel( int64_t sorted_lru_cost = costs[0]; for (int32_t l = 0; l < min(SL, kWarpSize); ++l) { +#ifdef __HIP_PLATFORM_HCC__ + int32_t insert_slot = __shfl(sorted_slot, l); + int64_t insert_current_lru_cost = __shfl(sorted_lru_cost, l); +#else int32_t insert_slot = __shfl_sync(0xFFFFFFFF, sorted_slot, l); int64_t insert_current_lru_cost = __shfl_sync(0xFFFFFFFF, sorted_lru_cost, l); +#endif if (insert_current_lru_cost == time_stamp) { return; } @@ -589,7 +604,11 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_insert_kernel( // lxu_cache_state int64_t current_idx = threadIdx.x == 0 ? lxu_cache_state[cache_set][insert_slot] : 0; +#ifdef __HIP_PLATFORM_HCC__ + current_idx = __shfl(current_idx, 0); +#else current_idx = __shfl_sync(0xFFFFFFFF, current_idx, 0); +#endif // not empty if (current_idx != static_cast(kCacheStateInvalid)) { @@ -862,9 +881,15 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_insert_byte_kernel( int64_t sorted_lru_cost = costs[0]; for (int32_t l = 0; l < min(SL, kWarpSize); ++l) { +#ifdef __HIP_PLATFORM_HCC__ + int32_t insert_slot = __shfl(sorted_slot, l); + int64_t insert_current_lru_cost = + __shfl(sorted_lru_cost, l); +#else int32_t insert_slot = __shfl_sync(0xFFFFFFFF, sorted_slot, l); int64_t insert_current_lru_cost = __shfl_sync(0xFFFFFFFF, sorted_lru_cost, l); +#endif if (insert_current_lru_cost == time_stamp) { return; } @@ -885,7 +910,11 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_insert_byte_kernel( // lxu_cache_state int64_t current_idx = threadIdx.x == 0 ? lxu_cache_state[cache_set][insert_slot] : 0; +#ifdef __HIP_PLATFORM_HCC__ + current_idx = __shfl(current_idx, 0); +#else current_idx = __shfl_sync(0xFFFFFFFF, current_idx, 0); +#endif // not empty if (current_idx != static_cast(kCacheStateInvalid)) { @@ -1112,7 +1141,11 @@ __global__ __launch_bounds__(kMaxThreads) void lfu_cache_find_uncached_kernel( << kLFUCounterBits); // invalid index, used as sentinel } +#ifdef __HIP_PLATFORM_HCC__ + if (!__any(found)) { +#else if (!__any_sync(0xFFFFFFFF, found)) { +#endif if (threadIdx.x == 0) { // sort so the highest LFUs come first in the segment. // assume lfu_state[idx] <= 2^40 - 1 and cache_set < 2^24 -1 @@ -1249,9 +1282,14 @@ __global__ __launch_bounds__(kCacheMaxThreads) void lfu_cache_insert_kernel( int64_t sorted_lfu_cost = costs[0]; for (int32_t l = 0; l < min(SL, kWarpSize); ++l) { +#ifdef __HIP_PLATFORM_HCC__ + int32_t insert_slot = __shfl(sorted_slot, l); + int64_t insert_current_lfu_cost = __shfl(sorted_lfu_cost, l); +#else int32_t insert_slot = __shfl_sync(0xFFFFFFFF, sorted_slot, l); int64_t insert_current_lfu_cost = __shfl_sync(0xFFFFFFFF, sorted_lfu_cost, l); +#endif int64_t insert_idx = cache_set_sorted_indices[n + l]; int64_t insert_lfu_cost = lfu_state[insert_idx]; @@ -1275,7 +1313,11 @@ __global__ __launch_bounds__(kCacheMaxThreads) void lfu_cache_insert_kernel( // lxu_cache_state int64_t current_idx = threadIdx.x == 0 ? lxu_cache_state[cache_set][insert_slot] : 0; +#ifdef __HIP_PLATFORM_HCC__ + current_idx = __shfl(current_idx, 0); +#else current_idx = __shfl_sync(0xFFFFFFFF, current_idx, 0); +#endif int32_t t_current = cache_index_table_map[current_idx]; int64_t idx_current = current_idx - cache_hash_size_cumsum[t_current]; int64_t weights_offset_current = weights_offsets[t_current]; @@ -1564,9 +1606,15 @@ __launch_bounds__(kCacheMaxThreads) void lfu_cache_insert_byte_kernel( int64_t sorted_lfu_cost = costs[0]; for (int32_t l = 0; l < min(SL, kWarpSize); ++l) { +#ifdef __HIP_PLATFORM_HCC__ + int32_t insert_slot = __shfl(sorted_slot, l); + int64_t insert_current_lfu_cost = + __shfl(sorted_lfu_cost, l); +#else int32_t insert_slot = __shfl_sync(0xFFFFFFFF, sorted_slot, l); int64_t insert_current_lfu_cost = __shfl_sync(0xFFFFFFFF, sorted_lfu_cost, l); +#endif int64_t insert_idx = cache_set_sorted_indices[n + l]; int64_t insert_lfu_cost = lfu_state[insert_idx]; @@ -1595,7 +1643,11 @@ __launch_bounds__(kCacheMaxThreads) void lfu_cache_insert_byte_kernel( // lxu_cache_state int64_t current_idx = threadIdx.x == 0 ? lxu_cache_state[cache_set][insert_slot] : 0; +#ifdef __HIP_PLATFORM_HCC__ + current_idx = __shfl(current_idx, 0); +#else current_idx = __shfl_sync(0xFFFFFFFF, current_idx, 0); +#endif int32_t t_current = cache_index_table_map[current_idx]; SparseType weight_ty_current = static_cast(weights_tys[t_current]); @@ -1752,7 +1804,11 @@ __global__ __launch_bounds__(kMaxThreads) void lxu_cache_lookup_kernel( if (found) { lxu_cache_locations[n] = cache_set * kWarpSize + slot; } +#ifdef __HIP_PLATFORM_HCC__ + if (!__any(found)) { +#else if (!__any_sync(0xFFFFFFFF, found)) { +#endif if (threadIdx.x == 0) { lxu_cache_locations[n] = kCacheLocationMissing; } diff --git a/fbgemm_gpu/src/split_embeddings_utils.cuh b/fbgemm_gpu/src/split_embeddings_utils.cuh index 21f33cb03..f10fa268c 100644 --- a/fbgemm_gpu/src/split_embeddings_utils.cuh +++ b/fbgemm_gpu/src/split_embeddings_utils.cuh @@ -41,8 +41,13 @@ inline __device__ void warpBitonicMergeLE16(K& k, V& v) { // Reverse the first comparison stage. // For example, merging a list of size 8 has the exchanges: // 0 <-> 15, 1 <-> 14, ... +#ifdef __HIP_PLATFORM_HCC__ + K otherK = __shfl_xor(k, 2 * L - 1); + V otherV = __shfl_xor(v, 2 * L - 1); +#else K otherK = shfl_xor(k, 2 * L - 1); V otherV = shfl_xor(v, 2 * L - 1); +#endif // Whether we are the lesser thread in the exchange bool small = !(laneId & L); @@ -64,8 +69,13 @@ inline __device__ void warpBitonicMergeLE16(K& k, V& v) { #pragma unroll for (int32_t stride = IsBitonic ? L : L / 2; stride > 0; stride /= 2) { +#ifdef __HIP_PLATFORM_HCC__ + K otherK = __shfl_xor(k, stride); + V otherV = __shfl_xor(v, stride); +#else K otherK = shfl_xor(k, stride); V otherV = shfl_xor(v, stride); +#endif // Whether we are the lesser thread in the exchange bool small = !(laneId & stride); @@ -86,7 +96,11 @@ inline __device__ void warpBitonicMergeLE16(K& k, V& v) { template struct BitonicSort { static inline __device__ void sort(K k[1], V v[1]) { +#ifdef __HIP_PLATFORM_HCC__ + static_assert(fbgemm_gpu::kWarpSize == 64, "unexpected warp size"); +#else static_assert(fbgemm_gpu::kWarpSize == 32, "unexpected warp size"); +#endif warpBitonicMergeLE16(k[0], v[0]); warpBitonicMergeLE16(k[0], v[0]); warpBitonicMergeLE16(k[0], v[0]); diff --git a/src/EmbeddingSpMDM.cc b/src/EmbeddingSpMDM.cc index 56451bab3..8e96c3cb4 100644 --- a/src/EmbeddingSpMDM.cc +++ b/src/EmbeddingSpMDM.cc @@ -1207,6 +1207,7 @@ void compressed_indices_remap( const inst_set_t isa = fbgemmInstructionSet(); if (isZmm(isa)) { +#ifndef __HIP_PLATFORM_HCC__ if (weights == nullptr) { internal::compressed_indices_remap_avx512( offsets_len, @@ -1228,6 +1229,7 @@ void compressed_indices_remap( out_offsets, out_weights); } +#endif } else { compressed_indices_remap_ref( offsets_len, From a2239361b095ee7e415293cc9da0eaa7761fe5ec Mon Sep 17 00:00:00 2001 From: liligwu Date: Wed, 26 Jan 2022 16:06:36 +0000 Subject: [PATCH 02/76] Use SHEFL_SYNC_MACRO to replace __shefl() and __shefl_sync() --- .../embedding_backward_code_generator.py | 27 ++----- ..._backward_split_indice_weights_template.cu | 12 +-- .../embedding_backward_split_template.cu | 42 ++-------- .../embedding_backward_template_helpers.cuh | 30 ++----- fbgemm_gpu/codegen/embedding_common.h | 6 ++ fbgemm_gpu/src/split_embeddings_cache_cuda.cu | 78 ++++--------------- 6 files changed, 44 insertions(+), 151 deletions(-) diff --git a/fbgemm_gpu/codegen/embedding_backward_code_generator.py b/fbgemm_gpu/codegen/embedding_backward_code_generator.py index 2d87581fd..ee2388f9e 100644 --- a/fbgemm_gpu/codegen/embedding_backward_code_generator.py +++ b/fbgemm_gpu/codegen/embedding_backward_code_generator.py @@ -397,11 +397,7 @@ def rowwise_adagrad() -> None: momentum1[idx] = new_sum_square_grads; multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps); } -#ifdef __HIP_PLATFORM_HCC__ - multiplier = __shfl(multiplier, 0); -#else - multiplier = __shfl_sync(0xFFFFFFFF, multiplier, 0); -#endif + multiplier = SHFL_SYNC_MACRO(multiplier, 0); """ split_weight_update_cpu = """ at::acc_type g_local_sum_square = 0.0; @@ -478,13 +474,8 @@ def rowwise_weighted_adagrad() -> None: multiplier = learning_rate * lambda / (cbrtf(new_sum_square_grads) + eps); correction = 1.0 - multiplier * weight_decay; } -#ifdef __HIP_PLATFORM_HCC__ - multiplier = __shfl(multiplier, 0); - correction = __shfl(correction, 0); -#else - multiplier = __shfl_sync(0xFFFFFFFF, multiplier, 0); - correction = __shfl_sync(0xFFFFFFFF, correction, 0); -#endif + multiplier = SHFL_SYNC_MACRO(multiplier, 0); + correction = SHFL_SYNC_MACRO(correction, 0); """ split_weight_update_cpu = """ // weight_decay not supported for cpu version @@ -645,11 +636,7 @@ def partial_rowwise_lamb() -> None: m2 = beta2 * momentum2[idx] + (1.0 - beta2) * g_avg_square; momentum2[idx] = m2; } -#ifdef __HIP_PLATFORM_HCC__ - m2 = __shfl(m2, 0); -#else - m2 = __shfl_sync(0xFFFFFFFF, m2, 0); -#endif + m2 = SHFL_SYNC_MACRO(m2, 0); at::acc_type m2_hat = 1.0 / (sqrtf((m2 / (1.0 - powf(beta2, iter)))) + eps); at::acc_type weight_sum_sq = 0.0; @@ -785,11 +772,7 @@ def partial_rowwise_adam() -> None: momentum2[idx] = v_t; v_hat_t = v_t / (1.0 - powf(beta2, iter)); } -#ifdef __HIP_PLATFORM_HCC__ - v_hat_t = __shfl(v_hat_t, 0); -#else - v_hat_t = __shfl_sync(0xFFFFFFFF, v_hat_t, 0); -#endif + v_hat_t = SHFL_SYNC_MACRO(v_hat_t, 0); """ split_weight_update = """ diff --git a/fbgemm_gpu/codegen/embedding_backward_split_indice_weights_template.cu b/fbgemm_gpu/codegen/embedding_backward_split_indice_weights_template.cu index 486297a46..186c60294 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_indice_weights_template.cu +++ b/fbgemm_gpu/codegen/embedding_backward_split_indice_weights_template.cu @@ -103,17 +103,9 @@ __launch_bounds__(kForwardMaxThreads) void {{ "dense" if dense else "split" }}_e int32_t cache_idx = (placement == PlacementType::MANAGED_CACHING && l < L) ? lxu_cache_locations[indices_start + l] : 0; {% endif %} for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) { -#ifdef __HIP_PLATFORM_HCC__ - int64_t idx_j = __shfl(idx, j); -#else - int64_t idx_j = __shfl_sync(0xFFFFFFFF, idx, j); -#endif + int64_t idx_j = SHFL_SYNC_MACRO(idx, 0); {% if not dense %} -#ifdef __HIP_PLATFORM_HCC__ - int32_t cache_idx_j = __shfl(cache_idx, j); -#else - int32_t cache_idx_j = __shfl_sync(0xFFFFFFFF, cache_idx, j); -#endif + int32_t cache_idx_j = SHFL_SYNC_MACRO(cache_idx, 0); {% endif %} at::acc_type grad_indice_weight = 0.0; diff --git a/fbgemm_gpu/codegen/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/embedding_backward_split_template.cu index f69a8dec2..5f5dbde2d 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/embedding_backward_split_template.cu @@ -190,26 +190,13 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ {% endif %} for (int32_t j = 0; j < kWarpSize && sl + j < sl_end; ++j) { {% if not nobag %} -#ifdef __HIP_PLATFORM_HCC__ - int32_t b_j = __shfl(b, j); - int32_t D_start_j = __shfl(D_start, j); -#else - int32_t b_j = __shfl_sync(0xFFFFFFFF, b, j); - int32_t D_start_j = __shfl_sync(0xFFFFFFFF, D_start, j); -#endif + int32_t b_j = SHFL_SYNC_MACRO(b, j); + int32_t D_start_j = SHFL_SYNC_MACRO(D_start, j); {% else %} -#ifdef __HIP_PLATFORM_HCC__ - int32_t l_j = __shfl(l, j); -#else - int32_t l_j = __shfl_sync(0xFFFFFFFF, l, j); -#endif + int32_t l_j = SHFL_SYNC_MACRO(l, j); {% endif %} {% if weighted %} -#ifdef __HIP_PLATFORM_HCC__ - at::acc_type idx_weight_j = __shfl(idx_weight, j); -#else - at::acc_type idx_weight_j = __shfl_sync(0xFFFFFFFF, idx_weight, j); -#endif + at::acc_type idx_weight_j = SHFL_SYNC_MACRO(idx_weight, j); {% endif %} #pragma unroll kMaxVecsPerThread @@ -562,26 +549,13 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ for (int32_t j = 0; j < kWarpSize && sl + j < sl_end; ++j) { {% if not nobag %} -#ifdef __HIP_PLATFORM_HCC__ - int32_t b_j = __shfl(b, j); - int32_t D_start_j = __shfl(D_start, j); -#else - int32_t b_j = __shfl_sync(0xFFFFFFFF, b, j); - int32_t D_start_j = __shfl_sync(0xFFFFFFFF, D_start, j); -#endif + int32_t b_j = SHFL_SYNC_MACRO(b, j); + int32_t D_start_j = SHFL_SYNC_MACRO(D_start, j); {% else %} -#ifdef __HIP_PLATFORM_HCC__ - int32_t l_j = __shfl(l, j); -#else - int32_t l_j = __shfl_sync(0xFFFFFFFF, l, j); -#endif + int32_t l_j = SHFL_SYNC_MACRO(l, j); {% endif %} {% if weighted %} -#ifdef __HIP_PLATFORM_HCC__ - at::acc_type idx_weight_j = __shfl(idx_weight, j); -#else - at::acc_type idx_weight_j = __shfl_sync(0xFFFFFFFF, idx_weight, j); -#endif + at::acc_type idx_weight_j = SHFL_SYNC_MACRO(idx_weight, j); {% endif %} #pragma unroll kMaxVecsPerThread diff --git a/fbgemm_gpu/codegen/embedding_backward_template_helpers.cuh b/fbgemm_gpu/codegen/embedding_backward_template_helpers.cuh index 3e79e8fce..beceed8b4 100644 --- a/fbgemm_gpu/codegen/embedding_backward_template_helpers.cuh +++ b/fbgemm_gpu/codegen/embedding_backward_template_helpers.cuh @@ -93,17 +93,10 @@ linearize_index_kernel( int32_t lane_id = threadIdx.x % fbgemm_gpu::kWarpSize; for (int32_t j = 0; j < fbgemm_gpu::kWarpSize; ++j) { -#ifdef __HIP_PLATFORM_HCC__ - index_t indices_start_warp = __shfl(indices_start, j); - int32_t b_t_warp = __shfl(b_t, j); - int32_t L_warp = __shfl(L, j); - index_t hash_offset_warp = __shfl(hash_offset, j); -#else - index_t indices_start_warp = __shfl_sync(0xFFFFFFFF, indices_start, j); - int32_t b_t_warp = __shfl_sync(0xFFFFFFFF, b_t, j); - int32_t L_warp = __shfl_sync(0xFFFFFFFF, L, j); - index_t hash_offset_warp = __shfl_sync(0xFFFFFFFF, hash_offset, j); -#endif + index_t indices_start_warp = SHFL_SYNC_MACRO(indices_start, j); + int32_t b_t_warp = SHFL_SYNC_MACRO(b_t, j); + int32_t L_warp = SHFL_SYNC_MACRO(L, j); + index_t hash_offset_warp = SHFL_SYNC_MACRO(hash_offset, j); for (int32_t i = lane_id; i < L_warp; i += fbgemm_gpu::kWarpSize) { index_t idx = __ldg(&indices[indices_start_warp + i]); infos[indices_start_warp + i] = b_t_warp; @@ -134,17 +127,10 @@ __global__ void nobag_linearize_index_kernel( int32_t lane_id = threadIdx.x % fbgemm_gpu::kWarpSize; for (int32_t j = 0; j < fbgemm_gpu::kWarpSize; ++j) { -#ifdef __HIP_PLATFORM_HCC__ - index_t indices_start_warp = __shfl(indices_start, j); - int32_t t_warp = __shfl(t, j); - int32_t L_warp = __shfl(L, j); - index_t hash_offset_warp = __shfl(hash_offset, j); -#else - index_t indices_start_warp = __shfl_sync(0xFFFFFFFF, indices_start, j); - int32_t t_warp = __shfl_sync(0xFFFFFFFF, t, j); - int32_t L_warp = __shfl_sync(0xFFFFFFFF, L, j); - index_t hash_offset_warp = __shfl_sync(0xFFFFFFFF, hash_offset, j); -#endif + index_t indices_start_warp = SHFL_SYNC_MACRO(indices_start, j); + int32_t t_warp = SHFL_SYNC_MACRO(t, j); + int32_t L_warp = SHFL_SYNC_MACRO(L, j); + index_t hash_offset_warp = SHFL_SYNC_MACRO(hash_offset, j); for (int32_t i = lane_id; i < L_warp; i += fbgemm_gpu::kWarpSize) { index_t idx = __ldg(&indices[indices_start_warp + i]); int64_t l_t = (indices_start_warp + i) * T + t_warp; diff --git a/fbgemm_gpu/codegen/embedding_common.h b/fbgemm_gpu/codegen/embedding_common.h index f4311a19a..ec2030591 100644 --- a/fbgemm_gpu/codegen/embedding_common.h +++ b/fbgemm_gpu/codegen/embedding_common.h @@ -36,3 +36,9 @@ enum class BoundsCheckMode : uint8_t { }; } // namespace + +#ifdef __HIP_PLATFORM_HCC__ + #define SHFL_SYNC_MACRO(var, srcLane) __shfl(var, srcLane) +#else + #define SHFL_SYNC_MACRO(var, srcLane) __shfl_sync(0xFFFFFFFF, var, srcLane) +#endif \ No newline at end of file diff --git a/fbgemm_gpu/src/split_embeddings_cache_cuda.cu b/fbgemm_gpu/src/split_embeddings_cache_cuda.cu index 4feceb4a6..13fb62652 100644 --- a/fbgemm_gpu/src/split_embeddings_cache_cuda.cu +++ b/fbgemm_gpu/src/split_embeddings_cache_cuda.cu @@ -254,15 +254,9 @@ __global__ __launch_bounds__(kMaxThreads) void linearize_cache_indices_kernel( // hash_offset < 0 for non-caching tables for (int32_t j = 0; j < kWarpSize; ++j) { -#ifdef __HIP_PLATFORM_HCC__ - int64_t indices_start_warp = __shfl(indices_start, j); - int32_t L_warp = __shfl(L, j); - int64_t hash_offset_warp = __shfl(hash_offset, j); -#else - int64_t indices_start_warp = __shfl_sync(0xFFFFFFFF, indices_start, j); - int32_t L_warp = __shfl_sync(0xFFFFFFFF, L, j); - int64_t hash_offset_warp = __shfl_sync(0xFFFFFFFF, hash_offset, j); -#endif + int64_t indices_start_warp = SHFL_SYNC_MACRO(indices_start, j); + int32_t L_warp = SHFL_SYNC_MACRO(L, j); + int64_t hash_offset_warp = SHFL_SYNC_MACRO(hash_offset, j); if (hash_offset_warp >= 0) { for (int32_t i = lane_id; i < L_warp; i += kWarpSize) { auto idx = __ldg(&indices[indices_start_warp + i]); @@ -581,14 +575,8 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_insert_kernel( int64_t sorted_lru_cost = costs[0]; for (int32_t l = 0; l < min(SL, kWarpSize); ++l) { -#ifdef __HIP_PLATFORM_HCC__ - int32_t insert_slot = __shfl(sorted_slot, l); - int64_t insert_current_lru_cost = __shfl(sorted_lru_cost, l); -#else - int32_t insert_slot = __shfl_sync(0xFFFFFFFF, sorted_slot, l); - int64_t insert_current_lru_cost = - __shfl_sync(0xFFFFFFFF, sorted_lru_cost, l); -#endif + int32_t insert_slot = SHFL_SYNC_MACRO(sorted_slot, l); + int64_t insert_current_lru_cost = SHFL_SYNC_MACRO(sorted_lru_cost, l); if (insert_current_lru_cost == time_stamp) { return; } @@ -604,11 +592,7 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_insert_kernel( // lxu_cache_state int64_t current_idx = threadIdx.x == 0 ? lxu_cache_state[cache_set][insert_slot] : 0; -#ifdef __HIP_PLATFORM_HCC__ - current_idx = __shfl(current_idx, 0); -#else - current_idx = __shfl_sync(0xFFFFFFFF, current_idx, 0); -#endif + current_idx = SHFL_SYNC_MACRO(current_idx, 0); // not empty if (current_idx != static_cast(kCacheStateInvalid)) { @@ -881,15 +865,8 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_insert_byte_kernel( int64_t sorted_lru_cost = costs[0]; for (int32_t l = 0; l < min(SL, kWarpSize); ++l) { -#ifdef __HIP_PLATFORM_HCC__ - int32_t insert_slot = __shfl(sorted_slot, l); - int64_t insert_current_lru_cost = - __shfl(sorted_lru_cost, l); -#else - int32_t insert_slot = __shfl_sync(0xFFFFFFFF, sorted_slot, l); - int64_t insert_current_lru_cost = - __shfl_sync(0xFFFFFFFF, sorted_lru_cost, l); -#endif + int32_t insert_slot = SHFL_SYNC_MACRO(sorted_slot, l); + int64_t insert_current_lru_cost = SHFL_SYNC_MACRO(sorted_lru_cost, l); if (insert_current_lru_cost == time_stamp) { return; } @@ -910,11 +887,7 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_insert_byte_kernel( // lxu_cache_state int64_t current_idx = threadIdx.x == 0 ? lxu_cache_state[cache_set][insert_slot] : 0; -#ifdef __HIP_PLATFORM_HCC__ - current_idx = __shfl(current_idx, 0); -#else - current_idx = __shfl_sync(0xFFFFFFFF, current_idx, 0); -#endif + current_idx = SHFL_SYNC_MACRO(current_idx, 0); // not empty if (current_idx != static_cast(kCacheStateInvalid)) { @@ -1282,14 +1255,8 @@ __global__ __launch_bounds__(kCacheMaxThreads) void lfu_cache_insert_kernel( int64_t sorted_lfu_cost = costs[0]; for (int32_t l = 0; l < min(SL, kWarpSize); ++l) { -#ifdef __HIP_PLATFORM_HCC__ - int32_t insert_slot = __shfl(sorted_slot, l); - int64_t insert_current_lfu_cost = __shfl(sorted_lfu_cost, l); -#else - int32_t insert_slot = __shfl_sync(0xFFFFFFFF, sorted_slot, l); - int64_t insert_current_lfu_cost = - __shfl_sync(0xFFFFFFFF, sorted_lfu_cost, l); -#endif + int32_t insert_slot = SHFL_SYNC_MACRO(sorted_slot, l); + int64_t insert_current_lfu_cost = SHFL_SYNC_MACRO(sorted_lfu_cost, l); int64_t insert_idx = cache_set_sorted_indices[n + l]; int64_t insert_lfu_cost = lfu_state[insert_idx]; @@ -1313,11 +1280,7 @@ __global__ __launch_bounds__(kCacheMaxThreads) void lfu_cache_insert_kernel( // lxu_cache_state int64_t current_idx = threadIdx.x == 0 ? lxu_cache_state[cache_set][insert_slot] : 0; -#ifdef __HIP_PLATFORM_HCC__ - current_idx = __shfl(current_idx, 0); -#else - current_idx = __shfl_sync(0xFFFFFFFF, current_idx, 0); -#endif + current_idx = SHFL_SYNC_MACRO(current_idx, 0); int32_t t_current = cache_index_table_map[current_idx]; int64_t idx_current = current_idx - cache_hash_size_cumsum[t_current]; int64_t weights_offset_current = weights_offsets[t_current]; @@ -1606,15 +1569,8 @@ __launch_bounds__(kCacheMaxThreads) void lfu_cache_insert_byte_kernel( int64_t sorted_lfu_cost = costs[0]; for (int32_t l = 0; l < min(SL, kWarpSize); ++l) { -#ifdef __HIP_PLATFORM_HCC__ - int32_t insert_slot = __shfl(sorted_slot, l); - int64_t insert_current_lfu_cost = - __shfl(sorted_lfu_cost, l); -#else - int32_t insert_slot = __shfl_sync(0xFFFFFFFF, sorted_slot, l); - int64_t insert_current_lfu_cost = - __shfl_sync(0xFFFFFFFF, sorted_lfu_cost, l); -#endif + int32_t insert_slot = SHFL_SYNC_MACRO(sorted_slot, l); + int64_t insert_current_lfu_cost = SHFL_SYNC_MACRO(sorted_lfu_cost, l); int64_t insert_idx = cache_set_sorted_indices[n + l]; int64_t insert_lfu_cost = lfu_state[insert_idx]; @@ -1643,11 +1599,7 @@ __launch_bounds__(kCacheMaxThreads) void lfu_cache_insert_byte_kernel( // lxu_cache_state int64_t current_idx = threadIdx.x == 0 ? lxu_cache_state[cache_set][insert_slot] : 0; -#ifdef __HIP_PLATFORM_HCC__ - current_idx = __shfl(current_idx, 0); -#else - current_idx = __shfl_sync(0xFFFFFFFF, current_idx, 0); -#endif + current_idx = SHFL_SYNC_MACRO(current_idx, 0); int32_t t_current = cache_index_table_map[current_idx]; SparseType weight_ty_current = static_cast(weights_tys[t_current]); From a506c52ee9bd03f144ae18c334e76588275df617 Mon Sep 17 00:00:00 2001 From: Li Li Date: Mon, 31 Jan 2022 10:37:02 -0600 Subject: [PATCH 03/76] Change the hipify dependency to hipify_torch (#7) * Change hipify dependency from torch.utils.torch_hipify to hipify_torch. * add the third_party/hipify_torch to git repo --- .gitmodules | 3 +++ fbgemm_gpu/include/fbgemm_gpu/quantize_ops.cuh | 4 +++- fbgemm_gpu/setup.py | 3 ++- third_party/hipify_torch | 1 + 4 files changed, 9 insertions(+), 2 deletions(-) create mode 160000 third_party/hipify_torch diff --git a/.gitmodules b/.gitmodules index 9b3f016d9..acbeeee2c 100644 --- a/.gitmodules +++ b/.gitmodules @@ -7,3 +7,6 @@ [submodule "third_party/googletest"] path = third_party/googletest url = https://github.com/google/googletest +[submodule "third_party/hipify_torch"] + path = third_party/hipify_torch + url = https://github.com/ROCmSoftwarePlatform/hipify_torch.git diff --git a/fbgemm_gpu/include/fbgemm_gpu/quantize_ops.cuh b/fbgemm_gpu/include/fbgemm_gpu/quantize_ops.cuh index 10d246ffb..f9902206b 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/quantize_ops.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/quantize_ops.cuh @@ -42,7 +42,9 @@ __device__ inline T max(const T* from, const T* to) { template __device__ inline __attribute__((always_inline)) T quantize_ops_shfl_xor(const T val, int laneMask, int width) { -#if CUDA_VERSION >= 9000 +#ifdef __HIP_PLATFORM_HCC__ + return __shfl_xor(val, laneMask, width); +#elif CUDA_VERSION >= 9000 return __shfl_xor_sync(0xffffffff, val, laneMask, width); #else return __shfl_xor(val, laneMask, width); diff --git a/fbgemm_gpu/setup.py b/fbgemm_gpu/setup.py index 7d033fa0b..d3f10f725 100644 --- a/fbgemm_gpu/setup.py +++ b/fbgemm_gpu/setup.py @@ -15,7 +15,8 @@ from setuptools import setup from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension import torch -from torch.utils.hipify import hipify_python +sys.path.append("..") +from third_party.hipify_torch.hipify import hipify_python cpu_only_build = False cur_dir = os.path.dirname(os.path.realpath(__file__)) diff --git a/third_party/hipify_torch b/third_party/hipify_torch new file mode 160000 index 000000000..88bd87904 --- /dev/null +++ b/third_party/hipify_torch @@ -0,0 +1 @@ +Subproject commit 88bd87904aaf5d68b908af9fe2ef6b32dbbcf45e From f596bdeec55be7a09f46e08ff1d719eb34520dae Mon Sep 17 00:00:00 2001 From: Li Li Date: Mon, 14 Feb 2022 17:25:20 -0600 Subject: [PATCH 04/76] IFU, merge from upstream commit c6df576 to main. (#8) * unify function signature of jagged_xD_to_dense (#813) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/813 As title Reviewed By: jiaqizhai, jianyuh Differential Revision: D33066551 fbshipit-source-id: 8e2fd3c21f3bde67c6b20045681c2549e3583bd3 * Daily `arc lint --take CLANGFORMAT` Reviewed By: zertosh Differential Revision: D33183467 fbshipit-source-id: d7c37f3522a38e85891524c544eab4fdb01270de * Assert Tensors allocated on GPU. (#819) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/819 Check inputs for correctness wrt to GPU allocation and device. Reviewed By: jspark1105, jianyuh Differential Revision: D33167469 fbshipit-source-id: 04f638d13bde93373d64cff1428ef743300400a6 * Support batched benchmark execution and fix benchmark stats reporting (#818) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/818 As title, support multiple execution of benchmark scripts and report aggregated metric. Further, require `--bag-size` argument to conform to input data file for proper metric accounting. Reviewed By: jianyuh Differential Revision: D33182257 fbshipit-source-id: a6eeeb25646c00665b6d29df9389eddab7618d4e * Direct Convolution JIT assembly for KH=2, KW = 6 Summary: this diff has specialized codegen for convolution case where KH=2 and KW=6 ## Performance results on local devserver with AVX2 instruction: 1, 16, 16, {2, 126}, 1, {2, 6}, {1, 2}, {0, 0, 0, 0}, {1, 1}, {0, 0}, false Fbgemm baseline: 3.8 GOPS This diff: 9.2 GOPS 1, 64, 64, {2, 257}, 1, {2, 6}, {1, 2}, {0, 0, 0, 0}, {1, 1}, {0, 0}, false Fbgemm baseline: 43.8 GOPS This diff: 61.2 GOPS ## How to invoke indirect convolution function: **At offline:** 1. Weights need to be transposed to (oc/8) - (kh) - (kw) - (ic/4) - 8 - 4 2. Create the convolution function based on problem size: ``` CodeGenBase codeObj; CodeGenBase::jit_micro_kernel_fp fn; fn = codeObj.getOrCreateDirectConv( true, conv_p.OUT_DIM[1], conv_p.IN_DIM[1] * conv_p.IC, conv_p.stride[1] * conv_p.IC); ``` 3. Compute the *col_offsets* of weight tensor 4. Make sure you have allocated the space for: output tensor (Cint32_fb, Cint8_fb), and some temporary space for input rowsum ( InSum: IN_DIM[0] x IN_DIM[1], rowSum: OUT_DIM[0] x OUT_DIM[1]) **Online:** Make sure we have: conv_p ( the problem info), Aint8 (input tensor), bBuf_tr ( the transposed weight tensor), Cint32_fb ( the 32-bit results after accumulation), Cint8_fb ( the final quantized 8-bit output). // compute direct conv row sum directConvRowSum(conv_p, Aint8.data(), inSum, rowSum, row_offsets.data()); // kernel for direct convolution for (int oc = 0; oc < conv_p.OC; oc+= 8) { fn(Aint8.data(), bBuf_tr.data() + oc * kernel_dim * conv_p.IC , bBuf_tr.data(), Cint32_fb.data() + oc, conv_p.IC * conv_p.K[1], conv_p.OC); } requantizationParams_t<> reqObj = { Aint8_zero_point, // Aq_zero_point Bint8_zero_point.data(), C_zero_point, C_multiplier.data(), rowSum, // row_offsets //row_offsets.data(), col_offsets.data(), // col_offsets nullptr, // bias static_cast(conv_p.OC), // ncols 1, // groups nullptr}; requantizeOutputProcessingAvx2(Cint8_fb.data(), Cint32_ref.data(), {0, conv_p.OUT_DIM[1] * conv_p.OUT_DIM[0], 0, conv_p.OC}, conv_p.OC, conv_p.OC, reqObj); For more details please refer to test_asmjit2.cc Reviewed By: dskhudia Differential Revision: D31775222 fbshipit-source-id: 294450613b0978277e75d171d6a560124c14ecda * suppress errors in `deeplearning/fbgemm/fbgemm_gpu` Differential Revision: D33201593 fbshipit-source-id: 251f338e03dfde1dcc4a83c4ff9df1fe27840bdb * fix copy right header of batch_benchmark_run.py (#820) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/820 As title Reviewed By: jianyuh Differential Revision: D33213812 fbshipit-source-id: d901e87ff1047ff969c99a330aa05c8d26e1954e * Assert Tensors allocated on GPU for generated code. (#821) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/821 Check inputs for correctness wrt to GPU allocation and device. Reviewed By: jspark1105 Differential Revision: D33189944 fbshipit-source-id: 36fb5eac677466e783ef5a754c28b6d838ea09b7 * Move all fbgemm_gpu provided Python ops to fbgemm namespace from fb namespace. (#823) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/823 Reviewed By: jianyuh Differential Revision: D33147038 fbshipit-source-id: fdcb667dfb920b4f04b7d0b08082afabe7213cc1 * Implement generic HBC by feature. (#822) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/822 Implement a generic version of HBC by feature, which takes in bin_boundaries. Reviewed By: jianyuh Differential Revision: D33232676 fbshipit-source-id: 99c77f6d081fdc89699948a6c9482b8806f598a3 * Benchmark for newly added generic HBC by feature. (#826) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/826 More benchmarking for new op, and also add "double" for benchmarking type. Reviewed By: jianyuh Differential Revision: D33241845 fbshipit-source-id: 38f08f5453fd8d112ff55c046a6ac091c23bc3de * Allways set dontfork on managed Tensor + new uvm clone (#824) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/824 Workaround for S256045. UVM Tensors are unmapped from the process page table on fork (spawn). The UVM fault handler then slows down the UVM CPU<->CPU copy substantially reestablishing those mappings. The workaround sets MADV_DONTFORK on the addresses (rounded down to page size) of UVM allocations - this prevents the removal from UVM pages from the original process page table. Additionally this introduces a single threaded UVM->CPU tensor copy to 1) Avoid 8 trainers on a host to concurrently all threads with copy_ 2) Avoid high concurency in the fault handler of the uvm kernel driver. Reviewed By: jianyuh Differential Revision: D33192043 fbshipit-source-id: 094f3dcd302d455efbf4e912d58ed28756cb653f * Use kWarpSize for warp size (#827) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/827 Reviewed By: rweyrauch Differential Revision: D33271792 fbshipit-source-id: dc66b6950b37e5d92c10406a3891568a7500e26e * Move fb.embedding_bag_rowwise_prune to fbgemm_gpu OSS. (#825) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/825 Move the fb.embedding_bag_rowwise_prune op from caffe2/fb/sparsenn to fbgemm_gpu. Reviewed By: jianyuh Differential Revision: D33240318 fbshipit-source-id: 4db93a1ecd9666881779eeada1e3e493aa7525e4 * Allow optional Tensor args to be empty or on GPU. (#828) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/828 Reviewed By: jianyuh Differential Revision: D33267641 fbshipit-source-id: b193ee5b7e9ea946a20672760c320f29b217b998 * Add output_dtype to training TBE op for CPU (#829) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/829 This Diff adds `output_dtype` to `split_embedding_codegen_lookup_{{ optimizer }}_function_cpu()`. Note that the CUDA version (`split_embedding_codegen_lookup_{{ optimizer }}_function()`) already has this argument (D32399931 (https://github.com/pytorch/FBGEMM/commit/7e1183c1f13cf6753f546eec48488bfb56d80481)). Reviewed By: jianyuh Differential Revision: D32969921 fbshipit-source-id: 695e54434dc4f65f9f4c60782c60a550e38d97a7 * fix copyright header of tensor_assert_test.cpp (#831) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/831 As title Reviewed By: rweyrauch Differential Revision: D33310866 fbshipit-source-id: 1cbdee1d7c00f0e900faac570bac330866887b1c * Add permute_pooled_embedding_modules_test into RE (#830) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/830 As title Reviewed By: rweyrauch Differential Revision: D33303898 fbshipit-source-id: c94a14bc398ecb58b68ca15d7e79204233ac67d1 * Use all to one op to do DtoD between remote and merge (#817) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/817 Previously we were simply calling `Tensor.to` to launch DtoD copy. Since PyTorch is doing two-way barrier for DtoD copy, all the DtoD copies are serialized even though they are launched from different devices. See the blue DtoD copies in the graph below. {F686842812} At first I went for merge_pooled_embedding directly but I forgot that MRS models also have sequence embeddings. Covering pooled embeddings are not enough in this case. This diff introduced a function that takes in a tuple of ivalues and move the underlining tensors to a given target device then outputs a vector of ivalues with underlining tensors in the same device. For each source device, we synchronize its current stream and launch all the copies for tensors in that device. Then we synchronize the current stream on target device to wait on all the copies. Now the copies from different devices can run in parallel. {F686843333} Reviewed By: yinghai, jianyuh, houseroad Differential Revision: D33065710 fbshipit-source-id: f479fa2ea20702e14419c8b87024a87d5bbb1a68 * Add MSFP option for ads hpc model numeric emulations (#832) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/832 Add fake conversions between MSFP and fp32 in both forward and backward pass of the hpc ads model training. TODO: Add compute kernels that split the FC operator into gemms for column_blocks of activations and row_blocks of weights Reviewed By: jspark1105 Differential Revision: D30942234 fbshipit-source-id: 601d671fd00622304a50651dedffd0de3ae01ae0 * Remove benchmark CMakeLists.txt (#835) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/835 As title. This file is no longer needed after we decide to support setup.py only OSS build approach. Reviewed By: jspark1105, rweyrauch Differential Revision: D33318121 fbshipit-source-id: 4f71b23f6e9e7e78d50fab20af53cdf9f63844ad * Increase code reuse between FP32, FP16, INT8, INT4 embedding types for infer TBE (#833) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/833 We merge the implementation for {FP32, FP16, INT8, INT4} weights in inference TBE into one unified template and increase the code reuse between these implementations. This will pave the way for the future enhancements (no need to change all 4 implementations for one new feature). Reviewed By: rweyrauch Differential Revision: D33343450 fbshipit-source-id: 24e59c4a2df5ef3da353535eb879a2365293bc1f * minimize functions defined in headers (#836) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/836 We had so much stuffs that didn't need to be at header files. Split long source files. Put experimental quantization functions to experimental namespace Reviewed By: rweyrauch Differential Revision: D33358916 fbshipit-source-id: cffcec344cbe565045ee2c564ce1cef529de4cf8 * add missing C10_CUDA_KERNEL_LAUNCH_CHECK (#837) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/837 As title Reviewed By: rweyrauch Differential Revision: D33359025 fbshipit-source-id: 162dd2897a5d56e7ac8ff3ba9ae5c8689961204b * Add seq embedding kernel for infer TBE (#834) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/834 - Add sequence embedding support in infer TBE kernel - TODO: "mask" solution for the duplicated embedding row access. cc jspark1105 Reviewed By: jspark1105 Differential Revision: D33341863 fbshipit-source-id: 47babe921dbaf086e2df92f4693b4718c01bcec1 * add missing new files to CMakeLists.txt (#838) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/838 This was missed in D33358916 (https://github.com/pytorch/FBGEMM/commit/38a6c3553e33a7c0d5b2e7758dbed6e3ae9e47b0) Reviewed By: colin2328 Differential Revision: D33370387 fbshipit-source-id: 72007f51afd6757690a1898098e8b6207c3c487b * Support int32_t indices/offsets for caching handling logics (#811) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/811 In training, we assume the indices / offsets are int64_t for embedding (TBE), but in inference, we assume the indices / offsets are int32_t. This Diff enables both int32_t and int64_t supports for the caching logics so that we can reuse the same functions for both training and inference, while reducing the extra overhead to convert the indices/offsets from int to long or vice versa. Reviewed By: jspark1105 Differential Revision: D33045589 fbshipit-source-id: 4e508a1095536a629bdab8e5577db74310032b23 * Add seq embedding benchmark Summary: 5x ~ 10x speedup in the benchmark level. Reviewed By: jspark1105 Differential Revision: D33355933 fbshipit-source-id: 2c609ae9ec5fd4fda48dbafa13b5eb75900fdf5f * fix warning count check in test_bounds_check (#839) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/839 In GPU multiple threads in a thread block can increase warning count for the same bound errors in offset array Reviewed By: jianyuh Differential Revision: D33379301 fbshipit-source-id: b00520cc613bb7e15c9f8cd4bdf0c61bd4dbd83b * fix typo in CMakeLists.txt (#840) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/840 Fixing a silly typo Reviewed By: jianyuh Differential Revision: D33380967 fbshipit-source-id: 8220cc87a2564107cb124d3f9c31b8d92cb7d1a4 * Slight perf optimization for infer TBE (#843) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/843 ~5% perf improvement for INT4 / INT8 inference TBE on A100 GPUs. Reviewed By: jspark1105 Differential Revision: D33388153 fbshipit-source-id: 63566e3dccd9ce4775abb3374251f9046512e131 * extract embedding input transpose out of embedding_backward_split_template.cu (#841) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/841 Refactoring to prepare D33381126 Other minor changes * Remove unused sorted_linear_indices_run_lengths parameter from bwd kernels Reviewed By: jianyuh Differential Revision: D33380032 fbshipit-source-id: b880cc3745a6f6dd63319109e753a470d6c28c49 * increase parallelism in batched unary embeddings backward (#842) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/842 Sort indices and have each thread handle indices with the same values (called a run in the code) Reviewed By: jianyuh Differential Revision: D33381126 fbshipit-source-id: aec1c0be619b9072f5a1f9273b66c03e5106ca02 * use DISPATCH_TO_CUDA macro (#845) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/845 We should use the macro consistently or just drop Reviewed By: jianyuh Differential Revision: D33392682 fbshipit-source-id: bd99286f55fe2d6e5bab231ec65dae02f16f35c2 * Follow-up comments (#844) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/844 Reviewed By: jspark1105 Differential Revision: D33393019 fbshipit-source-id: 1df7d8457a950a829f7ff2fe6f47595afdc9cc26 * HIP extension support for FBGEMM_GPU (#846) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/846 Reviewed By: jspark1105 Differential Revision: D33231489 fbshipit-source-id: 6bd46ddee45c767ad25c2d52b6c05030bba94082 * correct the max_shared_bytes logit evaluation logic in embedding_backward_split_template.cu * IFU from from upstream commit c6df576 to main. fbgemm-gpu is built and imported. Tests do NOT pass. Co-authored-by: Xing Liu Co-authored-by: CodemodService FBSourceClangFormatLinterBot <> Co-authored-by: Rick Weyrauch Co-authored-by: Martin Schatz Co-authored-by: Jiyuan Zhang Co-authored-by: Jongsoo Park Co-authored-by: Jason Park Co-authored-by: Stephan Uphoff Co-authored-by: Jianyu Huang Co-authored-by: Shintaro Iwasaki Co-authored-by: Shiyan Deng Co-authored-by: Summer Deng --- defs.bzl | 1 + fbgemm_gpu/CMakeLists.txt | 13 +- ...histogram_binning_calibration_benchmark.py | 75 +- .../bench/merge_embeddings_benchmark.py | 27 +- fbgemm_gpu/bench/scripts/README.md | 54 ++ .../bench/scripts/batch_benchmark_run.py | 90 ++ ...plit_table_batched_embeddings_benchmark.py | 209 ++++- .../embedding_backward_code_generator.py | 30 +- .../codegen/embedding_backward_dense_host.cpp | 17 +- .../embedding_backward_dense_host_cpu.cpp | 10 +- ...ing_backward_split_cpu_approx_template.cpp | 2 +- .../embedding_backward_split_cpu_template.cpp | 2 +- ...dding_backward_split_host_cpu_template.cpp | 13 +- ...embedding_backward_split_host_template.cpp | 11 +- ..._backward_split_indice_weights_template.cu | 22 +- .../embedding_backward_split_template.cu | 213 ++--- fbgemm_gpu/codegen/embedding_bounds_check.cu | 14 +- .../codegen/embedding_bounds_check_host.cpp | 17 +- .../embedding_bounds_check_host_cpu.cpp | 9 +- ...bedding_forward_quantized_cpu_template.cpp | 2 +- .../embedding_forward_quantized_host.cpp | 81 +- .../embedding_forward_quantized_host_cpu.cpp | 32 + ...edding_forward_quantized_split_template.cu | 816 ++++------------- .../codegen/embedding_forward_split_cpu.cpp | 4 +- .../embedding_forward_split_template.cu | 38 +- .../embedding_forward_template_helpers.cuh | 3 +- .../split_embedding_inference_converter.py | 9 +- .../split_table_batched_embeddings_ops.py | 89 +- fbgemm_gpu/fbgemm_gpu/uvm.py | 6 +- .../batched_unary_embedding_ops.cuh | 77 -- .../embedding_backward_template_helpers.cuh | 78 +- .../fbgemm_gpu}/embedding_common.h | 2 +- .../include/fbgemm_gpu/fbgemm_cuda_utils.cuh | 233 +++-- .../fbgemm_gpu/merge_pooled_embeddings.h | 18 + .../include/fbgemm_gpu/quantize_ops.cuh | 602 +------------ fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h | 43 +- .../include/fbgemm_gpu/sparse_ops_utils.h | 15 + .../fbgemm_gpu}/split_embeddings_utils.cuh | 38 +- fbgemm_gpu/setup.py | 4 + fbgemm_gpu/src/cumem_utils.cu | 48 +- fbgemm_gpu/src/cumem_utils.h | 4 + fbgemm_gpu/src/cumem_utils_host.cpp | 47 +- .../src/histogram_binning_calibration_ops.cu | 389 ++++++++ fbgemm_gpu/src/jagged_tensor_ops.cu | 236 +++++ fbgemm_gpu/src/layout_transform_ops.cu | 7 + .../src/merge_pooled_embeddings_gpu.cpp | 260 +++--- .../src/permute_pooled_embedding_ops.cu | 13 +- .../src/permute_pooled_embedding_ops_gpu.cpp | 21 +- fbgemm_gpu/src/quantize_ops.cu | 538 +++++++++++ fbgemm_gpu/src/sparse_ops.cu | 852 +++--------------- fbgemm_gpu/src/sparse_ops_cpu.cpp | 280 +++++- fbgemm_gpu/src/sparse_ops_gpu.cpp | 23 +- fbgemm_gpu/src/split_embeddings_cache_cuda.cu | 762 ++++++++++------ fbgemm_gpu/src/split_embeddings_utils.cu | 229 +++++ .../src/split_table_batched_embeddings.cpp | 61 +- .../test/merge_pooled_embeddings_test.py | 26 + .../permute_pooled_embedding_modules_test.py | 6 +- fbgemm_gpu/test/quantize_ops_test.py | 35 + fbgemm_gpu/test/sparse_ops_test.py | 159 +++- ...plit_embedding_inference_converter_test.py | 2 - .../split_table_batched_embeddings_test.py | 69 +- fbgemm_gpu/test/tensor_assert_test.cpp | 32 + fbgemm_gpu/test/uvm_test.py | 91 +- src/DirectConv.h | 159 ++++ src/GenerateKernel.h | 2 +- src/GenerateKernelDirectConvU8S8S32ACC32.cc | 475 ++++++++++ 66 files changed, 4757 insertions(+), 3088 deletions(-) create mode 100644 fbgemm_gpu/bench/scripts/README.md create mode 100644 fbgemm_gpu/bench/scripts/batch_benchmark_run.py delete mode 100644 fbgemm_gpu/include/fbgemm_gpu/batched_unary_embedding_ops.cuh rename fbgemm_gpu/{codegen => include/fbgemm_gpu}/embedding_backward_template_helpers.cuh (59%) rename fbgemm_gpu/{codegen => include/fbgemm_gpu}/embedding_common.h (97%) create mode 100644 fbgemm_gpu/include/fbgemm_gpu/merge_pooled_embeddings.h rename fbgemm_gpu/{src => include/fbgemm_gpu}/split_embeddings_utils.cuh (77%) create mode 100644 fbgemm_gpu/src/histogram_binning_calibration_ops.cu create mode 100644 fbgemm_gpu/src/jagged_tensor_ops.cu create mode 100644 fbgemm_gpu/src/quantize_ops.cu create mode 100644 fbgemm_gpu/src/split_embeddings_utils.cu create mode 100644 fbgemm_gpu/test/tensor_assert_test.cpp create mode 100644 src/DirectConv.h create mode 100644 src/GenerateKernelDirectConvU8S8S32ACC32.cc diff --git a/defs.bzl b/defs.bzl index c9c06ed68..f78ae8a07 100644 --- a/defs.bzl +++ b/defs.bzl @@ -25,6 +25,7 @@ def get_fbgemm_generic_srcs(with_base = False): "src/FbgemmI64.cc", "src/FbgemmSparseDense.cc", "src/FbgemmI8Spmdm.cc", + "src/GenerateKernelDirectConvU8S8S32ACC32.cc", "src/GenerateKernel.cc", "src/GenerateKernelU8S8S32ACC16.cc", "src/GenerateKernelU8S8S32ACC16Avx512.cc", # Acc16 AVX512 JIT code gen diff --git a/fbgemm_gpu/CMakeLists.txt b/fbgemm_gpu/CMakeLists.txt index 05181646f..c62f8585b 100644 --- a/fbgemm_gpu/CMakeLists.txt +++ b/fbgemm_gpu/CMakeLists.txt @@ -110,7 +110,6 @@ set(codegen_dependencies ${CMAKE_CODEGEN_DIR}/embedding_backward_split_host_template.cpp ${CMAKE_CODEGEN_DIR}/embedding_backward_split_indice_weights_template.cu ${CMAKE_CODEGEN_DIR}/embedding_backward_split_template.cu - ${CMAKE_CODEGEN_DIR}/embedding_backward_template_helpers.cuh ${CMAKE_CODEGEN_DIR}/embedding_forward_quantized_cpu_template.cpp ${CMAKE_CODEGEN_DIR}/embedding_forward_quantized_host.cpp ${CMAKE_CODEGEN_DIR}/embedding_forward_quantized_host_cpu.cpp @@ -122,6 +121,14 @@ set(codegen_dependencies ${CMAKE_CODEGEN_DIR}/__init__.template ${CMAKE_CODEGEN_DIR}/lookup_args.py ${CMAKE_CODEGEN_DIR}/split_embedding_codegen_lookup_invoker.template + ${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/cpu_utils.h + ${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/cub_namespace_postfix.cuh + ${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/dispatch_macros.h + ${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/embedding_backward_template_helpers.cuh + ${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/embedding_common.h + ${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/fbgemm_cuda_utils.cuh + ${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/split_embeddings_utils.cuh + ${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/sparse_ops_utils.h ) add_custom_command( @@ -222,8 +229,10 @@ set_source_files_properties( set(fbgemm_gpu_sources_gpu codegen/embedding_bounds_check.cu src/cumem_utils.cu + src/histogram_binning_calibration_ops.cu src/jagged_tensor_ops.cu src/layout_transform_ops.cu src/permute_pooled_embedding_ops.cu - src/sparse_ops.cu src/split_embeddings_cache_cuda.cu) + src/quantize_ops.cu src/sparse_ops.cu src/split_embeddings_cache_cuda.cu + src/split_embeddings_utils.cu) set_source_files_properties(${fbgemm_gpu_sources_gpu} PROPERTIES COMPILE_OPTIONS "${TORCH_CUDA_OPTIONS}") diff --git a/fbgemm_gpu/bench/histogram_binning_calibration_benchmark.py b/fbgemm_gpu/bench/histogram_binning_calibration_benchmark.py index 7fc2d2a5b..6d48a58e5 100644 --- a/fbgemm_gpu/bench/histogram_binning_calibration_benchmark.py +++ b/fbgemm_gpu/bench/histogram_binning_calibration_benchmark.py @@ -51,25 +51,43 @@ def main( warmup_runs: int, ) -> None: + data_types = [torch.half, torch.float, torch.double] + total_time = { "hbc": { "cpu": { torch.half: 0.0, torch.float: 0.0, + torch.double: 0.0, }, "gpu": { torch.half: 0.0, torch.float: 0.0, + torch.double: 0.0, }, }, "hbc_by_feature": { "cpu": { torch.half: 0.0, torch.float: 0.0, + torch.double: 0.0, + }, + "gpu": { + torch.half: 0.0, + torch.float: 0.0, + torch.double: 0.0, + }, + }, + "generic_hbc_by_feature": { + "cpu": { + torch.half: 0.0, + torch.float: 0.0, + torch.double: 0.0, }, "gpu": { torch.half: 0.0, torch.float: 0.0, + torch.double: 0.0, }, }, } @@ -89,9 +107,13 @@ def main( lower_bound: float = 0.0 upper_bound: float = 1.0 + w: float = (upper_bound - lower_bound) / num_bins bin_num_examples: Tensor = torch.empty([num_bins], dtype=torch.float64).fill_(0.0) bin_num_positives: Tensor = torch.empty([num_bins], dtype=torch.float64).fill_(0.0) + bin_boundaries: Tensor = torch.arange( + lower_bound + w, upper_bound - w / 2, w, dtype=torch.float64 + ) by_feature_bin_num_examples: Tensor = torch.empty( [num_bins * (num_segments + 1)], dtype=torch.float64 @@ -128,8 +150,22 @@ def fbgemm_hbc_by_feature_cpu(input: Tensor) -> Tuple[Tensor, Tensor]: 0.9995, ) + def fbgemm_generic_hbc_by_feature_cpu(input: Tensor) -> Tuple[Tensor, Tensor]: + return torch.ops.fbgemm.generic_histogram_binning_calibration_by_feature( + input, + segment_values, + segment_lengths, + num_segments, + by_feature_bin_num_examples, + by_feature_bin_num_positives, + bin_boundaries, + 0.4, + 0, + 0.9995, + ) + for step in range(iters + warmup_runs): - for data_type in [torch.half, torch.float]: + for data_type in data_types: curr_input = input_data_cpu.to(data_type) hbc_time, _ = benchmark_hbc_function( fbgemm_hbc_cpu, @@ -139,9 +175,16 @@ def fbgemm_hbc_by_feature_cpu(input: Tensor) -> Tuple[Tensor, Tensor]: hbc_by_feature_time, _ = benchmark_hbc_function( fbgemm_hbc_by_feature_cpu, curr_input ) + + generic_hbc_by_feature_time, _ = benchmark_hbc_function( + fbgemm_generic_hbc_by_feature_cpu, curr_input + ) if step >= warmup_runs: total_time["hbc"]["cpu"][data_type] += hbc_time total_time["hbc_by_feature"]["cpu"][data_type] += hbc_by_feature_time + total_time["generic_hbc_by_feature"]["cpu"][ + data_type + ] += generic_hbc_by_feature_time if torch.cuda.is_available(): bin_num_examples_gpu: Tensor = bin_num_examples.cuda() @@ -183,7 +226,27 @@ def fbgemm_hbc_by_feature_gpu(input: Tensor) -> Tuple[Tensor, Tensor]: 0.9995, ) - for data_type in [torch.half, torch.float]: + bin_boundaries_gpu: Tensor = bin_boundaries.cuda() + + def fbgemm_generic_hbc_by_feature_gpu( + input: Tensor, + ) -> Tuple[Tensor, Tensor]: + return ( + torch.ops.fbgemm.generic_histogram_binning_calibration_by_feature( + input, + segment_values_gpu, + segment_lengths_gpu, + num_segments, + by_feature_bin_num_examples_gpu, + by_feature_bin_num_positives_gpu, + bin_boundaries_gpu, + 0.4, + 0, + 0.9995, + ) + ) + + for data_type in data_types: curr_input_gpu = input_data_cpu.cuda().to(data_type) hbc_time, _ = benchmark_hbc_function( fbgemm_hbc_gpu, @@ -194,11 +257,19 @@ def fbgemm_hbc_by_feature_gpu(input: Tensor) -> Tuple[Tensor, Tensor]: fbgemm_hbc_by_feature_gpu, curr_input_gpu, ) + + generic_hbc_by_feature_time, _ = benchmark_hbc_function( + fbgemm_generic_hbc_by_feature_gpu, + curr_input_gpu, + ) if step >= warmup_runs: total_time["hbc"]["gpu"][data_type] += hbc_time total_time["hbc_by_feature"]["gpu"][ data_type ] += hbc_by_feature_time + total_time["generic_hbc_by_feature"]["gpu"][ + data_type + ] += generic_hbc_by_feature_time for op, curr_items in total_time.items(): for platform, data_items in curr_items.items(): diff --git a/fbgemm_gpu/bench/merge_embeddings_benchmark.py b/fbgemm_gpu/bench/merge_embeddings_benchmark.py index fe07f5c9b..bfabf5379 100644 --- a/fbgemm_gpu/bench/merge_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/merge_embeddings_benchmark.py @@ -23,13 +23,16 @@ @click.command() +@click.option("--all-to-one-only", is_flag=True, default=False) @click.option("--num-ads", default=1024, type=int) @click.option("--embedding-dimension", default=300, type=int) @click.option("--ads-tables", default=400, type=int) @click.option("--iters", default=10, type=int) @click.option("--p2p_bw", is_flag=True, default=False) @click.option("--dst-device", default=0, type=int) -def main(num_ads, embedding_dimension, ads_tables, iters, p2p_bw, dst_device) -> None: +def main( + all_to_one_only, num_ads, embedding_dimension, ads_tables, iters, p2p_bw, dst_device +) -> None: torch.cuda.set_device(dst_device) num_gpus = torch.cuda.device_count() ad_ds = [embedding_dimension * ads_tables for _ in range(num_gpus)] @@ -81,15 +84,25 @@ def benchmark_torch_function(iters: int, f, *args) -> float: for stream in streams: stack.enter_context(torch.cuda.stream(stream)) - t = benchmark_torch_function( - iters, - lambda: torch.ops.fbgemm.merge_pooled_embeddings( - pooled_ad_embeddings, batch_indices.size(0), batch_indices.device - ), - ) merged = torch.ops.fbgemm.merge_pooled_embeddings( pooled_ad_embeddings, batch_indices.size(0), batch_indices.device ) + + if all_to_one_only: + t = benchmark_torch_function( + iters, + lambda: torch.ops.fbgemm.all_to_one_device( + pooled_ad_embeddings, batch_indices.device + ), + ) + else: + t = benchmark_torch_function( + iters, + lambda: torch.ops.fbgemm.merge_pooled_embeddings( + pooled_ad_embeddings, batch_indices.size(0), batch_indices.device + ), + ) + print( f"Merge, B: {num_ads}, D: {embedding_dimension}, T: {ads_tables}, Num GPUs: {num_gpus}, Destination GPU: {dst_device} Output Size: {merged.numel() * 2 / 1.0e6:.2f}MB, BW: {merged.numel() * 2 / t / 1.0e9:.2f}GB/s, t: {t * 1.0e3:.2f}ms" ) diff --git a/fbgemm_gpu/bench/scripts/README.md b/fbgemm_gpu/bench/scripts/README.md new file mode 100644 index 000000000..33fa798d9 --- /dev/null +++ b/fbgemm_gpu/bench/scripts/README.md @@ -0,0 +1,54 @@ +# Running `batch_benchmark_run.py` +This script acts as a wrapper around the existing `split_table_batched_embeddings_benchmark.py` +benchmark to execute multiple benchmark instances and aggregate the results. + +Options for each execution are to be specified in individual lines of an input file that is +passed to the script via the `--command-file` argument. To accommodate various build +configurations, the command used to invoke `split_table_batched_embeddings_benchmark` instances +is passed to the script via the `--benchmark-command` argument. + +An example of a typical execution is: +``` +python batch_benchmark_run.py --benchmark-command "python split_table_batched_embeddings_benchmark.py" --command-file batch_input.txt +``` + +which will provide something like the following output: +``` +Running command 0: [] +... + +... +Running command 1: [] +... + +... +Number of commands run: 2 +Average FWD BW: 1197.9126493108731 GB/s + FWDBWD BW: 859.5188964175346 GB/s +``` + +Any commands failed will be reported to ease debugging. + +## Expected use-case +This script is intended to be used in conjunction with synthetic datasets provided +in the [dlrm_datasets repository](https://github.com/facebookresearch/dlrm_datasets). +Simply clone this repository to obtain the datasets. + +Datasets in this repository provide inputs to the `split_table_batched_embeddings_benchmark.py` +benchmark and can be specified with the `--requests_data_file` argument. A subset of tables +provided in the input dataset can be used for benchmarking through the `--tables` arguemnt. + +Please note that in order to use this feature, dimensions of the tables in the dataset +must conform to the corresponding arguments of the benchmark; these being the following: +* `--batch-size` +* `--num-tables` +* `--num-embeddings` +* `--bag-size` + +Hence, a typical line in the input file to `batch_benchmark_run.py` will look something like the following: +``` +device --requests_data_file ./fbgemm_t856_bs65536.pt --batch-size 65536 --num-tables 1 --tables "44" --num-embeddings 6618839 --bag-size 194 +``` + +An error will be shown if any of these arguments do not align with the data provided. This is +in order to ensure proper accounting when metric reporting is performed. diff --git a/fbgemm_gpu/bench/scripts/batch_benchmark_run.py b/fbgemm_gpu/bench/scripts/batch_benchmark_run.py new file mode 100644 index 000000000..4b73068ef --- /dev/null +++ b/fbgemm_gpu/bench/scripts/batch_benchmark_run.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import re +import subprocess + +import click + +logging.basicConfig(level=logging.DEBUG) + + +@click.command() +@click.option( + "--benchmark-command", + default="python split_table_batched_embeddings_benchmark.py", + help="Benchmark command to run", +) +@click.option( + "--command-file", + default="batch_input.txt", + help="File containing input commands to evaluate", +) +def batch_benchmark( + benchmark_command: str, + command_file: str, +) -> None: + assert ( + "split_table_batched_embeddings_benchmark" in benchmark_command + ), "split_table_batched_embeddings benchmark required for execution" + + benchmark_cmd = benchmark_command.strip().split() + + cmds_run = 0 + failed_runs = [] + total_fwd_bytes_read_gb = 0 + total_fwdbwd_bytes_read_gb = 0 + total_fwd_time_us = 0 + total_fwdbwd_time_us = 0 + with open(command_file) as cmd_file: + for line in cmd_file: + options = line.replace('"', "").strip().split() + cmd = benchmark_cmd + options + logging.info(f"Running command {cmds_run}: {cmd}") + result = subprocess.run( + cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT + ) + logging.info(result.stdout.decode("utf-8")) + # Parse results + found_fwd_results = False + found_fwdbwd_results = False + for line in result.stdout.decode("utf-8").splitlines(): + re_match = re.search(r"BW: ([\.\d]+) GB/s, T: ([\.\d]+)us", line) + if re_match: + bw_gb = float(re_match.groups()[0]) + time_us = int(re_match.groups()[1]) + total_bytes_read_gb = bw_gb * time_us / 1e6 + + if "Forward, " in line: + total_fwd_bytes_read_gb += total_bytes_read_gb + total_fwd_time_us += time_us + found_fwd_results = True + elif "ForwardBackward, " in line: + total_fwdbwd_bytes_read_gb += total_bytes_read_gb + total_fwdbwd_time_us += time_us + found_fwdbwd_results = True + else: + raise Exception( + f"Unexpected reported metric for line: '{line}'" + ) + if not (found_fwd_results and found_fwdbwd_results): + failed_runs.append(cmds_run) + cmds_run += 1 + logging.info(f"Number of commands run: {cmds_run}") + if failed_runs: + logging.info(f"Failed runs: {failed_runs}") + logging.info( + f"Average FWD BW: {total_fwd_bytes_read_gb / total_fwd_time_us * 1e6} GB/s" + ) + logging.info( + f" FWDBWD BW: {total_fwdbwd_bytes_read_gb / total_fwdbwd_time_us * 1e6} GB/s" + ) + + +if __name__ == "__main__": + batch_benchmark() diff --git a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py index 816fff57b..9c6e9e0b7 100644 --- a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py @@ -35,6 +35,7 @@ SparseType, SplitTableBatchedEmbeddingBagsCodegen, IntNBitTableBatchedEmbeddingBagsCodegen, + PoolingMode, ) from numpy.random import default_rng from torch import Tensor @@ -68,6 +69,34 @@ def get_table_batched_offsets_from_dense( ) +def get_offsets_from_dense(indices: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + (B, L) = indices.size() + return ( + indices.contiguous().view(-1), + torch.tensor( + np.cumsum(np.asarray([0] + [L for _ in range(B)])[:-1]).astype(np.int64) + ), + ) + + +def b_indices( + b: Callable[..., torch.Tensor], + x: torch.Tensor, + per_sample_weights: Optional[torch.Tensor] = None, + use_cpu: bool = False, + do_pooling: bool = True, +) -> torch.Tensor: + (indices, offsets) = get_offsets_from_dense(x) + if do_pooling: + return b( + indices.cuda(), + offsets.cuda(), + per_sample_weights=per_sample_weights, + ) + else: + return b(indices.cuda()) + + def generate_requests( iters: int, B: int, @@ -88,14 +117,15 @@ def generate_requests( if requests_data_file is not None: indices_tensor, offsets_tensor, lengths_tensor = torch.load(requests_data_file) - logging.warning("Ignoring L parameter as requests data file has been provided") - + average_L = 0 if tables is not None: emb_tables = tuple(int(x) for x in tables.split(",")) indices = torch.zeros(0, dtype=indices_tensor.dtype) offsets = torch.zeros(1, dtype=offsets_tensor.dtype) + total_L = 0 for t in emb_tables: t_offsets = offsets_tensor[B * t : B * (t + 1) + 1] + total_L += t_offsets[-1] - t_offsets[0] indices = torch.cat( (indices, indices_tensor[t_offsets[0] : t_offsets[-1]]) ) @@ -107,6 +137,7 @@ def generate_requests( ) indices_tensor = indices offsets_tensor = offsets + average_L = int(total_L / B) assert np.prod(offsets_tensor.size()) - 1 == np.prod((T, B)), ( f"Requested tables: {emb_tables} " @@ -117,12 +148,16 @@ def generate_requests( f"on tables: {emb_tables}" ) else: + average_L = int((offsets_tensor[-1] - offsets_tensor[0]) / B) assert (np.prod(offsets_tensor.size()) - 1) == np.prod((T, B)), ( f"Data file (indices = {indices_tensor.size()}, " f"offsets = {offsets_tensor.size()}, lengths = {lengths_tensor.size()}) " f"does not conform to inputs (T, B) = ({T}, {B})." ) + assert ( + L == average_L + ), f"Requested L does not align with provided data file ({L} vs. {average_L})" assert E > max(indices_tensor), ( f"Number of embeddings is not enough to support maximum index " f"provided by data file {E} vs. {max(indices_tensor)}" @@ -232,6 +267,90 @@ def benchmark_requests( return median_time if check_median else avg_time +def benchmark_requests_refer( + requests: List[Tuple[torch.IntTensor, torch.IntTensor, Optional[Tensor]]], + T: int, + B: int, + L: int, + E: int, + D: int, + pooling_mode: str, + weighted: bool, + flush_gpu_cache_size_mb: int = 0, + check_median: bool = False, +) -> float: + do_pooling = pooling_mode in ["sum", "mean"] + if do_pooling: + nn_embedding_list = [ + torch.nn.EmbeddingBag(E, D, mode=pooling_mode, sparse=True).cuda() + ] * T + else: + nn_embedding_list = [torch.nn.Embedding(E, D, sparse=True).cuda()] * T + + times = [] + if torch.cuda.is_available(): + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + for (indices, _, weights) in requests: + indices_list = indices.view(T, B, L).split(1) + + if weighted: + assert weights is not None + weights_list = weights.view(T, B, L).split(1) + + start_time = time.time() + if torch.cuda.is_available(): + if flush_gpu_cache_size_mb: + _ = torch.rand( + flush_gpu_cache_size_mb * 1024 * 1024 // 4, dtype=torch.float + ) + torch.cuda.synchronize() + start_event.record() + + nn_embedding_output = ( + [ + b_indices(nn_embedding, x, use_cpu=False, do_pooling=do_pooling) + for (nn_embedding, x) in zip(nn_embedding_list, indices_list) + ] + if not weighted + else [ + b_indices( + nn_embedding, + x, + per_sample_weights=xw.view(-1), + use_cpu=False, + do_pooling=do_pooling, + ) + for (nn_embedding, x, xw) in zip( + nn_embedding_list, + indices_list, + # pyre-fixme[61]: `weights_list` is undefined, or not always + # defined. + weights_list, + ) + ] + ) + if do_pooling: + final_output = torch.cat( + [f.view(B, -1) for f in nn_embedding_output], dim=1 + ) + else: + final_output = torch.cat(nn_embedding_output, dim=0).view(-1, D) + + if torch.cuda.is_available(): + end_event.record() + torch.cuda.synchronize() + it_time = start_event.elapsed_time(end_event) * 1.0e-3 + times.append(it_time) + else: + it_time = time.time() - start_time + times.append(it_time) + avg_time = sum(times) / len(requests) + median_time = statistics.median(times) + return median_time if check_median else avg_time + + def benchmark_pipelined_requests( requests: List[Tuple[torch.IntTensor, torch.IntTensor, Optional[Tensor]]], func1: Callable[[Tensor, Tensor, Optional[Tensor]], None], @@ -1043,6 +1162,7 @@ def nbit_cpu( # noqa C901 @click.option("--reuse", default=0.0) @click.option("--row-wise/--no-row-wise", default=True) @click.option("--weighted", is_flag=True, default=False) +@click.option("--pooling", type=str, default="sum") @click.option("--weighted-num-requires-grad", type=int, default=None) @click.option("--bounds-check-mode", type=int, default=BoundsCheckMode.WARNING.value) @click.option("--pruning-ratio", type=float, default=None) @@ -1054,6 +1174,7 @@ def nbit_cpu( # noqa C901 @click.option("--warmup-runs", default=2) @click.option("--output-dtype", type=SparseType, default=SparseType.FP16) @click.option("--report-aibench", is_flag=True) +@click.option("--run-reference", is_flag=True, default=False) @click.option("--requests_data_file", type=str, default=None) @click.option("--tables", type=str, default=None) def nbit_device( # noqa C901 @@ -1070,6 +1191,7 @@ def nbit_device( # noqa C901 reuse: float, row_wise: bool, weighted: bool, + pooling: str, weighted_num_requires_grad: Optional[int], bounds_check_mode: int, pruning_ratio: Optional[float], @@ -1081,6 +1203,7 @@ def nbit_device( # noqa C901 warmup_runs: int, output_dtype: SparseType, report_aibench: bool, + run_reference: bool, requests_data_file: Optional[str], tables: Optional[str], ) -> None: @@ -1124,6 +1247,17 @@ def nbit_device( # noqa C901 else: managed_option = EmbeddingLocation.MANAGED + if pooling is None or pooling == "sum": + pooling = "sum" + pooling_mode = PoolingMode.SUM + do_pooling = True + elif pooling == "mean": + pooling_mode = PoolingMode.MEAN + do_pooling = True + else: # "none" + pooling_mode = PoolingMode.NONE + do_pooling = False + emb = IntNBitTableBatchedEmbeddingBagsCodegen( [("", E, d, weights_precision, managed_option) for d in Ds], bounds_check_mode=BoundsCheckMode(bounds_check_mode), @@ -1131,15 +1265,22 @@ def nbit_device( # noqa C901 load_factor=load_factor, use_array_for_index_remapping=use_array_for_index_remapping, output_dtype=output_dtype, + pooling_mode=pooling_mode, ).cuda() emb.fill_random_weights() nparams_byte = sum(w.numel() for (w, _) in emb.split_embedding_weights()) param_size_multiplier = weights_precision.bit_rate() / 8.0 output_size_multiplier = output_dtype.bit_rate() / 8.0 - read_write_bytes = ( - output_size_multiplier * B * T * D + param_size_multiplier * B * T * L * D - ) + if do_pooling: + read_write_bytes = ( + output_size_multiplier * B * T * D + param_size_multiplier * B * T * L * D + ) + else: + read_write_bytes = ( + output_size_multiplier * B * T * L * D + + param_size_multiplier * B * T * L * D + ) logging.info( f"{weights_precision} Embedding tables: {E * T} rows, {nparams_byte / param_size_multiplier / 1.0e9: .2f} GParam, " f"{nparams_byte / 1.0e9: .2f} GB" # IntN TBE use byte for storage @@ -1222,6 +1363,62 @@ def nbit_device( # noqa C901 ) ) + if run_reference: + times = [] + for i in range(runs_of_iters): + requests = generate_requests( + iters, + B, + T, + L, + E, + reuse=reuse, + alpha=alpha, + weights_precision=weights_precision, + weighted=weighted, + requests_data_file=requests_data_file, + tables=tables, + ) + requests = [(a.int(), b.int(), c if c else None) for (a, b, c) in requests] + + # forward + time_per_iter_refer = benchmark_requests_refer( + requests, + T, + B, + L, + E, + D, + pooling, + weighted, + check_median=check_median, + ) + + # free up GPU memory + del requests + + logging.info( + f"Reference (nn.Embedding(Bag)) Iteration {i}: " + f"Forward, B: {B}, " + f"E: {E}, T: {T}, D: {D}, L: {L}, W: {weighted}, " + f"BW: {read_write_bytes / time_per_iter_refer / 1.0e9: .2f} GB/s, " # noqa: B950 + f"Time: {time_per_iter_refer * 1.0e6:.0f}us " + ) + + if i >= warmup_runs: + times.append(time_per_iter_refer) + + time_per_iter_refer = statistics.mean(times) + bandwidth = read_write_bytes / time_per_iter_refer / 1.0e9 + + logging.info( + f"Average of all iterations: " + f"Forward, B: {B}, " + f"E: {E}, T: {T}, D: {D}, L: {L}, W: {weighted}, " + f"Effective BW: {bandwidth: .2f} GB/s, " # noqa: B950 + f"Time: {time_per_iter_refer * 1.0e6:.0f}us " + ) + @cli.command() @click.option("--alpha", default=1.0) @@ -1869,7 +2066,7 @@ def bounds_check_indices( # noqa C901 # forward time_per_iter = benchmark_requests( requests, - lambda indices, offsets, _: torch.ops.fb.bounds_check_indices( + lambda indices, offsets, _: torch.ops.fbgemm.bounds_check_indices( rows_per_table, indices, offsets, diff --git a/fbgemm_gpu/codegen/embedding_backward_code_generator.py b/fbgemm_gpu/codegen/embedding_backward_code_generator.py index ee2388f9e..3acb2386d 100644 --- a/fbgemm_gpu/codegen/embedding_backward_code_generator.py +++ b/fbgemm_gpu/codegen/embedding_backward_code_generator.py @@ -397,7 +397,7 @@ def rowwise_adagrad() -> None: momentum1[idx] = new_sum_square_grads; multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps); } - multiplier = SHFL_SYNC_MACRO(multiplier, 0); + multiplier = shfl_sync(multiplier, 0); """ split_weight_update_cpu = """ at::acc_type g_local_sum_square = 0.0; @@ -474,8 +474,8 @@ def rowwise_weighted_adagrad() -> None: multiplier = learning_rate * lambda / (cbrtf(new_sum_square_grads) + eps); correction = 1.0 - multiplier * weight_decay; } - multiplier = SHFL_SYNC_MACRO(multiplier, 0); - correction = SHFL_SYNC_MACRO(correction, 0); + multiplier = shfl_sync(multiplier, 0); + correction = shfl_sync(correction, 0); """ split_weight_update_cpu = """ // weight_decay not supported for cpu version @@ -636,7 +636,7 @@ def partial_rowwise_lamb() -> None: m2 = beta2 * momentum2[idx] + (1.0 - beta2) * g_avg_square; momentum2[idx] = m2; } - m2 = SHFL_SYNC_MACRO(m2, 0); + m2 = shfl_sync(m2, 0); at::acc_type m2_hat = 1.0 / (sqrtf((m2 / (1.0 - powf(beta2, iter)))) + eps); at::acc_type weight_sum_sq = 0.0; @@ -772,7 +772,7 @@ def partial_rowwise_adam() -> None: momentum2[idx] = v_t; v_hat_t = v_t / (1.0 - powf(beta2, iter)); } - v_hat_t = SHFL_SYNC_MACRO(v_hat_t, 0); + v_hat_t = shfl_sync(v_hat_t, 0); """ split_weight_update = """ @@ -884,16 +884,28 @@ def forward_split() -> None: def forward_quantized() -> None: + @dataclass + class elem_type: + enum_name: str + cpp_type_name: str + + type_map = { + 32: elem_type("FP32", "float"), + 16: elem_type("FP16", "__half2"), + 8: elem_type("INT8", "uint32_t"), + 4: elem_type("INT4", "uint32_t"), + } + template = env.get_template("embedding_forward_quantized_split_template.cu") - src_cu = template.render(weighted=False) + src_cu = template.render(weighted=False, type_map=type_map) write("gen_embedding_forward_quantized_split_unweighted_codegen_cuda.cu", src_cu) - src_cu = template.render(weighted=True) + src_cu = template.render(weighted=True, type_map=type_map) write("gen_embedding_forward_quantized_split_weighted_codegen_cuda.cu", src_cu) template = env.get_template("embedding_forward_quantized_cpu_template.cpp") - src_cu = template.render(weighted=False) + src_cu = template.render(weighted=False, type_map=type_map) write("gen_embedding_forward_quantized_unweighted_codegen_cpu.cpp", src_cu) - src_cu = template.render(weighted=True) + src_cu = template.render(weighted=True, type_map=type_map) write("gen_embedding_forward_quantized_weighted_codegen_cpu.cpp", src_cu) diff --git a/fbgemm_gpu/codegen/embedding_backward_dense_host.cpp b/fbgemm_gpu/codegen/embedding_backward_dense_host.cpp index 8c76f867e..151f3c024 100644 --- a/fbgemm_gpu/codegen/embedding_backward_dense_host.cpp +++ b/fbgemm_gpu/codegen/embedding_backward_dense_host.cpp @@ -9,7 +9,8 @@ #include #include -#include "codegen/embedding_common.h" +#include "fbgemm_gpu/embedding_common.h" +#include "fbgemm_gpu/sparse_ops_utils.h" using Tensor = at::Tensor; @@ -394,9 +395,15 @@ Tensor split_embedding_codegen_lookup_dense_function( TORCH_LIBRARY_FRAGMENT(fb, m) { m.def( "dense_embedding_codegen_lookup_function(Tensor dev_weights, Tensor weights_offsets, Tensor D_offsets, int total_D, int max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad) -> Tensor"); - m.impl( + DISPATCH_TO_CUDA( "dense_embedding_codegen_lookup_function", - torch::dispatch( - c10::DispatchKey::CUDA, - TORCH_FN(split_embedding_codegen_lookup_dense_function))); + split_embedding_codegen_lookup_dense_function); +} + +TORCH_LIBRARY_FRAGMENT(fbgemm, m) { + m.def( + "dense_embedding_codegen_lookup_function(Tensor dev_weights, Tensor weights_offsets, Tensor D_offsets, int total_D, int max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad) -> Tensor"); + DISPATCH_TO_CUDA( + "dense_embedding_codegen_lookup_function", + split_embedding_codegen_lookup_dense_function); } diff --git a/fbgemm_gpu/codegen/embedding_backward_dense_host_cpu.cpp b/fbgemm_gpu/codegen/embedding_backward_dense_host_cpu.cpp index 680b27c71..39b9eef7c 100644 --- a/fbgemm_gpu/codegen/embedding_backward_dense_host_cpu.cpp +++ b/fbgemm_gpu/codegen/embedding_backward_dense_host_cpu.cpp @@ -8,8 +8,8 @@ #include #include -#include "codegen/embedding_common.h" #include "codegen/embedding_forward_split_cpu.h" +#include "fbgemm_gpu/embedding_common.h" using Tensor = at::Tensor; @@ -176,4 +176,12 @@ TORCH_LIBRARY_IMPL(fb, CPU, m) { TORCH_FN(split_embedding_codegen_lookup_dense_function))); } +TORCH_LIBRARY_IMPL(fbgemm, CPU, m) { + m.impl( + "dense_embedding_codegen_lookup_function", + torch::dispatch( + c10::DispatchKey::CPU, + TORCH_FN(split_embedding_codegen_lookup_dense_function))); +} + } // namespace diff --git a/fbgemm_gpu/codegen/embedding_backward_split_cpu_approx_template.cpp b/fbgemm_gpu/codegen/embedding_backward_split_cpu_approx_template.cpp index 1f686acda..18c35f323 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_cpu_approx_template.cpp +++ b/fbgemm_gpu/codegen/embedding_backward_split_cpu_approx_template.cpp @@ -11,9 +11,9 @@ #include #include -#include "codegen/embedding_common.h" #include "codegen/embedding_forward_split_cpu.h" #include "fbgemm/FbgemmEmbedding.h" +#include "fbgemm_gpu/embedding_common.h" using Tensor = at::Tensor; diff --git a/fbgemm_gpu/codegen/embedding_backward_split_cpu_template.cpp b/fbgemm_gpu/codegen/embedding_backward_split_cpu_template.cpp index dcc7ce410..2d142e230 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_cpu_template.cpp +++ b/fbgemm_gpu/codegen/embedding_backward_split_cpu_template.cpp @@ -12,10 +12,10 @@ #include #include -#include "codegen/embedding_common.h" #include "codegen/embedding_forward_split_cpu.h" #include "fbgemm/FbgemmEmbedding.h" #include "fbgemm/Types.h" +#include "fbgemm_gpu/embedding_common.h" using Tensor = at::Tensor; diff --git a/fbgemm_gpu/codegen/embedding_backward_split_host_cpu_template.cpp b/fbgemm_gpu/codegen/embedding_backward_split_host_cpu_template.cpp index e653ab8e4..d82cf3a26 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_host_cpu_template.cpp +++ b/fbgemm_gpu/codegen/embedding_backward_split_host_cpu_template.cpp @@ -10,8 +10,7 @@ #include #include "codegen/embedding_forward_split_cpu.h" - -#include "codegen/embedding_common.h" +#include "fbgemm_gpu/embedding_common.h" using Tensor = at::Tensor; @@ -186,7 +185,8 @@ Tensor split_embedding_codegen_lookup_{{ optimizer }}_function_cpu( bool gradient_clipping, double max_gradient, bool stochastic_rounding, - {{ args.split_function_args | join(", ") }}) { + {{ args.split_function_args | join(", ") }}, + int64_t output_dtype) { return SplitLookupFunction_{{ optimizer }}_Op::apply( host_weights, weights_placements, @@ -208,7 +208,12 @@ Tensor split_embedding_codegen_lookup_{{ optimizer }}_function_cpu( } TORCH_LIBRARY_FRAGMENT(fb, m) { - m.def("split_embedding_codegen_lookup_{{ optimizer }}_function_cpu(Tensor host_weights, Tensor weights_placements, Tensor weights_offsets, Tensor D_offsets, int total_D, int max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, bool gradient_clipping, float max_gradient, bool stochastic_rounding, {{ args.split_function_schemas | join(", ") }}) -> Tensor"); + m.def("split_embedding_codegen_lookup_{{ optimizer }}_function_cpu(Tensor host_weights, Tensor weights_placements, Tensor weights_offsets, Tensor D_offsets, int total_D, int max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, bool gradient_clipping, float max_gradient, bool stochastic_rounding, {{ args.split_function_schemas | join(", ") }}, int output_dtype=0) -> Tensor"); + m.impl("split_embedding_codegen_lookup_{{ optimizer }}_function_cpu", torch::dispatch(c10::DispatchKey::CPU, TORCH_FN(split_embedding_codegen_lookup_{{ optimizer }}_function_cpu))); +} + +TORCH_LIBRARY_FRAGMENT(fbgemm, m) { + m.def("split_embedding_codegen_lookup_{{ optimizer }}_function_cpu(Tensor host_weights, Tensor weights_placements, Tensor weights_offsets, Tensor D_offsets, int total_D, int max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, bool gradient_clipping, float max_gradient, bool stochastic_rounding, {{ args.split_function_schemas | join(", ") }}, int output_dtype=0) -> Tensor"); m.impl("split_embedding_codegen_lookup_{{ optimizer }}_function_cpu", torch::dispatch(c10::DispatchKey::CPU, TORCH_FN(split_embedding_codegen_lookup_{{ optimizer }}_function_cpu))); } diff --git a/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp b/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp index c3f519039..1262fa807 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp +++ b/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp @@ -10,7 +10,8 @@ #include #include -#include "codegen/embedding_common.h" +#include "fbgemm_gpu/embedding_common.h" +#include "fbgemm_gpu/sparse_ops_utils.h" using Tensor = at::Tensor; @@ -491,6 +492,12 @@ Tensor split_embedding_codegen_lookup_{{ optimizer }}_function( TORCH_LIBRARY_FRAGMENT(fb, m) { m.def("split_embedding_codegen_lookup_{{ optimizer }}_function(Tensor placeholder_autograd_tensor, Tensor dev_weights, Tensor uvm_weights, Tensor lxu_cache_weights, Tensor weights_placements, Tensor weights_offsets, Tensor D_offsets, int total_D, int max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, Tensor lxu_cache_locations, bool gradient_clipping, float max_gradient, bool stochastic_rounding, {{ args.split_function_schemas | join(", ") }}, int output_dtype=0) -> Tensor"); - m.impl("split_embedding_codegen_lookup_{{ optimizer }}_function", torch::dispatch(c10::DispatchKey::CUDA, TORCH_FN(split_embedding_codegen_lookup_{{ optimizer }}_function))); + DISPATCH_TO_CUDA("split_embedding_codegen_lookup_{{ optimizer }}_function", split_embedding_codegen_lookup_{{ optimizer }}_function); } + +TORCH_LIBRARY_FRAGMENT(fbgemm, m) { + m.def("split_embedding_codegen_lookup_{{ optimizer }}_function(Tensor placeholder_autograd_tensor, Tensor dev_weights, Tensor uvm_weights, Tensor lxu_cache_weights, Tensor weights_placements, Tensor weights_offsets, Tensor D_offsets, int total_D, int max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, Tensor lxu_cache_locations, bool gradient_clipping, float max_gradient, bool stochastic_rounding, {{ args.split_function_schemas | join(", ") }}, int output_dtype=0) -> Tensor"); + DISPATCH_TO_CUDA("split_embedding_codegen_lookup_{{ optimizer }}_function", split_embedding_codegen_lookup_{{ optimizer }}_function); +} + // clang-format on diff --git a/fbgemm_gpu/codegen/embedding_backward_split_indice_weights_template.cu b/fbgemm_gpu/codegen/embedding_backward_split_indice_weights_template.cu index 186c60294..dbe82a089 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_indice_weights_template.cu +++ b/fbgemm_gpu/codegen/embedding_backward_split_indice_weights_template.cu @@ -103,9 +103,9 @@ __launch_bounds__(kForwardMaxThreads) void {{ "dense" if dense else "split" }}_e int32_t cache_idx = (placement == PlacementType::MANAGED_CACHING && l < L) ? lxu_cache_locations[indices_start + l] : 0; {% endif %} for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) { - int64_t idx_j = SHFL_SYNC_MACRO(idx, 0); + int64_t idx_j = shfl_sync(idx, j); {% if not dense %} - int32_t cache_idx_j = SHFL_SYNC_MACRO(cache_idx, 0); + int32_t cache_idx_j = shfl_sync(cache_idx, j); {% endif %} at::acc_type grad_indice_weight = 0.0; @@ -187,6 +187,24 @@ Tensor {{ "dense" if dense else "split" }}_embedding_codegen_grad_indice_weights Tensor lxu_cache_locations, {% endif %} Tensor feature_requires_grad) { + TENSOR_ON_CUDA_GPU(grad_output); + TENSOR_ON_CUDA_GPU(dev_weights); + {% if not dense %} + TENSOR_ON_CUDA_GPU(uvm_weights); + TENSOR_ON_CUDA_GPU(lxu_cache_weights); + TENSOR_ON_CUDA_GPU(weights_placements); + {% endif %} + TENSOR_ON_CUDA_GPU(weights_offsets); + TENSOR_ON_CUDA_GPU(D_offsets); + TENSOR_ON_CUDA_GPU(indices); + TENSOR_ON_CUDA_GPU(offsets); + {% if not dense %} + TENSOR_ON_CUDA_GPU(lxu_cache_locations); + {% endif %} + if (feature_requires_grad.defined()) { + TENSOR_ON_CUDA_GPU(feature_requires_grad); + } + at::cuda::OptionalCUDAGuard device_guard; device_guard.set_index(dev_weights.get_device()); const auto T = D_offsets.size(0) - 1; diff --git a/fbgemm_gpu/codegen/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/embedding_backward_split_template.cu index 5f5dbde2d..e80f4041c 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/embedding_backward_split_template.cu @@ -6,7 +6,8 @@ */ // clang-format off {% set wdesc = "weighted" if weighted else "unweighted" %} -#include "codegen/embedding_backward_template_helpers.cuh" +#include "fbgemm_gpu/embedding_backward_template_helpers.cuh" +#include "fbgemm_gpu/split_embeddings_utils.cuh" {% if not dense %} constexpr int32_t kCacheLocationMissing = -1; @@ -106,8 +107,6 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ sorted_linear_indices_run, const at::PackedTensorAccessor32 sorted_linear_indices_cumulative_run_lengths, - const at::PackedTensorAccessor32 - sorted_linear_indices_run_lengths, const at::PackedTensorAccessor32 long_run_ids, const at::PackedTensorAccessor32 @@ -190,13 +189,14 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ {% endif %} for (int32_t j = 0; j < kWarpSize && sl + j < sl_end; ++j) { {% if not nobag %} - int32_t b_j = SHFL_SYNC_MACRO(b, j); - int32_t D_start_j = SHFL_SYNC_MACRO(D_start, j); + int32_t b_j = shfl_sync(b, j); + int32_t D_start_j = shfl_sync(D_start, j); {% else %} - int32_t l_j = SHFL_SYNC_MACRO(l, j); + int32_t l_j = shfl_sync(l, j); {% endif %} + {% if weighted %} - at::acc_type idx_weight_j = SHFL_SYNC_MACRO(idx_weight, j); + at::acc_type idx_weight_j = shfl_sync(idx_weight, j); {% endif %} #pragma unroll kMaxVecsPerThread @@ -458,8 +458,6 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ sorted_linear_indices_run, const at::PackedTensorAccessor32 sorted_linear_indices_cumulative_run_lengths, - const at::PackedTensorAccessor32 - sorted_linear_indices_run_lengths, {% if not nobag %} const at::PackedTensorAccessor32 sorted_infos, {% else %} @@ -549,26 +547,26 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ for (int32_t j = 0; j < kWarpSize && sl + j < sl_end; ++j) { {% if not nobag %} - int32_t b_j = SHFL_SYNC_MACRO(b, j); - int32_t D_start_j = SHFL_SYNC_MACRO(D_start, j); + int32_t b_j = shfl_sync(b, j); + int32_t D_start_j = shfl_sync(D_start, j); {% else %} - int32_t l_j = SHFL_SYNC_MACRO(l, j); + int32_t l_j = shfl_sync(l, j); {% endif %} {% if weighted %} - at::acc_type idx_weight_j = SHFL_SYNC_MACRO(idx_weight, j); + at::acc_type idx_weight_j = shfl_sync(idx_weight, j); {% endif %} - #pragma unroll kMaxVecsPerThread - for (int32_t i = 0; - i < kMaxVecsPerThread && 4 * kWarpSize * i + threadIdx.x * 4 < D; - ++i) { - int32_t d = 4 * kWarpSize * i + threadIdx.x * 4; - {% if not nobag %} - Vec4T> grad_out_vec( - &grad_output[b_j][0] + D_start_j + d); - {% else %} - Vec4T> grad_out_vec(&grad_output[l_j][d]); - {% endif %} + #pragma unroll kMaxVecsPerThread + for (int32_t i = 0; + i < kMaxVecsPerThread && 4 * kWarpSize * i + threadIdx.x * 4 < D; + ++i) { + int32_t d = 4 * kWarpSize * i + threadIdx.x * 4; + {% if not nobag %} + Vec4T> grad_out_vec( + &grad_output[b_j][0] + D_start_j + d); + {% else %} + Vec4T> grad_out_vec(&grad_output[l_j][d]); + {% endif %} {% if weighted %} grad_sum[i].fma_(grad_out_vec, idx_weight_j); {% else %} @@ -706,6 +704,28 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ bool stochastic_rounding, {% endif %} {{ args.split_function_args | join(", ") }}) { + + TENSOR_ON_CUDA_GPU(grad_output); + TENSOR_ON_CUDA_GPU(dev_weights); + {% if not dense %} + TENSOR_ON_CUDA_GPU(uvm_weights); + TENSOR_ON_CUDA_GPU(lxu_cache_weights); + TENSOR_ON_CUDA_GPU(weights_placements); + {% endif %} + TENSOR_ON_CUDA_GPU(weights_offsets); + {% if not nobag %} + TENSOR_ON_CUDA_GPU(D_offsets); + {% endif %} + TENSOR_ON_CUDA_GPU(hash_size_cumsum); + TENSOR_ON_CUDA_GPU(indices); + TENSOR_ON_CUDA_GPU(offsets); + {% if weighted %} + TENSOR_ON_CUDA_GPU(indice_weights); + {% endif %} + {% if not dense %} + TENSOR_ON_CUDA_GPU(lxu_cache_locations); + {% endif %} + at::cuda::OptionalCUDAGuard device_guard; device_guard.set_index(dev_weights.get_device()); @@ -736,10 +756,14 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ TORCH_CHECK(D <= {{ max_embedding_dim }}); {% endif %} -#ifndef __HIP_PLATFORM_HCC__ // V100: 96 KB; A100: 160 KB. int max_shared_bytes = 0; +#ifndef __HIP_PLATFORM_HCC__ cudaDeviceGetAttribute(&max_shared_bytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev_weights.get_device()); +#else + // MI100 has 64 KB local memory (shared memory) per workgroup + max_shared_bytes = 64 << 10; +#endif C10_CUDA_KERNEL_LAUNCH_CHECK(); int shared_kb = max_shared_bytes >> 10; // V100: 64 KB; A100: 96 KB. @@ -747,80 +771,27 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ int used_shared_kb = round_down(shared_kb * 2 / 3, 16); TORCH_CHECK(used_shared_kb > 0); int used_shared_bytes = used_shared_kb << 10; -#endif - {% if not nobag %} - auto infos = at::empty_like(indices, indices.options().dtype(at::kInt)); - {% else %} - auto infos = at::empty_like(indices, indices.options().dtype(at::kLong)); - {% endif %} - auto infos_sorted = at::empty_like(infos); - auto linear_indices = at::empty_like(indices); - auto linear_indices_sorted = at::empty_like(indices); - {% if not nobag %} - linearize_index_kernel<<< - div_round_up(B * T, kMaxThreads), - kMaxThreads, - 0, - at::cuda::getCurrentCUDAStream()>>>( - hash_size_cumsum.packed_accessor32(), - indices.packed_accessor32(), - offsets.packed_accessor32(), - infos.packed_accessor32(), - linear_indices.packed_accessor32()); - {% else %} - nobag_linearize_index_kernel<<< - div_round_up(B * T, kMaxThreads), - kMaxThreads, - 0, - at::cuda::getCurrentCUDAStream()>>>( - hash_size_cumsum.packed_accessor32(), - indices.packed_accessor32(), - offsets.packed_accessor32(), - infos.packed_accessor32(), - linear_indices.packed_accessor32()); - {% endif %} - C10_CUDA_KERNEL_LAUNCH_CHECK(); - { - size_t temp_storage_bytes = 0; - AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortPairs( - nullptr, - temp_storage_bytes, - linear_indices.data_ptr(), - linear_indices_sorted.data_ptr(), - {% if not nobag %} - infos.data_ptr(), - infos_sorted.data_ptr(), - {% else %} - infos.data_ptr(), - infos_sorted.data_ptr(), - {% endif %} - linear_indices.numel(), - 0, - total_hash_size_bits, - at::cuda::getCurrentCUDAStream(), - false)); - auto temp_storage = at::empty( - {static_cast(temp_storage_bytes)}, - indices.options().dtype(at::kByte)); - AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortPairs( - temp_storage.data_ptr(), - temp_storage_bytes, - linear_indices.data_ptr(), - linear_indices_sorted.data_ptr(), - {% if not nobag %} - infos.data_ptr(), - infos_sorted.data_ptr(), - {% else %} - infos.data_ptr(), - infos_sorted.data_ptr(), - {% endif %} - linear_indices.numel(), - 0, + Tensor linear_indices, linear_indices_sorted; + Tensor infos_sorted; + Tensor sorted_linear_indices_run, sorted_linear_indices_run_lengths, + sorted_linear_indices_num_runs, + sorted_linear_indices_cumulative_run_lengths; + std::tie( + linear_indices, + linear_indices_sorted, + infos_sorted, + sorted_linear_indices_run, + sorted_linear_indices_run_lengths, + sorted_linear_indices_num_runs, + sorted_linear_indices_cumulative_run_lengths) = + transpose_embedding_input( + hash_size_cumsum, total_hash_size_bits, - at::cuda::getCurrentCUDAStream(), - false)); - } + indices, + offsets, + {{"true" if nobag else "false"}}); + {% if not dense %} auto lxu_cache_locations_sorted = at::empty_like(lxu_cache_locations); if (lxu_cache_locations.size(0) > 0) { @@ -854,41 +825,6 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ false)); } {% endif %} - auto sorted_linear_indices_run = at::empty_like(indices); - auto sorted_linear_indices_run_lengths = - at::zeros_like(indices, indices.options().dtype(at::kInt)); - auto sorted_linear_indices_num_runs = - at::zeros({1}, indices.options().dtype(at::kInt)); - - { - size_t temp_storage_bytes = 0; - AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRunLengthEncode::Encode( - nullptr, - temp_storage_bytes, - linear_indices_sorted.data_ptr(), - sorted_linear_indices_run.data_ptr(), - sorted_linear_indices_run_lengths.data_ptr(), - sorted_linear_indices_num_runs.data_ptr(), - linear_indices_sorted.numel(), - at::cuda::getCurrentCUDAStream())); - // Allocate temporary storage - auto temp_storage = at::empty( - {static_cast(temp_storage_bytes)}, - indices.options().dtype(at::kByte)); - // Run encoding - AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRunLengthEncode::Encode( - temp_storage.data_ptr(), - temp_storage_bytes, - linear_indices_sorted.data_ptr(), - sorted_linear_indices_run.data_ptr(), - sorted_linear_indices_run_lengths.data_ptr(), - sorted_linear_indices_num_runs.data_ptr(), - linear_indices_sorted.numel(), - at::cuda::getCurrentCUDAStream())); - } - - auto sorted_linear_indices_cumulative_run_lengths = - asynchronous_complete_cumsum(sorted_linear_indices_run_lengths); {% if not dense %} DISPATCH_EMB_CACHE_TYPES( @@ -946,6 +882,10 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ } {% endif %} + // early memory release + linear_indices.reset(); + linear_indices_sorted.reset(); + auto grad_output_accessor = grad_output.packed_accessor32< at::acc_type<{{ "scalar_t" if dense else "cache_t" }}, true>, 2, @@ -1041,6 +981,7 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ used_shared_bytes); // V100: 64 KB; A100: 96 KB. #endif C10_CUDA_KERNEL_LAUNCH_CHECK(); + // dividing by kMaxThreads is a heuristic to avoid num of blocks far exceeding num_long_run_ids[0] split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_kernel_cta_per_row_1< {% if not dense %} emb_t, @@ -1050,7 +991,7 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ scalar_t, {% endif %} {{ kMaxVecsPerThread }}> - <<) * 4 * kWarpSize * {{ kMaxVecsPerThread }}, @@ -1076,8 +1017,6 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ .packed_accessor32(), sorted_linear_indices_cumulative_run_lengths .packed_accessor32(), - sorted_linear_indices_run_lengths - .packed_accessor32(), long_run_ids.packed_accessor32(), num_long_run_ids.packed_accessor32(), {% if not nobag %} @@ -1126,7 +1065,7 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ scalar_t, {% endif %} {{ kMaxVecsPerThread }}> - <<(), sorted_linear_indices_cumulative_run_lengths .packed_accessor32(), - sorted_linear_indices_run_lengths - .packed_accessor32(), {% if not nobag %} infos_sorted.packed_accessor32(), {% else %} diff --git a/fbgemm_gpu/codegen/embedding_bounds_check.cu b/fbgemm_gpu/codegen/embedding_bounds_check.cu index 9b1c706e5..dc114dded 100644 --- a/fbgemm_gpu/codegen/embedding_bounds_check.cu +++ b/fbgemm_gpu/codegen/embedding_bounds_check.cu @@ -4,7 +4,7 @@ * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ -#include "codegen/embedding_backward_template_helpers.cuh" +#include "fbgemm_gpu/embedding_backward_template_helpers.cuh" using Tensor = at::Tensor; using namespace fbgemm_gpu; @@ -69,11 +69,8 @@ __global__ void bounds_check_indices_kernel( } auto L = indices_end - indices_start; -#ifdef __HIP_PLATFORM_HCC__ - for (index_t i = (index_t) threadIdx.x; i < L; i += (index_t) fbgemm_gpu::kWarpSize) { -#else - for (auto i = threadIdx.x; i < L; i += fbgemm_gpu::kWarpSize) { -#endif + for (index_t i = (index_t)threadIdx.x; i < L; + i += (index_t)fbgemm_gpu::kWarpSize) { auto idx = indices[indices_start + i]; if (idx == -1) { // -1 indicates pruned rows. @@ -114,6 +111,11 @@ void bounds_check_indices_cuda( Tensor offsets, int64_t bounds_check_mode_, Tensor warning) { + TENSOR_ON_CUDA_GPU(rows_per_table); + TENSOR_ON_CUDA_GPU(indices); + TENSOR_ON_CUDA_GPU(offsets); + TENSOR_ON_CUDA_GPU(warning); + at::cuda::OptionalCUDAGuard device_guard; device_guard.set_index(rows_per_table.get_device()); diff --git a/fbgemm_gpu/codegen/embedding_bounds_check_host.cpp b/fbgemm_gpu/codegen/embedding_bounds_check_host.cpp index 69c0d9c22..45354a565 100644 --- a/fbgemm_gpu/codegen/embedding_bounds_check_host.cpp +++ b/fbgemm_gpu/codegen/embedding_bounds_check_host.cpp @@ -8,7 +8,9 @@ #include #include #include -#include +#include + +#include "fbgemm_gpu/sparse_ops_utils.h" using Tensor = at::Tensor; @@ -24,8 +26,13 @@ TORCH_LIBRARY_FRAGMENT(fb, m) { // or DCE'd, etc. m.def( "bounds_check_indices(Tensor rows_per_table, Tensor(a!) indices, Tensor(a!) offsets, int bounds_check_mode, Tensor(a!) warning) -> ()"); - m.impl( - "bounds_check_indices", - torch::dispatch( - c10::DispatchKey::CUDA, TORCH_FN(bounds_check_indices_cuda))); + DISPATCH_TO_CUDA("bounds_check_indices", bounds_check_indices_cuda); +} + +TORCH_LIBRARY_FRAGMENT(fbgemm, m) { + // The (a!) tells PyTorch this is an impure operation and so cannot be CSE'd + // or DCE'd, etc. + m.def( + "bounds_check_indices(Tensor rows_per_table, Tensor(a!) indices, Tensor(a!) offsets, int bounds_check_mode, Tensor(a!) warning) -> ()"); + DISPATCH_TO_CUDA("bounds_check_indices", bounds_check_indices_cuda); } diff --git a/fbgemm_gpu/codegen/embedding_bounds_check_host_cpu.cpp b/fbgemm_gpu/codegen/embedding_bounds_check_host_cpu.cpp index 8c82d7d9e..accb6767b 100644 --- a/fbgemm_gpu/codegen/embedding_bounds_check_host_cpu.cpp +++ b/fbgemm_gpu/codegen/embedding_bounds_check_host_cpu.cpp @@ -8,7 +8,7 @@ #include #include #include -#include "codegen/embedding_common.h" +#include "fbgemm_gpu/embedding_common.h" using Tensor = at::Tensor; @@ -110,3 +110,10 @@ TORCH_LIBRARY_FRAGMENT(fb, m) { torch::dispatch( c10::DispatchKey::CPU, TORCH_FN(bounds_check_indices_cpu))); } + +TORCH_LIBRARY_FRAGMENT(fbgemm, m) { + m.impl( + "bounds_check_indices", + torch::dispatch( + c10::DispatchKey::CPU, TORCH_FN(bounds_check_indices_cpu))); +} diff --git a/fbgemm_gpu/codegen/embedding_forward_quantized_cpu_template.cpp b/fbgemm_gpu/codegen/embedding_forward_quantized_cpu_template.cpp index b6aea36c6..749e7a9f7 100644 --- a/fbgemm_gpu/codegen/embedding_forward_quantized_cpu_template.cpp +++ b/fbgemm_gpu/codegen/embedding_forward_quantized_cpu_template.cpp @@ -10,8 +10,8 @@ #include #include -#include "codegen/embedding_common.h" #include "fbgemm_gpu/dispatch_macros.h" +#include "fbgemm_gpu/embedding_common.h" #include #include diff --git a/fbgemm_gpu/codegen/embedding_forward_quantized_host.cpp b/fbgemm_gpu/codegen/embedding_forward_quantized_host.cpp index 13c8d7216..6c6e9db88 100644 --- a/fbgemm_gpu/codegen/embedding_forward_quantized_host.cpp +++ b/fbgemm_gpu/codegen/embedding_forward_quantized_host.cpp @@ -8,8 +8,10 @@ #include #include #include -#include +#include #include "c10/core/ScalarType.h" +#include "fbgemm_gpu/embedding_common.h" +#include "fbgemm_gpu/sparse_ops_utils.h" using Tensor = at::Tensor; @@ -56,6 +58,25 @@ Tensor int_nbit_split_embedding_codegen_forward_weighted_cuda( Tensor lxu_cache_locations, int64_t unused); +Tensor int_nbit_split_embedding_nobag_codegen_forward_unweighted_cuda( + Tensor dev_weights, + Tensor uvm_weights, + Tensor weights_placements, + Tensor weights_offsets, + Tensor weights_tys, + int64_t D, + int64_t max_int2_D, + int64_t max_int4_D, + int64_t max_int8_D, + int64_t max_float16_D, + int64_t max_float32_D, + Tensor indices, + Tensor offsets, + int64_t output_dtype, + Tensor lxu_cache_weights, + Tensor lxu_cache_locations, + int64_t unused); + Tensor int_nbit_split_embedding_codegen_lookup_function( Tensor dev_weights, Tensor uvm_weights, @@ -76,6 +97,29 @@ Tensor int_nbit_split_embedding_codegen_lookup_function( int64_t output_dtype, c10::optional lxu_cache_weights, c10::optional lxu_cache_locations) { + if (static_cast(pooling_mode) == PoolingMode::NONE) { + std::vector max_D_list{ + max_int2_D, max_int4_D, max_int8_D, max_float16_D, max_float32_D}; + int64_t max_D = *std::max_element(max_D_list.begin(), max_D_list.end()); + return int_nbit_split_embedding_nobag_codegen_forward_unweighted_cuda( + dev_weights, + uvm_weights, + weights_placements, + weights_offsets, + weights_tys, + max_D, + max_int2_D, + max_int4_D, + max_int8_D, + max_float16_D, + max_float32_D, + indices, + offsets, + output_dtype, + lxu_cache_weights.value_or(at::empty({0, 0}, at::kByte)), + lxu_cache_locations.value_or(at::empty({0}, at::kInt)), + 0); + } if (!indice_weights) { return int_nbit_split_embedding_codegen_forward_unweighted_cuda( dev_weights, @@ -136,24 +180,33 @@ Tensor pruned_array_lookup_cuda( TORCH_LIBRARY_FRAGMENT(fb, m) { m.def( "int_nbit_split_embedding_codegen_lookup_function(Tensor dev_weights, Tensor uvm_weights, Tensor weights_placements, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, int total_D, int max_int2_D, int max_int4_D, int max_int8_D, int max_float16_D, int max_float32_D, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, int output_dtype=1, Tensor? lxu_cache_weights=None, Tensor? lxu_cache_locations=None) -> Tensor"); - m.impl( + DISPATCH_TO_CUDA( + "int_nbit_split_embedding_codegen_lookup_function", + int_nbit_split_embedding_codegen_lookup_function); + + m.def( + "pruned_hashmap_lookup(Tensor indices, Tensor offsets, Tensor hash_table, Tensor hash_table_offsets) -> Tensor"); + DISPATCH_TO_CUDA( + "pruned_hashmap_lookup", pruned_hashmap_lookup_unweighted_cuda); + + m.def( + "pruned_array_lookup(Tensor indices, Tensor offsets, Tensor index_remappings, Tensor index_remappings_offsets) -> Tensor"); + DISPATCH_TO_CUDA("pruned_array_lookup", pruned_array_lookup_cuda); +} + +TORCH_LIBRARY_FRAGMENT(fbgemm, m) { + m.def( + "int_nbit_split_embedding_codegen_lookup_function(Tensor dev_weights, Tensor uvm_weights, Tensor weights_placements, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, int total_D, int max_int2_D, int max_int4_D, int max_int8_D, int max_float16_D, int max_float32_D, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, int output_dtype=1, Tensor? lxu_cache_weights=None, Tensor? lxu_cache_locations=None) -> Tensor"); + DISPATCH_TO_CUDA( "int_nbit_split_embedding_codegen_lookup_function", - torch::dispatch( - c10::DispatchKey::CUDA, - TORCH_FN(int_nbit_split_embedding_codegen_lookup_function))); + int_nbit_split_embedding_codegen_lookup_function); m.def( "pruned_hashmap_lookup(Tensor indices, Tensor offsets, Tensor hash_table, Tensor hash_table_offsets) -> Tensor"); - m.impl( - "pruned_hashmap_lookup", - torch::dispatch( - c10::DispatchKey::CUDA, - TORCH_FN(pruned_hashmap_lookup_unweighted_cuda))); + DISPATCH_TO_CUDA( + "pruned_hashmap_lookup", pruned_hashmap_lookup_unweighted_cuda); m.def( "pruned_array_lookup(Tensor indices, Tensor offsets, Tensor index_remappings, Tensor index_remappings_offsets) -> Tensor"); - m.impl( - "pruned_array_lookup", - torch::dispatch( - c10::DispatchKey::CUDA, TORCH_FN(pruned_array_lookup_cuda))); + DISPATCH_TO_CUDA("pruned_array_lookup", pruned_array_lookup_cuda); } diff --git a/fbgemm_gpu/codegen/embedding_forward_quantized_host_cpu.cpp b/fbgemm_gpu/codegen/embedding_forward_quantized_host_cpu.cpp index 07d3a12ee..3219d0339 100644 --- a/fbgemm_gpu/codegen/embedding_forward_quantized_host_cpu.cpp +++ b/fbgemm_gpu/codegen/embedding_forward_quantized_host_cpu.cpp @@ -151,6 +151,38 @@ TORCH_LIBRARY_FRAGMENT(fb, m) { c10::DispatchKey::CPU, TORCH_FN(pruned_array_lookup_cpu))); } +TORCH_LIBRARY_FRAGMENT(fbgemm, m) { + m.impl( + "int_nbit_split_embedding_codegen_lookup_function", + torch::dispatch( + c10::DispatchKey::CPU, + TORCH_FN(int_nbit_split_embedding_codegen_lookup_function_cpu))); + + // GPU version of pruned_hashmap needs to use CPU version of + // pruned_hashmap_insert + m.def( + "pruned_hashmap_insert(Tensor indices, Tensor dense_indices, Tensor offsets, Tensor hash_table, Tensor hash_table_offsets) -> ()"); + m.impl( + "pruned_hashmap_insert", + torch::dispatch( + c10::DispatchKey::CPU, + TORCH_FN(pruned_hashmap_insert_unweighted_cpu))); + + // CPU version of hashmap Lookup isn't used. For CPUs, we should use + // PrunedMapCPU below. + m.impl( + "pruned_hashmap_lookup", + torch::dispatch( + c10::DispatchKey::CPU, + TORCH_FN(pruned_hashmap_lookup_unweighted_cpu))); + + // CPU version of array lookup. + m.impl( + "pruned_array_lookup", + torch::dispatch( + c10::DispatchKey::CPU, TORCH_FN(pruned_array_lookup_cpu))); +} + class PrunedMapCPU : public torch::jit::CustomClassHolder { public: PrunedMapCPU() {} diff --git a/fbgemm_gpu/codegen/embedding_forward_quantized_split_template.cu b/fbgemm_gpu/codegen/embedding_forward_quantized_split_template.cu index fa2b56df9..85a08e153 100644 --- a/fbgemm_gpu/codegen/embedding_forward_quantized_split_template.cu +++ b/fbgemm_gpu/codegen/embedding_forward_quantized_split_template.cu @@ -170,632 +170,28 @@ void cp_async_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard) { #endif } +{% for nobag in [True, False] %} +{% if not nobag or not weighted %} // TODO: increase code sharing (templates for accumulator_ty, accumulation, outputs per thread, etc?) +{% for bit_width in [32, 16, 8, 4] %} template __launch_bounds__(WarpsPerBlock * kWarpSize) -__global__ void fp32_split_embedding_codegen_forward_{{ wdesc }}_kernel_small_L( +__global__ void {{ type_map[bit_width].enum_name }}_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}_kernel_small_L( const at::PackedTensorAccessor64 dev_weights, const at::PackedTensorAccessor64 uvm_weights, const at::PackedTensorAccessor32 weights_placements, const at::PackedTensorAccessor32 weights_offsets, const at::PackedTensorAccessor32 weights_tys, + {% if not nobag %} const at::PackedTensorAccessor32 D_offsets, - const at::PackedTensorAccessor32 indices, - const at::PackedTensorAccessor32 offsets, - int64_t pooling_mode, - {% if weighted %} - at::PackedTensorAccessor32 - indice_weights, - {% endif %} - at::PackedTensorAccessor32 - output, // [B][total_D], - const at::PackedTensorAccessor64 lxu_cache_weights, - const at::PackedTensorAccessor32 lxu_cache_locations - ) { - int32_t B = output.size(0); - int32_t T = D_offsets.size(0) - 1; - int32_t bb_t = blockIdx.x * blockDim.y + threadIdx.y; - if (bb_t >= div_round_up(B, OutputRowsPerThread) * T) { - return; - } - static_assert( - std::is_same::value || std::is_same::value || std::is_same::value, - "output_t can only be float or half or bytes now" - ); - - uint32_t t = bb_t / div_round_up(B, OutputRowsPerThread); - - int32_t D_start = D_offsets[t]; - int32_t D_end = D_offsets[t + 1]; - int32_t D = D_end - D_start; - SparseType weight_ty = static_cast(weights_tys[t]); - if (weight_ty != SparseType::FP32) { - return; - } - - const int32_t D_bytes = padded_row_size_in_bytes(D, weight_ty); - - if (D_bytes <= MinNum128BRows * 128 || D_bytes > MaxNum128BRows * 128) { - return; - } - - uint32_t bb = bb_t % div_round_up(B, OutputRowsPerThread); - - int64_t weights_offset = weights_offsets[t]; - const int32_t D_total = padded_D(D, weight_ty); - const int32_t D_padding = D_total - D; - - uint32_t warp_idx = threadIdx.y; - int32_t indices_starts[OutputRowsPerThread]; - int32_t Ls[OutputRowsPerThread]; - int32_t max_Ls = 0; - - for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { - uint32_t b = min(static_cast(bb * OutputRowsPerThread + i), static_cast(B - 1)); - int32_t indices_start = offsets[t * B + b]; - int32_t indices_end = offsets[t * B + b + 1]; - indices_starts[i] = indices_start; - Ls[i] = indices_end - indices_start; - max_Ls = max(max_Ls, Ls[i]); - } - - const uint8_t* __restrict__ weights; - const auto placement = static_cast(weights_placements[t]); - if (placement == PlacementType::DEVICE) { - weights = &dev_weights[weights_offset]; - } else { - weights = &uvm_weights[weights_offset]; - } - constexpr size_t kOutputsPerThread = 1; - - constexpr uint32_t NumUint4PerRow = MaxNum128BRows * 128 / sizeof(uint4); - const uint32_t uint4_loads_per_row = div_round_up(D_bytes, sizeof(uint4)); - - VecNT<1> accumulators[OutputRowsPerThread][MaxNum128BRows]; - - for (uint32_t L_start = 0; L_start < max_Ls; L_start += InputRowsInFlight) { - uint32_t input_rows_in_flight = min(static_cast(InputRowsInFlight), max_Ls - L_start); - typedef uint4 AllBuffers[WarpsPerBlock][OutputRowsPerThread][InputRowsInFlight][NumUint4PerRow]; - __shared__ AllBuffers buffers; - - {% if weighted %} - typedef float AllIndiceWeights[WarpsPerBlock][OutputRowsPerThread][InputRowsInFlight]; - __shared__ AllIndiceWeights buffers_indice_weights; - {% endif %} - - for (uint32_t load_idx = threadIdx.x; load_idx < input_rows_in_flight * uint4_loads_per_row; load_idx += kWarpSize) { - uint32_t row_load_idx = load_idx % uint4_loads_per_row; - uint32_t input_row_idx = (load_idx / uint4_loads_per_row); - - #pragma unroll OutputRowsPerThread - for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { - bool valid = L_start + input_row_idx < Ls[i]; - bool cache_valid = (placement == PlacementType::MANAGED_CACHING && valid); - int32_t idx = valid ? indices[indices_starts[i] + L_start + input_row_idx] : -1; - int32_t cache_idx = cache_valid ? lxu_cache_locations[indices_starts[i] + L_start + input_row_idx] : -1; - valid = valid && (idx != -1); - const uint4* row; - if (cache_valid && cache_idx != kCacheLocationMissing) { - row = reinterpret_cast(&lxu_cache_weights[static_cast(cache_idx)][0]); - } else if (valid) { - row = reinterpret_cast(&weights[static_cast(idx) * D_bytes]); - } else { - row = reinterpret_cast(&weights[0]); - } - cp_async_zfill_cg(&buffers[warp_idx][i][input_row_idx][row_load_idx], &row[row_load_idx], valid); - - {% if weighted %} - buffers_indice_weights[warp_idx][i][input_row_idx] = valid ? indice_weights[indices_starts[i] + L_start + input_row_idx] : 0.0; - {% endif %} - } - } - // equivalent to fence + wait. - cp_async_wait<0>(); -#ifdef __HIP_PLATFORM_HCC__ - __syncthreads(); -#else - __syncwarp(); -#endif - for (uint32_t input_row_idx = 0; input_row_idx < input_rows_in_flight; ++input_row_idx) { - #pragma unroll OutputRowsPerThread - for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { - bool valid = L_start + input_row_idx < Ls[i]; - const uint32_t* row = reinterpret_cast(&buffers[warp_idx][i][input_row_idx][0]); - {% if weighted %} - float row_weight = buffers_indice_weights[warp_idx][i][input_row_idx]; - {% endif %} - #pragma unroll MaxNum128BRows - for (uint32_t j = 0; j < MaxNum128BRows; ++j) { - float v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; - if (valid) { - {% if weighted %} - accumulators[i][j].fma(v, row_weight); - {% else %} - accumulators[i][j].add(v); - {% endif %} - } - } - } - } - } - #pragma unroll OutputRowsPerThread - for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { - uint32_t b = min(static_cast(bb * OutputRowsPerThread + i), static_cast(B - 1)); - float inv_L = 1.0 / Ls[i]; - - if (std::is_same::value || std::is_same::value) { - #pragma unroll MaxNum128BRows - for (uint32_t j = 0; j < MaxNum128BRows; ++j) { - int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; - if (static_cast(pooling_mode) == PoolingMode::MEAN && Ls[i] != 0) { - accumulators[i][j].mul(inv_L); - } - if (output_d >= 0 && output_d < D) { - accumulators[i][j].store(&output[b][D_start + output_d]); - } - } - } else if (std::is_same::value) { - // INT8: - // apply per feature row-wise int8 - float thread_local_min = std::numeric_limits::max(); - float thread_local_max = std::numeric_limits::lowest(); - float2 qparams; - #pragma unroll MaxNum128BRows - for (uint32_t j = 0; j < MaxNum128BRows; ++j) { - int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; - if (static_cast(pooling_mode) == PoolingMode::MEAN && Ls[i] != 0) { - accumulators[i][j].mul(inv_L); - } - if (output_d >= 0 && output_d < D) { - thread_local_max = max(thread_local_max, accumulators[i][j].acc); - thread_local_min = min(thread_local_min, accumulators[i][j].acc); - } - } - qparams = warp_find_qparams(thread_local_min, thread_local_max); - int output_D_start = D_start + t * 8; - int output_D_end = output_D_start + D; - #pragma unroll MaxNum128BRows - for (uint32_t j = 0; j < MaxNum128BRows; ++j) { - int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; - if (output_d >= 0 && output_d < D) { - accumulators[i][j].store(&output[b][output_D_start + output_d], qparams); - } - } - if (threadIdx.x == 0) { - store_qparams_to_row(&output[b][output_D_end], qparams); - } - } else { - // INT4: not implemented yet - } - } -} - -// TODO: increase code sharing (templates for accumulator_ty, accumulation, outputs per thread, etc?) -template -__launch_bounds__(WarpsPerBlock * kWarpSize) -__global__ void fp16_split_embedding_codegen_forward_{{ wdesc }}_kernel_small_L( - const at::PackedTensorAccessor64 dev_weights, - const at::PackedTensorAccessor64 uvm_weights, - const at::PackedTensorAccessor32 weights_placements, - const at::PackedTensorAccessor32 weights_offsets, - const at::PackedTensorAccessor32 weights_tys, - const at::PackedTensorAccessor32 D_offsets, - const at::PackedTensorAccessor32 indices, - const at::PackedTensorAccessor32 offsets, - int64_t pooling_mode, - {% if weighted %} - at::PackedTensorAccessor32 - indice_weights, - {% endif %} - at::PackedTensorAccessor32 - output, // [B][total_D], - const at::PackedTensorAccessor64 lxu_cache_weights, - const at::PackedTensorAccessor32 lxu_cache_locations - ) { - int32_t B = output.size(0); - int32_t T = D_offsets.size(0) - 1; - int32_t bb_t = blockIdx.x * blockDim.y + threadIdx.y; - if (bb_t >= div_round_up(B, OutputRowsPerThread) * T) { - return; - } - static_assert( - std::is_same::value || std::is_same::value || std::is_same::value, - "output_t can only be float or half or bytes now" - ); - - uint32_t t = bb_t / div_round_up(B, OutputRowsPerThread); - - int32_t D_start = D_offsets[t]; - int32_t D_end = D_offsets[t + 1]; - int32_t D = D_end - D_start; - SparseType weight_ty = static_cast(weights_tys[t]); - if (weight_ty != SparseType::FP16) { - return; - } - - const int32_t D_bytes = padded_row_size_in_bytes(D, weight_ty); - - if (D_bytes <= MinNum128BRows * 128 || D_bytes > MaxNum128BRows * 128) { - return; - } - - uint32_t bb = bb_t % div_round_up(B, OutputRowsPerThread); - - int64_t weights_offset = weights_offsets[t]; - const int32_t D_total = padded_D(D, weight_ty); - const int32_t D_padding = D_total - D; - - uint32_t warp_idx = threadIdx.y; - int32_t indices_starts[OutputRowsPerThread]; - int32_t Ls[OutputRowsPerThread]; - int32_t max_Ls = 0; - - for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { - uint32_t b = min(static_cast(bb * OutputRowsPerThread + i), static_cast(B - 1)); - int32_t indices_start = offsets[t * B + b]; - int32_t indices_end = offsets[t * B + b + 1]; - indices_starts[i] = indices_start; - Ls[i] = indices_end - indices_start; - max_Ls = max(max_Ls, Ls[i]); - } - - const uint8_t* __restrict__ weights; - const auto placement = static_cast(weights_placements[t]); - if (placement == PlacementType::DEVICE) { - weights = &dev_weights[weights_offset]; - } else { - weights = &uvm_weights[weights_offset]; - } - constexpr size_t kOutputsPerThread = 2; - - constexpr uint32_t NumUint4PerRow = MaxNum128BRows * 128 / sizeof(uint4); - const uint32_t uint4_loads_per_row = div_round_up(D_bytes, sizeof(uint4)); - - VecNT<2> accumulators[OutputRowsPerThread][MaxNum128BRows]; - - for (uint32_t L_start = 0; L_start < max_Ls; L_start += InputRowsInFlight) { - uint32_t input_rows_in_flight = min(static_cast(InputRowsInFlight), max_Ls - L_start); - - typedef uint4 AllBuffers[WarpsPerBlock][OutputRowsPerThread][InputRowsInFlight][NumUint4PerRow]; - __shared__ AllBuffers buffers; - - {% if weighted %} - typedef float AllIndiceWeights[WarpsPerBlock][OutputRowsPerThread][InputRowsInFlight]; - __shared__ AllIndiceWeights buffers_indice_weights; - {% endif %} - - for (uint32_t load_idx = threadIdx.x; load_idx < input_rows_in_flight * uint4_loads_per_row; load_idx += kWarpSize) { - uint32_t row_load_idx = load_idx % uint4_loads_per_row; - uint32_t input_row_idx = (load_idx / uint4_loads_per_row); - - #pragma unroll OutputRowsPerThread - for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { - bool valid = L_start + input_row_idx < Ls[i]; - bool cache_valid = (placement == PlacementType::MANAGED_CACHING && valid); - int32_t idx = valid ? indices[indices_starts[i] + L_start + input_row_idx] : -1; - int32_t cache_idx = cache_valid ? lxu_cache_locations[indices_starts[i] + L_start + input_row_idx] : -1; - valid = valid && (idx != -1); - const uint4* row; - if (cache_valid && cache_idx != kCacheLocationMissing) { - row = reinterpret_cast(&lxu_cache_weights[static_cast(cache_idx)][0]); - } else if (valid) { - row = reinterpret_cast(&weights[static_cast(idx) * D_bytes]); - } else { - row = reinterpret_cast(&weights[0]); - } - cp_async_zfill_cg(&buffers[warp_idx][i][input_row_idx][row_load_idx], &row[row_load_idx], valid); - - {% if weighted %} - buffers_indice_weights[warp_idx][i][input_row_idx] = valid ? indice_weights[indices_starts[i] + L_start + input_row_idx] : 0.0; - {% endif %} - } - } - // equivalent to fence + wait. - cp_async_wait<0>(); -#ifdef __HIP_PLATFORM_HCC__ - __syncthreads(); -#else - __syncwarp(); -#endif - for (uint32_t input_row_idx = 0; input_row_idx < input_rows_in_flight; ++input_row_idx) { - #pragma unroll OutputRowsPerThread - for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { - bool valid = L_start + input_row_idx < Ls[i]; - const uint32_t* row = reinterpret_cast(&buffers[warp_idx][i][input_row_idx][0]); - - {% if weighted %} - float row_weight = buffers_indice_weights[warp_idx][i][input_row_idx]; - {% endif %} - - #pragma unroll MaxNum128BRows - for (uint32_t j = 0; j < MaxNum128BRows; ++j) { - __half2 v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; - - if (valid) { - {% if weighted %} - accumulators[i][j].fma(v, row_weight); - {% else %} - accumulators[i][j].add(v); - {% endif %} - } - } - } - } - } - - #pragma unroll OutputRowsPerThread - for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { - uint32_t b = min(static_cast(bb * OutputRowsPerThread + i), static_cast(B - 1)); - float inv_L = 1.0 / Ls[i]; - - if (std::is_same::value || std::is_same::value) { - - #pragma unroll MaxNum128BRows - for (uint32_t j = 0; j < MaxNum128BRows; ++j) { - int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; - if (static_cast(pooling_mode) == PoolingMode::MEAN && Ls[i] != 0) { - accumulators[i][j].mul(inv_L); - } - if (output_d >= 0 && output_d < D) { - accumulators[i][j].store(&output[b][D_start + output_d]); - } - } - } else if (std::is_same::value) { - // INT8: - // apply per feature row-wise int8 - float thread_local_min = std::numeric_limits::max(); - float thread_local_max = std::numeric_limits::lowest(); - float2 qparams; - #pragma unroll MaxNum128BRows - for (uint32_t j = 0; j < MaxNum128BRows; ++j) { - int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; - if (static_cast(pooling_mode) == PoolingMode::MEAN && Ls[i] != 0) { - accumulators[i][j].mul(inv_L); - } - if (output_d >= 0 && output_d < D) { - thread_local_max = max(thread_local_max, max(accumulators[i][j].acc.x, accumulators[i][j].acc.y)); - thread_local_min = min(thread_local_min, min(accumulators[i][j].acc.x, accumulators[i][j].acc.y)); - } - } - - qparams = warp_find_qparams(thread_local_min, thread_local_max); - int output_D_start = D_start + t * 8; - int output_D_end = output_D_start + D; - #pragma unroll MaxNum128BRows - for (uint32_t j = 0; j < MaxNum128BRows; ++j) { - int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; - if (output_d >= 0 && output_d < D) { - accumulators[i][j].store(&output[b][output_D_start + output_d], qparams); - } - } - if (threadIdx.x == 0) { - store_qparams_to_row(&output[b][output_D_end], qparams); - } - } else { - // INT4: not implemented yet - } - } -} - -template -__launch_bounds__(WarpsPerBlock * kWarpSize) -__global__ void int_8bit_split_embedding_codegen_forward_{{ wdesc }}_kernel_small_L( - const at::PackedTensorAccessor64 dev_weights, - const at::PackedTensorAccessor64 uvm_weights, - const at::PackedTensorAccessor32 weights_placements, - const at::PackedTensorAccessor32 weights_offsets, - const at::PackedTensorAccessor32 weights_tys, - const at::PackedTensorAccessor32 D_offsets, - const at::PackedTensorAccessor32 indices, - const at::PackedTensorAccessor32 offsets, - int64_t pooling_mode, - {% if weighted %} - at::PackedTensorAccessor32 - indice_weights, + {% else %} + int64_t D, {% endif %} - at::PackedTensorAccessor32 - output, // [B][total_D] - const at::PackedTensorAccessor64 lxu_cache_weights, - const at::PackedTensorAccessor32 lxu_cache_locations - ) { - int32_t B = output.size(0); - int32_t T = D_offsets.size(0) - 1; - int32_t bb_t = blockIdx.x * blockDim.y + threadIdx.y; - if (bb_t >= div_round_up(B, OutputRowsPerThread) * T) { - return; - } - static_assert( - std::is_same::value || std::is_same::value || std::is_same::value, - "output_t can only be float or half or bytes now" - ); - - uint32_t t = bb_t / div_round_up(B, OutputRowsPerThread); - - int32_t D_start = D_offsets[t]; - int32_t D_end = D_offsets[t + 1]; - int32_t D = D_end - D_start; - SparseType weight_ty = static_cast(weights_tys[t]); - if (weight_ty != SparseType::INT8) { - return; - } - - const int32_t D_bytes = padded_row_size_in_bytes(D, weight_ty); - - if (D_bytes <= MinNum128BRows * 128 || D_bytes > MaxNum128BRows * 128) { - return; - } - - uint32_t bb = bb_t % div_round_up(B, OutputRowsPerThread); - - int64_t weights_offset = weights_offsets[t]; - const int32_t D_total = padded_D(D, weight_ty); - const int32_t D_padding = D_total - D; - - uint32_t warp_idx = threadIdx.y; - int32_t indices_starts[OutputRowsPerThread]; - int32_t Ls[OutputRowsPerThread]; - int32_t max_Ls = 0; - - for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { - uint32_t b = min(static_cast(bb * OutputRowsPerThread + i), static_cast(B - 1)); - int32_t indices_start = offsets[t * B + b]; - int32_t indices_end = offsets[t * B + b + 1]; - indices_starts[i] = indices_start; - Ls[i] = indices_end - indices_start; - max_Ls = max(max_Ls, Ls[i]); - } - - const uint8_t* __restrict__ weights; - const auto placement = static_cast(weights_placements[t]); - if (placement == PlacementType::DEVICE) { - weights = &dev_weights[weights_offset]; - } else { - weights = &uvm_weights[weights_offset]; - } - constexpr size_t kOutputsPerThread = 4; - - constexpr uint32_t NumUint4PerRow = MaxNum128BRows * 128 / sizeof(uint4); - const uint32_t uint4_loads_per_row = div_round_up(D_bytes, sizeof(uint4)); - - VecNT<4> accumulators[OutputRowsPerThread][MaxNum128BRows]; - - for (uint32_t L_start = 0; L_start < max_Ls; L_start += InputRowsInFlight) { - uint32_t input_rows_in_flight = min(static_cast(InputRowsInFlight), max_Ls - L_start); - - typedef uint4 AllBuffers[WarpsPerBlock][OutputRowsPerThread][InputRowsInFlight][NumUint4PerRow]; - __shared__ AllBuffers buffers; - - {% if weighted %} - typedef float AllIndiceWeights[WarpsPerBlock][OutputRowsPerThread][InputRowsInFlight]; - __shared__ AllIndiceWeights buffers_indice_weights; - {% endif %} - - for (uint32_t load_idx = threadIdx.x; load_idx < input_rows_in_flight * uint4_loads_per_row; load_idx += kWarpSize) { - uint32_t row_load_idx = load_idx % uint4_loads_per_row; - uint32_t input_row_idx = (load_idx / uint4_loads_per_row); - - #pragma unroll OutputRowsPerThread - for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { - bool valid = L_start + input_row_idx < Ls[i]; - bool cache_valid = (placement == PlacementType::MANAGED_CACHING && valid); - int32_t idx = valid ? indices[indices_starts[i] + L_start + input_row_idx] : -1; - int32_t cache_idx = cache_valid ? lxu_cache_locations[indices_starts[i] + L_start + input_row_idx] : -1; - valid = valid && (idx != -1); - const uint4* row; - if (cache_valid && cache_idx != kCacheLocationMissing) { - row = reinterpret_cast(&lxu_cache_weights[static_cast(cache_idx)][0]); - } else if (valid) { - row = reinterpret_cast(&weights[static_cast(idx) * D_bytes]); - } else { - row = reinterpret_cast(&weights[0]); - } - cp_async_zfill_cg(&buffers[warp_idx][i][input_row_idx][row_load_idx], &row[row_load_idx], valid); - - {% if weighted %} - buffers_indice_weights[warp_idx][i][input_row_idx] = valid ? indice_weights[indices_starts[i] + L_start + input_row_idx] : 0.0; - {% endif %} - } - } - // equivalent to fence + wait. - cp_async_wait<0>(); -#ifdef __HIP_PLATFORM_HCC__ - __syncthreads(); -#else - __syncwarp(); -#endif - for (uint32_t input_row_idx = 0; input_row_idx < input_rows_in_flight; ++input_row_idx) { - #pragma unroll OutputRowsPerThread - for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { - bool valid = L_start + input_row_idx < Ls[i]; - const uint32_t* row = reinterpret_cast(&buffers[warp_idx][i][input_row_idx][0]); - half2 shift_scale = reinterpret_cast(row)[0]; - - {% if weighted %} - float row_weight = buffers_indice_weights[warp_idx][i][input_row_idx]; - {% endif %} - - #pragma unroll MaxNum128BRows - for (uint32_t j = 0; j < MaxNum128BRows; ++j) { - uint32_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; - if (valid) { - {% if weighted %} - accumulators[i][j].fma(v, shift_scale, row_weight); - {% else %} - accumulators[i][j].add(v, shift_scale); - {% endif %} - } - } - } - } - } - - #pragma unroll OutputRowsPerThread - for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { - uint32_t b = min(static_cast(bb * OutputRowsPerThread + i), static_cast(B - 1)); - float inv_L = 1.0 / Ls[i]; - - if (std::is_same::value || std::is_same::value) { - #pragma unroll MaxNum128BRows - for (uint32_t j = 0; j < MaxNum128BRows; ++j) { - int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; - - if (static_cast(pooling_mode) == PoolingMode::MEAN && Ls[i] != 0) { - accumulators[i][j].mul(inv_L); - } - - if (output_d >= 0 && output_d < D) { - accumulators[i][j].store(&output[b][D_start + output_d]); - } - } - } else if (std::is_same::value) { - // INT8: - // apply per feature row-wise int8 - float thread_local_min = std::numeric_limits::max(); - float thread_local_max = std::numeric_limits::lowest(); - float2 qparams; - #pragma unroll MaxNum128BRows - for (uint32_t j = 0; j < MaxNum128BRows; ++j) { - int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; - if (static_cast(pooling_mode) == PoolingMode::MEAN && Ls[i] != 0) { - accumulators[i][j].mul(inv_L); - } - if (output_d >= 0 && output_d < D) { - thread_local_max = max(thread_local_max, float4_max(accumulators[i][j].acc)); - thread_local_min = min(thread_local_min, float4_min(accumulators[i][j].acc)); - } - } - - qparams = warp_find_qparams(thread_local_min, thread_local_max); - int output_D_start = D_start + t * 8; - int output_D_end = output_D_start + D; - #pragma unroll MaxNum128BRows - for (uint32_t j = 0; j < MaxNum128BRows; ++j) { - int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; - if (output_d >= 0 && output_d < D) { - accumulators[i][j].store(&output[b][output_D_start + output_d], qparams); - } - } - if (threadIdx.x == 0) { - store_qparams_to_row(&output[b][output_D_end], qparams); - } - } else { - // INT4: not implemented yet - } - } -} - -template -__launch_bounds__(WarpsPerBlock * kWarpSize) -__global__ void int_4bit_split_embedding_codegen_forward_{{ wdesc }}_kernel_small_L( - const at::PackedTensorAccessor64 dev_weights, - const at::PackedTensorAccessor64 uvm_weights, - const at::PackedTensorAccessor32 weights_placements, - const at::PackedTensorAccessor32 weights_offsets, - const at::PackedTensorAccessor32 weights_tys, - const at::PackedTensorAccessor32 D_offsets, const at::PackedTensorAccessor32 indices, const at::PackedTensorAccessor32 offsets, + {% if not nobag %} int64_t pooling_mode, + {% endif %} {% if weighted %} at::PackedTensorAccessor32 indice_weights, @@ -805,8 +201,12 @@ __global__ void int_4bit_split_embedding_codegen_forward_{{ wdesc }}_kernel_smal const at::PackedTensorAccessor64 lxu_cache_weights, const at::PackedTensorAccessor32 lxu_cache_locations ) { + int32_t T = weights_offsets.size(0); + {% if not nobag %} int32_t B = output.size(0); - int32_t T = D_offsets.size(0) - 1; + {% else %} + int32_t B = (offsets.size(0) - 1) / T; + {% endif %} int32_t bb_t = blockIdx.x * blockDim.y + threadIdx.y; if (bb_t >= div_round_up(B, OutputRowsPerThread) * T) { return; @@ -818,11 +218,13 @@ __global__ void int_4bit_split_embedding_codegen_forward_{{ wdesc }}_kernel_smal uint32_t t = bb_t / div_round_up(B, OutputRowsPerThread); + {% if not nobag %} int32_t D_start = D_offsets[t]; int32_t D_end = D_offsets[t + 1]; int32_t D = D_end - D_start; + {% endif %} SparseType weight_ty = static_cast(weights_tys[t]); - if (weight_ty != SparseType::INT4) { + if (weight_ty != SparseType::{{ type_map[bit_width].enum_name }}) { return; } @@ -859,12 +261,14 @@ __global__ void int_4bit_split_embedding_codegen_forward_{{ wdesc }}_kernel_smal } else { weights = &uvm_weights[weights_offset]; } - constexpr size_t kOutputsPerThread = 8; + constexpr size_t kOutputsPerThread = {{ (32 // bit_width) }}; constexpr uint32_t NumUint4PerRow = MaxNum128BRows * 128 / sizeof(uint4); const uint32_t uint4_loads_per_row = div_round_up(D_bytes, sizeof(uint4)); - VecNT<8> accumulators[OutputRowsPerThread][MaxNum128BRows]; + {% if not nobag %} + VecNT<{{ (32 // bit_width) }}> accumulators[OutputRowsPerThread][MaxNum128BRows]; + {% endif %} for (uint32_t L_start = 0; L_start < max_Ls; L_start += InputRowsInFlight) { uint32_t input_rows_in_flight = min(static_cast(InputRowsInFlight), max_Ls - L_start); @@ -914,28 +318,78 @@ __global__ void int_4bit_split_embedding_codegen_forward_{{ wdesc }}_kernel_smal #pragma unroll OutputRowsPerThread for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { bool valid = L_start + input_row_idx < Ls[i]; + if (!valid) { + continue; + } const uint32_t* row = reinterpret_cast(&buffers[warp_idx][i][input_row_idx][0]); + {% if bit_width in [8, 4] %} half2 shift_scale = reinterpret_cast(row)[0]; + {% endif %} {% if weighted %} float row_weight = buffers_indice_weights[warp_idx][i][input_row_idx]; {% endif %} + using scalar_t = {{ type_map[bit_width].cpp_type_name }}; + + {% if not nobag %} #pragma unroll MaxNum128BRows for (uint32_t j = 0; j < MaxNum128BRows; ++j) { - uint32_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; - if (valid) { - {% if weighted %} - accumulators[i][j].fma(v, shift_scale, row_weight); - {% else %} - accumulators[i][j].add(v, shift_scale); - {% endif %} + scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; + {% if weighted %} + accumulators[i][j].fma(v, {% if bit_width in [8, 4] %} shift_scale, {% endif %} row_weight); + {% else %} + accumulators[i][j].add(v{% if bit_width in [8, 4] %}, shift_scale {% endif %}); + {% endif %} + } + {% else %} + int32_t output_j = indices_starts[i] + L_start + input_row_idx; + if (std::is_same::value || std::is_same::value) { + #pragma unroll MaxNum128BRows + for (uint32_t j = 0; j < MaxNum128BRows; ++j) { + int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; + scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; + if (output_d >= 0 && output_d < D) { + VecNT<{{ (32 // bit_width) }}> acc(v{% if bit_width in [8, 4] %}, shift_scale {% endif %}); + acc.store(&output[output_j][output_d]); + } + } + } else if (std::is_same::value) { + // INT8: + // apply per feature row-wise int8 + float thread_local_min = std::numeric_limits::max(); + float thread_local_max = std::numeric_limits::lowest(); + float2 qparams; + #pragma unroll MaxNum128BRows + for (uint32_t j = 0; j < MaxNum128BRows; ++j) { + int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; + scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; + VecNT<{{ (32 // bit_width) }}> acc(v{% if bit_width in [8, 4] %}, shift_scale {% endif %}); + if (output_d >= 0 && output_d < D) { + thread_local_max = max(thread_local_max, float{{ (32 // bit_width) }}_max(acc.acc)); + thread_local_min = min(thread_local_min, float{{ (32 // bit_width) }}_min(acc.acc)); + } + } + qparams = warp_find_qparams(thread_local_min, thread_local_max); + #pragma unroll MaxNum128BRows + for (uint32_t j = 0; j < MaxNum128BRows; ++j) { + int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; + scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; + if (output_d >= 0 && output_d < D) { + VecNT<{{ (32 // bit_width) }}> acc(v{% if bit_width in [8, 4] %}, shift_scale {% endif %}); + acc.store(&output[output_j][output_d], qparams); + } + } + if (threadIdx.x == 0) { + store_qparams_to_row(&output[output_j][D], qparams); } } + {% endif %} } } } + {% if not nobag %} #pragma unroll OutputRowsPerThread for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { uint32_t b = min(static_cast(bb * OutputRowsPerThread + i), static_cast(B - 1)); @@ -968,8 +422,8 @@ __global__ void int_4bit_split_embedding_codegen_forward_{{ wdesc }}_kernel_smal accumulators[i][j].mul(inv_L); } if (output_d >= 0 && output_d < D) { - thread_local_max = max(thread_local_max, float8_max(accumulators[i][j].acc)); - thread_local_min = min(thread_local_min, float8_min(accumulators[i][j].acc)); + thread_local_max = max(thread_local_max, float{{ (32 // bit_width) }}_max(accumulators[i][j].acc)); + thread_local_min = min(thread_local_min, float{{ (32 // bit_width) }}_min(accumulators[i][j].acc)); } } @@ -990,7 +444,11 @@ __global__ void int_4bit_split_embedding_codegen_forward_{{ wdesc }}_kernel_smal // INT4: not implemented yet } } + {% endif %} } +{% endfor %} // for bit_width in [32, 16, 8, 4] +{% endif %} // if not nobag or not weighted +{% endfor %} // for nobag in [True, False] __device__ inline uint32_t pruned_hash_function(uint32_t h) { // MurmorHash3 32-bit mixing function. @@ -1054,6 +512,10 @@ __global__ void int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_{ dense_indices[indices_start + l_start + subwarp_id] = slot_dense_idx; } #ifdef __HIP_PLATFORM_HCC__ + // FIXME: __any_sync with mask isn't supported by HIP yet. + // See https://fburl.com/fvy7j0lq for the similar context. + // assert false here with https://fburl.com/pfm7enw2 + assert(false); if (__any(found)) { #else if (__any_sync(subwarp_mask, found)) { @@ -1104,14 +566,20 @@ __global__ void int_nbit_split_embedding_codegen_forward_pruned_array_lookup_ker } -Tensor int_nbit_split_embedding_codegen_forward_{{ wdesc }}_cuda( +{% for nobag in [True, False] %} +{% if not nobag or not weighted %} +Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}_cuda( Tensor dev_weights, Tensor uvm_weights, Tensor weights_placements, Tensor weights_offsets, Tensor weights_tys, + {% if not nobag %} Tensor D_offsets, int64_t total_D, + {% else %} + int64_t D, + {% endif %} int64_t max_int2_D, int64_t max_int4_D, int64_t max_int8_D, @@ -1119,7 +587,9 @@ Tensor int_nbit_split_embedding_codegen_forward_{{ wdesc }}_cuda( int64_t max_float32_D, Tensor indices, Tensor offsets, + {% if not nobag %} int64_t pooling_mode, + {% endif %} {% if weighted %} Tensor indice_weights, {% endif %} @@ -1128,22 +598,48 @@ Tensor int_nbit_split_embedding_codegen_forward_{{ wdesc }}_cuda( Tensor lxu_cache_locations, int64_t unused ) { + TENSOR_ON_CUDA_GPU(dev_weights); + TENSOR_ON_CUDA_GPU(uvm_weights); + TENSOR_ON_CUDA_GPU(weights_placements); + TENSOR_ON_CUDA_GPU(weights_offsets); + TENSOR_ON_CUDA_GPU(weights_tys); + {% if not nobag %} + TENSOR_ON_CUDA_GPU(D_offsets); + {% endif %} + TENSOR_ON_CUDA_GPU(indices); + TENSOR_ON_CUDA_GPU(offsets); + {% if weighted %} + TENSOR_EMPTY_OR_ON_CUDA_GPU(indice_weights); + {% endif %} + TENSOR_EMPTY_OR_ON_CUDA_GPU(lxu_cache_weights); + TENSOR_EMPTY_OR_ON_CUDA_GPU(lxu_cache_locations); + at::cuda::OptionalCUDAGuard device_guard; device_guard.set_index(dev_weights.get_device()); + {% if not nobag %} int32_t T = D_offsets.numel() - 1; + {% else %} + int32_t total_L = indices.numel(); + int32_t T = weights_offsets.numel(); + {% endif %} TORCH_CHECK(T > 0); // offsets = [B x T + 1] int32_t B = (offsets.size(0) - 1) / T; TORCH_CHECK(B >= 0); + {% if not nobag %} TORCH_CHECK(total_D > 0); + {% else %} + TORCH_CHECK(D > 0); + {% endif %} TORCH_CHECK(max_int2_D == 0); Tensor output; const int kINT8QparamsBytes = 8; SparseType o_dtype = static_cast(output_dtype); TORCH_CHECK(o_dtype == SparseType::FP32 || o_dtype == SparseType::FP16 || o_dtype == SparseType::INT8); + {% if not nobag %} if (o_dtype == SparseType::FP32) { output = at::empty({B, total_D}, dev_weights.options().dtype(at::kFloat)); } else if (o_dtype == SparseType::FP16) { @@ -1151,6 +647,15 @@ Tensor int_nbit_split_embedding_codegen_forward_{{ wdesc }}_cuda( } else if (o_dtype == SparseType::INT8) { output = at::empty({B, total_D + T * kINT8QparamsBytes}, dev_weights.options().dtype(at::kByte)); } + {% else %} + if (o_dtype == SparseType::FP32) { + output = at::empty({total_L, D}, dev_weights.options().dtype(at::kFloat)); + } else if (o_dtype == SparseType::FP16) { + output = at::empty({total_L, D}, dev_weights.options().dtype(at::kHalf)); + } else if (o_dtype == SparseType::INT8) { + output = at::empty({total_L, D + kINT8QparamsBytes}, dev_weights.options().dtype(at::kByte)); + } + {% endif %} if (B == 0) { return output; @@ -1162,7 +667,7 @@ Tensor int_nbit_split_embedding_codegen_forward_{{ wdesc }}_cuda( constexpr int32_t kWarpsPerBlock = 4; #define X(OutputRowsPerThread, InputRowsInFlight, MinNum128BRows, MaxNum128BRows) \ - nbit::int_4bit_split_embedding_codegen_forward_{{ wdesc }}_kernel_small_L<<< \ + nbit::INT4_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}_kernel_small_L<<< \ nbit::div_round_up(T * nbit::div_round_up(B, OutputRowsPerThread), kWarpsPerBlock), \ dim3(kWarpSize, kWarpsPerBlock), \ 0, \ @@ -1172,10 +677,16 @@ Tensor int_nbit_split_embedding_codegen_forward_{{ wdesc }}_cuda( weights_placements.packed_accessor32(), \ weights_offsets.packed_accessor32(), \ weights_tys.packed_accessor32(), \ + {% if not nobag %} \ D_offsets.packed_accessor32(), \ + {% else %} \ + D, \ + {% endif %} \ indices.packed_accessor32(), \ offsets.packed_accessor32(), \ + {% if not nobag %} \ pooling_mode, \ + {% endif %} \ {% if weighted %} indice_weights.packed_accessor32(), {% endif %} \ output.packed_accessor32(), \ lxu_cache_weights.packed_accessor64(), \ @@ -1183,7 +694,7 @@ Tensor int_nbit_split_embedding_codegen_forward_{{ wdesc }}_cuda( ); \ C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - DISPATCH_OUTPUT_TYPES(output.type(), "int4_split_embedding_codegen_forward_kernel", ([&] { + DISPATCH_OUTPUT_TYPES(output.type(), "int4_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_kernel", ([&] { if (max_int4_D > 0) { auto max_int4_128b_rows = nbit::div_round_up(nbit::padded_row_size_in_bytes(max_int4_D, SparseType::INT4), 128); TORCH_CHECK(max_int4_128b_rows <= 4); @@ -1200,9 +711,8 @@ Tensor int_nbit_split_embedding_codegen_forward_{{ wdesc }}_cuda( })); #undef X - #define X(OutputRowsPerThread, InputRowsInFlight, MinNum128BRows, MaxNum128BRows) \ - nbit::int_8bit_split_embedding_codegen_forward_{{ wdesc }}_kernel_small_L<<< \ + nbit::INT8_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}_kernel_small_L<<< \ nbit::div_round_up(T * nbit::div_round_up(B, OutputRowsPerThread), kWarpsPerBlock), \ dim3(kWarpSize, kWarpsPerBlock), \ 0, \ @@ -1212,10 +722,16 @@ Tensor int_nbit_split_embedding_codegen_forward_{{ wdesc }}_cuda( weights_placements.packed_accessor32(), \ weights_offsets.packed_accessor32(), \ weights_tys.packed_accessor32(), \ + {% if not nobag %} \ D_offsets.packed_accessor32(), \ + {% else %} \ + D, \ + {% endif %} \ indices.packed_accessor32(), \ offsets.packed_accessor32(), \ + {% if not nobag %} \ pooling_mode, \ + {% endif %} \ {% if weighted %} indice_weights.packed_accessor32(), {% endif %} \ output.packed_accessor32(), \ lxu_cache_weights.packed_accessor64(), \ @@ -1223,7 +739,7 @@ Tensor int_nbit_split_embedding_codegen_forward_{{ wdesc }}_cuda( ); \ C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - DISPATCH_OUTPUT_TYPES(output.type(), "int8_split_embedding_codegen_forward_kernel", ([&] { + DISPATCH_OUTPUT_TYPES(output.type(), "int8_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_kernel", ([&] { if (max_int8_D > 0) { auto max_int8_128b_rows = nbit::div_round_up(nbit::padded_row_size_in_bytes(max_int8_D, SparseType::INT8), 128); TORCH_CHECK(max_int8_128b_rows <= 8); @@ -1244,7 +760,7 @@ Tensor int_nbit_split_embedding_codegen_forward_{{ wdesc }}_cuda( #undef X #define X(OutputRowsPerThread, InputRowsInFlight, MinNum128BRows, MaxNum128BRows) \ - nbit::fp16_split_embedding_codegen_forward_{{ wdesc }}_kernel_small_L<<< \ + nbit::FP16_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}_kernel_small_L<<< \ nbit::div_round_up(T * nbit::div_round_up(B, OutputRowsPerThread), kWarpsPerBlock), \ dim3(kWarpSize, kWarpsPerBlock), \ 0, \ @@ -1254,10 +770,16 @@ Tensor int_nbit_split_embedding_codegen_forward_{{ wdesc }}_cuda( weights_placements.packed_accessor32(), \ weights_offsets.packed_accessor32(), \ weights_tys.packed_accessor32(), \ + {% if not nobag %} \ D_offsets.packed_accessor32(), \ + {% else %} \ + D, \ + {% endif %} \ indices.packed_accessor32(), \ offsets.packed_accessor32(), \ + {% if not nobag %} \ pooling_mode, \ + {% endif %} \ {% if weighted %} indice_weights.packed_accessor32(), {% endif %} \ output.packed_accessor32(), \ lxu_cache_weights.packed_accessor64(), \ @@ -1265,7 +787,7 @@ Tensor int_nbit_split_embedding_codegen_forward_{{ wdesc }}_cuda( ); \ C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - DISPATCH_OUTPUT_TYPES(output.type(), "fp16_split_embedding_codegen_forward_kernel", ([&] { + DISPATCH_OUTPUT_TYPES(output.type(), "fp16_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_kernel", ([&] { if (max_float16_D > 0) { auto max_fp16_128b_rows = nbit::div_round_up(nbit::padded_row_size_in_bytes(max_float16_D, SparseType::FP16), 128); TORCH_CHECK(max_fp16_128b_rows <= 16); @@ -1286,7 +808,7 @@ Tensor int_nbit_split_embedding_codegen_forward_{{ wdesc }}_cuda( #undef X #define X(OutputRowsPerThread, InputRowsInFlight, MinNum128BRows, MaxNum128BRows) \ - nbit::fp32_split_embedding_codegen_forward_{{ wdesc }}_kernel_small_L<<< \ + nbit::FP32_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}_kernel_small_L<<< \ nbit::div_round_up(T * nbit::div_round_up(B, OutputRowsPerThread), kWarpsPerBlock), \ dim3(kWarpSize, kWarpsPerBlock), \ 0, \ @@ -1296,10 +818,16 @@ Tensor int_nbit_split_embedding_codegen_forward_{{ wdesc }}_cuda( weights_placements.packed_accessor32(), \ weights_offsets.packed_accessor32(), \ weights_tys.packed_accessor32(), \ + {% if not nobag %} \ D_offsets.packed_accessor32(), \ + {% else %} \ + D, \ + {% endif %} \ indices.packed_accessor32(), \ offsets.packed_accessor32(), \ + {% if not nobag %} \ pooling_mode, \ + {% endif %} \ {% if weighted %} indice_weights.packed_accessor32(), {% endif %} \ output.packed_accessor32(), \ lxu_cache_weights.packed_accessor64(), \ @@ -1307,7 +835,7 @@ Tensor int_nbit_split_embedding_codegen_forward_{{ wdesc }}_cuda( ); \ C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - DISPATCH_OUTPUT_TYPES(output.type(), "fp32_split_embedding_codegen_forward_kernel", ([&] { + DISPATCH_OUTPUT_TYPES(output.type(), "fp32_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_kernel", ([&] { if (max_float32_D > 0) { auto max_fp32_128b_rows = nbit::div_round_up(nbit::padded_row_size_in_bytes(max_float32_D, SparseType::FP32), 128); TORCH_CHECK(max_fp32_128b_rows <= 32); @@ -1321,12 +849,20 @@ Tensor int_nbit_split_embedding_codegen_forward_{{ wdesc }}_cuda( // TODO: 2-bit kernels. return output; } +{% endif %} // if not nobag or not weighted +{% endfor %} // for nobag in [True, False] Tensor pruned_hashmap_lookup_{{ wdesc }}_cuda( Tensor indices, Tensor offsets, Tensor hash_table, Tensor hash_table_offsets) { + + TENSOR_ON_CUDA_GPU(indices); + TENSOR_ON_CUDA_GPU(offsets); + TENSOR_ON_CUDA_GPU(hash_table); + TENSOR_ON_CUDA_GPU(hash_table_offsets); + at::cuda::OptionalCUDAGuard device_guard; device_guard.set_index(indices.get_device()); auto dense_indices = at::empty_like(indices); @@ -1358,6 +894,12 @@ Tensor pruned_array_lookup_cuda( Tensor offsets, Tensor index_remappings, Tensor index_remappings_offsets) { + + TENSOR_ON_CUDA_GPU(indices); + TENSOR_ON_CUDA_GPU(offsets); + TENSOR_ON_CUDA_GPU(index_remappings); + TENSOR_ON_CUDA_GPU(index_remappings_offsets); + at::cuda::OptionalCUDAGuard device_guard; device_guard.set_index(indices.get_device()); auto dense_indices = at::empty_like(indices); @@ -1395,4 +937,4 @@ Tensor pruned_array_lookup_cuda( return dense_indices; } {% endif %} -// clang-format on + // clang-format on diff --git a/fbgemm_gpu/codegen/embedding_forward_split_cpu.cpp b/fbgemm_gpu/codegen/embedding_forward_split_cpu.cpp index e7b77ba6f..cae49667d 100644 --- a/fbgemm_gpu/codegen/embedding_forward_split_cpu.cpp +++ b/fbgemm_gpu/codegen/embedding_forward_split_cpu.cpp @@ -5,11 +5,11 @@ * LICENSE file in the root directory of this source tree. */ #include "codegen/embedding_forward_split_cpu.h" -#include "codegen/embedding_common.h" #include "fbgemm/FbgemmEmbedding.h" #include "fbgemm/Types.h" #include "fbgemm/Utils.h" -#include "include/fbgemm_gpu/cpu_utils.h" +#include "fbgemm_gpu/cpu_utils.h" +#include "fbgemm_gpu/embedding_common.h" #ifdef FBCODE_CAFFE2 #include #include "folly/container/F14Map.h" diff --git a/fbgemm_gpu/codegen/embedding_forward_split_template.cu b/fbgemm_gpu/codegen/embedding_forward_split_template.cu index d4bcca1eb..ff5f47bac 100644 --- a/fbgemm_gpu/codegen/embedding_forward_split_template.cu +++ b/fbgemm_gpu/codegen/embedding_forward_split_template.cu @@ -127,29 +127,16 @@ __global__ void {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if noba at::acc_type idx_weight = l < L ? indice_weights[indices_start + l] : 0; {% endif %} for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) { -#ifdef __HIP_PLATFORM_HCC__ - int64_t idx_j = __shfl(idx, j); -#else - int64_t idx_j = __shfl_sync(0xFFFFFFFF, idx, j); -#endif - + int64_t idx_j = shfl_sync(idx, j); {% if nobag %} int64_t output_j = indices_start + l_start + j; {% endif %} {% if not dense %} -#ifdef __HIP_PLATFORM_HCC__ - int32_t cache_idx_j = __shfl(cache_idx, j); -#else - int32_t cache_idx_j = __shfl_sync(0xFFFFFFFF, cache_idx, j); -#endif + int32_t cache_idx_j = shfl_sync(cache_idx, j); {% endif %} {% if weighted %} -#ifdef __HIP_PLATFORM_HCC__ - at::acc_type idx_weight_j = __shfl(idx_weight, j); -#else - at::acc_type idx_weight_j = __shfl_sync(0xFFFFFFFF, idx_weight, j); -#endif + at::acc_type idx_weight_j = shfl_sync(idx_weight, j); {% endif %} {% if not dense %} @@ -317,6 +304,25 @@ Tensor {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else "" {% endif %} int64_t unused ) { + TENSOR_ON_CUDA_GPU(dev_weights); + {% if not dense %} + TENSOR_ON_CUDA_GPU(uvm_weights); + TENSOR_ON_CUDA_GPU(lxu_cache_weights); + TENSOR_ON_CUDA_GPU(weights_placements); + {% endif %} + TENSOR_ON_CUDA_GPU(weights_offsets); + {% if not nobag %} + TENSOR_ON_CUDA_GPU(D_offsets); + {% endif %} + TENSOR_ON_CUDA_GPU(indices); + TENSOR_ON_CUDA_GPU(offsets); + {% if weighted %} + TENSOR_ON_CUDA_GPU(indice_weights); + {% endif %} + {% if not dense %} + TENSOR_ON_CUDA_GPU(lxu_cache_locations); + {% endif %} + at::cuda::OptionalCUDAGuard device_guard; device_guard.set_index(dev_weights.get_device()); diff --git a/fbgemm_gpu/codegen/embedding_forward_template_helpers.cuh b/fbgemm_gpu/codegen/embedding_forward_template_helpers.cuh index 86c442a2c..272729777 100644 --- a/fbgemm_gpu/codegen/embedding_forward_template_helpers.cuh +++ b/fbgemm_gpu/codegen/embedding_forward_template_helpers.cuh @@ -26,6 +26,7 @@ #include #include -#include "codegen/embedding_common.h" #include "fbgemm_gpu/dispatch_macros.h" +#include "fbgemm_gpu/embedding_common.h" #include "fbgemm_gpu/fbgemm_cuda_utils.cuh" +#include "fbgemm_gpu/sparse_ops_utils.h" diff --git a/fbgemm_gpu/fbgemm_gpu/split_embedding_inference_converter.py b/fbgemm_gpu/fbgemm_gpu/split_embedding_inference_converter.py index 0fcf595d6..7780547c6 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_embedding_inference_converter.py +++ b/fbgemm_gpu/fbgemm_gpu/split_embedding_inference_converter.py @@ -17,13 +17,6 @@ from fbgemm_gpu.split_embedding_configs import SparseType from torch import Tensor, nn -# TODO: move torch.ops.fb.embedding_bag_rowwise_prune to OSS -try: - # pyre-ignore[21] - from fbgemm_gpu import open_source # noqa: F401 -except Exception: - torch.ops.load_library("//caffe2/torch/fb/sparsenn:sparsenn_operators") - # TODO: add per-feature based converter option (based on embedding_specs during inference) # TODO: optimize embedding pruning and quantization latency. class SplitEmbInferenceConverter: @@ -75,7 +68,7 @@ def _prune_embs( (indicators, threshold) = self._prune_by_weights_l2_norm(new_num_rows, weights) - return torch.ops.fb.embedding_bag_rowwise_prune( + return torch.ops.fbgemm.embedding_bag_rowwise_prune( weights, indicators, threshold, torch.int32 ) diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py index 8aae5b4cb..017381100 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py @@ -12,7 +12,7 @@ from dataclasses import dataclass from itertools import accumulate from math import log2 -from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Type, Union +from typing import Dict, List, NamedTuple, Optional, Tuple, Type, Union import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers import torch @@ -1458,84 +1458,6 @@ def init_embedding_weights_uniform(self, min_val: float, max_val: float) -> None param.uniform_(min_val, max_val) -class SequenceEmbeddingCodegen(SplitTableBatchedEmbeddingBagsCodegen): - """ - This class wraps around SplitTableBatchedEmbeddingBagsCodegen to get - sequence embedding op: nn.EmbeddingBag(sparse=True) - """ - - def __init__( - self, - **kwargs: Any, - ) -> None: - # assert T == 1 - assert "embedding_specs" in kwargs - assert len(kwargs["embedding_specs"]) == 1 - super(SequenceEmbeddingCodegen, self).__init__( - **kwargs, - ) - - # @torch.jit.ignore - def forward( - self, - indices: Tensor, - offsets: Optional[Tensor] = None, - per_sample_weights: Optional[Tensor] = None, - feature_requires_grad: Optional[Tensor] = None, - ) -> Tensor: - offsets = torch.arange( - 0, - indices.numel() + 1, - device=indices.device, - dtype=torch.int64, - ) - return super(SequenceEmbeddingCodegen, self).forward( - indices, - offsets, - per_sample_weights, - feature_requires_grad, - ) - - -class DenseSequenceEmbeddingCodegen(DenseTableBatchedEmbeddingBagsCodegen): - """ - This class wraps around DenseTableBatchedEmbeddingBagsCodegen to get - sequence embedding op, nn.EmbeddingBag(sparse=False) - """ - - def __init__( - self, - **kwargs: Any, - ) -> None: - # assert T == 1 - assert "embedding_specs" in kwargs - assert len(kwargs["embedding_specs"]) == 1 - super(DenseSequenceEmbeddingCodegen, self).__init__( - **kwargs, - ) - - # @torch.jit.ignore - def forward( - self, - indices: Tensor, - offsets: Optional[Tensor] = None, - per_sample_weights: Optional[Tensor] = None, - feature_requires_grad: Optional[Tensor] = None, - ) -> Tensor: - offsets = torch.arange( - 0, - indices.numel() + 1, - device=indices.device, - dtype=torch.int64, - ) - return super(DenseSequenceEmbeddingCodegen, self).forward( - indices, - offsets, - per_sample_weights, - feature_requires_grad, - ) - - def round_up(a: int, b: int) -> int: return int((a + b - 1) // b) * b @@ -1654,6 +1576,13 @@ def __init__( weights_tys: List[SparseType] = [e[3] for e in embedding_specs] locations: List[EmbeddingLocation] = [e[4] for e in embedding_specs] + # mixed D is not supported by no bag kernels + mixed_D = not all(d == dims[0] for d in dims) + if mixed_D: + assert ( + self.pooling_mode != PoolingMode.NONE + ), "Mixed dimension tables are only supported for pooling tables." + assert not self.use_cpu or all( loc == EmbeddingLocation.HOST for loc in locations ), "ComputeDevice.CPU is only for EmbeddingLocation.HOST!" @@ -1825,7 +1754,6 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> None: if not self.lxu_cache_weights.numel(): return - (indices, offsets) = indices.long(), offsets.long() linear_cache_indices = torch.ops.fb.linearize_cache_indices( self.cache_hash_size_cumsum, indices, @@ -2080,6 +2008,7 @@ def _apply_cache_state( cache_sets = ( int(1.0 * free_memory / self.max_D_cache) + ASSOC - 1 ) // ASSOC + cache_sets = 1 if cache_sets == 0 else cache_sets cache_load_factor = ( 1.0 * cache_sets * ASSOC / int(cache_state.total_cache_hash_size) ) diff --git a/fbgemm_gpu/fbgemm_gpu/uvm.py b/fbgemm_gpu/fbgemm_gpu/uvm.py index c7c1f0cf1..4850054d5 100644 --- a/fbgemm_gpu/fbgemm_gpu/uvm.py +++ b/fbgemm_gpu/fbgemm_gpu/uvm.py @@ -18,18 +18,18 @@ torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:cumem_utils") # Import all uvm enums from c++ library -create_enums(globals(), torch.ops.fb.fbgemm_gpu_uvm_enum_query) +create_enums(globals(), torch.ops.fbgemm.fbgemm_gpu_uvm_enum_query) def cudaMemAdvise( t: torch.Tensor, advice: Enum, ) -> None: - torch.ops.fb.cuda_mem_advise(t, advice.value) + torch.ops.fbgemm.cuda_mem_advise(t, advice.value) def cudaMemPrefetchAsync( t: torch.Tensor, device_t: Optional[torch.Tensor] = None, ) -> None: - torch.ops.fb.cuda_mem_prefetch_async(t, device_t) + torch.ops.fbgemm.cuda_mem_prefetch_async(t, device_t) diff --git a/fbgemm_gpu/include/fbgemm_gpu/batched_unary_embedding_ops.cuh b/fbgemm_gpu/include/fbgemm_gpu/batched_unary_embedding_ops.cuh deleted file mode 100644 index 87a4374f6..000000000 --- a/fbgemm_gpu/include/fbgemm_gpu/batched_unary_embedding_ops.cuh +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * All rights reserved. - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ -#pragma once - -#include -#include - -// Forward kernel for batched unary embedding op -template -__global__ void batched_unary_embeddings_forward_kernel( - const int32_t N, - const int32_t B, - const int32_t T, - const scalar_t* __restrict__ weight, // N * sum(E) * 1 (embedding dimension - // is 1) - const index_t* __restrict__ table_offsets, - const index_t* __restrict__ offsets, - const index_t* __restrict__ indices, - scalar_t* __restrict__ output // N * B * T -) { - index_t sum_E = table_offsets[T]; - int32_t b = blockIdx.x * blockDim.x + threadIdx.x; - if (b >= B) { - return; - } - int32_t t = blockIdx.y; - int32_t n = blockIdx.z; - index_t table_offset = table_offsets[t]; - index_t indices_start = offsets[t * B + b]; - index_t indices_end = offsets[t * B + b + 1]; - int32_t L = indices_end - indices_start; - // TODO: this should be at::acc_type - scalar_t sum = 0.0; - for (int32_t l = 0; l < L; ++l) { - auto idx = __ldg(&indices[indices_start + l]); - sum += weight[n * sum_E + table_offset + idx + 0]; - } - output[(n * B + b) * T + t] = sum; -} - -// Backward kernel for batched unary embedding op -template -__global__ void batched_unary_embeddings_backward_kernel( - const int32_t N, - const int32_t B, - const int32_t T, - const scalar_t* __restrict__ grad_output, // [N * B * T] - const index_t* __restrict__ table_offsets, - const index_t* __restrict__ offsets, - const index_t* __restrict__ indices, - scalar_t* __restrict__ grad_weight // [N * sum_E * 1] (embedding - // dimension is 1) -) { - index_t sum_E = table_offsets[T]; - int32_t n_t = blockIdx.x * blockDim.x + threadIdx.x; - int32_t n = n_t / T; - int32_t t = n_t % T; - if (n >= N) { - return; - } - index_t table_offset = table_offsets[t]; - - for (int32_t b = 0; b < B; ++b) { - int32_t indices_start = offsets[t * B + b]; - int32_t indices_end = offsets[t * B + b + 1]; - int32_t L = indices_end - indices_start; - const scalar_t go = grad_output[(n * B + b) * T + t]; - for (int32_t l = 0; l < L; ++l) { - index_t idx = __ldg(&indices[indices_start + l]); - grad_weight[n * sum_E + table_offset + idx + 0] += go; - } - } -} diff --git a/fbgemm_gpu/codegen/embedding_backward_template_helpers.cuh b/fbgemm_gpu/include/fbgemm_gpu/embedding_backward_template_helpers.cuh similarity index 59% rename from fbgemm_gpu/codegen/embedding_backward_template_helpers.cuh rename to fbgemm_gpu/include/fbgemm_gpu/embedding_backward_template_helpers.cuh index beceed8b4..ad0804fe8 100644 --- a/fbgemm_gpu/codegen/embedding_backward_template_helpers.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/embedding_backward_template_helpers.cuh @@ -28,9 +28,10 @@ #include #include -#include "codegen/embedding_common.h" -#include "fbgemm_gpu/dispatch_macros.h" -#include "fbgemm_gpu/fbgemm_cuda_utils.cuh" +#include "dispatch_macros.h" +#include "embedding_common.h" +#include "fbgemm_cuda_utils.cuh" +#include "sparse_ops_utils.h" inline at::Tensor asynchronous_complete_cumsum(at::Tensor t_in) { at::cuda::OptionalCUDAGuard device_guard; @@ -69,77 +70,6 @@ inline at::Tensor asynchronous_complete_cumsum(at::Tensor t_in) { return t_out; } -template -__global__ void __launch_bounds__(kMaxThreads) -linearize_index_kernel( - const at::PackedTensorAccessor32 - hash_size_cumsum, - const at::PackedTensorAccessor32 indices, - const at::PackedTensorAccessor32 offsets, - at::PackedTensorAccessor32 infos, - at::PackedTensorAccessor32 - linear_indices) { - int32_t T = hash_size_cumsum.size(0) - 1; - int32_t B = (offsets.size(0) - 1) / T; - int32_t b_t = blockIdx.x * blockDim.x + threadIdx.x; - int32_t b = b_t % B; - int32_t t = b_t / B; - bool valid = t < T; - - index_t hash_offset = valid ? hash_size_cumsum[t] : -1; - index_t indices_start = valid ? offsets[t * B + b] : -1; - int32_t L = valid ? offsets[t * B + b + 1] - indices_start : 0; - int32_t lane_id = threadIdx.x % fbgemm_gpu::kWarpSize; - - for (int32_t j = 0; j < fbgemm_gpu::kWarpSize; ++j) { - index_t indices_start_warp = SHFL_SYNC_MACRO(indices_start, j); - int32_t b_t_warp = SHFL_SYNC_MACRO(b_t, j); - int32_t L_warp = SHFL_SYNC_MACRO(L, j); - index_t hash_offset_warp = SHFL_SYNC_MACRO(hash_offset, j); - for (int32_t i = lane_id; i < L_warp; i += fbgemm_gpu::kWarpSize) { - index_t idx = __ldg(&indices[indices_start_warp + i]); - infos[indices_start_warp + i] = b_t_warp; - linear_indices[indices_start_warp + i] = hash_offset_warp + idx; - } - } -} - -template -__global__ void nobag_linearize_index_kernel( - const at::PackedTensorAccessor32 - hash_size_cumsum, - const at::PackedTensorAccessor32 indices, - const at::PackedTensorAccessor32 offsets, - at::PackedTensorAccessor32 infos, - at::PackedTensorAccessor32 - linear_indices) { - int32_t T = hash_size_cumsum.size(0) - 1; - int32_t B = (offsets.size(0) - 1) / T; - int32_t b_t = blockIdx.x * blockDim.x + threadIdx.x; - int32_t b = b_t % B; - int32_t t = b_t / B; - bool valid = t < T; - - index_t hash_offset = valid ? hash_size_cumsum[t] : -1; - index_t indices_start = valid ? offsets[t * B + b] : -1; - int32_t L = valid ? offsets[t * B + b + 1] - indices_start : 0; - int32_t lane_id = threadIdx.x % fbgemm_gpu::kWarpSize; - - for (int32_t j = 0; j < fbgemm_gpu::kWarpSize; ++j) { - index_t indices_start_warp = SHFL_SYNC_MACRO(indices_start, j); - int32_t t_warp = SHFL_SYNC_MACRO(t, j); - int32_t L_warp = SHFL_SYNC_MACRO(L, j); - index_t hash_offset_warp = SHFL_SYNC_MACRO(hash_offset, j); - for (int32_t i = lane_id; i < L_warp; i += fbgemm_gpu::kWarpSize) { - index_t idx = __ldg(&indices[indices_start_warp + i]); - int64_t l_t = (indices_start_warp + i) * T + t_warp; - infos[indices_start_warp + i] = l_t; - linear_indices[indices_start_warp + i] = hash_offset_warp + idx; - } - } -} - class FixedDivisor { public: explicit FixedDivisor(const int32_t d) : d_(d) { diff --git a/fbgemm_gpu/codegen/embedding_common.h b/fbgemm_gpu/include/fbgemm_gpu/embedding_common.h similarity index 97% rename from fbgemm_gpu/codegen/embedding_common.h rename to fbgemm_gpu/include/fbgemm_gpu/embedding_common.h index ec2030591..469b861c2 100644 --- a/fbgemm_gpu/codegen/embedding_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/embedding_common.h @@ -5,7 +5,7 @@ * LICENSE file in the root directory of this source tree. */ #pragma once -#include +#include namespace { diff --git a/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh b/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh index d033645bf..9de3cd43e 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh @@ -443,12 +443,19 @@ DEVICE_INLINE Vec4T vec4_acc( template DEVICE_INLINE T shfl_xor(const T val, int laneMask, int width = kWarpSize) { -#ifdef __HIP_PLATFORM_HCC__ +#if defined(__HIP_PLATFORM_HCC__) || CUDA_VERSION < 9000 return __shfl_xor(val, laneMask, width); -#elif CUDA_VERSION >= 9000 +#else return __shfl_xor_sync(0xffffffff, val, laneMask, width); +#endif +} + +template +DEVICE_INLINE T shfl_sync(const T val, int srcLane = 0, int width = kWarpSize) { +#if defined(__HIP_PLATFORM_HCC__) || CUDA_VERSION < 9000 + return __shfl(val, srcLane, width); #else - return __shfl_xor(val, laneMask, width); + return __shfl_sync(0xffffffff, val, srcLane, width); #endif } @@ -489,7 +496,7 @@ stochastic_rounding_scalar_uint8(float x, uint32_t random_bits) { // noise.F in [1, 2] noise.F = noise.F - 1.5; // noise.F in [-0.5, 0.5] - return std::lrintf(x + noise.F); + return lrintf(x + noise.F); } // This is a simple xorshift* RNG with 64 bits of state (vs 384 bits of state @@ -560,17 +567,12 @@ DEVICE_INLINE void stochastic_rounding_vector( float2 /* not used */) { uint4 random_bits = stochastic_rounding_rand4(&state); Half4 v; -#ifdef __HIP_PLATFORM_HCC__ - v.a = __halves2half2(stochastic_rounding_scalar(value.acc.x, random_bits.x), - stochastic_rounding_scalar(value.acc.y, random_bits.y)); - v.b = __halves2half2(stochastic_rounding_scalar(value.acc.z, random_bits.z), - stochastic_rounding_scalar(value.acc.w, random_bits.w)); -#else - v.a.x = stochastic_rounding_scalar(value.acc.x, random_bits.x); - v.a.y = stochastic_rounding_scalar(value.acc.y, random_bits.y); - v.b.x = stochastic_rounding_scalar(value.acc.z, random_bits.z); - v.b.y = stochastic_rounding_scalar(value.acc.w, random_bits.w); -#endif + v.a = __halves2half2( + stochastic_rounding_scalar(value.acc.x, random_bits.x), + stochastic_rounding_scalar(value.acc.y, random_bits.y)); + v.b = __halves2half2( + stochastic_rounding_scalar(value.acc.z, random_bits.z), + stochastic_rounding_scalar(value.acc.w, random_bits.w)); v.store(output); } @@ -582,17 +584,12 @@ DEVICE_INLINE void stochastic_rounding_vector( float2 /* not used */) { uint4 random_bits = stochastic_rounding_rand4(&state); Half4 v; -#ifdef __HIP_PLATFORM_HCC__ - v.a = __halves2half2(stochastic_rounding_scalar(value.acc.x, random_bits.x), - stochastic_rounding_scalar(value.acc.y, random_bits.y)); - v.b = __halves2half2(stochastic_rounding_scalar(value.acc.z, random_bits.z), - stochastic_rounding_scalar(value.acc.w, random_bits.w)); -#else - v.a.x = stochastic_rounding_scalar(value.acc.x, random_bits.x); - v.a.y = stochastic_rounding_scalar(value.acc.y, random_bits.y); - v.b.x = stochastic_rounding_scalar(value.acc.z, random_bits.z); - v.b.y = stochastic_rounding_scalar(value.acc.w, random_bits.w); -#endif + v.a = __halves2half2( + stochastic_rounding_scalar(value.acc.x, random_bits.x), + stochastic_rounding_scalar(value.acc.y, random_bits.y)); + v.b = __halves2half2( + stochastic_rounding_scalar(value.acc.z, random_bits.z), + stochastic_rounding_scalar(value.acc.w, random_bits.w)); v.store(output); } @@ -645,10 +642,10 @@ template <> DEVICE_INLINE void nearest_rounding_vector(uint8_t* output, Vec4T value, float2 qparams) { float inv_scale = 255.0f / (qparams.x * 255.0f + kQParamEps); - output[0] = std::lrintf((value.acc.x - qparams.y) * inv_scale); - output[1] = std::lrintf((value.acc.y - qparams.y) * inv_scale); - output[2] = std::lrintf((value.acc.z - qparams.y) * inv_scale); - output[3] = std::lrintf((value.acc.w - qparams.y) * inv_scale); + output[0] = lrintf((value.acc.x - qparams.y) * inv_scale); + output[1] = lrintf((value.acc.y - qparams.y) * inv_scale); + output[2] = lrintf((value.acc.z - qparams.y) * inv_scale); + output[3] = lrintf((value.acc.w - qparams.y) * inv_scale); } template <> @@ -657,10 +654,10 @@ DEVICE_INLINE void nearest_rounding_vector( Vec4T value, float2 qparams) { float inv_scale = 255.0f / (qparams.x * 255.0f + kQParamEps); - output[0] = std::lrintf((value.acc.x - qparams.y) * inv_scale); - output[1] = std::lrintf((value.acc.y - qparams.y) * inv_scale); - output[2] = std::lrintf((value.acc.z - qparams.y) * inv_scale); - output[3] = std::lrintf((value.acc.w - qparams.y) * inv_scale); + output[0] = lrintf((value.acc.x - qparams.y) * inv_scale); + output[1] = lrintf((value.acc.y - qparams.y) * inv_scale); + output[2] = lrintf((value.acc.z - qparams.y) * inv_scale); + output[3] = lrintf((value.acc.w - qparams.y) * inv_scale); } template <> @@ -848,7 +845,7 @@ struct SharedMemory>> { // Return if the address is aligned to the type (mainly for Vec4T). template DEVICE_INLINE bool is_aligned(const void* ptr) { - auto iptr = reinterpret_cast(ptr); + auto iptr = reinterpret_cast(ptr); return !(iptr % alignof(T)); } @@ -936,13 +933,8 @@ __device__ float2 warp_find_qparams(scalar_t local_min, scalar_t local_max) { qparams.x = (local_max - local_min) / 255.0f; qparams.y = local_min; } -#ifdef __HIP_PLATFORM_HCC__ - qparams.x = __shfl(qparams.x, 0); - qparams.y = __shfl(qparams.y, 0); -#else - qparams.x = __shfl_sync(0xFFFFFFFF, qparams.x, 0); - qparams.y = __shfl_sync(0xFFFFFFFF, qparams.y, 0); -#endif + qparams.x = shfl_sync(qparams.x, 0); + qparams.y = shfl_sync(qparams.y, 0); return qparams; } @@ -1081,30 +1073,45 @@ dequantize_permuted_int4(uint32_t packedVals, __half2 shift_scale) { // on each 4-bit value is expensive on the ALU, and 4-bit to half is expensive // on the XU. b) doing a 256-entry shared memory LUT on 8-bit pairs is // expensive on SMEM throughput. Credit to @jhj. - res.vals[0] = hmul_short2(v & 0x000F000F, __int2half_rn(32768)); - res.vals[1] = hmul_short2(v & 0x00F000F0, __int2half_rn(32768)); + res.vals[0] = hmul_short2(v & 0x000F000F, __float2half(32768)); + res.vals[1] = hmul_short2(v & 0x00F000F0, __float2half(32768)); v >>= 8; - res.vals[2] = hmul_short2(v & 0x000F000F, __int2half_rn(32768)); - res.vals[3] = hmul_short2(v & 0x00F000F0, __int2half_rn(32768)); + res.vals[2] = hmul_short2(v & 0x000F000F, __float2half(32768)); + res.vals[3] = hmul_short2(v & 0x00F000F0, __float2half(32768)); + + // ~5% perf gain is observed with the explicit type conversions using + // __float2half on Nvidia A100 GPUs (https://fburl.com/diff/ss8372zw) using + // NVCC 11.0. Additionally, HIP compiler requires these explicit type + // conversions. + half shift_scale_x = __low2half(shift_scale); + half shift_scale_y = __high2half(shift_scale); // TODO: Enable this for HIP #ifndef __HIP_PLATFORM_HCC__ res.vals[0] = hfma2( res.vals[0], - __half2(hmul(__ushort2half_rn(shift_scale.x), __int2half_rn(512)), hmul(__ushort2half_rn(shift_scale.x), __int2half_rn(512))), - __half2(__ushort2half_rn(shift_scale.y), __ushort2half_rn(shift_scale.y))); + __half2( + hmul(shift_scale_x, __float2half(512)), + hmul(shift_scale_x, __float2half(512))), + __half2(shift_scale_y, shift_scale_y)); res.vals[1] = hfma2( res.vals[1], - __half2(hmul(__ushort2half_rn(shift_scale.x), __int2half_rn(32)), hmul(__ushort2half_rn(shift_scale.x), __int2half_rn(32))), - __half2(__ushort2half_rn(shift_scale.y), __ushort2half_rn(shift_scale.y))); + __half2( + hmul(shift_scale_x, __float2half(32)), + hmul(shift_scale_x, __float2half(32))), + __half2(shift_scale_y, shift_scale_y)); res.vals[2] = hfma2( res.vals[2], - __half2(hmul(__ushort2half_rn(shift_scale.x), __int2half_rn(512)), hmul(__ushort2half_rn(shift_scale.x), __int2half_rn(512))), - __half2(__ushort2half_rn(shift_scale.y), __ushort2half_rn(shift_scale.y))); + __half2( + hmul(shift_scale_x, __float2half(512)), + hmul(shift_scale_x, __float2half(512))), + __half2(shift_scale_y, shift_scale_y)); res.vals[3] = hfma2( res.vals[3], - __half2(hmul(__ushort2half_rn(shift_scale.x), __int2half_rn(32)), hmul(__ushort2half_rn(shift_scale.x), __int2half_rn(32))), - __half2(__ushort2half_rn(shift_scale.y), __ushort2half_rn(shift_scale.y))); + __half2( + hmul(shift_scale_x, __float2half(32)), + hmul(shift_scale_x, __float2half(32))), + __half2(shift_scale_y, shift_scale_y)); #endif return res; } @@ -1114,17 +1121,25 @@ dequantize_permuted_int8(uint32_t packedVals, __half2 shift_scale) { half4 res; uint32_t v = packedVals; // See comment above, this is a minor variation. - res.vals[0] = hmul_short2(v & 0x00FF00FF, __int2half_rn(32768)); + res.vals[0] = hmul_short2(v & 0x00FF00FF, __float2half(32768)); v >>= 8; - res.vals[1] = hmul_short2(v & 0x00FF00FF, __int2half_rn(32768)); + res.vals[1] = hmul_short2(v & 0x00FF00FF, __float2half(32768)); + + half shift_scale_x = __low2half(shift_scale); + half shift_scale_y = __high2half(shift_scale); + res.vals[0] = hfma2( res.vals[0], - __half2(hmul(__ushort2half_rn(shift_scale.x), __int2half_rn(512)), hmul(__ushort2half_rn(shift_scale.x), __int2half_rn(512))), - __half2(__ushort2half_rn(shift_scale.y), __ushort2half_rn(shift_scale.y))); + __half2( + hmul(shift_scale_x, __float2half(512)), + hmul(shift_scale_x, __float2half(512))), + __half2(shift_scale_y, shift_scale_y)); res.vals[1] = hfma2( res.vals[1], - __half2(hmul(__ushort2half_rn(shift_scale.x), __int2half_rn(512)), hmul(__ushort2half_rn(shift_scale.x), __int2half_rn(512))), - __half2(__ushort2half_rn(shift_scale.y), __ushort2half_rn(shift_scale.y))); + __half2( + hmul(shift_scale_x, __float2half(512)), + hmul(shift_scale_x, __float2half(512))), + __half2(shift_scale_y, shift_scale_y)); return res; } @@ -1239,9 +1254,15 @@ struct VecNT {}; template <> struct VecNT<1> { float acc; + DEVICE_INLINE VecNT() { acc = 0; } + + DEVICE_INLINE VecNT(float a) { + acc = a; + } + DEVICE_INLINE void store(float* output_ptr) { *output_ptr = acc; } @@ -1257,7 +1278,7 @@ struct VecNT<1> { DEVICE_INLINE void store(uint8_t* output_ptr, float2 qparams) { float inv_scale = 255.0f / (qparams.x * 255.0f + kQParamEps); - output_ptr[0] = std::lrintf((acc - qparams.y) * inv_scale); + output_ptr[0] = lrintf((acc - qparams.y) * inv_scale); } DEVICE_INLINE void store(float* output_ptr, float2 qparams) { @@ -1287,10 +1308,15 @@ struct VecNT<1> { template <> struct VecNT<2> { float2 acc; + DEVICE_INLINE VecNT() { acc = make_zero_float2(); } + DEVICE_INLINE VecNT(half2 a) { + acc = __half22float2(a); + } + DEVICE_INLINE void store(float* output_ptr) { *reinterpret_cast(output_ptr) = *reinterpret_cast(&acc); } @@ -1306,8 +1332,8 @@ struct VecNT<2> { DEVICE_INLINE void store(uint8_t* output_ptr, float2 qparams) { float inv_scale = 255.0f / (qparams.x * 255.0f + kQParamEps); - output_ptr[0] = std::lrintf((acc.x - qparams.y) * inv_scale); - output_ptr[1] = std::lrintf((acc.y - qparams.y) * inv_scale); + output_ptr[0] = lrintf((acc.x - qparams.y) * inv_scale); + output_ptr[1] = lrintf((acc.y - qparams.y) * inv_scale); } DEVICE_INLINE void store(float* output_ptr, float2 qparams) { @@ -1338,9 +1364,16 @@ struct VecNT<2> { template <> struct VecNT<4> { float4 acc; + DEVICE_INLINE VecNT() { acc = make_zero_float4(); } + + DEVICE_INLINE VecNT(uint32_t v, half2 shift_scale) { + acc = make_zero_float4(); + acc = accumulate_packed_int8(acc, v, shift_scale); + } + DEVICE_INLINE void store(float* output_ptr) { bool aligned_16b = intptr_t(output_ptr) % 16 == 0; bool aligned_8b = intptr_t(output_ptr) % 8 == 0; @@ -1371,10 +1404,10 @@ struct VecNT<4> { *reinterpret_cast(output_ptr + 0) = v.x; *reinterpret_cast(output_ptr + 2) = v.y; } else { - *(output_ptr + 0) = val.vals[0].x; - *(output_ptr + 1) = val.vals[0].y; - *(output_ptr + 2) = val.vals[1].x; - *(output_ptr + 3) = val.vals[1].y; + *(output_ptr + 0) = __low2half(val.vals[0]); + *(output_ptr + 1) = __high2half(val.vals[0]); + *(output_ptr + 2) = __low2half(val.vals[1]); + *(output_ptr + 3) = __high2half(val.vals[1]); } } @@ -1384,10 +1417,10 @@ struct VecNT<4> { DEVICE_INLINE void store(uint8_t* output_ptr, float2 qparams) { float inv_scale = 255.0f / (qparams.x * 255.0f + kQParamEps); - output_ptr[0] = std::lrintf((acc.x - qparams.y) * inv_scale); - output_ptr[1] = std::lrintf((acc.y - qparams.y) * inv_scale); - output_ptr[2] = std::lrintf((acc.z - qparams.y) * inv_scale); - output_ptr[3] = std::lrintf((acc.w - qparams.y) * inv_scale); + output_ptr[0] = lrintf((acc.x - qparams.y) * inv_scale); + output_ptr[1] = lrintf((acc.y - qparams.y) * inv_scale); + output_ptr[2] = lrintf((acc.z - qparams.y) * inv_scale); + output_ptr[3] = lrintf((acc.w - qparams.y) * inv_scale); } DEVICE_INLINE void store(float* output_ptr, float2 qparams) { @@ -1420,10 +1453,16 @@ struct VecNT<4> { template <> struct VecNT<8> { float8 acc; + DEVICE_INLINE VecNT() { acc = make_zero_float8(); } + DEVICE_INLINE VecNT(uint32_t v, half2 shift_scale) { + acc = make_zero_float8(); + acc = accumulate_packed_int4(acc, v, shift_scale); + } + DEVICE_INLINE void store(float* output_ptr) { bool aligned_16b = intptr_t(output_ptr) % 16 == 0; bool aligned_8b = intptr_t(output_ptr) % 8 == 0; @@ -1470,14 +1509,14 @@ struct VecNT<8> { *reinterpret_cast(output_ptr + 4) = v.z; *reinterpret_cast(output_ptr + 6) = v.w; } else { - *(output_ptr + 0) = val.vals[0].x; - *(output_ptr + 1) = val.vals[0].y; - *(output_ptr + 2) = val.vals[1].x; - *(output_ptr + 3) = val.vals[1].y; - *(output_ptr + 4) = val.vals[2].x; - *(output_ptr + 5) = val.vals[2].y; - *(output_ptr + 6) = val.vals[3].x; - *(output_ptr + 7) = val.vals[3].y; + *(output_ptr + 0) = __low2half(val.vals[0]); + *(output_ptr + 1) = __high2half(val.vals[0]); + *(output_ptr + 2) = __low2half(val.vals[1]); + *(output_ptr + 3) = __high2half(val.vals[1]); + *(output_ptr + 4) = __low2half(val.vals[2]); + *(output_ptr + 5) = __high2half(val.vals[2]); + *(output_ptr + 6) = __low2half(val.vals[3]); + *(output_ptr + 7) = __high2half(val.vals[3]); } } @@ -1487,14 +1526,14 @@ struct VecNT<8> { DEVICE_INLINE void store(uint8_t* output_ptr, float2 qparams) { float inv_scale = 255.0f / (qparams.x * 255.0f + kQParamEps); - output_ptr[0] = std::lrintf((acc.vals[0].x - qparams.y) * inv_scale); - output_ptr[1] = std::lrintf((acc.vals[0].y - qparams.y) * inv_scale); - output_ptr[2] = std::lrintf((acc.vals[0].z - qparams.y) * inv_scale); - output_ptr[3] = std::lrintf((acc.vals[0].w - qparams.y) * inv_scale); - output_ptr[4] = std::lrintf((acc.vals[1].x - qparams.y) * inv_scale); - output_ptr[5] = std::lrintf((acc.vals[1].y - qparams.y) * inv_scale); - output_ptr[6] = std::lrintf((acc.vals[1].z - qparams.y) * inv_scale); - output_ptr[7] = std::lrintf((acc.vals[1].w - qparams.y) * inv_scale); + output_ptr[0] = lrintf((acc.vals[0].x - qparams.y) * inv_scale); + output_ptr[1] = lrintf((acc.vals[0].y - qparams.y) * inv_scale); + output_ptr[2] = lrintf((acc.vals[0].z - qparams.y) * inv_scale); + output_ptr[3] = lrintf((acc.vals[0].w - qparams.y) * inv_scale); + output_ptr[4] = lrintf((acc.vals[1].x - qparams.y) * inv_scale); + output_ptr[5] = lrintf((acc.vals[1].y - qparams.y) * inv_scale); + output_ptr[6] = lrintf((acc.vals[1].z - qparams.y) * inv_scale); + output_ptr[7] = lrintf((acc.vals[1].w - qparams.y) * inv_scale); } DEVICE_INLINE void store(float* output_ptr, float2 qparams) { @@ -1531,6 +1570,26 @@ struct VecNT<8> { #define min(a, b) ((a) < (b) ? (a) : (b)) #define max(a, b) ((a) > (b) ? (a) : (b)) +DEVICE_INLINE float float1_max(float val) { + return val; +} + +DEVICE_INLINE float float1_min(float val) { + return val; +} + +DEVICE_INLINE float float2_max(float2 val) { + float max_val = val.x; + max_val = max(max_val, val.y); + return max_val; +} + +DEVICE_INLINE float float2_min(float2 val) { + float min_val = val.x; + min_val = min(min_val, val.y); + return min_val; +} + DEVICE_INLINE float float4_max(float4 val) { float max_val = val.x; max_val = max(max_val, val.y); diff --git a/fbgemm_gpu/include/fbgemm_gpu/merge_pooled_embeddings.h b/fbgemm_gpu/include/fbgemm_gpu/merge_pooled_embeddings.h new file mode 100644 index 000000000..f22d4998b --- /dev/null +++ b/fbgemm_gpu/include/fbgemm_gpu/merge_pooled_embeddings.h @@ -0,0 +1,18 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace fbgemm_gpu { + +std::vector all_to_one_device( + std::vector inputTensors, + at::Device target_device); + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/include/fbgemm_gpu/quantize_ops.cuh b/fbgemm_gpu/include/fbgemm_gpu/quantize_ops.cuh index f9902206b..3e7ffa184 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/quantize_ops.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/quantize_ops.cuh @@ -6,17 +6,8 @@ */ #pragma once -#include -#ifndef __HIP_PLATFORM_HCC__ -#include -#endif -#include -#include - -#define QUANTIZE_OPS_MAX(a, b) ((a) > (b) ? (a) : (b)) -#define QUANTIZE_OPS_MIN(a, b) ((a) < (b) ? (a) : (b)) - namespace fbgemm_gpu { + template __device__ inline T min(const T* from, const T* to) { T result = *(from++); @@ -38,594 +29,3 @@ __device__ inline T max(const T* from, const T* to) { } } // namespace fbgemm_gpu - -template -__device__ inline __attribute__((always_inline)) T -quantize_ops_shfl_xor(const T val, int laneMask, int width) { -#ifdef __HIP_PLATFORM_HCC__ - return __shfl_xor(val, laneMask, width); -#elif CUDA_VERSION >= 9000 - return __shfl_xor_sync(0xffffffff, val, laneMask, width); -#else - return __shfl_xor(val, laneMask, width); -#endif -} - -__global__ inline void _get_8bit_qparam_cuda_kernel( - const float* __restrict__ input, - int nrows, - int ncols, - uint8_t* __restrict__ output, - float* __restrict__ range_list) { - const int row = (int)blockIdx.x * blockDim.y + threadIdx.y; - - const int ncols_aligned = (ncols + 4 - 1) / 4 * 4; - const int output_columns = ncols_aligned + 2 * sizeof(float); - - // starting values for future reductions - // TODO: Fix this for HIP -#ifdef __HIP_PLATFORM_HCC__ - float minimum_element = 0; - float maximum_element = 0; -#else - float minimum_element = CUDART_INF_F; - float maximum_element = -CUDART_INF_F; -#endif - - // always a power of 2 up to size 32. Multiple rows can share the same warp - // when smaller than 32. - const int lane_width = blockDim.x; - - // March warp-wise through the row, doing thread local min and max reductions. - // This loop will only execute once when ncol <= 32 - if (row < nrows) { - const float* const input_row = input + row * ncols; - - for (int col = threadIdx.x; col < ncols; col += lane_width) { - // Get thread-local minmax. These are the smallest min and max ever seen - // by this thread. - minimum_element = fminf(minimum_element, input_row[col]); - maximum_element = fmaxf(maximum_element, input_row[col]); - } - } - - // Perform warp-wide min and max reductions. All threads in the warp - // participate, even if they aren't assigned to a row, since we can't assume - // the existence of the `*_sync` warp primitives with support for masking. - for (int offset = lane_width >> 1; offset > 0; offset >>= 1) { - minimum_element = fminf( - minimum_element, - quantize_ops_shfl_xor(minimum_element, offset, lane_width)); - maximum_element = fmaxf( - maximum_element, - quantize_ops_shfl_xor(maximum_element, offset, lane_width)); - } - - // only the leading thread in the warp is needed to return the final result in - // output. Additionally, threads mapped to non-existent rows do not write to - // the output array. - if (threadIdx.x != 0 || row >= nrows) { - return; - } - - const float range = maximum_element - minimum_element; - float* const output_row_qparams = - reinterpret_cast(output + row * output_columns + ncols_aligned); - - output_row_qparams[0] = range / 255.0f; - output_row_qparams[1] = minimum_element; - range_list[row] = range; -} - -__global__ inline void _compute_8bit_quantize_cuda_kernel( - const float* const __restrict__ input, - const float* const __restrict__ range_list, - const int nrows, - const int ncols, - std::uint8_t* const __restrict__ output) { - constexpr float kEpsilon = 1e-8f; - - const int ncols_aligned = (ncols + 4 - 1) / 4 * 4; - const int output_columns = ncols_aligned + 2 * sizeof(float); - - int row = (int)blockIdx.y * blockDim.y + threadIdx.y; - const int col = (int)blockIdx.x * blockDim.x + threadIdx.x; - const int row_incre = blockDim.y * gridDim.y; - for (/*row*/; row < nrows; row += row_incre) { - if (col < ncols) { - // load scale, bias - float* row_qparams = reinterpret_cast( - output + row * output_columns + ncols_aligned); - float bias = row_qparams[1]; - - int input_idx = row * ncols + col; - uint8_t* output_addr = output + row * output_columns + col; - // TODO: lift range_list into shared memory. However, when nrows is large, - // it might exceed the size of shared memory. - const auto inverse_scale = 255.0f / (range_list[row] + kEpsilon); - output_addr[0] = std::lrintf((input[input_idx] - bias) * inverse_scale); - } - } -} - -// FP32 -> Fused 8-bit rowwise kernel -__global__ inline void _float_to_fused8bitrowwise_cuda_kernel( - const float* __restrict__ input, - int nrows, - int ncols, - std::uint8_t* __restrict__ output) { - constexpr float kEpsilon = 1e-8f; - - int ncols_aligned = (ncols + 4 - 1) / 4 * 4; - int output_columns = ncols_aligned + 2 * sizeof(float); - - int64_t row = (int)blockIdx.x * blockDim.x + threadIdx.x; - - if (row < nrows) { - const float* input_row = input + row * ncols; - std::uint8_t* output_row = output + row * output_columns; - float* output_row_scale_bias = - reinterpret_cast(output_row + ncols_aligned); - - float minimum_element = fbgemm_gpu::min(input_row, input_row + ncols); - float maximum_element = fbgemm_gpu::max(input_row, input_row + ncols); - float range = maximum_element - minimum_element; - - output_row_scale_bias[0] = range / 255.0f; - output_row_scale_bias[1] = minimum_element; - const auto inverse_scale = 255.0f / (range + kEpsilon); - for (std::size_t col = 0; col < ncols; ++col) { - output_row[col] = - std::lrintf((input_row[col] - minimum_element) * inverse_scale); - } - } -} - -// Fused 8-bit rowwise -> FP32 kernel -__global__ inline void _fused8bitrowwise_to_float_cuda_kernel( - const std::uint8_t* const __restrict__ input, - const int nrows, - const int ncols, - float* const __restrict__ output) { - const int output_columns = ncols - 2 * sizeof(float); - - int row = (int)blockIdx.y * blockDim.y + threadIdx.y; - const int col = (int)blockIdx.x * blockDim.x + threadIdx.x; - const int row_incre = blockDim.y * gridDim.y; - for (/*row*/; row < nrows; row += row_incre) { - if (col < output_columns) { - const std::uint8_t* input_row = input + row * ncols; - const float* input_row_scale_bias = - reinterpret_cast(input_row + output_columns); - float* output_row = output + row * output_columns; - - output_row[col] = - input_row[col] * input_row_scale_bias[0] + input_row_scale_bias[1]; - } - } -} - -// Fake 8-bit quantize kernel: FP32 -> UINT8 rowwise -> FP32 -__global__ inline void _fake_8bit_quantize_cuda_kernel( - const float* __restrict__ input, - int nrows, - int ncols, - float* __restrict__ output) { - constexpr float kEpsilon = 1e-8f; - const int row_incre = blockDim.y * gridDim.y; - for (int row = blockIdx.x * blockDim.x + threadIdx.x; row < nrows; - row += row_incre) { - const float* input_row = input + row * ncols; - float* output_row = output + row * ncols; - const int col_incre = blockDim.x * gridDim.x; - for (int col = blockIdx.y * blockDim.y + threadIdx.y; col < ncols; - col += col_incre) { - float minimum_element = fbgemm_gpu::min(input_row, input_row + ncols); - float maximum_element = fbgemm_gpu::max(input_row, input_row + ncols); - float range = maximum_element - minimum_element; - const auto inverse_scale = 255.0f / (range + kEpsilon); - std::uint8_t quantized_val = - std::lrintf((input_row[col] - minimum_element) * inverse_scale); - output_row[col] = quantized_val * (range / 255.0f) + minimum_element; - } - } -} - -// FP32 -> Fused 4/2-bit rowwise kernel -__global__ inline void _float_to_fusednbitrowwise_cuda_kernel( - int bit_rate, - const float* __restrict__ input, - int nrows, - int ncols, - std::uint8_t* __restrict__ output) { - int num_elem_per_byte = 8 / bit_rate; - int output_columns = - (ncols + num_elem_per_byte - 1) / num_elem_per_byte + 2 * sizeof(__half); - - int row = (int)blockIdx.x * blockDim.x + threadIdx.x; - const int row_incre = blockDim.x * gridDim.x; - for (/*row*/; row < nrows; row += row_incre) { - const float* input_row = input + row * ncols; - std::uint8_t* output_row = output + row * output_columns; - __half* output_row_scale_bias = reinterpret_cast<__half*>( - output_row + (ncols + num_elem_per_byte - 1) / num_elem_per_byte); - - float minimum_element = fbgemm_gpu::min(input_row, input_row + ncols); - float maximum_element = fbgemm_gpu::max(input_row, input_row + ncols); - minimum_element = __half2float(__float2half(minimum_element)); - const float range = maximum_element - minimum_element; - - float scale = __half2float( - __float2half(range == 0 ? 1.0f : range / ((1 << bit_rate) - 1))); - if (scale == 0) { - // Corner case handling when maximum_element == minimum_element - // Any scale would work because X - minimum_element will be 0 for all X - scale = 1.0f; - } - float inverse_scale = 1.0f / scale; - if (std::isinf(inverse_scale)) { - scale = 1.0f; - inverse_scale = 1.0f; - } - - output_row_scale_bias[0] = __float2half(scale); - output_row_scale_bias[1] = __float2half(minimum_element); - for (std::size_t col = 0; col < ncols; ++col) { - float X = input_row[col]; - - std::uint8_t quantized = QUANTIZE_OPS_MAX( - 0, - QUANTIZE_OPS_MIN( - static_cast( - std::lrintf((X - minimum_element) * inverse_scale)), - static_cast((1 << bit_rate) - 1))); - - if (col % num_elem_per_byte == 0) { - output_row[col / num_elem_per_byte] = quantized; - } else { - output_row[col / num_elem_per_byte] |= - (quantized << ((col & (num_elem_per_byte - 1)) * bit_rate)); - } - } - } -} - -// Fused 4/2-bit rowwise -> FP32 kernel -__global__ inline void _fusednbitrowwise_to_float_cuda_kernel( - const int bit_rate, - const std::uint8_t* input, - const int nrows, - const int ncols, - float* const output) { - const int num_elem_per_byte = 8 / bit_rate; - const int output_columns = (ncols - 2 * sizeof(__half)) * num_elem_per_byte; - - int row = (int)blockIdx.y * blockDim.y + threadIdx.y; - const int col = (int)blockIdx.x * blockDim.x + threadIdx.x; - const int row_incre = blockDim.y * gridDim.y; - for (/*row*/; row < nrows; row += row_incre) { - if (row < nrows && col < output_columns) { - const std::uint8_t* input_row = input + row * ncols; - const __half* input_row_scale_bias = reinterpret_cast( - input_row + - (output_columns + num_elem_per_byte - 1) / num_elem_per_byte); - float scale = __half2float(input_row_scale_bias[0]); - float bias = __half2float(input_row_scale_bias[1]); - float* output_row = output + row * output_columns; - - std::uint8_t quantized = input_row[col / num_elem_per_byte]; - quantized >>= (col % num_elem_per_byte) * bit_rate; - quantized &= (1 << bit_rate) - 1; - output_row[col] = scale * quantized + bias; - } - } -} - -// FP32 -> BF16 kernel -__global__ inline void _float_to_bfloat16_cuda_kernel( - const float* __restrict__ input, - const int nrows, - const int ncols, - uint16_t* __restrict__ output) { - const int row_incre = blockDim.y * gridDim.y; - const int col_incre = blockDim.x * gridDim.x; - for (int row = blockIdx.y * blockDim.y + threadIdx.y; row < nrows; - row += row_incre) { - const float* input_row = input + row * ncols; - uint16_t* output_row = output + row * ncols; - for (int col = blockIdx.x * blockDim.x + threadIdx.x; col < ncols; - col += col_incre) { - // Add 2^15 and right shift 16 to do round-nearest - output_row[col] = - (*reinterpret_cast(input_row + col) + (1 << 15)) >> - 16; - } - } -} - -// BF16 -> FP32 kernel -__global__ inline void _bfloat16_to_float_cuda_kernel( - const uint16_t* __restrict__ input, - const int nrows, - const int ncols, - float* __restrict__ output) { - const int row_incre = blockDim.y * gridDim.y; - const int col_incre = blockDim.x * gridDim.x; - for (int row = blockIdx.y * blockDim.y + threadIdx.y; row < nrows; - row += row_incre) { - for (int col = blockIdx.x * blockDim.x + threadIdx.x; col < ncols; - col += col_incre) { - const uint16_t* input_row = input + row * ncols; - float* output_row = output + row * ncols; - uint32_t val_fp32 = static_cast( - reinterpret_cast(input_row)[col]) - << 16; - reinterpret_cast(output_row)[col] = val_fp32; - } - } -} - -typedef union { - uint32_t I; - float F; -} fint32; - -// TODO: add a flag later to control whether underflow -// flushes to 0 or clips to smallest denorm number. -__device__ inline uint8_t float_to_hfp8( - float val_fp, - int ebits, - int mbits, - int bias, - float min_pos, - float max_pos) { - fint32 val_out, bouncer, smallest_normal; - uint32_t sign_bit; - - val_out.F = val_fp; - sign_bit = val_out.I & 0x80000000; - val_out.I = val_out.I & 0x7FFFFFFF; - val_out.F = min(val_out.F, max_pos); - - smallest_normal.I = (127 - bias + 1) - << 23; // smallest hfp8 normal number in FP32 - // I don't know if the input "min_pos" is the smallest denormalized number - // or the smallest normalized number. The test below needs to be done with - // the smallest normal number, which is the numerical value 2^(1-bias) - - // The conversion for denormalized values are slightly different. HFP8 is so - // low precision that gradual underflow is probably crucial - if (val_out.F >= smallest_normal.F) { - // Use round to nearest even. We make use of the standard rounding mechanism - // in FP32 rather than rounding the mantissa and handling tie-to-even and - // incrementing exponent We want to round of 23-mbits of the FP32 value - // val_in This can be done by adding a power of 2 exactly 23-mbits larger - // than the exponent of val_in This forces val_in to be moved to the right - // and rounding exact at the location corresponding to having mbits of - // explicit mantissa left - bouncer.I = (val_out.I & 0xFF800000) + ((23 - mbits) << 23); - val_out.F = (bouncer.F + val_out.F) - bouncer.F; - // adding the bouncer rounds off bits, and subtracting bouncer - // leaves the desired value, albeit in FP32 encoding - // All we need is to change the exponent encoding to using "bias" - val_out.I = uint32_t(val_out.I - ((127 - bias) << 23)) << (8 - ebits); - val_out.I = - ((val_out.I | sign_bit) >> - 24); // the 8 lsbs is the desired HFP8 encoding - - } else { - // When the value is in the denormal range, IEEE numbers essentially becomes - // a fixed point number. The lsb is the smallest non-zero number - // 2^(1-bias-mbits) Hence, we define the bouncer so that its lsb is this - // smallest non-zero number Adding the input to this bouncer forces rounding - // to occur appropriately Also, in this situation, after adding the bouncer, - // the 8 least significant bits of the sum is already the HFP8 encoding of - // the desired result. Just need to restore the sign bit - bouncer.I = (127 + (23 + (1 - bias - mbits))) << 23; - val_out.F = bouncer.F + val_out.F; - val_out.I = val_out.I | (sign_bit >> 24); - ; - } - - uint8_t bfp8_val = val_out.I; // get the 8 lsbs - return bfp8_val; -} - -// TODO: add a flag later to control whether underflow -// flushes to 0 or clips to smallest denorm number. -// -// This following is a "FakeQuant" operator. -// That is, the output is a 32-bit encoding that is to be -// interpreted as a FP32 number. The value has low precision, -// which in general has 1+mbits of precision (the 1 is due to -// the implicit bit), except when the value is subnormal. -__device__ inline float float_to_flexp( - float val_fp, - int ebits, - int mbits, - int bias, - float min_pos, - float max_pos) { - fint32 X, bouncer, scale, inv_scale; - uint32_t sign_bit; - int32_t E, expo, emin, delta_E, nbits2round; - - X.F = val_fp; - sign_bit = X.I & 0x80000000; - X.I = X.I & 0x7FFFFFFF; - - emin = 1 - bias; - - // Because the input value can be of extreme magnitude - // We scale them into less extreme to avoid potential exception during - // manipulation - E = ((X.I & 0x7F800000) >> 23) - 127; - if (E >= 0) { - scale.I = 0X2F800000; - inv_scale.I = 0X4F800000; // scale is 2^-32, inv_scale is 2^32 - delta_E = -32; - } else { - scale.I = 0x4F800000; - inv_scale.I = 0x2F800000; - delta_E = 32; - } - X.F *= scale.F; // at this point X is never close to over/underflow - expo = ((X.I & 0x7F800000) >> 23) - 127 - delta_E; - - // If expo >= emin - // We round to mbits explicit mantissa bits - // That is, we want to round off 23-mbits of the trailing bits in X - nbits2round = 23 - mbits; - // However, if expo < emin, we need to round more bits off - nbits2round += max(emin - expo, 0); - - bouncer.I = (nbits2round << 23) + (X.I & 0x7F800000); - X.F = X.F + bouncer.F; // Because bouncer is exactly 2^nbits2round bigger - // this addition forces the rounding off of nbits2round - X.F = X.F - bouncer.F; // X.F is the original X with nbits2round rounded off - - // restore the true magnitude by undoing the previous scale - X.F *= inv_scale.F; - // clip on the large end of the domain - X.F = min(X.F, max_pos); - // restores the original sign - X.I |= sign_bit; - - float val_flexp = X.F; - return val_flexp; -} - -__device__ inline float -flexp_to_float(float val_flexp, int ebits, int mbits, int bias) { - // Hello - - // Because float_to_flexp is a fakequant operator, - // the input flexp number is already a FP32 number - // with limited precision. - // Thus this flexp_to_float is really a no-op - float val_fp = val_flexp; - return val_fp; -} - -__device__ inline float -hfp8_to_float(uint8_t hfp8_val, int ebits, int mbits, int bias) { - fint32 val_out, sign, multiplier; - - sign.I = (hfp8_val & 0x80) << 24; - val_out.I = (hfp8_val & 0x7F) << (24 - (8 - ebits)); - // printf("val_out %d %d\n", val_out.I, hfp8_val); - // so that the mantissa bits start at the mantissa bit positions of FP32 - // encoding - - // Let the hfp8 mantissa bits correspond to the value frac, 0 <= frac < 1 - // So if the hfp8 value is a normal number, it's value is 2^e x (1+frac) - // where e is its (true, unbiased) exponent - // If the hfp8 value is denormal, the value is 2^(1-bias) x frac - - // However, the bit pattern in the 8-bit exponent field of val_out.F - // is bias+e when hfp8 is normal, and 0 when hfp8 is subnormal. - // So, as an FP32 value, when hfp8 is normal, val_out.F represents the value - // of 2^(bias+e-127) * (1+frac) - // And when hfp8 is subnormal, val_out.F is also subnormal, and represents the - // value of 2^(-126) * frac In either case, val_out.F corresponds to - // 2^(bias-127) * (value of hfp8 input) Thus, if we multiply val_out.F by - // 2^(127-bias), we obtain the hfp8 value as an FP32 number - - multiplier.I = (127 + (127 - bias)) << 23; // multiplier.F is 2^(127-bias) - val_out.F *= multiplier.F; - val_out.I |= sign.I; - return val_out.F; -} - -__global__ inline void _float_to_hfp8_cuda_kernel( - const float* __restrict__ input, - const int nrows, - const int ncols, - uint8_t* __restrict__ output, - int ebits, - int mbits, - int bias, - float min_pos, - float max_pos) { - const int row_incre = blockDim.y * gridDim.y; - const int col_incre = blockDim.x * gridDim.x; - for (int row = blockIdx.y * blockDim.y + threadIdx.y; row < nrows; - row += row_incre) { - const float* input_row = input + row * ncols; - uint8_t* output_row = output + row * ncols; - for (int col = blockIdx.x * blockDim.x + threadIdx.x; col < ncols; - col += col_incre) { - output_row[col] = - float_to_hfp8(input_row[col], ebits, mbits, bias, min_pos, max_pos); - } - } -} - -__global__ inline void _hfp8_to_float_cuda_kernel( - const uint8_t* __restrict__ input, - const int nrows, - const int ncols, - float* __restrict__ output, - int ebits, - int mbits, - int bias) { - const int row_incre = blockDim.y * gridDim.y; - const int col_incre = blockDim.x * gridDim.x; - for (int row = blockIdx.y * blockDim.y + threadIdx.y; row < nrows; - row += row_incre) { - for (int col = blockIdx.x * blockDim.x + threadIdx.x; col < ncols; - col += col_incre) { - const uint8_t* input_row = input + row * ncols; - float* output_row = output + row * ncols; - output_row[col] = hfp8_to_float(input_row[col], ebits, mbits, bias); - } - } -} - -__global__ inline void _float_to_flexp_cuda_kernel( - const float* __restrict__ input, - const int nrows, - const int ncols, - float* __restrict__ output, - int ebits, - int mbits, - int bias, - float min_pos, - float max_pos) { - const int row_incre = blockDim.y * gridDim.y; - const int col_incre = blockDim.x * gridDim.x; - for (int row = blockIdx.y * blockDim.y + threadIdx.y; row < nrows; - row += row_incre) { - const float* input_row = input + row * ncols; - float* output_row = output + row * ncols; - for (int col = blockIdx.x * blockDim.x + threadIdx.x; col < ncols; - col += col_incre) { - output_row[col] = - float_to_flexp(input_row[col], ebits, mbits, bias, min_pos, max_pos); - } - } -} - -__global__ inline void _flexp_to_float_cuda_kernel( - const float* __restrict__ input, - const int nrows, - const int ncols, - float* __restrict__ output, - int ebits, - int mbits, - int bias) { - const int row_incre = blockDim.y * gridDim.y; - const int col_incre = blockDim.x * gridDim.x; - for (int row = blockIdx.y * blockDim.y + threadIdx.y; row < nrows; - row += row_incre) { - for (int col = blockIdx.x * blockDim.x + threadIdx.x; col < ncols; - col += col_incre) { - const float* input_row = input + row * ncols; - float* output_row = output + row * ncols; - output_row[col] = flexp_to_float(input_row[col], ebits, mbits, bias); - } - } -} - -#undef QUANTIZE_OPS_MAX -#undef QUANTIZE_OPS_MIN diff --git a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h index f3f672e1d..e430836be 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h +++ b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h @@ -154,12 +154,12 @@ at::Tensor batched_unary_embeddings_backward_cuda( const at::Tensor& indices); at::Tensor jagged_2d_to_dense_forward_cuda( - at::Tensor embeddings, + at::Tensor values, at::Tensor offsets, int32_t max_L); at::Tensor jagged_2d_to_dense_backward_cuda( - at::Tensor grad_padded_embeddings, + at::Tensor grad_padded_values, at::Tensor offsets, int32_t total_L); @@ -298,4 +298,43 @@ histogram_binning_calibration_by_feature_cuda( int64_t bin_ctr_in_use_after = 0, double bin_ctr_weight_value = 1.0); +// Same as above, but accepts generic "bin_boundaries", which is assumed to be +// sorted. +// +// Returns calibrated_prediction. +std::tuple +generic_histogram_binning_calibration_by_feature_cpu( + const at::Tensor& logit, + const at::Tensor& segment_value, + const at::Tensor& segment_lengths, + int64_t num_segments, + const at::Tensor& bin_num_examples, + const at::Tensor& bin_num_positives, + const at::Tensor& bin_boundaries, + double positive_weight, + int64_t bin_ctr_in_use_after = 0, + double bin_ctr_weight_value = 1.0); + +std::tuple +generic_histogram_binning_calibration_by_feature_cuda( + const at::Tensor& logit, + const at::Tensor& segment_value, + const at::Tensor& segment_lengths, + int64_t num_segments, + const at::Tensor& bin_num_examples, + const at::Tensor& bin_num_positives, + const at::Tensor& bin_boundaries, + double positive_weight, + int64_t bin_ctr_in_use_after = 0, + double bin_ctr_weight_value = 1.0); + +std::tuple embedding_bag_rowwise_prune( + const at::Tensor& weights, + const at::Tensor& indicator, + const double threshold, + at::ScalarType compressed_indices_dtype, + const bool abs, + const int64_t min_non_pruned_rows, + const c10::optional& min_save_ratio); + } // namespace fbgemm_gpu diff --git a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops_utils.h b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops_utils.h index 8221f2b5c..306011889 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops_utils.h +++ b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops_utils.h @@ -48,6 +48,15 @@ inline bool torch_tensor_on_cuda_gpu_check( return !ten.has_value() || ten->is_cuda(); } +inline bool torch_tensor_empty_or_on_cuda_gpu_check(const at::Tensor& ten) { + return (ten.numel() == 0) || ten.is_cuda(); +} + +inline bool torch_tensor_empty_or_on_cuda_gpu_check( + const c10::optional& ten) { + return !ten.has_value() || (ten->numel() == 0) || ten->is_cuda(); +} + #define DISPATCH_TO_CUDA(name, function) \ m.impl(name, torch::dispatch(c10::DispatchKey::CUDA, TORCH_FN(function))) @@ -71,6 +80,12 @@ inline bool torch_tensor_on_cuda_gpu_check( #x " must be a CUDA tensor; it is currently on device ", \ torch_tensor_device_name(x)) +#define TENSOR_EMPTY_OR_ON_CUDA_GPU(x) \ + TORCH_CHECK( \ + torch_tensor_empty_or_on_cuda_gpu_check(x), \ + #x " must be empty or a CUDA tensor; it is currently on device ", \ + torch_tensor_device_name(x)) + #define TENSORS_ON_SAME_DEVICE(x, y) \ TORCH_CHECK( \ torch_tensor_on_same_device_check(x, y), \ diff --git a/fbgemm_gpu/src/split_embeddings_utils.cuh b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.cuh similarity index 77% rename from fbgemm_gpu/src/split_embeddings_utils.cuh rename to fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.cuh index f10fa268c..99f568664 100644 --- a/fbgemm_gpu/src/split_embeddings_utils.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.cuh @@ -41,13 +41,8 @@ inline __device__ void warpBitonicMergeLE16(K& k, V& v) { // Reverse the first comparison stage. // For example, merging a list of size 8 has the exchanges: // 0 <-> 15, 1 <-> 14, ... -#ifdef __HIP_PLATFORM_HCC__ - K otherK = __shfl_xor(k, 2 * L - 1); - V otherV = __shfl_xor(v, 2 * L - 1); -#else - K otherK = shfl_xor(k, 2 * L - 1); - V otherV = shfl_xor(v, 2 * L - 1); -#endif + K otherK = fbgemm_gpu::shfl_xor(k, 2 * L - 1); + V otherV = fbgemm_gpu::shfl_xor(v, 2 * L - 1); // Whether we are the lesser thread in the exchange bool small = !(laneId & L); @@ -69,13 +64,8 @@ inline __device__ void warpBitonicMergeLE16(K& k, V& v) { #pragma unroll for (int32_t stride = IsBitonic ? L : L / 2; stride > 0; stride /= 2) { -#ifdef __HIP_PLATFORM_HCC__ - K otherK = __shfl_xor(k, stride); - V otherV = __shfl_xor(v, stride); -#else - K otherK = shfl_xor(k, stride); - V otherV = shfl_xor(v, stride); -#endif + K otherK = fbgemm_gpu::shfl_xor(k, stride); + V otherV = fbgemm_gpu::shfl_xor(v, stride); // Whether we are the lesser thread in the exchange bool small = !(laneId & stride); @@ -122,3 +112,23 @@ std::pair lru_cache_find_uncached_cuda( at::Tensor lxu_cache_state, int64_t time_stamp, at::Tensor lru_state); + +/** + * "Transpose" embedding inputs by sorting indices by their values. + * Logically this transpose compressed sparse row (CSR) representation + * stored in indices and offsets to compressed sparse column (CSC). + */ +std::tuple< + at::Tensor /*linear_indices*/, + at::Tensor /*linear_indices_sorted*/, + at::Tensor /*infos_sorted*/, + at::Tensor /*sorted_linear_indices_run*/, + at::Tensor /*sorted_linear_indices_run_lengths*/, + at::Tensor /*sorted_linear_indices_num_runs*/, + at::Tensor /*sorted_linear_indices_cumulative_run_lengths*/> +transpose_embedding_input( + at::Tensor hash_size_cumsum, + int64_t total_hash_size_bits, + at::Tensor indices, + at::Tensor offsets, + bool nobag = false); diff --git a/fbgemm_gpu/setup.py b/fbgemm_gpu/setup.py index d3f10f725..7145a16b9 100644 --- a/fbgemm_gpu/setup.py +++ b/fbgemm_gpu/setup.py @@ -264,6 +264,7 @@ def post_hipify(hip_file): os.path.join(cur_dir, "src/cumem_utils_host.cpp"), os.path.join(cur_dir, "src/quantize_ops_cpu.cpp"), os.path.join(cur_dir, "src/quantize_ops_gpu.cpp"), + os.path.join(cur_dir, "src/quantize_ops.cu"), os.path.join(cur_dir, "src/cpu_utils.cpp"), os.path.join(cur_dir, "src/sparse_ops_cpu.cpp"), os.path.join(cur_dir, "src/sparse_ops_gpu.cpp"), @@ -274,6 +275,9 @@ def post_hipify(hip_file): os.path.join(cur_dir, "src/layout_transform_ops_cpu.cpp"), os.path.join(cur_dir, "src/layout_transform_ops_gpu.cpp"), os.path.join(cur_dir, "src/layout_transform_ops.cu"), + os.path.join(cur_dir, "src/jagged_tensor_ops.cu"), + os.path.join(cur_dir, "src/histogram_binning_calibration_ops.cu"), + os.path.join(cur_dir, "src/split_embeddings_utils.cu"), ], include_dirs=[ cur_dir, diff --git a/fbgemm_gpu/src/cumem_utils.cu b/fbgemm_gpu/src/cumem_utils.cu index cf7045356..e973d3fa8 100644 --- a/fbgemm_gpu/src/cumem_utils.cu +++ b/fbgemm_gpu/src/cumem_utils.cu @@ -9,6 +9,8 @@ #include #include +#include +#include #include "cumem_utils.h" #include "fbgemm_gpu/enum_utils.h" @@ -116,6 +118,21 @@ Tensor new_managed_tensor_internal( return at::empty({0}, self.options()) .set_(indirect_storage, 0, sizes, strides); } + +std::tuple adjust_to_page_boundaries(void* ptr, size_t size) { + static uint64_t page_mask = ([]() -> uint64_t { + uint64_t page_size = (uint64_t)sysconf(_SC_PAGESIZE); + return (page_size - 1); + })(); + + uint64_t raw_ptr = (uint64_t)ptr; + uint64_t raw_ptr_adjusted = raw_ptr & ~page_mask; + uint64_t raw_ptr_end_adjusted = (raw_ptr + size + page_mask) & ~page_mask; + uint64_t size_adjusted = raw_ptr_end_adjusted - raw_ptr_adjusted; + + return std::make_tuple((void*)raw_ptr_adjusted, (size_t)size_adjusted); +} + } // namespace // Allocate a cuda Tensor with unified managed memory (UVM) @@ -139,6 +156,12 @@ Tensor new_managed_tensor(Tensor self, std::vector sizes) { ptr, size_bytes, cudaMemAdviseSetAccessedBy, at::cuda::current_device())); C10_CUDA_KERNEL_LAUNCH_CHECK(); + // Work around fork issue - see uvm_mem_advice_dont_fork for details + auto adjusted = adjust_to_page_boundaries(ptr, size_bytes); + int result = + madvise(std::get<0>(adjusted), std::get<1>(adjusted), MADV_DONTFORK); + TORCH_CHECK(result == 0) + return t; } @@ -234,7 +257,6 @@ int64_t uvm_get_guard_index(Tensor& t) { } return cuda_device_index; } - } // namespace #ifdef __HIP_PLATFORM_HCC__ @@ -286,11 +308,14 @@ void uvm_cuda_mem_advise(Tensor t, int64_t cudaMemoryAdvise) { device_guard.set_index(cuda_device_index); +#ifndef __HIP_PLATFORM_HCC__ + // FIXME: some advanced "cudaMemAdvise" flags are not supported by HIP. AT_CUDA_CHECK(cudaMemAdvise( ptr, size_bytes, static_cast(cudaMemoryAdvise), hint_device)); +#endif return; } #endif @@ -332,16 +357,33 @@ void uvm_mem_advice_dont_fork(Tensor t) { size_t size_bytes = at::detail::computeStorageNbytes( t.sizes(), t.strides(), t.dtype().itemsize()); - int result = madvise(ptr, size_bytes, MADV_DONTFORK); + auto adjusted = adjust_to_page_boundaries(ptr, size_bytes); + + int result = + madvise(std::get<0>(adjusted), std::get<1>(adjusted), MADV_DONTFORK); TORCH_CHECK(result == 0) return; } -FBGEMM_GPU_ENUM_GLOGAL(uvm) +Tensor uvm_to_cpu_clone(Tensor t) { + TORCH_CHECK(uvm_storage(t)); + TORCH_CHECK(t.is_contiguous()); + + Tensor cpu_clone = at::empty_like(t, t.options().device(kCPU)); + + size_t size_bytes = at::detail::computeStorageNbytes( + t.sizes(), t.strides(), t.dtype().itemsize()); + + memcpy(cpu_clone.data_ptr(), t.data_ptr(), size_bytes); + return cpu_clone; +} + +FBGEMM_GPU_ENUM_GLOGAL(uvm) #ifdef __HIP_PLATFORM_HCC__ +// FIXME: some advanced "cudaMemAdvise" flags are not supported by HIP. FBGEMM_GPU_ENUM_REGISTER_START(uvm, hipMemoryAdvise){ FBGEMM_GPU_ENUM_ITEM(cudaMemAdviseSetReadMostly), FBGEMM_GPU_ENUM_ITEM(cudaMemAdviseUnsetReadMostly), diff --git a/fbgemm_gpu/src/cumem_utils.h b/fbgemm_gpu/src/cumem_utils.h index 65a5989b0..1e00eef65 100644 --- a/fbgemm_gpu/src/cumem_utils.h +++ b/fbgemm_gpu/src/cumem_utils.h @@ -44,6 +44,10 @@ void uvm_cuda_mem_prefetch_async(Tensor t, c10::optional device_t); // table on fork - causing slowdown on the next access from a CPU. void uvm_mem_advice_dont_fork(Tensor t); +// Copy a contigious uvm Tensor (uvm_storage(t) is true) into a CPU Tensor +// The copy uses single threaded memcpy +Tensor uvm_to_cpu_clone(Tensor t); + FBGEMM_GPU_ENUM_CREATE_TAG(uvm) } // namespace fbgemm_gpu diff --git a/fbgemm_gpu/src/cumem_utils_host.cpp b/fbgemm_gpu/src/cumem_utils_host.cpp index 41f7933ab..739b6f990 100644 --- a/fbgemm_gpu/src/cumem_utils_host.cpp +++ b/fbgemm_gpu/src/cumem_utils_host.cpp @@ -7,6 +7,7 @@ #include #include #include "fbgemm_gpu/enum_utils.h" +#include "fbgemm_gpu/sparse_ops_utils.h" #include "cumem_utils.h" @@ -21,18 +22,11 @@ TORCH_LIBRARY_FRAGMENT(fb, m) { "uvm_to_device(Tensor self, Tensor prototype) -> Tensor", TORCH_FN(uvm_to_device)); m.def("uvm_to_cpu(Tensor t) -> Tensor"); - m.impl( - "uvm_to_cpu", - torch::dispatch(c10::DispatchKey::CUDA, TORCH_FN(uvm_to_cpu))); + DISPATCH_TO_CUDA("uvm_to_cpu", uvm_to_cpu); m.def("new_managed_tensor(Tensor self, int[] sizes) -> Tensor"); - m.impl( - "new_managed_tensor", - torch::dispatch(c10::DispatchKey::CUDA, TORCH_FN(new_managed_tensor))); + DISPATCH_TO_CUDA("new_managed_tensor", new_managed_tensor); m.def("new_vanilla_managed_tensor(Tensor self, int[] sizes) -> Tensor"); - m.impl( - "new_vanilla_managed_tensor", - torch::dispatch( - c10::DispatchKey::CUDA, TORCH_FN(new_vanilla_managed_tensor))); + DISPATCH_TO_CUDA("new_vanilla_managed_tensor", new_vanilla_managed_tensor); m.def( "cuda_mem_advise(Tensor t, int advice) -> ()", TORCH_FN(uvm_cuda_mem_advise)); @@ -43,7 +37,40 @@ TORCH_LIBRARY_FRAGMENT(fb, m) { "uvm_mem_advice_dont_fork(Tensor t) -> ()", TORCH_FN(uvm_mem_advice_dont_fork)); + m.def("uvm_to_cpu_clone(Tensor t) -> Tensor", TORCH_FN(uvm_to_cpu_clone)); + +#ifndef __HIP_PLATFORM_HCC__ + // FIXME: some advanced "cudaMemAdvise" flags are not supported by HIP. + m.def(FBGEMM_GPU_ENUM_OP(uvm, fbgemm_gpu_uvm_enum_query)); +#endif +} + +TORCH_LIBRARY_FRAGMENT(fbgemm, m) { + m.def("is_uvm_tensor(Tensor t) -> bool", TORCH_FN(is_uvm_tensor)); + m.def("uvm_storage(Tensor t) -> bool", TORCH_FN(uvm_storage)); + m.def( + "uvm_to_device(Tensor self, Tensor prototype) -> Tensor", + TORCH_FN(uvm_to_device)); + m.def("uvm_to_cpu(Tensor t) -> Tensor"); + DISPATCH_TO_CUDA("uvm_to_cpu", uvm_to_cpu); + m.def("new_managed_tensor(Tensor self, int[] sizes) -> Tensor"); + DISPATCH_TO_CUDA("new_managed_tensor", new_managed_tensor); + m.def("new_vanilla_managed_tensor(Tensor self, int[] sizes) -> Tensor"); + DISPATCH_TO_CUDA("new_vanilla_managed_tensor", new_vanilla_managed_tensor); + m.def( + "cuda_mem_advise(Tensor t, int advice) -> ()", + TORCH_FN(uvm_cuda_mem_advise)); + m.def( + "cuda_mem_prefetch_async(Tensor t, Tensor? device_t) -> ()", + TORCH_FN(uvm_cuda_mem_prefetch_async)); + m.def( + "uvm_mem_advice_dont_fork(Tensor t) -> ()", + TORCH_FN(uvm_mem_advice_dont_fork)); + +#ifndef __HIP_PLATFORM_HCC__ + // FIXME: some advanced "cudaMemAdvise" flags are not supported by HIP. m.def(FBGEMM_GPU_ENUM_OP(uvm, fbgemm_gpu_uvm_enum_query)); +#endif } } // namespace fbgemm_gpu diff --git a/fbgemm_gpu/src/histogram_binning_calibration_ops.cu b/fbgemm_gpu/src/histogram_binning_calibration_ops.cu new file mode 100644 index 000000000..34699d4f7 --- /dev/null +++ b/fbgemm_gpu/src/histogram_binning_calibration_ops.cu @@ -0,0 +1,389 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include +#include +#include "fbgemm_gpu/fbgemm_cuda_utils.cuh" +#include "fbgemm_gpu/sparse_ops.h" +#include "fbgemm_gpu/sparse_ops_utils.h" + +using Tensor = at::Tensor; + +namespace fbgemm_gpu { + +template +__global__ void histogram_binning_calibration_kernel( + const int64_t num_logits, + const int64_t num_bins, + const double recalibrate_value, + const double step, + const int64_t bin_ctr_in_use_after, + const double bin_ctr_weight_value, + const T* const logit_data, + const double* const bin_num_examples_data, + const double* const bin_num_positives_data, + T* const calibrated_prediction_data, + int64_t* const bin_ids_data) { + const int32_t index = blockIdx.x * blockDim.x + threadIdx.x; + if (index >= num_logits) { + return; + } + + const T pre_sigmoid = logit_data[index] + recalibrate_value; + const double uncalibrated = 1.0 / (1.0 + exp(-pre_sigmoid)); + + bin_ids_data[index] = ceil(uncalibrated / step) - 1; + + const auto curr_bin_num_examples = bin_num_examples_data[bin_ids_data[index]]; + if (curr_bin_num_examples > bin_ctr_in_use_after) { + const auto curr_bin_ctr = + bin_num_positives_data[bin_ids_data[index]] / curr_bin_num_examples; + calibrated_prediction_data[index] = curr_bin_ctr * bin_ctr_weight_value + + uncalibrated * (1.0 - bin_ctr_weight_value); + } else { + calibrated_prediction_data[index] = uncalibrated; + } +} + +std::tuple histogram_binning_calibration_cuda( + const Tensor& logit, + const Tensor& bin_num_examples, + const Tensor& bin_num_positives, + double positive_weight, + double lower_bound, + double upper_bound, + int64_t bin_ctr_in_use_after, + double bin_ctr_weight_value) { + TENSOR_ON_CUDA_GPU(logit); + TENSOR_ON_CUDA_GPU(bin_num_examples); + TENSOR_ON_CUDA_GPU(bin_num_positives); + TORCH_CHECK(bin_num_examples.numel() == bin_num_positives.numel()); + + at::cuda::OptionalCUDAGuard device_guard; + device_guard.set_index(logit.get_device()); + + Tensor calibrated_prediction = at::empty_like(logit); + Tensor bin_ids = at::empty({logit.numel()}, logit.options().dtype(at::kLong)); + const double recalibrate_value = std::log(positive_weight); + const double step = (upper_bound - lower_bound) / + static_cast(bin_num_examples.numel()); + + const int32_t num_threads = 512; + const auto logit_packed = logit.contiguous(); + const auto bin_num_examples_packed = bin_num_examples.contiguous(); + const auto bin_num_positives_packed = bin_num_positives.contiguous(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + logit.type(), "histogram_binning_calibration_cuda", [&]() { + histogram_binning_calibration_kernel + <<>>( + logit.numel(), + bin_num_examples.numel(), + recalibrate_value, + step, + bin_ctr_in_use_after, + bin_ctr_weight_value, + logit_packed.data_ptr(), + bin_num_examples_packed.data_ptr(), + bin_num_positives_packed.data_ptr(), + calibrated_prediction.data_ptr(), + bin_ids.data_ptr()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + + return std::make_tuple(calibrated_prediction, bin_ids); +} + +template +__global__ void to_dense_segment_value_kernel( + const int64_t num_lengths, + const int64_t* const segment_value_data, + const T* const segment_offsets_data, + int64_t* const dense_segment_value_data) { + const int32_t index = blockIdx.x * blockDim.x + threadIdx.x; + if (index >= num_lengths - 1) { + return; + } + + const auto curr_offset = segment_offsets_data[index]; + const auto next_offset = segment_offsets_data[index + 1]; + if (next_offset > curr_offset) { + // Add 1 to distinguish between 0 inserted by densification vs. original + // value. + dense_segment_value_data[index] = segment_value_data[curr_offset] + 1; + } +} + +template +__global__ void histogram_binning_calibration_by_feature_kernel( + const int64_t num_logits, + const int64_t num_bins, + const int64_t num_segments, + const double recalibrate_value, + const double step, + const int64_t bin_ctr_in_use_after, + const double bin_ctr_weight_value, + const T* const logit_data, + const int64_t* const dense_segment_value_data, + const double* const bin_num_examples_data, + const double* const bin_num_positives_data, + T* const calibrated_prediction_data, + int64_t* const bin_ids_data) { + const int32_t index = blockIdx.x * blockDim.x + threadIdx.x; + if (index >= num_logits) { + return; + } + + const T pre_sigmoid = logit_data[index] + recalibrate_value; + const double uncalibrated = 1.0 / (1.0 + exp(-pre_sigmoid)); + + const int64_t curr_segment_value = + dense_segment_value_data[index] > num_segments + ? 0 + : std::max(0L, dense_segment_value_data[index] * num_bins); + + bin_ids_data[index] = ceil(uncalibrated / step) - 1 + curr_segment_value; + + const auto curr_bin_num_examples = bin_num_examples_data[bin_ids_data[index]]; + if (curr_bin_num_examples > bin_ctr_in_use_after) { + const auto curr_bin_ctr = + bin_num_positives_data[bin_ids_data[index]] / curr_bin_num_examples; + calibrated_prediction_data[index] = curr_bin_ctr * bin_ctr_weight_value + + uncalibrated * (1.0 - bin_ctr_weight_value); + } else { + calibrated_prediction_data[index] = uncalibrated; + } +} + +std::tuple histogram_binning_calibration_by_feature_cuda( + const Tensor& logit, + const Tensor& segment_value, + const Tensor& segment_lengths, + int64_t num_segments, + const Tensor& bin_num_examples, + const Tensor& bin_num_positives, + int64_t num_bins, + double positive_weight, + double lower_bound, + double upper_bound, + int64_t bin_ctr_in_use_after, + double bin_ctr_weight_value) { + TENSOR_ON_CUDA_GPU(logit); + TENSOR_ON_CUDA_GPU(segment_value); + TENSOR_ON_CUDA_GPU(segment_lengths); + TENSOR_ON_CUDA_GPU(bin_num_examples); + TENSOR_ON_CUDA_GPU(bin_num_positives); + TORCH_CHECK(bin_num_examples.numel() == bin_num_positives.numel()); + + at::cuda::OptionalCUDAGuard device_guard; + device_guard.set_index(logit.get_device()); + + // Convert lengths to offsets for better handling on GPUs. + const auto segment_lengths_packed = segment_lengths.contiguous(); + auto segment_offsets = + asynchronous_complete_cumsum_gpu(segment_lengths_packed.view(-1)); + + // dense_segment_value is used as a temporary storage. + Tensor dense_segment_value = + at::zeros({logit.numel()}, segment_value.options()); + + const int32_t num_threads = 512; + const auto segment_value_packed = segment_value.contiguous(); + const auto segment_offsets_packed = segment_offsets.contiguous(); + auto dense_segment_value_packed = dense_segment_value.contiguous(); + AT_DISPATCH_INDEX_TYPES( + segment_offsets.scalar_type(), "to_dense_segment_value_cuda", [&]() { + to_dense_segment_value_kernel + <<>>( + segment_offsets.numel(), + segment_value_packed.data_ptr(), + segment_offsets_packed.data_ptr(), + dense_segment_value_packed.data_ptr()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + + Tensor calibrated_prediction = at::empty_like(logit); + Tensor bin_ids = at::empty({logit.numel()}, logit.options().dtype(at::kLong)); + const double recalibrate_value = std::log(positive_weight); + const double step = + (upper_bound - lower_bound) / static_cast(num_bins); + + const auto logit_packed = logit.contiguous(); + const auto bin_num_examples_packed = bin_num_examples.contiguous(); + const auto bin_num_positives_packed = bin_num_positives.contiguous(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + logit.type(), "histogram_binning_calibration_by_feature_cuda", [&]() { + histogram_binning_calibration_by_feature_kernel + <<>>( + logit.numel(), + num_bins, + num_segments, + recalibrate_value, + step, + bin_ctr_in_use_after, + bin_ctr_weight_value, + logit_packed.data_ptr(), + dense_segment_value_packed.data_ptr(), + bin_num_examples_packed.data_ptr(), + bin_num_positives_packed.data_ptr(), + calibrated_prediction.data_ptr(), + bin_ids.data_ptr()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + + return std::make_tuple(calibrated_prediction, bin_ids); +} + +template +__global__ void generic_histogram_binning_calibration_by_feature_kernel( + const int64_t num_logits, + const int64_t num_bins, + const int64_t num_segments, + const double recalibrate_value, + const int64_t bin_ctr_in_use_after, + const double bin_ctr_weight_value, + const T* const logit_data, + const int64_t* const dense_segment_value_data, + const double* const bin_num_examples_data, + const double* const bin_num_positives_data, + const double* const bin_boundaries, + T* const calibrated_prediction_data, + int64_t* const bin_ids_data) { + const int32_t index = blockIdx.x * blockDim.x + threadIdx.x; + if (index >= num_logits) { + return; + } + + const T pre_sigmoid = logit_data[index] + recalibrate_value; + const double uncalibrated = 1.0 / (1.0 + exp(-pre_sigmoid)); + + // Perform binary search. + int left = 0; + int right = num_bins - 1; + while (left != right) { + const int middle = (left + right) >> 1; + if (bin_boundaries[middle] < uncalibrated) { + left = middle + 1; + } else { + right = middle; + } + } + const int curr_bin_id = left; + + const int64_t curr_segment_value = + dense_segment_value_data[index] > num_segments + ? 0 + : std::max(0L, dense_segment_value_data[index] * num_bins); + + bin_ids_data[index] = curr_bin_id + curr_segment_value; + + const auto curr_bin_num_examples = bin_num_examples_data[bin_ids_data[index]]; + if (curr_bin_num_examples > bin_ctr_in_use_after) { + const auto curr_bin_ctr = + bin_num_positives_data[bin_ids_data[index]] / curr_bin_num_examples; + calibrated_prediction_data[index] = curr_bin_ctr * bin_ctr_weight_value + + uncalibrated * (1.0 - bin_ctr_weight_value); + } else { + calibrated_prediction_data[index] = uncalibrated; + } +} + +std::tuple +generic_histogram_binning_calibration_by_feature_cuda( + const Tensor& logit, + const Tensor& segment_value, + const Tensor& segment_lengths, + int64_t num_segments, + const Tensor& bin_num_examples, + const Tensor& bin_num_positives, + const Tensor& bin_boundaries, + double positive_weight, + int64_t bin_ctr_in_use_after, + double bin_ctr_weight_value) { + TENSOR_ON_CUDA_GPU(logit); + TENSOR_ON_CUDA_GPU(segment_value); + TENSOR_ON_CUDA_GPU(segment_lengths); + TENSOR_ON_CUDA_GPU(bin_num_examples); + TENSOR_ON_CUDA_GPU(bin_num_positives); + TENSOR_ON_CUDA_GPU(bin_boundaries); + TORCH_CHECK(bin_num_examples.numel() == bin_num_positives.numel()); + TORCH_CHECK( + bin_num_examples.numel() == + (num_segments + 1) * (bin_boundaries.numel() + 1)); + + at::cuda::OptionalCUDAGuard device_guard; + device_guard.set_index(logit.get_device()); + + // Convert lengths to offsets for better handling on GPUs. + const auto segment_lengths_packed = segment_lengths.contiguous(); + auto segment_offsets = + asynchronous_complete_cumsum_gpu(segment_lengths_packed.view(-1)); + + // dense_segment_value is used as a temporary storage. + Tensor dense_segment_value = + at::zeros({logit.numel()}, segment_value.options()); + + const int32_t num_threads = 512; + const auto segment_value_packed = segment_value.contiguous(); + const auto segment_offsets_packed = segment_offsets.contiguous(); + auto dense_segment_value_packed = dense_segment_value.contiguous(); + AT_DISPATCH_INDEX_TYPES( + segment_offsets.scalar_type(), "to_dense_segment_value_cuda", [&]() { + to_dense_segment_value_kernel + <<>>( + segment_offsets.numel(), + segment_value_packed.data_ptr(), + segment_offsets_packed.data_ptr(), + dense_segment_value_packed.data_ptr()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + + Tensor calibrated_prediction = at::empty_like(logit); + Tensor bin_ids = at::empty({logit.numel()}, logit.options().dtype(at::kLong)); + const double recalibrate_value = std::log(positive_weight); + + const auto logit_packed = logit.contiguous(); + const auto bin_num_examples_packed = bin_num_examples.contiguous(); + const auto bin_num_positives_packed = bin_num_positives.contiguous(); + const auto bin_boundaries_packed = bin_boundaries.contiguous(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + logit.type(), "histogram_binning_calibration_by_feature_cuda", [&]() { + generic_histogram_binning_calibration_by_feature_kernel + <<>>( + logit.numel(), + bin_boundaries.numel() + 1, + num_segments, + recalibrate_value, + bin_ctr_in_use_after, + bin_ctr_weight_value, + logit_packed.data_ptr(), + dense_segment_value_packed.data_ptr(), + bin_num_examples_packed.data_ptr(), + bin_num_positives_packed.data_ptr(), + bin_boundaries_packed.data_ptr(), + calibrated_prediction.data_ptr(), + bin_ids.data_ptr()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + + return std::make_tuple(calibrated_prediction, bin_ids); +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/src/jagged_tensor_ops.cu b/fbgemm_gpu/src/jagged_tensor_ops.cu new file mode 100644 index 000000000..6c6a241ad --- /dev/null +++ b/fbgemm_gpu/src/jagged_tensor_ops.cu @@ -0,0 +1,236 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include +#include +#include "fbgemm_gpu/fbgemm_cuda_utils.cuh" +#include "fbgemm_gpu/sparse_ops_utils.h" + +using Tensor = at::Tensor; + +namespace fbgemm_gpu { + +template +__global__ void jagged_2d_to_dense_forward_kernel( + int32_t B, + int32_t max_L, + int32_t D, + at::PackedTensorAccessor32 offsets, + at::PackedTensorAccessor64 values, + at::PackedTensorAccessor64 padded_values) { + int32_t b_l = blockIdx.x * blockDim.y + threadIdx.y; + int32_t l = b_l / B; + int32_t b = b_l % B; + if (b_l >= B * max_L) { + return; + } + int32_t row_start = offsets[b]; + int32_t row_end = offsets[b + 1]; + int32_t length = row_end - row_start; + if (l < length) { + for (int32_t d = 0; d < D; d += fbgemm_gpu::kWarpSize) { + if (d + threadIdx.x < D) { + padded_values[b][l][d + threadIdx.x] = + values[row_start + l][d + threadIdx.x]; + } + } + } else { + for (int32_t d = 0; d < D; d += fbgemm_gpu::kWarpSize) { + if (d + threadIdx.x < D) { + padded_values[b][l][d + threadIdx.x] = 0.0; + } + } + } +} + +Tensor +jagged_2d_to_dense_forward_cuda(Tensor values, Tensor offsets, int32_t max_L) { + TENSOR_ON_CUDA_GPU(values); + TENSOR_ON_CUDA_GPU(offsets); + + TORCH_CHECK(values.dim() == 2); + TORCH_CHECK(offsets.dim() == 1); + TORCH_CHECK(max_L > 0); + at::cuda::OptionalCUDAGuard device_guard; + device_guard.set_index(values.get_device()); + + int32_t D = values.size(1); + int32_t B = offsets.numel() - 1; + auto padded_values = at::empty({B, max_L, D}, values.options()); + const auto values_contig = values.contiguous(); + const auto offsets_contig = offsets.contiguous(); + + AT_DISPATCH_INDEX_TYPES( + offsets.scalar_type(), "jagged_2d_to_dense_forward_kernel_1", ([&]() { + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + values.scalar_type(), + "jagged_2d_to_dense_forward_kernel_2", + ([&]() { + jagged_2d_to_dense_forward_kernel + <<>>( + B, + max_L, + D, + offsets_contig.packed_accessor32(), + values_contig.packed_accessor64(), + padded_values.packed_accessor64()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + })); + })); + + return padded_values; +} + +template +__global__ void jagged_2d_to_dense_backward_kernel( + int32_t B, + int32_t max_L, + int32_t D, + at::PackedTensorAccessor32 offsets, + at::PackedTensorAccessor64 grad_padded_values, + at::PackedTensorAccessor64 grad_values) { + int32_t b_l = blockIdx.x * blockDim.y + threadIdx.y; + int32_t l = b_l / B; + int32_t b = b_l % B; + if (b_l >= B * max_L) { + return; + } + int32_t row_start = offsets[b]; + int32_t row_end = offsets[b + 1]; + int32_t length = row_end - row_start; + if (l < length) { + for (int32_t d = 0; d < D; d += fbgemm_gpu::kWarpSize) { + if (d + threadIdx.x < D) { + grad_values[row_start + l][d + threadIdx.x] = + grad_padded_values[b][l][d + threadIdx.x]; + } + } + } +} + +Tensor jagged_2d_to_dense_backward_cuda( + Tensor grad_padded_values, + Tensor offsets, + int32_t total_L) { + TENSOR_ON_CUDA_GPU(grad_padded_values); + TENSOR_ON_CUDA_GPU(offsets); + + TORCH_CHECK(grad_padded_values.dim() == 3); + TORCH_CHECK(offsets.dim() == 1); + TORCH_CHECK(total_L >= 0); + TORCH_CHECK(offsets.numel() == grad_padded_values.size(0) + 1); + at::cuda::OptionalCUDAGuard device_guard; + device_guard.set_index(grad_padded_values.get_device()); + + int32_t B = grad_padded_values.size(0); + int32_t max_L = grad_padded_values.size(1); + int32_t D = grad_padded_values.size(2); + auto grad_values = at::zeros({total_L, D}, grad_padded_values.options()); + const auto grad_padded_values_config = grad_padded_values.contiguous(); + const auto offsets_contig = offsets.contiguous(); + + AT_DISPATCH_INDEX_TYPES( + offsets.scalar_type(), "jagged_2d_to_dense_backward_kernel_1", ([&]() { + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad_padded_values.scalar_type(), + "jagged_2d_to_dense_backward_kernel_2", + ([&]() { + jagged_2d_to_dense_backward_kernel + <<>>( + B, + max_L, + D, + offsets_contig.packed_accessor32(), + grad_padded_values_config + .packed_accessor64(), + grad_values.packed_accessor64()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + })); + })); + + return grad_values; +} + +template +__global__ void jagged_1d_to_dense_kernel( + int32_t B, + int32_t max_L, + data_t padding_value, + at::PackedTensorAccessor32 offsets, + at::PackedTensorAccessor64 values, + at::PackedTensorAccessor64 padded_values) { + const int32_t b_l = blockIdx.x * blockDim.x + threadIdx.x; + if (b_l >= B * max_L) { + return; + } + int32_t b = b_l / max_L; + int32_t l = b_l % max_L; + int32_t row_start = offsets[b]; + int32_t row_end = offsets[b + 1]; + int32_t length = row_end - row_start; + if (l < length) { + padded_values[b][l] = values[row_start + l]; + } else { + padded_values[b][l] = padding_value; + } +} + +Tensor jagged_1d_to_dense_gpu( + Tensor values, + Tensor offsets, + int64_t max_L, + int64_t padding_value) { + TENSOR_ON_CUDA_GPU(values); + TENSOR_ON_CUDA_GPU(offsets); + + TORCH_CHECK(values.dim() == 1); + TORCH_CHECK(offsets.dim() == 1); + TORCH_CHECK(max_L > 0); + at::cuda::OptionalCUDAGuard device_guard; + device_guard.set_index(values.get_device()); + + int32_t B = offsets.numel() - 1; + auto padded_values = at::empty({B, max_L}, values.options()); + const auto values_contig = values.contiguous(); + const auto offsets_contig = offsets.contiguous(); + const int32_t num_threads = 512; // 256~1024 per xingl + AT_DISPATCH_INDEX_TYPES( + offsets.scalar_type(), "jagged_1d_to_dense_kernel_1", ([&]() { + AT_DISPATCH_ALL_TYPES( + values.scalar_type(), "jagged_1d_to_dense_kernel_2", ([&]() { + jagged_1d_to_dense_kernel + <<>>( + B, + max_L, + padding_value, + offsets_contig.packed_accessor32(), + values_contig.packed_accessor64(), + padded_values.packed_accessor64()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + })); + })); + + return padded_values; +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/src/layout_transform_ops.cu b/fbgemm_gpu/src/layout_transform_ops.cu index f00400ef1..9666aadaa 100644 --- a/fbgemm_gpu/src/layout_transform_ops.cu +++ b/fbgemm_gpu/src/layout_transform_ops.cu @@ -13,6 +13,7 @@ #include "fbgemm_gpu/layout_transform_ops.cuh" #include "fbgemm_gpu/sparse_ops.h" +#include "fbgemm_gpu/sparse_ops_utils.h" #include #include @@ -31,6 +32,8 @@ namespace fbgemm_gpu { Tensor recat_embedding_grad_output_cuda( Tensor grad_output, // [B_local][T_global][D] std::vector num_features_per_rank) { + TENSOR_ON_CUDA_GPU(grad_output); + at::cuda::OptionalCUDAGuard device_guard; device_guard.set_index(grad_output.get_device()); @@ -69,6 +72,7 @@ Tensor recat_embedding_grad_output_cuda( Tensor recat_embedding_grad_output_mixed_D_cuda( const Tensor& grad_output, // [B_local][Sum_T_global(D)] const std::vector& dim_sum_per_rank) { + TENSOR_ON_CUDA_GPU(grad_output); TORCH_CHECK(grad_output.is_contiguous()); at::cuda::OptionalCUDAGuard device_guard; @@ -110,6 +114,9 @@ Tensor recat_embedding_grad_output_mixed_D_batch_cuda( const Tensor& grad_output, // [B_local][Sum_T_global(D)] const Tensor& dim_sum_per_rank, const Tensor& cumsum_dim_sum_per_rank) { + TENSOR_ON_CUDA_GPU(grad_output); + TENSOR_ON_CUDA_GPU(dim_sum_per_rank); + TENSOR_ON_CUDA_GPU(cumsum_dim_sum_per_rank); TORCH_CHECK(grad_output.is_contiguous()); at::cuda::OptionalCUDAGuard device_guard; diff --git a/fbgemm_gpu/src/merge_pooled_embeddings_gpu.cpp b/fbgemm_gpu/src/merge_pooled_embeddings_gpu.cpp index e621cf682..04ec601ea 100644 --- a/fbgemm_gpu/src/merge_pooled_embeddings_gpu.cpp +++ b/fbgemm_gpu/src/merge_pooled_embeddings_gpu.cpp @@ -15,12 +15,16 @@ #include #include -// TODO: Enable merge_pooled_embeddings for HIP +// FIXME: Enable merge_pooled_embeddings for HIP. +// AMD GPUs don't seem to have nvml equivalent library support. #ifndef __HIP_PLATFORM_HCC__ #include #include +#include "fbgemm_gpu/merge_pooled_embeddings.h" +#include "fbgemm_gpu/sparse_ops_utils.h" + using Tensor = at::Tensor; #define NVML_CHECK(fn) \ @@ -33,6 +37,7 @@ using Node = int64_t; using Links = int64_t; template using AdjacencyMatrix = std::function; +namespace { AdjacencyMatrix get_nvlink_matrix() { auto world_size = at::cuda::getNumGPUs(); @@ -184,80 +189,72 @@ AdjacencyMatrix get_intermediate_node(AdjacencyMatrix links) { return [](Node, Node) { return -1; }; } } -namespace { -Tensor cat_dim_1( - std::vector tensors, - int batch_size, - at::Device output_device) { - if (tensors.size() == 0) { - return at::empty({0}, at::TensorOptions().device(output_device)); - } - int64_t total_dim_1 = 0; - std::vector cumulative_dims; - cumulative_dims.push_back(0); - for (const auto& t : tensors) { - TORCH_CHECK(t.dim() == 2); - TORCH_CHECK(t.size(0) == batch_size); - total_dim_1 += t.size(-1); - cumulative_dims.push_back(total_dim_1); - } - - auto* prop = at::cuda::getCurrentDeviceProperties(); - auto output = at::empty( - {batch_size, total_dim_1}, - tensors.front().options().device(output_device)); - TORCH_CHECK( - output.stride(0) * output.element_size() <= - static_cast(prop->memPitch)); - std::vector copy_begin_events(tensors.size()); - std::vector copy_completion_events(tensors.size()); +// Tensors in `output_tensors` should all be on target_device. We copy the +// tensor in the same index from `input_tensors` to `output_tensors`. If the +// tensor in `input_tensors` is already in the `target_device`, we will skip +// copy it if `skip_if_same_device` is true. +void all_to_one( + std::vector& input_tensors, + std::vector& output_tensors, + at::Device target_device, + bool skip_if_same_device) { + auto num_gpus = at::cuda::getNumGPUs(); + std::vector copy_begin_events(num_gpus); + std::vector copy_completion_events(num_gpus); - Node dst_device_id = output_device.index(); static auto intermediate_nodes = get_intermediate_node(get_nvlink_matrix()); - // Do the intermediate copies, if required by our multi-hop config. - for (auto& ten : tensors) { - Node src_device_id = ten.device().index(); - auto intermediate_node = intermediate_nodes(src_device_id, dst_device_id); + for (auto& ten : input_tensors) { + Node src_device_id = ten.get_device(); + auto intermediate_node = + intermediate_nodes(src_device_id, target_device.index()); if (intermediate_node != -1) { ten = ten.to(at::Device(at::kCUDA, intermediate_node)); } } - // synchronize source streams and launch copies on source stream. - for (const auto i : c10::irange(tensors.size())) { - auto src = tensors[i]; - if (src.device() != output.device()) { - auto dst = output.slice(1, cumulative_dims[i], cumulative_dims[i + 1]); + // For each source device, we sync its current stream and launch all the + // copies that are from that device. + for (const auto device_id : c10::irange(num_gpus)) { + auto src_device = at::Device(at::kCUDA, device_id); + if (src_device == target_device) { + continue; + } - at::Device dst_device = dst.device(); - at::Device src_device = src.device(); - at::cuda::CUDAGuard device_guard(src_device); - // We always perform the copy on the source device, using the current - // stream on the source device, and we fully synchronize on both src and - // dst's current streams for completion of the copy. We have to explicitly - // do this for non-contig copies. This mimics the behavior of cross-device - // cudaMemcpyAsync on the default stream. + // synchronize source streams and launch copies on source stream. + at::cuda::CUDAGuard device_guard(src_device); + // We always perform the copy on the source device, using the current + // stream on the source device, and we fully synchronize on both src and + // dst's current streams for completion of the copy. We have to explicitly + // do this for non-contig copies. This mimics the behavior of cross-device + // cudaMemcpyAsync on the default stream. + + at::cuda::CUDAStream copy_stream = + at::cuda::getCurrentCUDAStream(device_id); + // This is a cross-device copy on the src current stream and dst current + // stream. We perform a two-way barrier between both devices' streams + // before the copy. This ensures that any write-after-write and + // write-after-read dependencies on the destination side are handled, so + // that no one is operating on the dst memory when we perform the copy. + // src waits on dst barrier (src already waits on src) + auto& dst_ready = copy_begin_events[device_id]; + device_guard.set_device(target_device); + dst_ready.record(at::cuda::getCurrentCUDAStream(target_device.index())); + device_guard.set_device(src_device); + dst_ready.block(copy_stream); + for (const auto i : c10::irange(input_tensors.size())) { + auto& src = input_tensors[i]; + if (src.device() != src_device) { + continue; + } - at::cuda::CUDAStream copy_stream = - at::cuda::getCurrentCUDAStream(src_device.index()); - // This is a cross-device copy on the src current stream and dst current - // stream. We perform a two-way barrier between both devices' streams - // before the copy. This ensures that any write-after-write and - // write-after-read dependencies on the destination side are handled, so - // that no one is operating on the dst memory when we perform the copy. - // src waits on dst barrier (src already waits on src) - auto& dst_ready = copy_begin_events[i]; - device_guard.set_device(dst_device); - dst_ready.record(at::cuda::getCurrentCUDAStream(dst_device.index())); - device_guard.set_device(src_device); - dst_ready.block(copy_stream); + auto& dst = output_tensors[i]; // on source device, launch memcpy. AT_CUDA_CHECK(cudaMemcpy2DAsync( dst.data_ptr(), dst.stride(0) * dst.element_size(), src.data_ptr(), - src.stride(0) * dst.element_size(), + src.stride(0) * src.element_size(), src.size(1) * src.element_size(), src.size(0), cudaMemcpyDeviceToDevice, @@ -266,57 +263,84 @@ Tensor cat_dim_1( } // Do the same-GPU cases. - for (const auto i : c10::irange(tensors.size())) { - auto src = tensors[i]; - if (src.device() == output.device()) { - auto dst = output.slice(1, cumulative_dims[i], cumulative_dims[i + 1]); - at::Device src_device = src.device(); - // single device memcpy, not that src_device == dst_device. - at::cuda::CUDAStream copy_stream = - at::cuda::getCurrentCUDAStream(src_device.index()); - AT_CUDA_CHECK(cudaMemcpy2DAsync( - dst.data_ptr(), - dst.stride(0) * dst.element_size(), - src.data_ptr(), - src.stride(0) * src.element_size(), - src.size(1) * src.element_size(), - src.size(0), - cudaMemcpyDeviceToDevice, - copy_stream)); + if (!skip_if_same_device) { + for (const auto i : c10::irange(input_tensors.size())) { + auto& src = input_tensors[i]; + if (src.device() == target_device) { + auto& dst = output_tensors[i]; + // single device memcpy, not that src_device == dst_device. + at::cuda::CUDAStream copy_stream = + at::cuda::getCurrentCUDAStream(target_device.index()); + AT_CUDA_CHECK(cudaMemcpy2DAsync( + dst.data_ptr(), + dst.stride(0) * dst.element_size(), + src.data_ptr(), + src.stride(0) * src.element_size(), + src.size(1) * src.element_size(), + src.size(0), + cudaMemcpyDeviceToDevice, + copy_stream)); + } } } + // wait for cross-device copies to complete. - for (const auto i : c10::irange(tensors.size())) { - auto src = tensors[i]; - if (src.device() != output.device()) { - auto dst = output.slice(1, cumulative_dims[i], cumulative_dims[i + 1]); - at::Device dst_device = dst.device(); - at::Device src_device = src.device(); + for (const auto device_id : c10::irange(num_gpus)) { + if (device_id != target_device.index()) { + auto src_device = at::Device(at::kCUDA, device_id); // Still on src_device, record stream event at::cuda::CUDAGuard device_guard(src_device); at::cuda::CUDAStream copy_stream = - at::cuda::getCurrentCUDAStream(src_device.index()); + at::cuda::getCurrentCUDAStream(device_id); - auto& src_ready = copy_completion_events[i]; + auto& src_ready = copy_completion_events[device_id]; src_ready.record(copy_stream); - device_guard.set_device(dst_device); - src_ready.block(at::cuda::getCurrentCUDAStream(dst_device.index())); + device_guard.set_device(target_device); + src_ready.block(at::cuda::getCurrentCUDAStream(target_device.index())); } } AT_CUDA_CHECK(cudaGetLastError()); +} + +Tensor cat_dim_1( + std::vector tensors, + int batch_size, + at::Device output_device) { + if (tensors.size() == 0) { + return at::empty({0}, at::TensorOptions().device(output_device)); + } + int64_t total_dim_1 = 0; + std::vector cumulative_dims; + cumulative_dims.push_back(0); + for (const auto& t : tensors) { + TORCH_CHECK(t.dim() == 2); + TORCH_CHECK(t.size(0) == batch_size); + total_dim_1 += t.size(-1); + cumulative_dims.push_back(total_dim_1); + } + + auto* prop = at::cuda::getCurrentDeviceProperties(); + auto output = at::empty( + {batch_size, total_dim_1}, + tensors.front().options().device(output_device)); + TORCH_CHECK( + output.stride(0) * output.element_size() <= + static_cast(prop->memPitch)); + std::vector output_tensors; + output_tensors.reserve(tensors.size()); + + for (const auto i : c10::irange(tensors.size())) { + output_tensors.push_back( + output.slice(1, cumulative_dims[i], cumulative_dims[i + 1])); + } + all_to_one( + tensors, output_tensors, output_device, /* skip_if_same_device */ false); return output; } -} // namespace -namespace fbgemm_gpu { - -// TODO: Add device arg. -Tensor merge_pooled_embeddings( - std::vector pooled_embeddings, - int64_t batch_size, - at::Device target_device) { +void init_p2p_access() { static std::once_flag flag; std::call_once(flag, []() { for (const auto i : c10::irange(at::cuda::getNumGPUs())) { @@ -334,25 +358,55 @@ Tensor merge_pooled_embeddings( } } }); +} +} // namespace + +namespace fbgemm_gpu { + +Tensor merge_pooled_embeddings( + std::vector pooled_embeddings, + int64_t batch_size, + at::Device target_device) { + init_p2p_access(); at::cuda::CUDAGuard g(target_device); TORCH_CHECK(!pooled_embeddings.empty()); return cat_dim_1(pooled_embeddings, batch_size, target_device); } +std::vector all_to_one_device( + std::vector input_tensors, + at::Device target_device) { + init_p2p_access(); + at::cuda::CUDAGuard g(target_device); + + std::vector output_tensors; + output_tensors.reserve(input_tensors.size()); + + for (const auto& tensor : input_tensors) { + output_tensors.push_back( + tensor.device() != target_device + ? at::empty(tensor.sizes(), tensor.options().device(target_device)) + : tensor); + } + all_to_one( + input_tensors, + output_tensors, + target_device, + /* skip_if_same_device */ true); + return output_tensors; +} + } // namespace fbgemm_gpu TORCH_LIBRARY_FRAGMENT(fbgemm, m) { m.def( "merge_pooled_embeddings(Tensor[] pooled_embeddings, int batch_size, Device target_device) -> Tensor"); -} - -TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) { - m.impl( - "merge_pooled_embeddings", - torch::dispatch( - c10::DispatchKey::CUDA, - TORCH_FN(fbgemm_gpu::merge_pooled_embeddings))); + DISPATCH_TO_CUDA( + "merge_pooled_embeddings", fbgemm_gpu::merge_pooled_embeddings); + m.def( + "all_to_one_device(Tensor[] input_tensors, Device target_device) -> Tensor[]"); + DISPATCH_TO_CUDA("all_to_one_device", fbgemm_gpu::all_to_one_device); } #endif diff --git a/fbgemm_gpu/src/permute_pooled_embedding_ops.cu b/fbgemm_gpu/src/permute_pooled_embedding_ops.cu index a6af19428..58d3bd2d2 100644 --- a/fbgemm_gpu/src/permute_pooled_embedding_ops.cu +++ b/fbgemm_gpu/src/permute_pooled_embedding_ops.cu @@ -14,6 +14,7 @@ #include "fbgemm_gpu/fbgemm_cuda_utils.cuh" #include "fbgemm_gpu/layout_transform_ops.cuh" #include "fbgemm_gpu/permute_pooled_embedding_ops.h" +#include "fbgemm_gpu/sparse_ops_utils.h" using Tensor = at::Tensor; @@ -25,6 +26,12 @@ Tensor permute_pooled_embs_gpu( const Tensor& permute_list, const Tensor& inv_offset_dim_list, const Tensor& inv_permute_list) { + // inv_permute_list is not being used so it's not checked here. + TENSOR_ON_CUDA_GPU(pooled_embs); + TENSOR_ON_CUDA_GPU(offset_dim_list); + TENSOR_ON_CUDA_GPU(permute_list); + TENSOR_ON_CUDA_GPU(inv_offset_dim_list); + at::cuda::OptionalCUDAGuard device_guard; device_guard.set_index(pooled_embs.get_device()); // We couldn't pass the "pooled_embs.is_contiguous()" check in the backward @@ -35,9 +42,9 @@ Tensor permute_pooled_embs_gpu( const int64_t T = permute_list.numel(); const int64_t dim_sum = pooled_embs_contiguous.size(1); // inv_permute_list is not being used so it's not checked here. - TORCH_CHECK(pooled_embs_contiguous.device() == offset_dim_list.device()); - TORCH_CHECK(pooled_embs_contiguous.device() == permute_list.device()); - TORCH_CHECK(pooled_embs_contiguous.device() == inv_offset_dim_list.device()); + TENSORS_ON_SAME_DEVICE(pooled_embs_contiguous, offset_dim_list); + TENSORS_ON_SAME_DEVICE(pooled_embs_contiguous, permute_list); + TENSORS_ON_SAME_DEVICE(pooled_embs_contiguous, inv_offset_dim_list); TORCH_CHECK(offset_dim_list.numel() == permute_list.numel() + 1); TORCH_CHECK(offset_dim_list.numel() == inv_offset_dim_list.numel()); Tensor permuted_pooled_embs = at::empty_like(pooled_embs_contiguous); diff --git a/fbgemm_gpu/src/permute_pooled_embedding_ops_gpu.cpp b/fbgemm_gpu/src/permute_pooled_embedding_ops_gpu.cpp index 90f885db2..f12358fcf 100644 --- a/fbgemm_gpu/src/permute_pooled_embedding_ops_gpu.cpp +++ b/fbgemm_gpu/src/permute_pooled_embedding_ops_gpu.cpp @@ -5,17 +5,14 @@ * LICENSE file in the root directory of this source tree. */ #include -#include #include #include -#include -#include #include -#include -#include -#include #include +#include "fbgemm_gpu/permute_pooled_embedding_ops.h" +#include "fbgemm_gpu/sparse_ops_utils.h" + using Tensor = at::Tensor; namespace fbgemm_gpu { @@ -138,11 +135,7 @@ Tensor permute_pooled_embs_auto_grad_cpu( TORCH_LIBRARY_FRAGMENT(fbgemm, m) { m.def( "permute_pooled_embs(Tensor pooled_embs, Tensor offset_dim_list, Tensor permute_list, Tensor inv_offset_dim_list, Tensor inv_permute_list) -> Tensor"); - m.impl( - "permute_pooled_embs", - torch::dispatch( - c10::DispatchKey::CUDA, - TORCH_FN(fbgemm_gpu::permute_pooled_embs_gpu))); + DISPATCH_TO_CUDA("permute_pooled_embs", fbgemm_gpu::permute_pooled_embs_gpu); m.def( "permute_pooled_embs_auto_grad(Tensor pooled_embs, Tensor offset_dim_list, Tensor permute_list, Tensor inv_offset_dim_list, Tensor inv_permute_list) -> Tensor"); m.impl( @@ -150,9 +143,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { torch::dispatch( c10::DispatchKey::CPU, TORCH_FN(fbgemm_gpu::permute_pooled_embs_auto_grad_cpu))); - m.impl( + DISPATCH_TO_CUDA( "permute_pooled_embs_auto_grad", - torch::dispatch( - c10::DispatchKey::CUDA, - TORCH_FN(fbgemm_gpu::permute_pooled_embs_auto_grad_gpu))); + fbgemm_gpu::permute_pooled_embs_auto_grad_gpu); } diff --git a/fbgemm_gpu/src/quantize_ops.cu b/fbgemm_gpu/src/quantize_ops.cu new file mode 100644 index 000000000..236505c5e --- /dev/null +++ b/fbgemm_gpu/src/quantize_ops.cu @@ -0,0 +1,538 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include +#include +#ifndef __HIP_PLATFORM_HCC__ +#include +#endif +#include "fbgemm_gpu/quantize_ops.cuh" +#include "fbgemm_gpu/sparse_ops_utils.h" + +using Tensor = at::Tensor; + +namespace fbgemm_gpu { + +namespace { + +// FP32 -> Fused 8-bit rowwise kernel +__global__ inline void _float_to_fused8bitrowwise_cuda_kernel( + const float* __restrict__ input, + int nrows, + int ncols, + std::uint8_t* __restrict__ output) { + constexpr float kEpsilon = 1e-8f; + + int ncols_aligned = (ncols + 4 - 1) / 4 * 4; + int output_columns = ncols_aligned + 2 * sizeof(float); + + int64_t row = (int)blockIdx.x * blockDim.x + threadIdx.x; + + if (row < nrows) { + const float* input_row = input + row * ncols; + std::uint8_t* output_row = output + row * output_columns; + float* output_row_scale_bias = + reinterpret_cast(output_row + ncols_aligned); + + float minimum_element = fbgemm_gpu::min(input_row, input_row + ncols); + float maximum_element = fbgemm_gpu::max(input_row, input_row + ncols); + float range = maximum_element - minimum_element; + + output_row_scale_bias[0] = range / 255.0f; + output_row_scale_bias[1] = minimum_element; + const auto inverse_scale = 255.0f / (range + kEpsilon); + for (std::size_t col = 0; col < ncols; ++col) { + output_row[col] = + lrintf((input_row[col] - minimum_element) * inverse_scale); + } + } +} + +template +__device__ inline __attribute__((always_inline)) T +quantize_ops_shfl_xor(const T val, int laneMask, int width) { +#ifdef __HIP_PLATFORM_HCC__ + return __shfl_xor(val, laneMask, width); +#elif CUDA_VERSION >= 9000 + return __shfl_xor_sync(0xffffffff, val, laneMask, width); +#else + return __shfl_xor(val, laneMask, width); +#endif +} + +__global__ inline void _get_8bit_qparam_cuda_kernel( + const float* __restrict__ input, + int nrows, + int ncols, + uint8_t* __restrict__ output, + float* __restrict__ range_list) { + const int row = (int)blockIdx.x * blockDim.y + threadIdx.y; + + const int ncols_aligned = (ncols + 4 - 1) / 4 * 4; + const int output_columns = ncols_aligned + 2 * sizeof(float); + + // starting values for future reductions +#ifdef __HIP_PLATFORM_HCC__ +#define HIPRT_INF_F __int_as_float(0x7f800000) + float minimum_element = HIPRT_INF_F; + float maximum_element = -HIPRT_INF_F; +#undef HIPRT_INF_F +#else + float minimum_element = CUDART_INF_F; + float maximum_element = -CUDART_INF_F; +#endif + + // always a power of 2 up to size 32. Multiple rows can share the same warp + // when smaller than 32. + const int lane_width = blockDim.x; + + // March warp-wise through the row, doing thread local min and max reductions. + // This loop will only execute once when ncol <= 32 + if (row < nrows) { + const float* const input_row = input + row * ncols; + + for (int col = threadIdx.x; col < ncols; col += lane_width) { + // Get thread-local minmax. These are the smallest min and max ever seen + // by this thread. + minimum_element = fminf(minimum_element, input_row[col]); + maximum_element = fmaxf(maximum_element, input_row[col]); + } + } + + // Perform warp-wide min and max reductions. All threads in the warp + // participate, even if they aren't assigned to a row, since we can't assume + // the existence of the `*_sync` warp primitives with support for masking. + for (int offset = lane_width >> 1; offset > 0; offset >>= 1) { + minimum_element = fminf( + minimum_element, + quantize_ops_shfl_xor(minimum_element, offset, lane_width)); + maximum_element = fmaxf( + maximum_element, + quantize_ops_shfl_xor(maximum_element, offset, lane_width)); + } + + // only the leading thread in the warp is needed to return the final result in + // output. Additionally, threads mapped to non-existent rows do not write to + // the output array. + if (threadIdx.x != 0 || row >= nrows) { + return; + } + + const float range = maximum_element - minimum_element; + float* const output_row_qparams = + reinterpret_cast(output + row * output_columns + ncols_aligned); + + output_row_qparams[0] = range / 255.0f; + output_row_qparams[1] = minimum_element; + range_list[row] = range; +} + +__global__ inline void _compute_8bit_quantize_cuda_kernel( + const float* const __restrict__ input, + const float* const __restrict__ range_list, + const int nrows, + const int ncols, + std::uint8_t* const __restrict__ output) { + constexpr float kEpsilon = 1e-8f; + + const int ncols_aligned = (ncols + 4 - 1) / 4 * 4; + const int output_columns = ncols_aligned + 2 * sizeof(float); + + int row = (int)blockIdx.y * blockDim.y + threadIdx.y; + const int col = (int)blockIdx.x * blockDim.x + threadIdx.x; + const int row_incre = blockDim.y * gridDim.y; + for (/*row*/; row < nrows; row += row_incre) { + if (col < ncols) { + // load scale, bias + float* row_qparams = reinterpret_cast( + output + row * output_columns + ncols_aligned); + float bias = row_qparams[1]; + + int input_idx = row * ncols + col; + uint8_t* output_addr = output + row * output_columns + col; + // TODO: lift range_list into shared memory. However, when nrows is large, + // it might exceed the size of shared memory. + const auto inverse_scale = 255.0f / (range_list[row] + kEpsilon); + output_addr[0] = lrintf((input[input_idx] - bias) * inverse_scale); + } + } +} + +// Fused 8-bit rowwise -> FP32 kernel +__global__ inline void _fused8bitrowwise_to_float_cuda_kernel( + const std::uint8_t* const __restrict__ input, + const int nrows, + const int ncols, + float* const __restrict__ output) { + const int output_columns = ncols - 2 * sizeof(float); + + int row = (int)blockIdx.y * blockDim.y + threadIdx.y; + const int col = (int)blockIdx.x * blockDim.x + threadIdx.x; + const int row_incre = blockDim.y * gridDim.y; + for (/*row*/; row < nrows; row += row_incre) { + if (col < output_columns) { + const std::uint8_t* input_row = input + row * ncols; + const float* input_row_scale_bias = + reinterpret_cast(input_row + output_columns); + float* output_row = output + row * output_columns; + + output_row[col] = + input_row[col] * input_row_scale_bias[0] + input_row_scale_bias[1]; + } + } +} + +#define QUANTIZE_OPS_MAX(a, b) ((a) > (b) ? (a) : (b)) +#define QUANTIZE_OPS_MIN(a, b) ((a) < (b) ? (a) : (b)) + +// FP32 -> Fused 4/2-bit rowwise kernel +__global__ inline void _float_to_fusednbitrowwise_cuda_kernel( + int bit_rate, + const float* __restrict__ input, + int nrows, + int ncols, + std::uint8_t* __restrict__ output) { + int num_elem_per_byte = 8 / bit_rate; + int output_columns = + (ncols + num_elem_per_byte - 1) / num_elem_per_byte + 2 * sizeof(__half); + + int row = (int)blockIdx.x * blockDim.x + threadIdx.x; + const int row_incre = blockDim.x * gridDim.x; + for (/*row*/; row < nrows; row += row_incre) { + const float* input_row = input + row * ncols; + std::uint8_t* output_row = output + row * output_columns; + __half* output_row_scale_bias = reinterpret_cast<__half*>( + output_row + (ncols + num_elem_per_byte - 1) / num_elem_per_byte); + + float minimum_element = fbgemm_gpu::min(input_row, input_row + ncols); + float maximum_element = fbgemm_gpu::max(input_row, input_row + ncols); + minimum_element = __half2float(__float2half(minimum_element)); + const float range = maximum_element - minimum_element; + + float scale = __half2float( + __float2half(range == 0 ? 1.0f : range / ((1 << bit_rate) - 1))); + if (scale == 0) { + // Corner case handling when maximum_element == minimum_element + // Any scale would work because X - minimum_element will be 0 for all X + scale = 1.0f; + } + float inverse_scale = 1.0f / scale; + if (std::isinf(inverse_scale)) { + scale = 1.0f; + inverse_scale = 1.0f; + } + + output_row_scale_bias[0] = __float2half(scale); + output_row_scale_bias[1] = __float2half(minimum_element); + for (std::size_t col = 0; col < ncols; ++col) { + float X = input_row[col]; + + std::uint8_t quantized = QUANTIZE_OPS_MAX( + 0, + QUANTIZE_OPS_MIN( + static_cast(lrintf((X - minimum_element) * inverse_scale)), + static_cast((1 << bit_rate) - 1))); + + if (col % num_elem_per_byte == 0) { + output_row[col / num_elem_per_byte] = quantized; + } else { + output_row[col / num_elem_per_byte] |= + (quantized << ((col & (num_elem_per_byte - 1)) * bit_rate)); + } + } + } +} + +// Fused 4/2-bit rowwise -> FP32 kernel +__global__ inline void _fusednbitrowwise_to_float_cuda_kernel( + const int bit_rate, + const std::uint8_t* input, + const int nrows, + const int ncols, + float* const output) { + const int num_elem_per_byte = 8 / bit_rate; + const int output_columns = (ncols - 2 * sizeof(__half)) * num_elem_per_byte; + + int row = (int)blockIdx.y * blockDim.y + threadIdx.y; + const int col = (int)blockIdx.x * blockDim.x + threadIdx.x; + const int row_incre = blockDim.y * gridDim.y; + for (/*row*/; row < nrows; row += row_incre) { + if (row < nrows && col < output_columns) { + const std::uint8_t* input_row = input + row * ncols; + const __half* input_row_scale_bias = reinterpret_cast( + input_row + + (output_columns + num_elem_per_byte - 1) / num_elem_per_byte); + float scale = __half2float(input_row_scale_bias[0]); + float bias = __half2float(input_row_scale_bias[1]); + float* output_row = output + row * output_columns; + + std::uint8_t quantized = input_row[col / num_elem_per_byte]; + quantized >>= (col % num_elem_per_byte) * bit_rate; + quantized &= (1 << bit_rate) - 1; + output_row[col] = scale * quantized + bias; + } + } +} +} // namespace + +Tensor _float_to_fused8bitrowwise_gpu(const Tensor& input) { + TENSOR_ON_CUDA_GPU(input); + TORCH_CHECK(input.is_contiguous(), "input must be contiguous"); + + at::cuda::OptionalCUDAGuard device_guard; + device_guard.set_index(input.get_device()); + + const auto input_sizes = input.sizes(); + const auto last_dim = input_sizes.size() - 1; + const int nrows = c10::size_to_dim_(last_dim, input_sizes); + const int ncols = input_sizes[last_dim]; + const int ncols_aligned = (ncols + 4 - 1) / 4 * 4; + const int output_columns = ncols_aligned + 2 * sizeof(float); + + // Global memory instructions support reading or writing words of size equal + // to 1, 2, 4, 8, or 16 bytes. Any access (via a variable or a pointer) to + // data residing in global memory compiles to a single global memory + // instruction if and only if the size of the data type is 1, 2, 4, 8, or 16 + // bytes and the data is naturally aligned (i.e., its address is a multiple of + // that size). + auto output_dims = input_sizes.vec(); + output_dims[last_dim] = output_columns; + auto output = at::empty( + output_dims, // 4 = sizeof(float) + input.options().dtype(at::kByte)); + + if (nrows == 0 || ncols == 0) { + return output; + } + + constexpr int threads_per_block = 256; + const auto num_blocks = cuda_calc_xblock_count(nrows, threads_per_block); + // think unsigned as we use 0, 255 + + if (nrows <= 20) { + _float_to_fused8bitrowwise_cuda_kernel<<< + num_blocks, + threads_per_block, + 0, + at::cuda::getCurrentCUDAStream()>>>( + input.data_ptr(), nrows, ncols, output.data_ptr()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + // range_tensor is used to store the range for each embedding row. + // We save range/255.0f as row scale, and use 255.0f / (range + kEpsilon) to + // quantize. This will guarantee the numerical match but bring some perf + // regression. + auto range_tensor = at::empty({nrows}, input.options().dtype(at::kFloat)); + + { + // we need a blockDim.x that is a power of 2 no larger than the warp size + // of 32 + + int blockDim_x = 1; + if (ncols > 16) { + // max warp size + blockDim_x = 32; + } else { + while (blockDim_x < ncols) { + blockDim_x <<= 1; + } + } + + const int rows_per_block = threads_per_block / blockDim_x; + const auto num_blocks_warp = + cuda_calc_xblock_count(nrows, rows_per_block); + + _get_8bit_qparam_cuda_kernel<<< + num_blocks_warp, + dim3(blockDim_x, rows_per_block), + 0, + at::cuda::getCurrentCUDAStream()>>>( + input.data_ptr(), + nrows, + ncols, + output.data_ptr(), + range_tensor.data_ptr()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } + + { + const int blockDim_x = std::min(ncols, threads_per_block); + dim3 blockDim(blockDim_x, threads_per_block / blockDim_x); + const auto gridDim_x = cuda_calc_xblock_count(ncols, blockDim.x); + const auto gridDim_y = cuda_calc_block_count(nrows, blockDim.y); + dim3 gridDim(gridDim_x, gridDim_y); + + _compute_8bit_quantize_cuda_kernel<<< + gridDim, + blockDim, + 0, + at::cuda::getCurrentCUDAStream()>>>( + input.data_ptr(), + range_tensor.data_ptr(), + nrows, + ncols, + output.data_ptr()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } + } + + return output; +} + +Tensor _fused8bitrowwise_to_float_gpu(const Tensor& input) { + TENSOR_ON_CUDA_GPU(input); + TORCH_CHECK(input.is_contiguous(), "input must be contiguous"); + + at::cuda::OptionalCUDAGuard device_guard; + device_guard.set_index(input.get_device()); + + const auto input_sizes = input.sizes(); + const auto last_dim = input_sizes.size() - 1; + const int nrows = c10::size_to_dim_(last_dim, input_sizes); + const int ncols = input_sizes[last_dim]; + const int ncols_aligned = (ncols + 4 - 1) / 4 * 4; + const int output_columns = ncols_aligned - 2 * sizeof(float); + + // Global memory instructions support reading or writing words of size equal + // to 1, 2, 4, 8, or 16 bytes. Any access (via a variable or a pointer) to + // data residing in global memory compiles to a single global memory + // instruction if and only if the size of the data type is 1, 2, 4, 8, or 16 + // bytes and the data is naturally aligned (i.e., its address is a multiple of + // that size). + auto output_dims = input_sizes.vec(); + output_dims[last_dim] = output_columns; + auto output = at::empty( + output_dims, // 4 = sizeof(float) + input.options().dtype(at::kFloat)); + + if (nrows == 0 || output_columns == 0) { + return output; + } + + constexpr int threads_per_block = 256; + + const int blockDim_x = std::min(threads_per_block, output_columns); + dim3 blockDim(blockDim_x, threads_per_block / blockDim_x); + + const auto gridDim_x = cuda_calc_xblock_count(output_columns, blockDim.x); + const auto gridDim_y = cuda_calc_block_count(nrows, blockDim.y); + dim3 gridDim(gridDim_x, gridDim_y); + + _fused8bitrowwise_to_float_cuda_kernel<<< + gridDim, + blockDim, + 0, + at::cuda::getCurrentCUDAStream()>>>( + input.data_ptr(), nrows, ncols, output.data_ptr()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return output; +} + +Tensor _float_to_fusednbitrowwise_gpu( + const Tensor& input, + const int64_t bit_rate) { + TENSOR_ON_CUDA_GPU(input); + TENSOR_NDIM_EQUALS(input, 2); + + at::cuda::OptionalCUDAGuard device_guard; + device_guard.set_index(input.get_device()); + + const int nrows = input.size(0); + const int ncols = input.size(1); + const int num_elem_per_byte = 8 / bit_rate; + TORCH_CHECK( + ncols % (2 * num_elem_per_byte) == 0, + "ncols needs to be multiple of 2 Bytes (half type size) to make the address aligned"); + const int output_columns = + (ncols + num_elem_per_byte - 1) / num_elem_per_byte + + 2 * sizeof(at::Half); + + // Global memory instructions support reading or writing words of size equal + // to 1, 2, 4, 8, or 16 bytes. Any access (via a variable or a pointer) to + // data residing in global memory compiles to a single global memory + // instruction if and only if the size of the data type is 1, 2, 4, 8, or 16 + // bytes and the data is naturally aligned (i.e., its address is a multiple of + // that size). + auto output = at::empty( + {nrows, output_columns}, + input.options().dtype(at::kByte)); // at::kBytes for uint8_t + + if (nrows == 0 || ncols == 0) { + return output; + } + + constexpr auto threads_per_block = 256; + const auto num_blocks = cuda_calc_xblock_count(nrows, threads_per_block); + // think unsigned as we use 0, 255 + + _float_to_fusednbitrowwise_cuda_kernel<<< + num_blocks, + threads_per_block, + 0, + at::cuda::getCurrentCUDAStream()>>>( + bit_rate, + input.data_ptr(), + nrows, + ncols, + output.data_ptr()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return output; +} + +Tensor _fusednbitrowwise_to_float_gpu( + const Tensor& input, + const int64_t bit_rate) { + TENSOR_ON_CUDA_GPU(input); + TENSOR_NDIM_EQUALS(input, 2); + + at::cuda::OptionalCUDAGuard device_guard; + device_guard.set_index(input.get_device()); + + const int nrows = input.size(0); + const int ncols = input.size(1); + const int num_elem_per_byte = 8 / bit_rate; + const int output_columns = (ncols - 2 * sizeof(at::Half)) * num_elem_per_byte; + + // Global memory instructions support reading or writing words of size equal + // to 1, 2, 4, 8, or 16 bytes. Any access (via a variable or a pointer) to + // data residing in global memory compiles to a single global memory + // instruction if and only if the size of the data type is 1, 2, 4, 8, or 16 + // bytes and the data is naturally aligned (i.e., its address is a multiple of + // that size). + auto output = at::empty( + {nrows, output_columns}, // 4 = sizeof(float) + input.options().dtype(at::kFloat)); // at::kBytes for uint8_t + + if (nrows == 0 || output_columns == 0) { + return output; + } + + constexpr int threads_per_block = 256; + + const int blockDim_x = std::min(output_columns, threads_per_block); + dim3 blockDim(blockDim_x, threads_per_block / blockDim_x); + const auto gridDim_x = cuda_calc_xblock_count(output_columns, blockDim.x); + const auto gridDim_y = cuda_calc_block_count(nrows, blockDim.y); + dim3 gridDim(gridDim_x, gridDim_y); + + _fusednbitrowwise_to_float_cuda_kernel<<< + gridDim, + blockDim, + 0, + at::cuda::getCurrentCUDAStream()>>>( + bit_rate, + input.data_ptr(), + nrows, + ncols, + output.data_ptr()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return output; +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/src/sparse_ops.cu b/fbgemm_gpu/src/sparse_ops.cu index 876a1dc26..45ae079ef 100644 --- a/fbgemm_gpu/src/sparse_ops.cu +++ b/fbgemm_gpu/src/sparse_ops.cu @@ -4,8 +4,6 @@ * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ -#include "fbgemm_gpu/batched_unary_embedding_ops.cuh" -#include "fbgemm_gpu/quantize_ops.cuh" #include "fbgemm_gpu/sparse_ops.cuh" #include "fbgemm_gpu/sparse_ops.h" #include "fbgemm_gpu/sparse_ops_utils.h" @@ -18,15 +16,15 @@ #include -#include "ATen/Parallel.h" - // clang-format off #include "fbgemm_gpu/cub_namespace_prefix.cuh" #include "cub/device/device_scan.cuh" #include "fbgemm_gpu/cub_namespace_postfix.cuh" // clang-format on +#include "fbgemm_gpu/embedding_backward_template_helpers.cuh" #include "fbgemm_gpu/fbgemm_cuda_utils.cuh" +#include "fbgemm_gpu/split_embeddings_utils.cuh" using Tensor = at::Tensor; @@ -127,6 +125,8 @@ Tensor segment_sum_csr_cuda( } Tensor asynchronous_inclusive_cumsum_gpu(const Tensor& t_in) { + TENSOR_ON_CUDA_GPU(t_in); + at::cuda::OptionalCUDAGuard device_guard; device_guard.set_index(t_in.get_device()); size_t temp_storage_bytes = 0; @@ -162,6 +162,8 @@ Tensor asynchronous_inclusive_cumsum_gpu(const Tensor& t_in) { } Tensor asynchronous_exclusive_cumsum_gpu(const Tensor& t_in) { + TENSOR_ON_CUDA_GPU(t_in); + at::cuda::OptionalCUDAGuard device_guard; device_guard.set_index(t_in.get_device()); size_t temp_storage_bytes = 0; @@ -197,6 +199,8 @@ Tensor asynchronous_exclusive_cumsum_gpu(const Tensor& t_in) { } Tensor asynchronous_complete_cumsum_gpu(const Tensor& t_in) { + TENSOR_ON_CUDA_GPU(t_in); + at::cuda::OptionalCUDAGuard device_guard; device_guard.set_index(t_in.get_device()); size_t temp_storage_bytes = 0; @@ -766,263 +770,6 @@ block_bucketize_sparse_features_cuda( return {new_lengths, new_indices, new_weights, new_pos, unbucketize_permute}; } -Tensor _float_to_fused8bitrowwise_gpu(const Tensor& input) { - TENSOR_ON_CUDA_GPU(input); - TORCH_CHECK(input.is_contiguous(), "input must be contiguous"); - - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(input.get_device()); - - const auto input_sizes = input.sizes(); - const auto last_dim = input_sizes.size() - 1; - const int nrows = c10::size_to_dim_(last_dim, input_sizes); - const int ncols = input_sizes[last_dim]; - const int ncols_aligned = (ncols + 4 - 1) / 4 * 4; - const int output_columns = ncols_aligned + 2 * sizeof(float); - - // Global memory instructions support reading or writing words of size equal - // to 1, 2, 4, 8, or 16 bytes. Any access (via a variable or a pointer) to - // data residing in global memory compiles to a single global memory - // instruction if and only if the size of the data type is 1, 2, 4, 8, or 16 - // bytes and the data is naturally aligned (i.e., its address is a multiple of - // that size). - auto output_dims = input_sizes.vec(); - output_dims[last_dim] = output_columns; - auto output = at::empty( - output_dims, // 4 = sizeof(float) - input.options().dtype(at::kByte)); - - if (nrows == 0 || ncols == 0) { - return output; - } - - constexpr int threads_per_block = 256; - const auto num_blocks = cuda_calc_xblock_count(nrows, threads_per_block); - // think unsigned as we use 0, 255 - - if (nrows <= 20) { - _float_to_fused8bitrowwise_cuda_kernel<<< - num_blocks, - threads_per_block, - 0, - at::cuda::getCurrentCUDAStream()>>>( - input.data_ptr(), nrows, ncols, output.data_ptr()); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - } else { - // range_tensor is used to store the range for each embedding row. - // We save range/255.0f as row scale, and use 255.0f / (range + kEpsilon) to - // quantize. This will guarantee the numerical match but bring some perf - // regression. - auto range_tensor = at::empty({nrows}, input.options().dtype(at::kFloat)); - - { - // we need a blockDim.x that is a power of 2 no larger than the warp size - // of 32 - - int blockDim_x = 1; - if (ncols > 16) { - // max warp size - blockDim_x = 32; - } else { - while (blockDim_x < ncols) { - blockDim_x <<= 1; - } - } - - const int rows_per_block = threads_per_block / blockDim_x; - const auto num_blocks_warp = - cuda_calc_xblock_count(nrows, rows_per_block); - - _get_8bit_qparam_cuda_kernel<<< - num_blocks_warp, - dim3(blockDim_x, rows_per_block), - 0, - at::cuda::getCurrentCUDAStream()>>>( - input.data_ptr(), - nrows, - ncols, - output.data_ptr(), - range_tensor.data_ptr()); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - } - - { - const int blockDim_x = std::min(ncols, threads_per_block); - dim3 blockDim(blockDim_x, threads_per_block / blockDim_x); - const auto gridDim_x = cuda_calc_xblock_count(ncols, blockDim.x); - const auto gridDim_y = cuda_calc_block_count(nrows, blockDim.y); - dim3 gridDim(gridDim_x, gridDim_y); - - _compute_8bit_quantize_cuda_kernel<<< - gridDim, - blockDim, - 0, - at::cuda::getCurrentCUDAStream()>>>( - input.data_ptr(), - range_tensor.data_ptr(), - nrows, - ncols, - output.data_ptr()); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - } - } - - return output; -} - -Tensor _fused8bitrowwise_to_float_gpu(const Tensor& input) { - TENSOR_ON_CUDA_GPU(input); - TORCH_CHECK(input.is_contiguous(), "input must be contiguous"); - - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(input.get_device()); - - const auto input_sizes = input.sizes(); - const auto last_dim = input_sizes.size() - 1; - const int nrows = c10::size_to_dim_(last_dim, input_sizes); - const int ncols = input_sizes[last_dim]; - const int ncols_aligned = (ncols + 4 - 1) / 4 * 4; - const int output_columns = ncols_aligned - 2 * sizeof(float); - - // Global memory instructions support reading or writing words of size equal - // to 1, 2, 4, 8, or 16 bytes. Any access (via a variable or a pointer) to - // data residing in global memory compiles to a single global memory - // instruction if and only if the size of the data type is 1, 2, 4, 8, or 16 - // bytes and the data is naturally aligned (i.e., its address is a multiple of - // that size). - auto output_dims = input_sizes.vec(); - output_dims[last_dim] = output_columns; - auto output = at::empty( - output_dims, // 4 = sizeof(float) - input.options().dtype(at::kFloat)); - - if (nrows == 0 || output_columns == 0) { - return output; - } - - constexpr int threads_per_block = 256; - - const int blockDim_x = std::min(threads_per_block, output_columns); - dim3 blockDim(blockDim_x, threads_per_block / blockDim_x); - - const auto gridDim_x = cuda_calc_xblock_count(output_columns, blockDim.x); - const auto gridDim_y = cuda_calc_block_count(nrows, blockDim.y); - dim3 gridDim(gridDim_x, gridDim_y); - - _fused8bitrowwise_to_float_cuda_kernel<<< - gridDim, - blockDim, - 0, - at::cuda::getCurrentCUDAStream()>>>( - input.data_ptr(), nrows, ncols, output.data_ptr()); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - - return output; -} - -Tensor _float_to_fusednbitrowwise_gpu( - const Tensor& input, - const int64_t bit_rate) { - TENSOR_ON_CUDA_GPU(input); - TENSOR_NDIM_EQUALS(input, 2); - - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(input.get_device()); - - const int nrows = input.size(0); - const int ncols = input.size(1); - const int num_elem_per_byte = 8 / bit_rate; - TORCH_CHECK( - ncols % (2 * num_elem_per_byte) == 0, - "ncols needs to be multiple of 2 Bytes (half type size) to make the address aligned"); - const int output_columns = - (ncols + num_elem_per_byte - 1) / num_elem_per_byte + - 2 * sizeof(at::Half); - - // Global memory instructions support reading or writing words of size equal - // to 1, 2, 4, 8, or 16 bytes. Any access (via a variable or a pointer) to - // data residing in global memory compiles to a single global memory - // instruction if and only if the size of the data type is 1, 2, 4, 8, or 16 - // bytes and the data is naturally aligned (i.e., its address is a multiple of - // that size). - auto output = at::empty( - {nrows, output_columns}, - input.options().dtype(at::kByte)); // at::kBytes for uint8_t - - if (nrows == 0 || ncols == 0) { - return output; - } - - constexpr auto threads_per_block = 256; - const auto num_blocks = cuda_calc_xblock_count(nrows, threads_per_block); - // think unsigned as we use 0, 255 - - _float_to_fusednbitrowwise_cuda_kernel<<< - num_blocks, - threads_per_block, - 0, - at::cuda::getCurrentCUDAStream()>>>( - bit_rate, - input.data_ptr(), - nrows, - ncols, - output.data_ptr()); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - - return output; -} - -Tensor _fusednbitrowwise_to_float_gpu( - const Tensor& input, - const int64_t bit_rate) { - TENSOR_ON_CUDA_GPU(input); - TENSOR_NDIM_EQUALS(input, 2); - - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(input.get_device()); - - const int nrows = input.size(0); - const int ncols = input.size(1); - const int num_elem_per_byte = 8 / bit_rate; - const int output_columns = (ncols - 2 * sizeof(at::Half)) * num_elem_per_byte; - - // Global memory instructions support reading or writing words of size equal - // to 1, 2, 4, 8, or 16 bytes. Any access (via a variable or a pointer) to - // data residing in global memory compiles to a single global memory - // instruction if and only if the size of the data type is 1, 2, 4, 8, or 16 - // bytes and the data is naturally aligned (i.e., its address is a multiple of - // that size). - auto output = at::empty( - {nrows, output_columns}, // 4 = sizeof(float) - input.options().dtype(at::kFloat)); // at::kBytes for uint8_t - - if (nrows == 0 || output_columns == 0) { - return output; - } - - constexpr int threads_per_block = 256; - - const int blockDim_x = std::min(output_columns, threads_per_block); - dim3 blockDim(blockDim_x, threads_per_block / blockDim_x); - const auto gridDim_x = cuda_calc_xblock_count(output_columns, blockDim.x); - const auto gridDim_y = cuda_calc_block_count(nrows, blockDim.y); - dim3 gridDim(gridDim_x, gridDim_y); - - _fusednbitrowwise_to_float_cuda_kernel<<< - gridDim, - blockDim, - 0, - at::cuda::getCurrentCUDAStream()>>>( - bit_rate, - input.data_ptr(), - nrows, - ncols, - output.data_ptr()); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - - return output; -} - template __global__ void reorder_batched_ad_lengths_kernel( // reorder lengths from (ragged) [B x T x #num_ads_b)] to @@ -1135,11 +882,7 @@ __global__ void reorder_batched_ad_indices_kernel( const int32_t output_segment_start = reordered_cat_ad_offsets[output_segment_offset_start]; -#ifdef __HIP_PLATFORM_HCC__ for (int32_t i = threadIdx.x; i < input_segment_end - input_segment_start; -#else - for (auto i = threadIdx.x; i < input_segment_end - input_segment_start; -#endif i += blockDim.x) { reordered_cat_ad_indices[output_segment_start + i] = cat_ad_indices[input_segment_start + i]; @@ -1190,6 +933,38 @@ Tensor reorder_batched_ad_indices_gpu( return reordered_cat_ad_indices; } +// Forward kernel for batched unary embedding op +template +__global__ void batched_unary_embeddings_forward_kernel( + const int32_t N, + const int32_t B, + const int32_t T, + const scalar_t* __restrict__ weight, // N * sum(E) * 1 (embedding dimension + // is 1) + const index_t* __restrict__ table_offsets, + const index_t* __restrict__ offsets, + const index_t* __restrict__ indices, + scalar_t* __restrict__ output // N * B * T +) { + index_t sum_E = table_offsets[T]; + int32_t b = blockIdx.x * blockDim.x + threadIdx.x; + if (b >= B) { + return; + } + int32_t t = blockIdx.y; + int32_t n = blockIdx.z; + index_t table_offset = table_offsets[t]; + index_t indices_start = offsets[t * B + b]; + index_t indices_end = offsets[t * B + b + 1]; + int32_t L = indices_end - indices_start; + at::acc_type sum = 0.0; + for (int32_t l = 0; l < L; ++l) { + auto idx = __ldg(&indices[indices_start + l]); + sum += weight[n * sum_E + table_offset + idx + 0]; + } + output[(n * B + b) * T + t] = sum; +} + Tensor batched_unary_embeddings_forward_cuda( const Tensor& weight, const Tensor& table_offsets, @@ -1234,6 +1009,74 @@ Tensor batched_unary_embeddings_forward_cuda( return output; } +// Backward kernel for batched unary embedding op +// We sort input indices so we don't have race conditions, an approach similar +// to the usual split table batched embedding backward. +// We can think of the following alternatives but each with challenges: +// 1) Assign output elements to different threads. Each thread scan all indices +// corresponding to the table it owns but only accumulate gradients when an +// index value matches with the output element it owns. +// A challenge is each thread need to binary search to map from [0 .. sum_E] +// to table id. +// 2) Densify indices and offsets to create [B, sum_E] matrix. Then, do batched +// GEMM where ith GEMM multiplies [N, B] submatrix of grad_output with +// [B, E_i] submatrix where E_i is the num of embeddings of ith table. +// Concatenating the GEMM outputs will result in [N, B, T] +// A challenge is there's no available batched GEMM routine with varying K +// dimension. +template +__global__ void batched_unary_embeddings_backward_kernel( + const int32_t N, + const int32_t B, + const int32_t T, + const scalar_t* __restrict__ grad_output, // [N * B * T] + const index_t* __restrict__ table_offsets, + scalar_t* __restrict__ grad_weight, // [N * sum_E * 1] (embedding + // dimension is 1) + const at::PackedTensorAccessor32 + sorted_linear_indices_run, + const int32_t* __restrict__ sorted_linear_indices_cumulative_run_lengths, + const int32_t* __restrict__ sorted_infos, + const int32_t* __restrict__ sorted_linear_indices_num_runs, + FixedDivisor fd) { + int32_t run_id = blockIdx.x * blockDim.x + threadIdx.x; + int32_t n = blockIdx.y; + if (n >= N) { + return; + } + if (run_id >= sorted_linear_indices_run.size(0)) { + return; + } + if (run_id >= sorted_linear_indices_num_runs[0]) { + return; + } + int64_t linear_index = sorted_linear_indices_run[run_id]; + int32_t segment_start = sorted_linear_indices_cumulative_run_lengths[run_id]; + int32_t segment_end = + sorted_linear_indices_cumulative_run_lengths[run_id + 1]; + int32_t SL = segment_end - segment_start; + + if (SL == 0) { + return; + } + + // now, each segment corresponds to exactly one table `t` and row in + // that table (`idx`). Thus, we can hoist out some of the book-keeping. + auto info = sorted_infos[segment_start]; + int t = fd.Div(info); + + at::acc_type grad_sum = 0.0; + for (int32_t sl = 0; sl < SL; ++sl) { + int32_t b = fd.Mod(sorted_infos[segment_start + sl]); + grad_sum += grad_output[(n * B + b) * T + t]; + } + + index_t table_offset = table_offsets[t]; + index_t sum_E = table_offsets[T]; + int64_t idx = linear_index - table_offset; + grad_weight[n * sum_E + table_offset + idx] = grad_sum; +} + Tensor batched_unary_embeddings_backward_cuda( const Tensor& grad_output, const Tensor& weight, @@ -1256,8 +1099,30 @@ Tensor batched_unary_embeddings_backward_cuda( TORCH_CHECK(N > 0); TORCH_CHECK(B > 0); TORCH_CHECK(T > 0); - int threads = std::min(N * T, 512); - dim3 blocks(cuda_calc_xblock_count(N * T, threads)); + + // weight: [N, sum_E] + // total_hash_size_bits = log2(sum_E) + int64_t total_hash_size_bits = log2(weight.numel() / N) + 1; + + Tensor linear_indices, linear_indices_sorted; + Tensor infos_sorted; + Tensor sorted_linear_indices_run, sorted_linear_indices_run_lengths, + sorted_linear_indices_num_runs, + sorted_linear_indices_cumulative_run_lengths; + std::tie( + linear_indices, + linear_indices_sorted, + infos_sorted, + sorted_linear_indices_run, + sorted_linear_indices_run_lengths, + sorted_linear_indices_num_runs, + sorted_linear_indices_cumulative_run_lengths) = + transpose_embedding_input( + table_offsets, total_hash_size_bits, indices, offsets); + + int threads = std::min(sorted_linear_indices_run.numel(), 512); + dim3 blocks( + cuda_calc_xblock_count(sorted_linear_indices_run.numel(), threads), N); auto grad_weight = at::zeros_like(weight); AT_DISPATCH_INDEX_TYPES( indices.type(), "batched_unary_embeddings_backward_kernel", ([&] { @@ -1272,453 +1137,20 @@ Tensor batched_unary_embeddings_backward_cuda( T, grad_output.data_ptr(), table_offsets.data_ptr(), - offsets.data_ptr(), - indices.data_ptr(), - grad_weight.data_ptr()); + grad_weight.data_ptr(), + sorted_linear_indices_run.packed_accessor32< + index_t, + 1, + at::RestrictPtrTraits>(), + sorted_linear_indices_cumulative_run_lengths + .data_ptr(), + infos_sorted.data_ptr(), + sorted_linear_indices_num_runs.data_ptr(), + FixedDivisor(B)); C10_CUDA_KERNEL_LAUNCH_CHECK(); })); })); return grad_weight; } -template -__global__ void jagged_2d_to_dense_forward_kernel( - int32_t B, - int32_t max_L, - int32_t D, - at::PackedTensorAccessor32 offsets, - at::PackedTensorAccessor64 embeddings, - at::PackedTensorAccessor64 padded_embeddings) { - int32_t b_l = blockIdx.x * blockDim.y + threadIdx.y; - int32_t l = b_l / B; - int32_t b = b_l % B; - if (b_l >= B * max_L) { - return; - } - int32_t row_start = offsets[b]; - int32_t row_end = offsets[b + 1]; - int32_t length = row_end - row_start; - if (l < length) { - for (int32_t d = 0; d < D; d += fbgemm_gpu::kWarpSize) { - if (d + threadIdx.x < D) { - padded_embeddings[b][l][d + threadIdx.x] = - embeddings[row_start + l][d + threadIdx.x]; - } - } - } else { - for (int32_t d = 0; d < D; d += fbgemm_gpu::kWarpSize) { - if (d + threadIdx.x < D) { - padded_embeddings[b][l][d + threadIdx.x] = 0.0; - } - } - } -} - -Tensor jagged_2d_to_dense_forward_cuda( - Tensor embeddings, - Tensor offsets, - int32_t max_L) { - TORCH_CHECK(embeddings.dim() == 2); - TORCH_CHECK(offsets.dim() == 1); - TORCH_CHECK(max_L > 0); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(embeddings.get_device()); - - int32_t D = embeddings.size(1); - int32_t B = offsets.numel() - 1; - auto padded_embeddings = at::empty({B, max_L, D}, embeddings.options()); - const auto embeddings_contig = embeddings.contiguous(); - const auto offsets_contig = offsets.contiguous(); - - AT_DISPATCH_INDEX_TYPES( - offsets.scalar_type(), "jagged_2d_to_dense_forward_kernel_1", ([&]() { - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - embeddings.scalar_type(), - "jagged_2d_to_dense_forward_kernel_2", - ([&]() { - jagged_2d_to_dense_forward_kernel - <<>>( - B, - max_L, - D, - offsets_contig.packed_accessor32(), - embeddings_contig.packed_accessor64(), - padded_embeddings.packed_accessor64()); - })); - })); - - return padded_embeddings; -} - -template -__global__ void jagged_2d_to_dense_backward_kernel( - int32_t B, - int32_t max_L, - int32_t D, - at::PackedTensorAccessor32 offsets, - at::PackedTensorAccessor64 grad_padded_embeddings, - at::PackedTensorAccessor64 grad_embeddings) { - int32_t b_l = blockIdx.x * blockDim.y + threadIdx.y; - int32_t l = b_l / B; - int32_t b = b_l % B; - if (b_l >= B * max_L) { - return; - } - int32_t row_start = offsets[b]; - int32_t row_end = offsets[b + 1]; - int32_t length = row_end - row_start; - if (l < length) { - for (int32_t d = 0; d < D; d += fbgemm_gpu::kWarpSize) { - if (d + threadIdx.x < D) { - grad_embeddings[row_start + l][d + threadIdx.x] = - grad_padded_embeddings[b][l][d + threadIdx.x]; - } - } - } -} - -Tensor jagged_2d_to_dense_backward_cuda( - Tensor grad_padded_embeddings, - Tensor offsets, - int32_t total_L) { - TORCH_CHECK(grad_padded_embeddings.dim() == 3); - TORCH_CHECK(offsets.dim() == 1); - TORCH_CHECK(total_L >= 0); - TORCH_CHECK(offsets.numel() == grad_padded_embeddings.size(0) + 1); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(grad_padded_embeddings.get_device()); - - int32_t B = grad_padded_embeddings.size(0); - int32_t max_L = grad_padded_embeddings.size(1); - int32_t D = grad_padded_embeddings.size(2); - auto grad_embeddings = - at::zeros({total_L, D}, grad_padded_embeddings.options()); - const auto grad_padded_embeddings_config = - grad_padded_embeddings.contiguous(); - const auto offsets_contig = offsets.contiguous(); - - AT_DISPATCH_INDEX_TYPES( - offsets.scalar_type(), "jagged_2d_to_dense_backward_kernel_1", ([&]() { - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - grad_padded_embeddings.scalar_type(), - "jagged_2d_to_dense_backward_kernel_2", - ([&]() { - jagged_2d_to_dense_backward_kernel - <<>>( - B, - max_L, - D, - offsets_contig.packed_accessor32(), - grad_padded_embeddings_config - .packed_accessor64(), - grad_embeddings.packed_accessor64()); - })); - })); - - return grad_embeddings; -} - -template -__global__ void jagged_1d_to_dense_kernel( - int32_t B, - int32_t max_L, - data_t padding_value, - at::PackedTensorAccessor32 offsets, - at::PackedTensorAccessor64 values, - at::PackedTensorAccessor64 padded_values) { - const int32_t b_l = blockIdx.x * blockDim.x + threadIdx.x; - if (b_l >= B * max_L) { - return; - } - int32_t b = b_l / max_L; - int32_t l = b_l % max_L; - int32_t row_start = offsets[b]; - int32_t row_end = offsets[b + 1]; - int32_t length = row_end - row_start; - if (l < length) { - padded_values[b][l] = values[row_start + l]; - } else { - padded_values[b][l] = padding_value; - } -} - -Tensor jagged_1d_to_dense_gpu( - Tensor values, - Tensor offsets, - int64_t max_L, - int64_t padding_value) { - TORCH_CHECK(values.dim() == 1); - TORCH_CHECK(offsets.dim() == 1); - TORCH_CHECK(max_L > 0); - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(values.get_device()); - - int32_t B = offsets.numel() - 1; - auto padded_values = at::empty({B, max_L}, values.options()); - const auto values_contig = values.contiguous(); - const auto offsets_contig = offsets.contiguous(); - const int32_t num_threads = 512; // 256~1024 per xingl - AT_DISPATCH_INDEX_TYPES( - offsets.scalar_type(), "jagged_1d_to_dense_kernel_1", ([&]() { - AT_DISPATCH_ALL_TYPES( - values.scalar_type(), "jagged_1d_to_dense_kernel_2", ([&]() { - jagged_1d_to_dense_kernel - <<>>( - B, - max_L, - padding_value, - offsets_contig.packed_accessor32(), - values_contig.packed_accessor64(), - padded_values.packed_accessor64()); - })); - })); - - return padded_values; -} - -template -__global__ void histogram_binning_calibration_kernel( - const int64_t num_logits, - const int64_t num_bins, - const double recalibrate_value, - const double step, - const int64_t bin_ctr_in_use_after, - const double bin_ctr_weight_value, - const T* const logit_data, - const double* const bin_num_examples_data, - const double* const bin_num_positives_data, - T* const calibrated_prediction_data, - int64_t* const bin_ids_data) { - const int32_t index = blockIdx.x * blockDim.x + threadIdx.x; - if (index >= num_logits) { - return; - } - - const T pre_sigmoid = logit_data[index] + recalibrate_value; - const double uncalibrated = 1.0 / (1.0 + exp(-pre_sigmoid)); - - bin_ids_data[index] = ceil(uncalibrated / step) - 1; - - const auto curr_bin_num_examples = bin_num_examples_data[bin_ids_data[index]]; - if (curr_bin_num_examples > bin_ctr_in_use_after) { - const auto curr_bin_ctr = - bin_num_positives_data[bin_ids_data[index]] / curr_bin_num_examples; - calibrated_prediction_data[index] = curr_bin_ctr * bin_ctr_weight_value + - uncalibrated * (1.0 - bin_ctr_weight_value); - } else { - calibrated_prediction_data[index] = uncalibrated; - } -} - -std::tuple histogram_binning_calibration_cuda( - const Tensor& logit, - const Tensor& bin_num_examples, - const Tensor& bin_num_positives, - double positive_weight, - double lower_bound, - double upper_bound, - int64_t bin_ctr_in_use_after, - double bin_ctr_weight_value) { - TENSOR_ON_CUDA_GPU(logit); - TENSOR_ON_CUDA_GPU(bin_num_examples); - TENSOR_ON_CUDA_GPU(bin_num_positives); - TORCH_CHECK(bin_num_examples.numel() == bin_num_positives.numel()); - - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(logit.get_device()); - - Tensor calibrated_prediction = at::empty_like(logit); - Tensor bin_ids = at::empty({logit.numel()}, logit.options().dtype(at::kLong)); - const double recalibrate_value = std::log(positive_weight); - const double step = (upper_bound - lower_bound) / - static_cast(bin_num_examples.numel()); - - const int32_t num_threads = 512; - const auto logit_packed = logit.contiguous(); - const auto bin_num_examples_packed = bin_num_examples.contiguous(); - const auto bin_num_positives_packed = bin_num_positives.contiguous(); - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - logit.type(), "histogram_binning_calibration_cuda", [&]() { - histogram_binning_calibration_kernel - <<>>( - logit.numel(), - bin_num_examples.numel(), - recalibrate_value, - step, - bin_ctr_in_use_after, - bin_ctr_weight_value, - logit_packed.data_ptr(), - bin_num_examples_packed.data_ptr(), - bin_num_positives_packed.data_ptr(), - calibrated_prediction.data_ptr(), - bin_ids.data_ptr()); - }); - - return std::make_tuple(calibrated_prediction, bin_ids); -} - -template -__global__ void to_dense_segment_value_kernel( - const int64_t num_lengths, - const int64_t* const segment_value_data, - const T* const segment_offsets_data, - int64_t* const dense_segment_value_data) { - const int32_t index = blockIdx.x * blockDim.x + threadIdx.x; - if (index >= num_lengths - 1) { - return; - } - - const auto curr_offset = segment_offsets_data[index]; - const auto next_offset = segment_offsets_data[index + 1]; - if (next_offset > curr_offset) { - // Add 1 to distinguish between 0 inserted by densification vs. original - // value. - dense_segment_value_data[index] = segment_value_data[curr_offset] + 1; - } -} - -template -__global__ void histogram_binning_calibration_by_feature_kernel( - const int64_t num_logits, - const int64_t num_bins, - const int64_t num_segments, - const double recalibrate_value, - const double step, - const int64_t bin_ctr_in_use_after, - const double bin_ctr_weight_value, - const T* const logit_data, - const int64_t* const dense_segment_value_data, - const double* const bin_num_examples_data, - const double* const bin_num_positives_data, - T* const calibrated_prediction_data, - int64_t* const bin_ids_data) { - const int32_t index = blockIdx.x * blockDim.x + threadIdx.x; - if (index >= num_logits) { - return; - } - - const T pre_sigmoid = logit_data[index] + recalibrate_value; - const double uncalibrated = 1.0 / (1.0 + exp(-pre_sigmoid)); - - const int64_t curr_segment_value = - dense_segment_value_data[index] > num_segments - ? 0 - : std::max(0L, dense_segment_value_data[index] * num_bins); - - bin_ids_data[index] = ceil(uncalibrated / step) - 1 + curr_segment_value; - - const auto curr_bin_num_examples = bin_num_examples_data[bin_ids_data[index]]; - if (curr_bin_num_examples > bin_ctr_in_use_after) { - const auto curr_bin_ctr = - bin_num_positives_data[bin_ids_data[index]] / curr_bin_num_examples; - calibrated_prediction_data[index] = curr_bin_ctr * bin_ctr_weight_value + - uncalibrated * (1.0 - bin_ctr_weight_value); - } else { - calibrated_prediction_data[index] = uncalibrated; - } -} - -std::tuple histogram_binning_calibration_by_feature_cuda( - const Tensor& logit, - const Tensor& segment_value, - const Tensor& segment_lengths, - int64_t num_segments, - const Tensor& bin_num_examples, - const Tensor& bin_num_positives, - int64_t num_bins, - double positive_weight, - double lower_bound, - double upper_bound, - int64_t bin_ctr_in_use_after, - double bin_ctr_weight_value) { - TENSOR_ON_CUDA_GPU(logit); - TENSOR_ON_CUDA_GPU(segment_value); - TENSOR_ON_CUDA_GPU(segment_lengths); - TENSOR_ON_CUDA_GPU(bin_num_examples); - TENSOR_ON_CUDA_GPU(bin_num_positives); - TORCH_CHECK(bin_num_examples.numel() == bin_num_positives.numel()); - - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(logit.get_device()); - - // Convert lengths to offsets for better handling on GPUs. - const auto segment_lengths_packed = segment_lengths.contiguous(); - auto segment_offsets = - asynchronous_complete_cumsum_gpu(segment_lengths_packed.view(-1)); - - // dense_segment_value is used as a temporary storage. - Tensor dense_segment_value = - at::zeros({logit.numel()}, segment_value.options()); - - const int32_t num_threads = 512; - const auto segment_value_packed = segment_value.contiguous(); - const auto segment_offsets_packed = segment_offsets.contiguous(); - auto dense_segment_value_packed = dense_segment_value.contiguous(); - AT_DISPATCH_INDEX_TYPES( - segment_offsets.scalar_type(), "to_dense_segment_value_cuda", [&]() { - to_dense_segment_value_kernel - <<>>( - segment_offsets.numel(), - segment_value_packed.data_ptr(), - segment_offsets_packed.data_ptr(), - dense_segment_value_packed.data_ptr()); - }); - - Tensor calibrated_prediction = at::empty_like(logit); - Tensor bin_ids = at::empty({logit.numel()}, logit.options().dtype(at::kLong)); - const double recalibrate_value = std::log(positive_weight); - const double step = - (upper_bound - lower_bound) / static_cast(num_bins); - - const auto logit_packed = logit.contiguous(); - const auto bin_num_examples_packed = bin_num_examples.contiguous(); - const auto bin_num_positives_packed = bin_num_positives.contiguous(); - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - logit.type(), "histogram_binning_calibration_by_feature_cuda", [&]() { - histogram_binning_calibration_by_feature_kernel - <<>>( - logit.numel(), - num_bins, - num_segments, - recalibrate_value, - step, - bin_ctr_in_use_after, - bin_ctr_weight_value, - logit_packed.data_ptr(), - dense_segment_value_packed.data_ptr(), - bin_num_examples_packed.data_ptr(), - bin_num_positives_packed.data_ptr(), - calibrated_prediction.data_ptr(), - bin_ids.data_ptr()); - }); - - return std::make_tuple(calibrated_prediction, bin_ids); -} - } // namespace fbgemm_gpu diff --git a/fbgemm_gpu/src/sparse_ops_cpu.cpp b/fbgemm_gpu/src/sparse_ops_cpu.cpp index 80ea6d222..35e04d8a0 100644 --- a/fbgemm_gpu/src/sparse_ops_cpu.cpp +++ b/fbgemm_gpu/src/sparse_ops_cpu.cpp @@ -5,7 +5,9 @@ * LICENSE file in the root directory of this source tree. */ +#include #include +#include #include #include @@ -838,8 +840,8 @@ void jagged_2d_to_dense_forward_kernel( int32_t max_L, int32_t D, const index_t* offsets, - const scalar_t* embeddings_data, - scalar_t* padded_embeddings_data) { + const scalar_t* values_data, + scalar_t* padded_values_data) { const auto block_size = max_L * D; const auto embedding_byte_size = D * sizeof(scalar_t); for (auto b = 0; b < B; ++b) { @@ -851,53 +853,51 @@ void jagged_2d_to_dense_forward_kernel( } auto padding_length = max_L - length; memcpy( - &padded_embeddings_data[b * block_size], - &embeddings_data[start_idx * D], + &padded_values_data[b * block_size], + &values_data[start_idx * D], length * embedding_byte_size); memset( - &padded_embeddings_data[b * block_size + length * D], + &padded_values_data[b * block_size + length * D], 0, padding_length * embedding_byte_size); } } -Tensor jagged_2d_to_dense_forward_cpu( - Tensor embeddings, - Tensor offsets, - int64_t max_L) { - TORCH_CHECK(embeddings.dim() == 2); +Tensor +jagged_2d_to_dense_forward_cpu(Tensor values, Tensor offsets, int64_t max_L) { + TORCH_CHECK(values.dim() == 2); TORCH_CHECK(offsets.dim() == 1); TORCH_CHECK(max_L > 0); const auto B = offsets.numel() - 1; - const auto D = embeddings.size(1); - const auto embeddings_contig = embeddings.expect_contiguous(); + const auto D = values.size(1); + const auto values_contig = values.expect_contiguous(); const auto offsets_contig = offsets.expect_contiguous(); - if (embeddings.size(0) == 0) { - return at::zeros({B, max_L, D}, embeddings.options()); + if (values.size(0) == 0) { + return at::zeros({B, max_L, D}, values.options()); } - auto padded_embeddings = at::empty({B, max_L, D}, embeddings.options()); + auto padded_values = at::empty({B, max_L, D}, values.options()); AT_DISPATCH_INDEX_TYPES( offsets_contig->scalar_type(), "jagged_2d_to_dense_forward_by_offsets", ([&]() { AT_DISPATCH_FLOATING_TYPES_AND_HALF( - embeddings_contig->scalar_type(), - "jagged_2d_to_dense_forward_by_embeddings", + values_contig->scalar_type(), + "jagged_2d_to_dense_forward_by_values", ([&]() { jagged_2d_to_dense_forward_kernel( B, max_L, D, offsets_contig->data_ptr(), - embeddings_contig->data_ptr(), - padded_embeddings.data_ptr()); + values_contig->data_ptr(), + padded_values.data_ptr()); })); })); - return padded_embeddings; + return padded_values; } template @@ -1136,6 +1136,116 @@ std::tuple histogram_binning_calibration_by_feature_cpu( return std::make_tuple(calibrated_prediction, bin_ids); } +template +void _generic_histogram_binning_calibration_by_feature_cpu_kernel( + const int64_t num_logits, + const int64_t num_bins, + const int64_t num_segments, + const int64_t num_lengths, + const double recalibrate_value, + const int64_t bin_ctr_in_use_after, + const double bin_ctr_weight_value, + const T* const logit_data, + const int64_t* const segment_value_data, + const int64_t* const segment_lengths_data, + const double* const bin_num_examples_data, + const double* const bin_num_positives_data, + const double* const bin_boundaries, + int64_t* const dense_segment_value_data, + T* const calibrated_prediction_data, + int64_t* const bin_ids_data) { + int k = 0; + for (const auto i : c10::irange(num_lengths)) { + if (segment_lengths_data[i] > 0) { + // Add 1 to distinguish between 0 inserted by densification vs. original + // value. + dense_segment_value_data[i] = segment_value_data[k] + 1; + ++k; + } + } + + for (const auto i : c10::irange(num_logits)) { + const T pre_sigmoid = logit_data[i] + recalibrate_value; + const double uncalibrated = 1.0 / (1.0 + std::exp(-pre_sigmoid)); + + const int curr_bin_id = + std::lower_bound( + bin_boundaries, bin_boundaries + num_bins, uncalibrated) - + bin_boundaries; + + const int64_t curr_segment_value = + dense_segment_value_data[i] > num_segments + ? 0 + : std::max(0L, dense_segment_value_data[i] * num_bins); + + bin_ids_data[i] = curr_bin_id + curr_segment_value; + + const auto curr_bin_num_examples = bin_num_examples_data[bin_ids_data[i]]; + if (curr_bin_num_examples > bin_ctr_in_use_after) { + const auto curr_bin_ctr = + bin_num_positives_data[bin_ids_data[i]] / curr_bin_num_examples; + calibrated_prediction_data[i] = curr_bin_ctr * bin_ctr_weight_value + + uncalibrated * (1.0 - bin_ctr_weight_value); + } else { + calibrated_prediction_data[i] = uncalibrated; + } + } +} + +std::tuple generic_histogram_binning_calibration_by_feature_cpu( + const Tensor& logit, + const Tensor& segment_value, + const Tensor& segment_lengths, + int64_t num_segments, + const Tensor& bin_num_examples, + const Tensor& bin_num_positives, + const Tensor& bin_boundaries, + double positive_weight, + int64_t bin_ctr_in_use_after, + double bin_ctr_weight_value) { + TENSOR_ON_CPU(logit); + TENSOR_ON_CPU(segment_value); + TENSOR_ON_CPU(segment_lengths); + TENSOR_ON_CPU(bin_num_examples); + TENSOR_ON_CPU(bin_num_positives); + TENSOR_ON_CPU(bin_boundaries); + TORCH_CHECK(bin_num_examples.numel() == bin_num_positives.numel()); + TORCH_CHECK( + bin_num_examples.numel() == + (num_segments + 1) * (bin_boundaries.numel() + 1)); + + // dense_segment_value is used as a temporary storage. + Tensor dense_segment_value = + at::zeros({logit.numel()}, segment_value.options()); + Tensor calibrated_prediction = at::empty_like(logit); + Tensor bin_ids = at::empty({logit.numel()}, logit.options().dtype(at::kLong)); + const double recalibrate_value = std::log(positive_weight); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + logit.type(), + "generic_histogram_binning_calibration_by_feature_cpu", + [&]() { + _generic_histogram_binning_calibration_by_feature_cpu_kernel( + logit.numel(), + bin_boundaries.numel() + 1, + num_segments, + segment_lengths.numel(), + recalibrate_value, + bin_ctr_in_use_after, + bin_ctr_weight_value, + logit.data_ptr(), + segment_value.data_ptr(), + segment_lengths.data_ptr(), + bin_num_examples.data_ptr(), + bin_num_positives.data_ptr(), + bin_boundaries.data_ptr(), + dense_segment_value.data_ptr(), + calibrated_prediction.data_ptr(), + bin_ids.data_ptr()); + }); + + return std::make_tuple(calibrated_prediction, bin_ids); +} + template void _segment_sum_csr_cpu_kernel( const int num_segments, @@ -1172,6 +1282,125 @@ Tensor segment_sum_csr_cpu( })); return output; } + +bool should_prune( + const Tensor& weights, + const int64_t num_rows_kept, + double min_save_ratio) { + TENSOR_ON_CPU(weights); + const auto weight_sizes = weights.sizes(); + + const int64_t data_byte_size = sizeof(float); + const int64_t num_cols = weight_sizes[1]; + + // Size of the pruned weights tensor. + const int64_t lut_after_prune_size = + num_rows_kept * num_cols * data_byte_size; + + constexpr auto index_byte_size = sizeof(int); + const auto lut_num_row = weight_sizes[0]; + const int64_t compressed_idx_overhead_size = lut_num_row * index_byte_size; + + const int64_t original_size = data_byte_size * weights.numel(); + return (compressed_idx_overhead_size + lut_after_prune_size) < + min_save_ratio * original_size; +} + +// This operator introduces sparsity to a weight matrix by applying +// magnitude based pruning at a row level. The importance level of a row is +// specified using an 'indicator' vector which contains a single value per +// row of the weight matrix. +// +// A row is considered important and not pruned if the indicator value for that +// particular row is greater than the pruning 'threshold' value. +// +// This operator doesn't zero out the pruned rows in-place. Instead, it returns +// a tuple that contains a pruned weights tensor as well as a map that can be +// used to refer the original row in the pruned weights tensor. We refer this +// map as 'compressed indices map' going forward. + +// The compressed indices map is an 1D tensor that contains one entry per +// original row in 'weights'. The array index is the index for the original +// non-pruned weight tensor and the value would be the re-mapped index in the +// pruned weights tensor. If the value for a index is -1, it means the +// corresponding row has been pruned from the original weight tensor. + +// Arguments: +// 'weights' - the weight tensor that needs to be pruned rowwise. +// 'indicator' - the magnitude for every row of the 'weights' matrix. +// 'threshold' - the pruning threshold that will be used for comparison +// against the indicator row value. +// 'compressed_indices_dtype' - dtype for the compressed map indices. +// This should be either int32 or int64. +// 'abs' - whether we should perform abs() on the indicator value or not. +// 'min_non_pruned_rows' - a minimum threshold on the number of rows +// that should be present after pruning. +// 'min_save_ratio' - a parameter to tradeoff between lookup table CPU overhead +// with the reduction in memory bandwidth due to pruned rows. +// Pruning will be skipped for the entire matrix if the physical size of +// pruned weights and indices mapping is greater than +// min_save_ratio * weights size. +// 'compressed indices map' will contain a single element [0] in this case. +// +// Returns: a tuple, +// - The first value is the pruned weight tensor whose dtype is float. +// - The second value is a 1D tensor whose dtype is 'compressed_indices_dtype'. +std::tuple embedding_bag_rowwise_prune( + const Tensor& weights, + const Tensor& indicator, + const double threshold, + at::ScalarType compressed_indices_dtype, + const bool abs, + const int64_t min_non_pruned_rows, + const c10::optional& min_save_ratio) { + TENSOR_ON_CPU(weights); + TENSOR_ON_CPU(indicator); + TENSOR_NDIM_EQUALS(weights, 2); + TORCH_CHECK( + indicator.numel() == weights.sizes()[0], + "Number of elements in 'indicator' should be equivalent to " + "number of rows in 'weights'.") + TORCH_CHECK( + threshold >= 0.0, "Threshold should be greater than or equal to zero."); + TORCH_CHECK( + compressed_indices_dtype == at::ScalarType::Int || + compressed_indices_dtype == at::ScalarType::Long, + "'compressed_indices_dtype' should be Int/Long."); + + const auto indicator_contig = indicator.expect_contiguous(); + const auto indicator_data = indicator_contig->data_ptr(); + auto rowwise_prune_mask = at::empty({indicator.numel()}, at::kBool); + int num_kept = 0; + for (const auto i : c10::irange(indicator.numel())) { + const float val = abs ? std::abs(indicator_data[i]) : indicator_data[i]; + bool should_keep_row = val > threshold; + + // The total number of rows post-pruning should be greater than or equal + // to 'min_non_pruned_rows'. + // Skip pruning the current row to satisfy the above criteria. + if (num_kept < min_non_pruned_rows && + num_kept + (indicator.numel() - i) <= min_non_pruned_rows) { + should_keep_row = true; + } + if (!should_keep_row) { + rowwise_prune_mask[i] = false; + continue; + } + rowwise_prune_mask[i] = true; + num_kept++; + } + + if (min_save_ratio.has_value() && + !should_prune(weights, min_non_pruned_rows, min_save_ratio.value())) { + auto compressed_indices_mapping = at::empty({1}, compressed_indices_dtype); + compressed_indices_mapping[0] = 0; + return std::tuple(weights, compressed_indices_mapping); + } + + return at::native::_rowwise_prune( + weights, rowwise_prune_mask, compressed_indices_dtype); +} + } // namespace fbgemm_gpu TORCH_LIBRARY_FRAGMENT(fbgemm, m) { @@ -1190,15 +1419,19 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { m.def( "batched_unary_embeddings(Tensor weight, Tensor table_offsets, Tensor offsets, Tensor indices) -> Tensor"); m.def( - "jagged_2d_to_dense(Tensor embeddings, Tensor offsets, int max_sequence_length) -> Tensor"); + "jagged_2d_to_dense(Tensor values, Tensor offsets, int max_sequence_length) -> Tensor"); m.def( "jagged_1d_to_dense(Tensor values, Tensor offsets, int max_sequence_length, int padding_value) -> Tensor"); m.def( "histogram_binning_calibration(Tensor logit, Tensor bin_num_examples, Tensor bin_num_positives, float positive_weight, float lower_bound, float upper_bound, int bin_ctr_in_use_after, float bin_ctr_weight_value) -> (Tensor, Tensor)"); m.def( "histogram_binning_calibration_by_feature(Tensor logit, Tensor segment_value, Tensor segment_lengths, int num_segments, Tensor bin_num_examples, Tensor bin_num_positives, int num_bins, float positive_weight, float lower_bound, float upper_bound, int bin_ctr_in_use_after, float bin_ctr_weight_value) -> (Tensor, Tensor)"); + m.def( + "generic_histogram_binning_calibration_by_feature(Tensor logit, Tensor segment_value, Tensor segment_lengths, int num_segments, Tensor bin_num_examples, Tensor bin_num_positives, Tensor bin_boundaries, float positive_weight, int bin_ctr_in_use_after, float bin_ctr_weight_value) -> (Tensor, Tensor)"); m.def( "segment_sum_csr(int batch_size, Tensor csr_seg, Tensor values) -> Tensor"); + m.def( + "embedding_bag_rowwise_prune(Tensor weight, Tensor indicator, float threshold, ScalarType compressed_indices_dtype, bool abs=True, int min_num_rows=0, float? min_save_ratio=1.0) -> (Tensor, Tensor)"); } TORCH_LIBRARY_IMPL(fbgemm, CPU, m) { @@ -1231,5 +1464,10 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) { m.impl( "histogram_binning_calibration_by_feature", fbgemm_gpu::histogram_binning_calibration_by_feature_cpu); + m.impl( + "generic_histogram_binning_calibration_by_feature", + fbgemm_gpu::generic_histogram_binning_calibration_by_feature_cpu); m.impl("segment_sum_csr", fbgemm_gpu::segment_sum_csr_cpu); + m.impl( + "embedding_bag_rowwise_prune", fbgemm_gpu::embedding_bag_rowwise_prune); } diff --git a/fbgemm_gpu/src/sparse_ops_gpu.cpp b/fbgemm_gpu/src/sparse_ops_gpu.cpp index 42859d399..6199a939c 100644 --- a/fbgemm_gpu/src/sparse_ops_gpu.cpp +++ b/fbgemm_gpu/src/sparse_ops_gpu.cpp @@ -62,15 +62,15 @@ class Jagged2DToDenseGPUOp public: static torch::autograd::variable_list forward( torch::autograd::AutogradContext* ctx, - Tensor embeddings, + Tensor values, Tensor offsets, int32_t max_sequence_length) { - int32_t total_L = embeddings.size(0); + int32_t total_L = values.size(0); ctx->save_for_backward({offsets}); ctx->saved_data["total_L"] = total_L; - return {jagged_2d_to_dense_forward_cuda( - embeddings, offsets, max_sequence_length)}; + return { + jagged_2d_to_dense_forward_cuda(values, offsets, max_sequence_length)}; } static torch::autograd::variable_list backward( @@ -82,11 +82,11 @@ class Jagged2DToDenseGPUOp int32_t total_L = ctx->saved_data["total_L"].toInt(); using torch::autograd::Variable; - auto grad_padded_embeddings = grad_outputs[0]; - auto grad_embeddings = jagged_2d_to_dense_backward_cuda( - grad_padded_embeddings, offsets, total_L); + auto grad_padded_values = grad_outputs[0]; + auto grad_values = + jagged_2d_to_dense_backward_cuda(grad_padded_values, offsets, total_L); return { - grad_embeddings, + grad_values, Variable(), // offsets Variable() // max_sequence_length }; @@ -94,11 +94,11 @@ class Jagged2DToDenseGPUOp }; Tensor jagged_2d_to_dense_gpu( - Tensor embeddings, + Tensor values, Tensor offsets, int64_t max_sequence_length) { return Jagged2DToDenseGPUOp::apply( - embeddings, offsets, static_cast(max_sequence_length))[0]; + values, offsets, static_cast(max_sequence_length))[0]; } } // namespace fbgemm_gpu @@ -133,5 +133,8 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) { DISPATCH_TO_CUDA( "histogram_binning_calibration_by_feature", fbgemm_gpu::histogram_binning_calibration_by_feature_cuda); + DISPATCH_TO_CUDA( + "generic_histogram_binning_calibration_by_feature", + fbgemm_gpu::generic_histogram_binning_calibration_by_feature_cuda); DISPATCH_TO_CUDA("segment_sum_csr", fbgemm_gpu::segment_sum_csr_cuda); } diff --git a/fbgemm_gpu/src/split_embeddings_cache_cuda.cu b/fbgemm_gpu/src/split_embeddings_cache_cuda.cu index 13fb62652..34a55d598 100644 --- a/fbgemm_gpu/src/split_embeddings_cache_cuda.cu +++ b/fbgemm_gpu/src/split_embeddings_cache_cuda.cu @@ -29,10 +29,10 @@ #include #include "fbgemm_gpu/dispatch_macros.h" +#include "fbgemm_gpu/embedding_common.h" #include "fbgemm_gpu/fbgemm_cuda_utils.cuh" - -#include "../codegen/embedding_common.h" -#include "split_embeddings_utils.cuh" +#include "fbgemm_gpu/sparse_ops_utils.h" +#include "fbgemm_gpu/split_embeddings_utils.cuh" constexpr size_t kCacheMaxThreads = 512; @@ -73,18 +73,19 @@ __host__ __device__ inline int32_t padded_row_size_in_bytes( } } // namespace -// TODO: do we care about 64-bit indices? Currently we just ignore. -__host__ DEVICE_INLINE uint32_t cache_slot(int32_t h_in, int32_t C) { - // MurmorHash3 32-bit mixing function. - uint32_t h = (uint32_t)h_in; - h ^= h >> 16; - h *= 0x85ebca6b; - h ^= h >> 13; - h *= 0xc2b2ae35; - h ^= h >> 16; - // https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/ - return ((uint64_t)h * (uint64_t)C) >> 32; -} +// // TODO: do we care about 64-bit indices? Currently we just ignore. +// __host__ DEVICE_INLINE uint32_t cache_slot(int32_t h_in, int32_t C) { +// // MurmorHash3 32-bit mixing function. +// uint32_t h = (uint32_t)h_in; +// h ^= h >> 16; +// h *= 0x85ebca6b; +// h ^= h >> 13; +// h *= 0xc2b2ae35; +// h ^= h >> 16; +// // +// https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/ +// return ((uint64_t)h * (uint64_t)C) >> 32; +// } __host__ DEVICE_INLINE uint32_t cache_slot(int64_t h_in, int32_t C) { // MurmurHash3 64-bit mixing function. @@ -110,13 +111,13 @@ __global__ __launch_bounds__(kMaxThreads) void lxu_cache_flush_kernel( at::PackedTensorAccessor64 weights, const at::PackedTensorAccessor32 cache_hash_size_cumsum, - const at::PackedTensorAccessor32 + const at::PackedTensorAccessor64 cache_index_table_map, const at::PackedTensorAccessor32 weights_offsets, const at::PackedTensorAccessor32 D_offsets, - at::PackedTensorAccessor32 + const at::PackedTensorAccessor32 lxu_cache_state, at::PackedTensorAccessor64 lxu_cache_weights, @@ -188,6 +189,14 @@ void lxu_cache_flush_cuda( Tensor lxu_cache_state, Tensor lxu_cache_weights, bool stochastic_rounding) { + TENSOR_ON_CUDA_GPU(uvm_weights); + TENSOR_ON_CUDA_GPU(cache_hash_size_cumsum); + TENSOR_ON_CUDA_GPU(cache_index_table_map); + TENSOR_ON_CUDA_GPU(weights_offsets); + TENSOR_ON_CUDA_GPU(D_offsets); + TENSOR_ON_CUDA_GPU(lxu_cache_state); + TENSOR_ON_CUDA_GPU(lxu_cache_weights); + at::cuda::OptionalCUDAGuard device_guard; device_guard.set_index(lxu_cache_weights.get_device()); @@ -216,7 +225,7 @@ void lxu_cache_flush_cuda( cache_hash_size_cumsum .packed_accessor32(), cache_index_table_map - .packed_accessor32(), + .packed_accessor64(), weights_offsets .packed_accessor32(), D_offsets @@ -232,12 +241,13 @@ void lxu_cache_flush_cuda( return; } +template __global__ __launch_bounds__(kMaxThreads) void linearize_cache_indices_kernel( const at::PackedTensorAccessor32 cache_hash_size_cumsum, - const at::PackedTensorAccessor32 indices, - const at::PackedTensorAccessor32 offsets, - at::PackedTensorAccessor32 + const at::PackedTensorAccessor32 indices, + const at::PackedTensorAccessor32 offsets, + at::PackedTensorAccessor32 linear_cache_indices) { int32_t T = cache_hash_size_cumsum.size(0) - 1; int64_t total_cache_hash_size = cache_hash_size_cumsum[T]; @@ -248,15 +258,15 @@ __global__ __launch_bounds__(kMaxThreads) void linearize_cache_indices_kernel( bool valid = t < T; int64_t hash_offset = valid ? cache_hash_size_cumsum[t] : -1; - int64_t indices_start = valid ? offsets[t * B + b] : -1; + auto indices_start = valid ? offsets[t * B + b] : -1; int32_t L = valid ? offsets[t * B + b + 1] - indices_start : 0; int32_t lane_id = threadIdx.x % kWarpSize; // hash_offset < 0 for non-caching tables for (int32_t j = 0; j < kWarpSize; ++j) { - int64_t indices_start_warp = SHFL_SYNC_MACRO(indices_start, j); - int32_t L_warp = SHFL_SYNC_MACRO(L, j); - int64_t hash_offset_warp = SHFL_SYNC_MACRO(hash_offset, j); + auto indices_start_warp = shfl_sync(indices_start, j); + int32_t L_warp = shfl_sync(L, j); + int64_t hash_offset_warp = shfl_sync(hash_offset, j); if (hash_offset_warp >= 0) { for (int32_t i = lane_id; i < L_warp; i += kWarpSize) { auto idx = __ldg(&indices[indices_start_warp + i]); @@ -274,6 +284,10 @@ Tensor linearize_cache_indices_cuda( Tensor cache_hash_size_cumsum, Tensor indices, Tensor offsets) { + TENSOR_ON_CUDA_GPU(cache_hash_size_cumsum); + TENSOR_ON_CUDA_GPU(indices); + TENSOR_ON_CUDA_GPU(offsets); + at::cuda::OptionalCUDAGuard device_guard; device_guard.set_index(cache_hash_size_cumsum.get_device()); @@ -287,18 +301,21 @@ Tensor linearize_cache_indices_cuda( if (B == 0) { return linear_cache_indices; } - linearize_cache_indices_kernel<<< - div_round_up(B * T, kMaxThreads), - kMaxThreads, - 0, - at::cuda::getCurrentCUDAStream()>>>( - cache_hash_size_cumsum - .packed_accessor32(), - indices.packed_accessor32(), - offsets.packed_accessor32(), - linear_cache_indices - .packed_accessor32()); - C10_CUDA_KERNEL_LAUNCH_CHECK(); + AT_DISPATCH_INDEX_TYPES( + indices.scalar_type(), "linearize_cache_indices_kernel", [&]() { + linearize_cache_indices_kernel<<< + div_round_up(B * T, kMaxThreads), + kMaxThreads, + 0, + at::cuda::getCurrentCUDAStream()>>>( + cache_hash_size_cumsum + .packed_accessor32(), + indices.packed_accessor32(), + offsets.packed_accessor32(), + linear_cache_indices + .packed_accessor32()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); return linear_cache_indices; } @@ -306,6 +323,8 @@ std::tuple> get_unique_indices_cuda( Tensor linear_indices, int64_t max_indices, bool compute_count) { + TENSOR_ON_CUDA_GPU(linear_indices); + at::cuda::OptionalCUDAGuard device_guard; device_guard.set_index(linear_indices.get_device()); @@ -320,88 +339,93 @@ std::tuple> get_unique_indices_cuda( unique_indices_count = at::empty( {linear_indices.numel()}, linear_indices.options().dtype(at::kInt)); } - - // sort indices - size_t temp_storage_bytes_0 = 0; - AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortKeys( - nullptr, - temp_storage_bytes_0, - linear_indices.data_ptr(), - sorted_indices.data_ptr(), - N, - 0, - int(log2(float(max_indices + 1)) + 1), - at::cuda::getCurrentCUDAStream(), - false)); - auto temp_storage_0 = at::empty( - {static_cast(temp_storage_bytes_0)}, - linear_indices.options().dtype(at::kByte)); - AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortKeys( - temp_storage_0.data_ptr(), - temp_storage_bytes_0, - linear_indices.data_ptr(), - sorted_indices.data_ptr(), - N, - 0, - int(log2(float(max_indices + 1)) + 1), - at::cuda::getCurrentCUDAStream(), - false)); - // get unique indices - if (compute_count) { - size_t temp_storage_bytes_1 = 0; - AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRunLengthEncode::Encode( - nullptr, - temp_storage_bytes_1, - sorted_indices.data_ptr(), - unique_indices.data_ptr(), - unique_indices_count->data_ptr(), - unique_indices_length.data_ptr(), - N, - at::cuda::getCurrentCUDAStream(), - false)); - auto temp_storage_1 = at::empty( - {static_cast(temp_storage_bytes_1)}, - linear_indices.options().dtype(at::kByte)); - AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRunLengthEncode::Encode( - temp_storage_1.data_ptr(), - temp_storage_bytes_1, - sorted_indices.data_ptr(), - unique_indices.data_ptr(), - unique_indices_count->data_ptr(), - unique_indices_length.data_ptr(), - N, - at::cuda::getCurrentCUDAStream(), - false)); - } else { - size_t temp_storage_bytes_1 = 0; - AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceSelect::Unique( - nullptr, - temp_storage_bytes_1, - sorted_indices.data_ptr(), - unique_indices.data_ptr(), - unique_indices_length.data_ptr(), - N, - at::cuda::getCurrentCUDAStream(), - false)); - auto temp_storage_1 = at::empty( - {static_cast(temp_storage_bytes_1)}, - linear_indices.options().dtype(at::kByte)); - AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceSelect::Unique( - temp_storage_1.data_ptr(), - temp_storage_bytes_1, - sorted_indices.data_ptr(), - unique_indices.data_ptr(), - unique_indices_length.data_ptr(), - N, - at::cuda::getCurrentCUDAStream(), - false)); - } + AT_DISPATCH_INDEX_TYPES( + linear_indices.scalar_type(), "get_unique_indices_cuda", [&]() { + // sort indices + size_t temp_storage_bytes_0 = 0; + AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortKeys( + nullptr, + temp_storage_bytes_0, + linear_indices.data_ptr(), + sorted_indices.data_ptr(), + N, + 0, + int(log2(float(max_indices + 1)) + 1), + at::cuda::getCurrentCUDAStream(), + false)); + auto temp_storage_0 = at::empty( + {static_cast(temp_storage_bytes_0)}, + linear_indices.options().dtype(at::kByte)); + AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortKeys( + temp_storage_0.data_ptr(), + temp_storage_bytes_0, + linear_indices.data_ptr(), + sorted_indices.data_ptr(), + N, + 0, + int(log2(float(max_indices + 1)) + 1), + at::cuda::getCurrentCUDAStream(), + false)); + // get unique indices + if (compute_count) { + size_t temp_storage_bytes_1 = 0; + AT_CUDA_CHECK( + FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRunLengthEncode::Encode( + nullptr, + temp_storage_bytes_1, + sorted_indices.data_ptr(), + unique_indices.data_ptr(), + unique_indices_count->data_ptr(), + unique_indices_length.data_ptr(), + N, + at::cuda::getCurrentCUDAStream(), + false)); + auto temp_storage_1 = at::empty( + {static_cast(temp_storage_bytes_1)}, + linear_indices.options().dtype(at::kByte)); + AT_CUDA_CHECK( + FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRunLengthEncode::Encode( + temp_storage_1.data_ptr(), + temp_storage_bytes_1, + sorted_indices.data_ptr(), + unique_indices.data_ptr(), + unique_indices_count->data_ptr(), + unique_indices_length.data_ptr(), + N, + at::cuda::getCurrentCUDAStream(), + false)); + } else { + size_t temp_storage_bytes_1 = 0; + AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceSelect::Unique( + nullptr, + temp_storage_bytes_1, + sorted_indices.data_ptr(), + unique_indices.data_ptr(), + unique_indices_length.data_ptr(), + N, + at::cuda::getCurrentCUDAStream(), + false)); + auto temp_storage_1 = at::empty( + {static_cast(temp_storage_bytes_1)}, + linear_indices.options().dtype(at::kByte)); + AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceSelect::Unique( + temp_storage_1.data_ptr(), + temp_storage_bytes_1, + sorted_indices.data_ptr(), + unique_indices.data_ptr(), + unique_indices_length.data_ptr(), + N, + at::cuda::getCurrentCUDAStream(), + false)); + } + }); return std::make_tuple( unique_indices, unique_indices_length, unique_indices_count); } +template __global__ __launch_bounds__(kMaxThreads) void lru_cache_find_uncached_kernel( - const at::PackedTensorAccessor32 + const at::PackedTensorAccessor32 unique_indices, const int32_t* __restrict__ N_unique, int64_t max_indices, @@ -442,6 +466,10 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_find_uncached_kernel( } #ifdef __HIP_PLATFORM_HCC__ + // FIXME: __any_sync with mask isn't supported by HIP yet. + // See https://fburl.com/fvy7j0lq for the similar context. + // assert false here with https://fburl.com/pfm7enw2 + assert(false); if (!__any(found)) { #else if (!__any_sync(0xFFFFFFFF, found)) { @@ -459,6 +487,11 @@ std::pair lru_cache_find_uncached_cuda( Tensor lxu_cache_state, int64_t time_stamp, Tensor lru_state) { + TENSOR_ON_CUDA_GPU(unique_indices); + TENSOR_ON_CUDA_GPU(unique_indices_length); + TENSOR_ON_CUDA_GPU(lxu_cache_state); + TENSOR_ON_CUDA_GPU(lru_state); + at::cuda::OptionalCUDAGuard device_guard; device_guard.set_index(unique_indices.get_device()); @@ -468,49 +501,54 @@ std::pair lru_cache_find_uncached_cuda( auto sorted_cache_sets = empty_like(cache_sets); auto cache_set_sorted_unique_indices = empty_like(unique_indices); - // Find uncached indices - lru_cache_find_uncached_kernel<<< - div_round_up(N, kMaxThreads / kWarpSize), - dim3(kWarpSize, kMaxThreads / kWarpSize), - 0, - at::cuda::getCurrentCUDAStream()>>>( - unique_indices.packed_accessor32(), - unique_indices_length.data_ptr(), - max_indices, - lxu_cache_state.packed_accessor32(), - cache_sets.packed_accessor32(), - time_stamp, - lru_state.packed_accessor32()); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - // Sort the cache sets and ids - size_t temp_storage_bytes = 0; - AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortPairs( - nullptr, - temp_storage_bytes, - cache_sets.data_ptr(), - sorted_cache_sets.data_ptr(), - unique_indices.data_ptr(), - cache_set_sorted_unique_indices.data_ptr(), - N, - 0, - int(log2(float(lxu_cache_state.size(0) + 1)) + 1), - at::cuda::getCurrentCUDAStream(), - false)); - auto temp_storage = at::empty( - {static_cast(temp_storage_bytes)}, - unique_indices.options().dtype(at::kByte)); - AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortPairs( - temp_storage.data_ptr(), - temp_storage_bytes, - cache_sets.data_ptr(), - sorted_cache_sets.data_ptr(), - unique_indices.data_ptr(), - cache_set_sorted_unique_indices.data_ptr(), - N, - 0, - int(log2(float(lxu_cache_state.size(0) + 1)) + 1), - at::cuda::getCurrentCUDAStream(), - false)); + AT_DISPATCH_INDEX_TYPES( + unique_indices.scalar_type(), "lru_cache_find_uncached_cuda", [&]() { + // Find uncached indices + lru_cache_find_uncached_kernel<<< + div_round_up(N, kMaxThreads / kWarpSize), + dim3(kWarpSize, kMaxThreads / kWarpSize), + 0, + at::cuda::getCurrentCUDAStream()>>>( + unique_indices + .packed_accessor32(), + unique_indices_length.data_ptr(), + max_indices, + lxu_cache_state + .packed_accessor32(), + cache_sets.packed_accessor32(), + time_stamp, + lru_state.packed_accessor32()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + // Sort the cache sets and ids + size_t temp_storage_bytes = 0; + AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortPairs( + nullptr, + temp_storage_bytes, + cache_sets.data_ptr(), + sorted_cache_sets.data_ptr(), + unique_indices.data_ptr(), + cache_set_sorted_unique_indices.data_ptr(), + N, + 0, + int(log2(float(lxu_cache_state.size(0) + 1)) + 1), + at::cuda::getCurrentCUDAStream(), + false)); + auto temp_storage = at::empty( + {static_cast(temp_storage_bytes)}, + unique_indices.options().dtype(at::kByte)); + AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortPairs( + temp_storage.data_ptr(), + temp_storage_bytes, + cache_sets.data_ptr(), + sorted_cache_sets.data_ptr(), + unique_indices.data_ptr(), + cache_set_sorted_unique_indices.data_ptr(), + N, + 0, + int(log2(float(lxu_cache_state.size(0) + 1)) + 1), + at::cuda::getCurrentCUDAStream(), + false)); + }); return {sorted_cache_sets, cache_set_sorted_unique_indices}; } @@ -519,7 +557,7 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_insert_kernel( at::PackedTensorAccessor64 weights, const at::PackedTensorAccessor32 cache_hash_size_cumsum, - const at::PackedTensorAccessor32 + const at::PackedTensorAccessor64 cache_index_table_map, const at::PackedTensorAccessor32 weights_offsets, @@ -575,8 +613,8 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_insert_kernel( int64_t sorted_lru_cost = costs[0]; for (int32_t l = 0; l < min(SL, kWarpSize); ++l) { - int32_t insert_slot = SHFL_SYNC_MACRO(sorted_slot, l); - int64_t insert_current_lru_cost = SHFL_SYNC_MACRO(sorted_lru_cost, l); + int32_t insert_slot = shfl_sync(sorted_slot, l); + int64_t insert_current_lru_cost = shfl_sync(sorted_lru_cost, l); if (insert_current_lru_cost == time_stamp) { return; } @@ -592,7 +630,7 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_insert_kernel( // lxu_cache_state int64_t current_idx = threadIdx.x == 0 ? lxu_cache_state[cache_set][insert_slot] : 0; - current_idx = SHFL_SYNC_MACRO(current_idx, 0); + current_idx = shfl_sync(current_idx, 0); // not empty if (current_idx != static_cast(kCacheStateInvalid)) { @@ -696,6 +734,18 @@ void lru_cache_insert_cuda( int64_t time_stamp, Tensor lru_state, bool stochastic_rounding) { + TENSOR_ON_CUDA_GPU(weights); + TENSOR_ON_CUDA_GPU(cache_hash_size_cumsum); + TENSOR_ON_CUDA_GPU(cache_index_table_map); + TENSOR_ON_CUDA_GPU(weights_offsets); + TENSOR_ON_CUDA_GPU(D_offsets); + TENSOR_ON_CUDA_GPU(sorted_cache_sets); + TENSOR_ON_CUDA_GPU(cache_set_sorted_unique_indices); + TENSOR_ON_CUDA_GPU(unique_indices_length); + TENSOR_ON_CUDA_GPU(lxu_cache_state); + TENSOR_ON_CUDA_GPU(lxu_cache_weights); + TENSOR_ON_CUDA_GPU(lru_state); + at::cuda::OptionalCUDAGuard device_guard; device_guard.set_index(weights.get_device()); @@ -723,7 +773,7 @@ void lru_cache_insert_cuda( cache_hash_size_cumsum .packed_accessor32(), cache_index_table_map - .packed_accessor32(), + .packed_accessor64(), weights_offsets .packed_accessor32(), D_offsets @@ -759,6 +809,16 @@ void lru_cache_populate_cuda( int64_t time_stamp, Tensor lru_state, bool stochastic_rounding) { + TENSOR_ON_CUDA_GPU(weights); + TENSOR_ON_CUDA_GPU(cache_hash_size_cumsum); + TENSOR_ON_CUDA_GPU(cache_index_table_map); + TENSOR_ON_CUDA_GPU(weights_offsets); + TENSOR_ON_CUDA_GPU(D_offsets); + TENSOR_ON_CUDA_GPU(linear_cache_indices); + TENSOR_ON_CUDA_GPU(lxu_cache_state); + TENSOR_ON_CUDA_GPU(lxu_cache_weights); + TENSOR_ON_CUDA_GPU(lru_state); + at::cuda::OptionalCUDAGuard device_guard; device_guard.set_index(weights.get_device()); @@ -805,11 +865,12 @@ void lru_cache_populate_cuda( stochastic_rounding); } +template __global__ __launch_bounds__(kMaxThreads) void lru_cache_insert_byte_kernel( at::PackedTensorAccessor64 weights, const at::PackedTensorAccessor32 cache_hash_size_cumsum, - const at::PackedTensorAccessor32 + const at::PackedTensorAccessor64 cache_index_table_map, const at::PackedTensorAccessor32 weights_offsets, @@ -819,7 +880,7 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_insert_byte_kernel( D_offsets, const at::PackedTensorAccessor32 sorted_cache_sets, - const at::PackedTensorAccessor32 + const at::PackedTensorAccessor32 cache_set_sorted_indices, const int32_t* __restrict__ N_unique, at::PackedTensorAccessor32 @@ -865,12 +926,12 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_insert_byte_kernel( int64_t sorted_lru_cost = costs[0]; for (int32_t l = 0; l < min(SL, kWarpSize); ++l) { - int32_t insert_slot = SHFL_SYNC_MACRO(sorted_slot, l); - int64_t insert_current_lru_cost = SHFL_SYNC_MACRO(sorted_lru_cost, l); + int32_t insert_slot = shfl_sync(sorted_slot, l); + int64_t insert_current_lru_cost = shfl_sync(sorted_lru_cost, l); if (insert_current_lru_cost == time_stamp) { return; } - int64_t insert_idx = cache_set_sorted_indices[n + l]; + index_t insert_idx = cache_set_sorted_indices[n + l]; int32_t t_insert = cache_index_table_map[insert_idx]; SparseType weight_ty_insert = static_cast(weights_tys[t_insert]); @@ -887,7 +948,7 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_insert_byte_kernel( // lxu_cache_state int64_t current_idx = threadIdx.x == 0 ? lxu_cache_state[cache_set][insert_slot] : 0; - current_idx = SHFL_SYNC_MACRO(current_idx, 0); + current_idx = shfl_sync(current_idx, 0); // not empty if (current_idx != static_cast(kCacheStateInvalid)) { @@ -941,33 +1002,55 @@ void lru_cache_insert_byte_cuda( Tensor lxu_cache_weights, int64_t time_stamp, Tensor lru_state) { + TENSOR_ON_CUDA_GPU(weights); + TENSOR_ON_CUDA_GPU(cache_hash_size_cumsum); + TENSOR_ON_CUDA_GPU(cache_index_table_map); + TENSOR_ON_CUDA_GPU(weights_offsets); + TENSOR_ON_CUDA_GPU(weights_tys); + TENSOR_ON_CUDA_GPU(D_offsets); + TENSOR_ON_CUDA_GPU(sorted_cache_sets); + TENSOR_ON_CUDA_GPU(cache_set_sorted_unique_indices); + TENSOR_ON_CUDA_GPU(unique_indices_length); + TENSOR_ON_CUDA_GPU(lxu_cache_state); + TENSOR_ON_CUDA_GPU(lxu_cache_weights); + TENSOR_ON_CUDA_GPU(lru_state); + at::cuda::OptionalCUDAGuard device_guard; device_guard.set_index(weights.get_device()); int32_t N = cache_set_sorted_unique_indices.numel(); - lru_cache_insert_byte_kernel<<< - div_round_up(N, kMaxThreads / kWarpSize), - dim3(kWarpSize, kMaxThreads / kWarpSize), - 0, - at::cuda::getCurrentCUDAStream()>>>( - weights.packed_accessor64(), - cache_hash_size_cumsum - .packed_accessor32(), - cache_index_table_map - .packed_accessor32(), - weights_offsets.packed_accessor32(), - weights_tys.packed_accessor32(), - D_offsets.packed_accessor32(), - sorted_cache_sets.packed_accessor32(), - cache_set_sorted_unique_indices - .packed_accessor32(), - unique_indices_length.data_ptr(), - lxu_cache_state.packed_accessor32(), - lxu_cache_weights.packed_accessor64(), - time_stamp, - lru_state.packed_accessor32()); - C10_CUDA_KERNEL_LAUNCH_CHECK(); + AT_DISPATCH_INDEX_TYPES( + cache_set_sorted_unique_indices.scalar_type(), + "lru_cache_insert_byte_cuda", + [&]() { + lru_cache_insert_byte_kernel<<< + div_round_up(N, kMaxThreads / kWarpSize), + dim3(kWarpSize, kMaxThreads / kWarpSize), + 0, + at::cuda::getCurrentCUDAStream()>>>( + weights.packed_accessor64(), + cache_hash_size_cumsum + .packed_accessor32(), + cache_index_table_map + .packed_accessor64(), + weights_offsets + .packed_accessor32(), + weights_tys.packed_accessor32(), + D_offsets.packed_accessor32(), + sorted_cache_sets + .packed_accessor32(), + cache_set_sorted_unique_indices + .packed_accessor32(), + unique_indices_length.data_ptr(), + lxu_cache_state + .packed_accessor32(), + lxu_cache_weights + .packed_accessor64(), + time_stamp, + lru_state.packed_accessor32()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); } void lru_cache_populate_byte_cuda( @@ -983,6 +1066,17 @@ void lru_cache_populate_byte_cuda( Tensor lxu_cache_weights, int64_t time_stamp, Tensor lru_state) { + TENSOR_ON_CUDA_GPU(weights); + TENSOR_ON_CUDA_GPU(cache_hash_size_cumsum); + TENSOR_ON_CUDA_GPU(cache_index_table_map); + TENSOR_ON_CUDA_GPU(weights_offsets); + TENSOR_ON_CUDA_GPU(weights_tys); + TENSOR_ON_CUDA_GPU(D_offsets); + TENSOR_ON_CUDA_GPU(linear_cache_indices); + TENSOR_ON_CUDA_GPU(lxu_cache_state); + TENSOR_ON_CUDA_GPU(lxu_cache_weights); + TENSOR_ON_CUDA_GPU(lru_state); + at::cuda::OptionalCUDAGuard device_guard; device_guard.set_index(weights.get_device()); @@ -1029,8 +1123,9 @@ void lru_cache_populate_byte_cuda( lru_state); } +template __global__ __launch_bounds__(kMaxThreads) void lfu_update_counts_kernel( - const at::PackedTensorAccessor32 + const at::PackedTensorAccessor32 unique_indices, const int32_t* __restrict__ N_unique, const at::PackedTensorAccessor32 @@ -1040,7 +1135,7 @@ __global__ __launch_bounds__(kMaxThreads) void lfu_update_counts_kernel( if (n >= *N_unique) { return; } - int64_t idx = unique_indices[n]; + auto idx = unique_indices[n]; lfu_state[idx] += unique_indices_count[n]; } @@ -1049,20 +1144,29 @@ void lfu_update_counts_cuda( Tensor unique_indices_length, Tensor unique_indices_count, Tensor lfu_state) { + TENSOR_ON_CUDA_GPU(unique_indices); + TENSOR_ON_CUDA_GPU(unique_indices_length); + TENSOR_ON_CUDA_GPU(unique_indices_count); + TENSOR_ON_CUDA_GPU(lfu_state); + at::cuda::OptionalCUDAGuard device_guard; device_guard.set_index(unique_indices.get_device()); int32_t N = unique_indices.size(0); - lfu_update_counts_kernel<<< - div_round_up(N, kMaxThreads), - kMaxThreads, - 0, - at::cuda::getCurrentCUDAStream()>>>( - unique_indices.packed_accessor32(), - unique_indices_length.data_ptr(), - unique_indices_count - .packed_accessor32(), - lfu_state.packed_accessor64()); + AT_DISPATCH_INDEX_TYPES( + unique_indices.scalar_type(), "lfu_update_counts_cuda", [&]() { + lfu_update_counts_kernel<<< + div_round_up(N, kMaxThreads), + kMaxThreads, + 0, + at::cuda::getCurrentCUDAStream()>>>( + unique_indices + .packed_accessor32(), + unique_indices_length.data_ptr(), + unique_indices_count + .packed_accessor32(), + lfu_state.packed_accessor64()); + }); C10_CUDA_KERNEL_LAUNCH_CHECK(); } @@ -1070,8 +1174,9 @@ constexpr int32_t kCacheSetBits = 24; constexpr int32_t kLFUCounterBits = 40; static_assert(kCacheSetBits + kLFUCounterBits == 8 * sizeof(int64_t), ""); +template __global__ __launch_bounds__(kMaxThreads) void lfu_cache_find_uncached_kernel( - const at::PackedTensorAccessor32 + const at::PackedTensorAccessor32 unique_indices, const int32_t* __restrict__ N_unique, int64_t max_indices, @@ -1115,6 +1220,10 @@ __global__ __launch_bounds__(kMaxThreads) void lfu_cache_find_uncached_kernel( } #ifdef __HIP_PLATFORM_HCC__ + // FIXME: __any_sync with mask isn't supported by HIP yet. + // See https://fburl.com/fvy7j0lq for the similar context. + // assert false here with https://fburl.com/pfm7enw2 + assert(false); if (!__any(found)) { #else if (!__any_sync(0xFFFFFFFF, found)) { @@ -1134,6 +1243,11 @@ std::pair lfu_cache_find_uncached_cuda( int64_t max_indices, Tensor lxu_cache_state, Tensor lfu_state) { + TENSOR_ON_CUDA_GPU(unique_indices); + TENSOR_ON_CUDA_GPU(unique_indices_length); + TENSOR_ON_CUDA_GPU(lxu_cache_state); + TENSOR_ON_CUDA_GPU(lfu_state); + at::cuda::OptionalCUDAGuard device_guard; device_guard.set_index(unique_indices.get_device()); @@ -1143,48 +1257,53 @@ std::pair lfu_cache_find_uncached_cuda( auto sorted_cache_sets = empty_like(cache_sets); auto cache_set_sorted_unique_indices = empty_like(unique_indices); - // Find uncached indices - lfu_cache_find_uncached_kernel<<< - div_round_up(N, kMaxThreads / kWarpSize), - dim3(kWarpSize, kMaxThreads / kWarpSize), - 0, - at::cuda::getCurrentCUDAStream()>>>( - unique_indices.packed_accessor32(), - unique_indices_length.data_ptr(), - max_indices, - lxu_cache_state.packed_accessor32(), - (uint64_t*)cache_sets.data_ptr(), - lfu_state.packed_accessor64()); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - // Sort the cache sets and ids - size_t temp_storage_bytes = 0; - AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortPairs( - nullptr, - temp_storage_bytes, - (uint64_t*)cache_sets.data_ptr(), - (uint64_t*)sorted_cache_sets.data_ptr(), - unique_indices.data_ptr(), - cache_set_sorted_unique_indices.data_ptr(), - N, - 0, - int(log2(float(lxu_cache_state.size(0) + 1)) + 1) + kLFUCounterBits, - at::cuda::getCurrentCUDAStream(), - false)); - auto temp_storage = at::empty( - {static_cast(temp_storage_bytes)}, - unique_indices.options().dtype(at::kByte)); - AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortPairs( - temp_storage.data_ptr(), - temp_storage_bytes, - (uint64_t*)cache_sets.data_ptr(), - (uint64_t*)sorted_cache_sets.data_ptr(), - unique_indices.data_ptr(), - cache_set_sorted_unique_indices.data_ptr(), - N, - 0, - int(log2(float(lxu_cache_state.size(0) + 1)) + 1) + kLFUCounterBits, - at::cuda::getCurrentCUDAStream(), - false)); + AT_DISPATCH_INDEX_TYPES( + unique_indices.scalar_type(), "lfu_cache_find_uncached_cuda", [&]() { + // Find uncached indices + lfu_cache_find_uncached_kernel<<< + div_round_up(N, kMaxThreads / kWarpSize), + dim3(kWarpSize, kMaxThreads / kWarpSize), + 0, + at::cuda::getCurrentCUDAStream()>>>( + unique_indices + .packed_accessor32(), + unique_indices_length.data_ptr(), + max_indices, + lxu_cache_state + .packed_accessor32(), + (uint64_t*)cache_sets.data_ptr(), + lfu_state.packed_accessor64()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + // Sort the cache sets and ids + size_t temp_storage_bytes = 0; + AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortPairs( + nullptr, + temp_storage_bytes, + (uint64_t*)cache_sets.data_ptr(), + (uint64_t*)sorted_cache_sets.data_ptr(), + unique_indices.data_ptr(), + cache_set_sorted_unique_indices.data_ptr(), + N, + 0, + int(log2(float(lxu_cache_state.size(0) + 1)) + 1) + kLFUCounterBits, + at::cuda::getCurrentCUDAStream(), + false)); + auto temp_storage = at::empty( + {static_cast(temp_storage_bytes)}, + unique_indices.options().dtype(at::kByte)); + AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortPairs( + temp_storage.data_ptr(), + temp_storage_bytes, + (uint64_t*)cache_sets.data_ptr(), + (uint64_t*)sorted_cache_sets.data_ptr(), + unique_indices.data_ptr(), + cache_set_sorted_unique_indices.data_ptr(), + N, + 0, + int(log2(float(lxu_cache_state.size(0) + 1)) + 1) + kLFUCounterBits, + at::cuda::getCurrentCUDAStream(), + false)); + }); return {sorted_cache_sets, cache_set_sorted_unique_indices}; } @@ -1193,7 +1312,7 @@ __global__ __launch_bounds__(kCacheMaxThreads) void lfu_cache_insert_kernel( at::PackedTensorAccessor64 weights, const at::PackedTensorAccessor32 cache_hash_size_cumsum, - const at::PackedTensorAccessor32 + const at::PackedTensorAccessor64 cache_index_table_map, const at::PackedTensorAccessor32 weights_offsets, @@ -1255,8 +1374,8 @@ __global__ __launch_bounds__(kCacheMaxThreads) void lfu_cache_insert_kernel( int64_t sorted_lfu_cost = costs[0]; for (int32_t l = 0; l < min(SL, kWarpSize); ++l) { - int32_t insert_slot = SHFL_SYNC_MACRO(sorted_slot, l); - int64_t insert_current_lfu_cost = SHFL_SYNC_MACRO(sorted_lfu_cost, l); + int32_t insert_slot = shfl_sync(sorted_slot, l); + int64_t insert_current_lfu_cost = shfl_sync(sorted_lfu_cost, l); int64_t insert_idx = cache_set_sorted_indices[n + l]; int64_t insert_lfu_cost = lfu_state[insert_idx]; @@ -1280,7 +1399,7 @@ __global__ __launch_bounds__(kCacheMaxThreads) void lfu_cache_insert_kernel( // lxu_cache_state int64_t current_idx = threadIdx.x == 0 ? lxu_cache_state[cache_set][insert_slot] : 0; - current_idx = SHFL_SYNC_MACRO(current_idx, 0); + current_idx = shfl_sync(current_idx, 0); int32_t t_current = cache_index_table_map[current_idx]; int64_t idx_current = current_idx - cache_hash_size_cumsum[t_current]; int64_t weights_offset_current = weights_offsets[t_current]; @@ -1379,6 +1498,18 @@ void lfu_cache_insert_cuda( Tensor lxu_cache_weights, Tensor lfu_state, bool stochastic_rounding) { + TENSOR_ON_CUDA_GPU(weights); + TENSOR_ON_CUDA_GPU(cache_hash_size_cumsum); + TENSOR_ON_CUDA_GPU(cache_index_table_map); + TENSOR_ON_CUDA_GPU(weights_offsets); + TENSOR_ON_CUDA_GPU(D_offsets); + TENSOR_ON_CUDA_GPU(sorted_cache_sets); + TENSOR_ON_CUDA_GPU(cache_set_sorted_unique_indices); + TENSOR_ON_CUDA_GPU(unique_indices_length); + TENSOR_ON_CUDA_GPU(lxu_cache_state); + TENSOR_ON_CUDA_GPU(lxu_cache_weights); + TENSOR_ON_CUDA_GPU(lfu_state); + at::cuda::OptionalCUDAGuard device_guard; device_guard.set_index(weights.get_device()); @@ -1406,7 +1537,7 @@ void lfu_cache_insert_cuda( cache_hash_size_cumsum .packed_accessor32(), cache_index_table_map - .packed_accessor32(), + .packed_accessor64(), weights_offsets .packed_accessor32(), D_offsets @@ -1439,6 +1570,16 @@ void lfu_cache_populate_cuda( Tensor lxu_cache_weights, Tensor lfu_state, bool stochastic_rounding) { + TENSOR_ON_CUDA_GPU(weights); + TENSOR_ON_CUDA_GPU(cache_hash_size_cumsum); + TENSOR_ON_CUDA_GPU(cache_index_table_map); + TENSOR_ON_CUDA_GPU(weights_offsets); + TENSOR_ON_CUDA_GPU(D_offsets); + TENSOR_ON_CUDA_GPU(linear_cache_indices); + TENSOR_ON_CUDA_GPU(lxu_cache_state); + TENSOR_ON_CUDA_GPU(lxu_cache_weights); + TENSOR_ON_CUDA_GPU(lfu_state); + at::cuda::OptionalCUDAGuard device_guard; device_guard.set_index(weights.get_device()); @@ -1502,12 +1643,13 @@ void lfu_cache_populate_cuda( // uint8_t only). Basically no "high-precision cache" support for now. // - The insert/evict of embedding row from the cache are done in a byte-by-byte // manner. +template __global__ __launch_bounds__(kCacheMaxThreads) void lfu_cache_insert_byte_kernel( at::PackedTensorAccessor64 weights, const at::PackedTensorAccessor32 cache_hash_size_cumsum, - const at::PackedTensorAccessor32 + const at::PackedTensorAccessor64 cache_index_table_map, const at::PackedTensorAccessor32 weights_offsets, @@ -1516,7 +1658,7 @@ __launch_bounds__(kCacheMaxThreads) void lfu_cache_insert_byte_kernel( const at::PackedTensorAccessor32 D_offsets, const uint64_t* __restrict__ sorted_cache_sets, - const at::PackedTensorAccessor32 + const at::PackedTensorAccessor32 cache_set_sorted_indices, const int32_t* __restrict__ N_unique, at::PackedTensorAccessor32 @@ -1569,9 +1711,9 @@ __launch_bounds__(kCacheMaxThreads) void lfu_cache_insert_byte_kernel( int64_t sorted_lfu_cost = costs[0]; for (int32_t l = 0; l < min(SL, kWarpSize); ++l) { - int32_t insert_slot = SHFL_SYNC_MACRO(sorted_slot, l); - int64_t insert_current_lfu_cost = SHFL_SYNC_MACRO(sorted_lfu_cost, l); - int64_t insert_idx = cache_set_sorted_indices[n + l]; + int32_t insert_slot = shfl_sync(sorted_slot, l); + int64_t insert_current_lfu_cost = shfl_sync(sorted_lfu_cost, l); + index_t insert_idx = cache_set_sorted_indices[n + l]; int64_t insert_lfu_cost = lfu_state[insert_idx]; if (insert_current_lfu_cost > insert_lfu_cost) { @@ -1599,7 +1741,7 @@ __launch_bounds__(kCacheMaxThreads) void lfu_cache_insert_byte_kernel( // lxu_cache_state int64_t current_idx = threadIdx.x == 0 ? lxu_cache_state[cache_set][insert_slot] : 0; - current_idx = SHFL_SYNC_MACRO(current_idx, 0); + current_idx = shfl_sync(current_idx, 0); int32_t t_current = cache_index_table_map[current_idx]; SparseType weight_ty_current = static_cast(weights_tys[t_current]); @@ -1647,31 +1789,52 @@ void lfu_cache_insert_byte_cuda( Tensor lxu_cache_state, Tensor lxu_cache_weights, Tensor lfu_state) { + TENSOR_ON_CUDA_GPU(weights); + TENSOR_ON_CUDA_GPU(cache_hash_size_cumsum); + TENSOR_ON_CUDA_GPU(cache_index_table_map); + TENSOR_ON_CUDA_GPU(weights_offsets); + TENSOR_ON_CUDA_GPU(weights_tys) + TENSOR_ON_CUDA_GPU(D_offsets); + TENSOR_ON_CUDA_GPU(sorted_cache_sets); + TENSOR_ON_CUDA_GPU(cache_set_sorted_unique_indices); + TENSOR_ON_CUDA_GPU(unique_indices_length); + TENSOR_ON_CUDA_GPU(lxu_cache_state); + TENSOR_ON_CUDA_GPU(lxu_cache_weights); + TENSOR_ON_CUDA_GPU(lfu_state); + at::cuda::OptionalCUDAGuard device_guard; device_guard.set_index(weights.get_device()); int32_t N = cache_set_sorted_unique_indices.numel(); - lfu_cache_insert_byte_kernel<<< - div_round_up(N, kCacheMaxThreads / kWarpSize), - dim3(kWarpSize, kCacheMaxThreads / kWarpSize), - 0, - at::cuda::getCurrentCUDAStream()>>>( - weights.packed_accessor64(), - cache_hash_size_cumsum - .packed_accessor32(), - cache_index_table_map - .packed_accessor32(), - weights_offsets.packed_accessor32(), - weights_tys.packed_accessor32(), - D_offsets.packed_accessor32(), - (uint64_t*)sorted_cache_sets.data_ptr(), - cache_set_sorted_unique_indices - .packed_accessor32(), - unique_indices_length.data_ptr(), - lxu_cache_state.packed_accessor32(), - lxu_cache_weights.packed_accessor64(), - lfu_state.packed_accessor64()); + AT_DISPATCH_INDEX_TYPES( + cache_set_sorted_unique_indices.scalar_type(), + "lfu_cache_insert_byte_cuda", + [&]() { + lfu_cache_insert_byte_kernel<<< + div_round_up(N, kCacheMaxThreads / kWarpSize), + dim3(kWarpSize, kCacheMaxThreads / kWarpSize), + 0, + at::cuda::getCurrentCUDAStream()>>>( + weights.packed_accessor64(), + cache_hash_size_cumsum + .packed_accessor32(), + cache_index_table_map + .packed_accessor64(), + weights_offsets + .packed_accessor32(), + weights_tys.packed_accessor32(), + D_offsets.packed_accessor32(), + (uint64_t*)sorted_cache_sets.data_ptr(), + cache_set_sorted_unique_indices + .packed_accessor32(), + unique_indices_length.data_ptr(), + lxu_cache_state + .packed_accessor32(), + lxu_cache_weights + .packed_accessor64(), + lfu_state.packed_accessor64()); + }); C10_CUDA_KERNEL_LAUNCH_CHECK(); } @@ -1688,6 +1851,17 @@ void lfu_cache_populate_byte_cuda( Tensor lxu_cache_state, Tensor lxu_cache_weights, Tensor lfu_state) { + TENSOR_ON_CUDA_GPU(weights); + TENSOR_ON_CUDA_GPU(cache_hash_size_cumsum); + TENSOR_ON_CUDA_GPU(cache_index_table_map); + TENSOR_ON_CUDA_GPU(weights_offsets); + TENSOR_ON_CUDA_GPU(weights_tys) + TENSOR_ON_CUDA_GPU(D_offsets); + TENSOR_ON_CUDA_GPU(linear_cache_indices); + TENSOR_ON_CUDA_GPU(lxu_cache_state); + TENSOR_ON_CUDA_GPU(lxu_cache_weights); + TENSOR_ON_CUDA_GPU(lfu_state); + at::cuda::OptionalCUDAGuard device_guard; device_guard.set_index(weights.get_device()); @@ -1736,8 +1910,9 @@ void lfu_cache_populate_byte_cuda( lfu_state); } +template __global__ __launch_bounds__(kMaxThreads) void lxu_cache_lookup_kernel( - const at::PackedTensorAccessor32 + const at::PackedTensorAccessor32 linear_cache_indices, const at::PackedTensorAccessor32 lxu_cache_state, @@ -1757,6 +1932,10 @@ __global__ __launch_bounds__(kMaxThreads) void lxu_cache_lookup_kernel( lxu_cache_locations[n] = cache_set * kWarpSize + slot; } #ifdef __HIP_PLATFORM_HCC__ + // FIXME: __any_sync with mask isn't supported by HIP yet. + // See https://fburl.com/fvy7j0lq for the similar context. + // assert false here with https://fburl.com/pfm7enw2 + assert(false); if (!__any(found)) { #else if (!__any_sync(0xFFFFFFFF, found)) { @@ -1770,6 +1949,9 @@ __global__ __launch_bounds__(kMaxThreads) void lxu_cache_lookup_kernel( Tensor lxu_cache_lookup_cuda( Tensor linear_cache_indices, Tensor lxu_cache_state) { + TENSOR_ON_CUDA_GPU(linear_cache_indices); + TENSOR_ON_CUDA_GPU(lxu_cache_state); + at::cuda::OptionalCUDAGuard device_guard; device_guard.set_index(linear_cache_indices.get_device()); @@ -1784,17 +1966,21 @@ Tensor lxu_cache_lookup_cuda( const dim3 threads(kWarpSize, kMaxThreads / kWarpSize); const dim3 blocks(div_round_up(N, kMaxThreads / kWarpSize)); - lxu_cache_lookup_kernel<<< - blocks, - threads, - 0, - at::cuda::getCurrentCUDAStream()>>>( - linear_cache_indices - .packed_accessor32(), - lxu_cache_state.packed_accessor32(), - lxu_cache_locations - .packed_accessor32()); - C10_CUDA_KERNEL_LAUNCH_CHECK(); + AT_DISPATCH_INDEX_TYPES( + linear_cache_indices.scalar_type(), "lxu_cache_lookup_cuda", [&]() { + lxu_cache_lookup_kernel<<< + blocks, + threads, + 0, + at::cuda::getCurrentCUDAStream()>>>( + linear_cache_indices + .packed_accessor32(), + lxu_cache_state + .packed_accessor32(), + lxu_cache_locations + .packed_accessor32()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); return lxu_cache_locations; } diff --git a/fbgemm_gpu/src/split_embeddings_utils.cu b/fbgemm_gpu/src/split_embeddings_utils.cu new file mode 100644 index 000000000..46f14b645 --- /dev/null +++ b/fbgemm_gpu/src/split_embeddings_utils.cu @@ -0,0 +1,229 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include "fbgemm_gpu/split_embeddings_utils.cuh" + +#include +#include +#include "fbgemm_gpu/embedding_backward_template_helpers.cuh" + +using Tensor = at::Tensor; + +using namespace fbgemm_gpu; + +template +__global__ void linearize_index_kernel( + const at::PackedTensorAccessor32 + hash_size_cumsum, + const at::PackedTensorAccessor32 indices, + const at::PackedTensorAccessor32 offsets, + at::PackedTensorAccessor32 infos, + at::PackedTensorAccessor32 + linear_indices) { + int32_t T = hash_size_cumsum.size(0) - 1; + int32_t B = (offsets.size(0) - 1) / T; + int32_t b_t = blockIdx.x * blockDim.x + threadIdx.x; + int32_t b = b_t % B; + int32_t t = b_t / B; + bool valid = t < T; + + index_t hash_offset = valid ? hash_size_cumsum[t] : -1; + index_t indices_start = valid ? offsets[t * B + b] : -1; + int32_t L = valid ? offsets[t * B + b + 1] - indices_start : 0; + int32_t lane_id = threadIdx.x % fbgemm_gpu::kWarpSize; + + for (int32_t j = 0; j < fbgemm_gpu::kWarpSize; ++j) { + index_t indices_start_warp = fbgemm_gpu::shfl_sync(indices_start, j); + int32_t b_t_warp = fbgemm_gpu::shfl_sync(b_t, j); + int32_t L_warp = fbgemm_gpu::shfl_sync(L, j); + index_t hash_offset_warp = fbgemm_gpu::shfl_sync(hash_offset, j); + for (int32_t i = lane_id; i < L_warp; i += fbgemm_gpu::kWarpSize) { + index_t idx = __ldg(&indices[indices_start_warp + i]); + infos[indices_start_warp + i] = b_t_warp; + linear_indices[indices_start_warp + i] = hash_offset_warp + idx; + } + } +} + +template +__global__ void nobag_linearize_index_kernel( + const at::PackedTensorAccessor32 + hash_size_cumsum, + const at::PackedTensorAccessor32 indices, + const at::PackedTensorAccessor32 offsets, + at::PackedTensorAccessor32 infos, + at::PackedTensorAccessor32 + linear_indices) { + int32_t T = hash_size_cumsum.size(0) - 1; + int32_t B = (offsets.size(0) - 1) / T; + int32_t b_t = blockIdx.x * blockDim.x + threadIdx.x; + int32_t b = b_t % B; + int32_t t = b_t / B; + bool valid = t < T; + + index_t hash_offset = valid ? hash_size_cumsum[t] : -1; + index_t indices_start = valid ? offsets[t * B + b] : -1; + int32_t L = valid ? offsets[t * B + b + 1] - indices_start : 0; + int32_t lane_id = threadIdx.x % fbgemm_gpu::kWarpSize; + + for (int32_t j = 0; j < fbgemm_gpu::kWarpSize; ++j) { + index_t indices_start_warp = fbgemm_gpu::shfl_sync(indices_start, j); + int32_t t_warp = fbgemm_gpu::shfl_sync(t, j); + int32_t L_warp = fbgemm_gpu::shfl_sync(L, j); + index_t hash_offset_warp = fbgemm_gpu::shfl_sync(hash_offset, j); + for (int32_t i = lane_id; i < L_warp; i += fbgemm_gpu::kWarpSize) { + index_t idx = __ldg(&indices[indices_start_warp + i]); + int64_t l_t = (indices_start_warp + i) * T + t_warp; + infos[indices_start_warp + i] = l_t; + linear_indices[indices_start_warp + i] = hash_offset_warp + idx; + } + } +} + +std::tuple< + Tensor /*linear_indices*/, + Tensor /*linear_indices_sorted*/, + Tensor /*infos_sorted*/, + Tensor /*sorted_linear_indices_run*/, + Tensor /*sorted_linear_indices_run_lengths*/, + Tensor /*sorted_linear_indices_num_runs*/, + Tensor /*sorted_linear_indices_cumulative_run_lengths*/> +transpose_embedding_input( + Tensor hash_size_cumsum, + int64_t total_hash_size_bits, + Tensor indices, + Tensor offsets, + bool nobag) { + int32_t T = hash_size_cumsum.size(0) - 1; + int32_t B = (offsets.size(0) - 1) / T; + + auto infos = at::empty_like( + indices, indices.options().dtype(nobag ? at::kLong : at::kInt)); + auto infos_sorted = at::empty_like(infos); + auto linear_indices = at::empty_like(indices); + auto linear_indices_sorted = at::empty_like(indices); + + Tensor sorted_linear_indices_run; + Tensor sorted_linear_indices_run_lengths; + Tensor sorted_linear_indices_num_runs; + + using at::RestrictPtrTraits; + + AT_DISPATCH_INDEX_TYPES( + infos.scalar_type(), "transpose_embedding_input1", ([&] { + using info_t = index_t; + AT_DISPATCH_INDEX_TYPES( + indices.scalar_type(), "transpose_embedding_input2", ([&] { + if (!nobag) { + linearize_index_kernel<<< + div_round_up(B * T, kMaxThreads), + kMaxThreads, + 0, + at::cuda::getCurrentCUDAStream()>>>( + hash_size_cumsum + .packed_accessor32(), + indices.packed_accessor32(), + offsets.packed_accessor32(), + infos.packed_accessor32(), + linear_indices + .packed_accessor32()); + } else { + nobag_linearize_index_kernel<<< + div_round_up(B * T, kMaxThreads), + kMaxThreads, + 0, + at::cuda::getCurrentCUDAStream()>>>( + hash_size_cumsum + .packed_accessor32(), + indices.packed_accessor32(), + offsets.packed_accessor32(), + infos.packed_accessor32(), + linear_indices + .packed_accessor32()); + } + C10_CUDA_KERNEL_LAUNCH_CHECK(); + { + size_t temp_storage_bytes = 0; + AT_CUDA_CHECK( + FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortPairs( + nullptr, + temp_storage_bytes, + linear_indices.data_ptr(), + linear_indices_sorted.data_ptr(), + infos.data_ptr(), + infos_sorted.data_ptr(), + linear_indices.numel(), + 0, + total_hash_size_bits, + at::cuda::getCurrentCUDAStream(), + false)); + auto temp_storage = at::empty( + {static_cast(temp_storage_bytes)}, + indices.options().dtype(at::kByte)); + AT_CUDA_CHECK( + FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortPairs( + temp_storage.data_ptr(), + temp_storage_bytes, + linear_indices.data_ptr(), + linear_indices_sorted.data_ptr(), + infos.data_ptr(), + infos_sorted.data_ptr(), + linear_indices.numel(), + 0, + total_hash_size_bits, + at::cuda::getCurrentCUDAStream(), + false)); + } + + sorted_linear_indices_run = at::empty_like(indices); + sorted_linear_indices_run_lengths = + at::zeros_like(indices, indices.options().dtype(at::kInt)); + sorted_linear_indices_num_runs = + at::zeros({1}, indices.options().dtype(at::kInt)); + + { + size_t temp_storage_bytes = 0; + AT_CUDA_CHECK( + FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRunLengthEncode::Encode( + nullptr, + temp_storage_bytes, + linear_indices_sorted.data_ptr(), + sorted_linear_indices_run.data_ptr(), + sorted_linear_indices_run_lengths.data_ptr(), + sorted_linear_indices_num_runs.data_ptr(), + linear_indices_sorted.numel(), + at::cuda::getCurrentCUDAStream())); + // Allocate temporary storage + auto temp_storage = at::empty( + {static_cast(temp_storage_bytes)}, + indices.options().dtype(at::kByte)); + // Run encoding + AT_CUDA_CHECK( + FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRunLengthEncode::Encode( + temp_storage.data_ptr(), + temp_storage_bytes, + linear_indices_sorted.data_ptr(), + sorted_linear_indices_run.data_ptr(), + sorted_linear_indices_run_lengths.data_ptr(), + sorted_linear_indices_num_runs.data_ptr(), + linear_indices_sorted.numel(), + at::cuda::getCurrentCUDAStream())); + } + })); + })); + + auto sorted_linear_indices_cumulative_run_lengths = + asynchronous_complete_cumsum(sorted_linear_indices_run_lengths); + + return { + linear_indices, + linear_indices_sorted, + infos_sorted, + sorted_linear_indices_run, + sorted_linear_indices_run_lengths, + sorted_linear_indices_num_runs, + sorted_linear_indices_cumulative_run_lengths}; +} diff --git a/fbgemm_gpu/src/split_table_batched_embeddings.cpp b/fbgemm_gpu/src/split_table_batched_embeddings.cpp index 5ebc3ec5c..3f2fc8c52 100644 --- a/fbgemm_gpu/src/split_table_batched_embeddings.cpp +++ b/fbgemm_gpu/src/split_table_batched_embeddings.cpp @@ -8,6 +8,8 @@ #include #include +#include "fbgemm_gpu/sparse_ops_utils.h" + using Tensor = at::Tensor; // Map index to cache_set. h_in: linear_indices; C: #cache_sets. @@ -107,48 +109,59 @@ namespace { TORCH_LIBRARY_FRAGMENT(fb, m) { m.def( "linearize_cache_indices(Tensor cache_hash_size_cumsum, Tensor indices, Tensor offsets) -> Tensor"); - m.impl( - "linearize_cache_indices", - torch::dispatch( - c10::DispatchKey::CUDA, TORCH_FN(linearize_cache_indices_cuda))); + DISPATCH_TO_CUDA("linearize_cache_indices", linearize_cache_indices_cuda); m.def( "lru_cache_populate(Tensor weights, Tensor hash_size_cumsum, int total_cache_hash_size, Tensor cache_index_table_map, Tensor weights_offsets, Tensor D_offsets, Tensor linear_cache_indices, Tensor(a!) lxu_cache_state, Tensor(b!) lxu_cache_weights, int time_stamp, Tensor(c!) lru_state, bool stochastic_rounding) -> ()"); - m.impl( - "lru_cache_populate", - torch::dispatch( - c10::DispatchKey::CUDA, TORCH_FN(lru_cache_populate_cuda))); + DISPATCH_TO_CUDA("lru_cache_populate", lru_cache_populate_cuda); m.def( "lru_cache_populate_byte(Tensor weights, Tensor hash_size_cumsum, int total_cache_hash_size, Tensor cache_index_table_map, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, Tensor linear_cache_indices, Tensor(a!) lxu_cache_state, Tensor(b!) lxu_cache_weights, int time_stamp, Tensor(c!) lru_state) -> ()"); - m.impl( - "lru_cache_populate_byte", - torch::dispatch( - c10::DispatchKey::CUDA, TORCH_FN(lru_cache_populate_byte_cuda))); + DISPATCH_TO_CUDA("lru_cache_populate_byte", lru_cache_populate_byte_cuda); m.def( "lfu_cache_populate(Tensor weights, Tensor cache_hash_size_cumsum, int total_cache_hash_size, Tensor cache_index_table_map, Tensor weights_offsets, Tensor D_offsets, Tensor linear_cache_indices, Tensor(a!) lxu_cache_state, Tensor(b!) lxu_cache_weights, Tensor(c!) lfu_state, bool stochastic_rounding) -> ()"); - m.impl( - "lfu_cache_populate", - torch::dispatch( - c10::DispatchKey::CUDA, TORCH_FN(lfu_cache_populate_cuda))); + DISPATCH_TO_CUDA("lfu_cache_populate", lfu_cache_populate_cuda); m.def( "lfu_cache_populate_byte(Tensor weights, Tensor cache_hash_size_cumsum, int total_cache_hash_size, Tensor cache_index_table_map, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, Tensor linear_cache_indices, Tensor(a!) lxu_cache_state, Tensor(b!) lxu_cache_weights, Tensor(c!) lfu_state) -> ()"); + DISPATCH_TO_CUDA("lfu_cache_populate_byte", lfu_cache_populate_byte_cuda); + m.def( + "lxu_cache_lookup(Tensor linear_cache_indices, Tensor lxu_cache_state) -> Tensor"); + DISPATCH_TO_CUDA("lxu_cache_lookup", lxu_cache_lookup_cuda); + m.def( + "lxu_cache_flush(Tensor(a!) uvm_weights, Tensor cache_hash_size_cumsum, Tensor cache_index_table_map, Tensor weights_offsets, Tensor D_offsets, int total_D, Tensor(b!) lxu_cache_state, Tensor(c!) lxu_cache_weights, bool stochastic_rounding) -> ()"); + DISPATCH_TO_CUDA("lxu_cache_flush", lxu_cache_flush_cuda); + m.def("lxu_cache_slot(int h_in, int C) -> int"); m.impl( - "lfu_cache_populate_byte", + "lxu_cache_slot", torch::dispatch( - c10::DispatchKey::CUDA, TORCH_FN(lfu_cache_populate_byte_cuda))); + c10::DispatchKey::CatchAll, TORCH_FN(host_lxu_cache_slot))); +} + +TORCH_LIBRARY_FRAGMENT(fbgemm, m) { + m.def( + "linearize_cache_indices(Tensor cache_hash_size_cumsum, Tensor indices, Tensor offsets) -> Tensor"); + DISPATCH_TO_CUDA("linearize_cache_indices", linearize_cache_indices_cuda); + m.def( + "lru_cache_populate(Tensor weights, Tensor hash_size_cumsum, int total_cache_hash_size, Tensor cache_index_table_map, Tensor weights_offsets, Tensor D_offsets, Tensor linear_cache_indices, Tensor(a!) lxu_cache_state, Tensor(b!) lxu_cache_weights, int time_stamp, Tensor(c!) lru_state, bool stochastic_rounding) -> ()"); + DISPATCH_TO_CUDA("lru_cache_populate", lru_cache_populate_cuda); + m.def( + "lru_cache_populate_byte(Tensor weights, Tensor hash_size_cumsum, int total_cache_hash_size, Tensor cache_index_table_map, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, Tensor linear_cache_indices, Tensor(a!) lxu_cache_state, Tensor(b!) lxu_cache_weights, int time_stamp, Tensor(c!) lru_state) -> ()"); + DISPATCH_TO_CUDA("lru_cache_populate_byte", lru_cache_populate_byte_cuda); + m.def( + "lfu_cache_populate(Tensor weights, Tensor cache_hash_size_cumsum, int total_cache_hash_size, Tensor cache_index_table_map, Tensor weights_offsets, Tensor D_offsets, Tensor linear_cache_indices, Tensor(a!) lxu_cache_state, Tensor(b!) lxu_cache_weights, Tensor(c!) lfu_state, bool stochastic_rounding) -> ()"); + DISPATCH_TO_CUDA("lfu_cache_populate", lfu_cache_populate_cuda); + m.def( + "lfu_cache_populate_byte(Tensor weights, Tensor cache_hash_size_cumsum, int total_cache_hash_size, Tensor cache_index_table_map, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, Tensor linear_cache_indices, Tensor(a!) lxu_cache_state, Tensor(b!) lxu_cache_weights, Tensor(c!) lfu_state) -> ()"); + DISPATCH_TO_CUDA("lfu_cache_populate_byte", lfu_cache_populate_byte_cuda); m.def( "lxu_cache_lookup(Tensor linear_cache_indices, Tensor lxu_cache_state) -> Tensor"); - m.impl( - "lxu_cache_lookup", - torch::dispatch(c10::DispatchKey::CUDA, TORCH_FN(lxu_cache_lookup_cuda))); + DISPATCH_TO_CUDA("lxu_cache_lookup", lxu_cache_lookup_cuda); m.def( "lxu_cache_flush(Tensor(a!) uvm_weights, Tensor cache_hash_size_cumsum, Tensor cache_index_table_map, Tensor weights_offsets, Tensor D_offsets, int total_D, Tensor(b!) lxu_cache_state, Tensor(c!) lxu_cache_weights, bool stochastic_rounding) -> ()"); - m.impl( - "lxu_cache_flush", - torch::dispatch(c10::DispatchKey::CUDA, TORCH_FN(lxu_cache_flush_cuda))); + DISPATCH_TO_CUDA("lxu_cache_flush", lxu_cache_flush_cuda); m.def("lxu_cache_slot(int h_in, int C) -> int"); m.impl( "lxu_cache_slot", torch::dispatch( c10::DispatchKey::CatchAll, TORCH_FN(host_lxu_cache_slot))); } + } // namespace diff --git a/fbgemm_gpu/test/merge_pooled_embeddings_test.py b/fbgemm_gpu/test/merge_pooled_embeddings_test.py index f83b46a2b..e09075e88 100644 --- a/fbgemm_gpu/test/merge_pooled_embeddings_test.py +++ b/fbgemm_gpu/test/merge_pooled_embeddings_test.py @@ -89,6 +89,32 @@ def ref(pooled_ad_embeddings, batch_indices): torch.testing.assert_allclose(output_ref, output.cpu()) torch.testing.assert_allclose(output_ref, output_cpu) + @given( + num_inputs=st.integers(min_value=1, max_value=10), + num_gpus=st.integers(min_value=1, max_value=torch.cuda.device_count()), + non_default_stream=st.booleans(), + r=st.randoms(use_true_random=False), + ) + # Can instantiate 8 contexts which takes a long time. + @settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None) + def test_all_to_one_device( + self, + num_inputs, + num_gpus, + non_default_stream, + r, + ) -> None: + dst_device = torch.device(f"cuda:{r.randint(0, num_gpus - 1)}") + with torch.cuda.device(dst_device): + inputs = [torch.randn(10, 20) for _ in range(num_inputs)] + cuda_inputs = [ + input.to(f"cuda:{i % num_gpus}") for i, input in enumerate(inputs) + ] + cuda_outputs = torch.ops.fbgemm.all_to_one_device(cuda_inputs, dst_device) + for i, o in zip(inputs, cuda_outputs): + self.assertEqual(o.device, dst_device) + torch.testing.assert_allclose(o.cpu(), i) + if __name__ == "__main__": unittest.main() diff --git a/fbgemm_gpu/test/permute_pooled_embedding_modules_test.py b/fbgemm_gpu/test/permute_pooled_embedding_modules_test.py index ee49fea10..3c39d4fea 100644 --- a/fbgemm_gpu/test/permute_pooled_embedding_modules_test.py +++ b/fbgemm_gpu/test/permute_pooled_embedding_modules_test.py @@ -91,11 +91,9 @@ def test_permutation_autograd(self) -> None: output.sum().backward() # check grads for fc1 when permuted, equals to fc2 weights times input_sum + permute_res = net.permute_pooled_embeddings(net.fc1.weight.grad.view(1, 10)) self.assertTrue( - net.permute_pooled_embeddings(net.fc1.weight.grad.view(1, 10)) - .isclose(input_sum * net.fc2.weight) - .all() - .item() + permute_res.isclose(input_sum * net.fc2.weight, rtol=1e-03).all().item() ) def test_compatibility(self) -> None: diff --git a/fbgemm_gpu/test/quantize_ops_test.py b/fbgemm_gpu/test/quantize_ops_test.py index ae312fc2c..959494eff 100644 --- a/fbgemm_gpu/test/quantize_ops_test.py +++ b/fbgemm_gpu/test/quantize_ops_test.py @@ -25,6 +25,8 @@ except Exception: torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") + torch.ops.load_library("//caffe2/torch/fb/sparsenn:sparsenn_operators_gpu") + torch.ops.load_library("//caffe2/torch/fb/sparsenn:sparsenn_operators") from fbgemm_gpu.test.test_utils import ( fused_rowwise_8bit_quantize_reference, fused_rowwise_8bit_dequantize_reference, @@ -243,5 +245,38 @@ def test_quantize_and_dequantize_op_cuda_large_nrows(self) -> None: torch.testing.assert_allclose(dequantized_data_gpu.cpu(), reference) +class TestDenseMLPQuantizationConversion(unittest.TestCase): + # pyre-ignore [56]: Invalid decoration, was not able to infer the type of argument + @given( + nrows=st.integers(min_value=0, max_value=100), + ncols=st.integers(min_value=0, max_value=100), + ) + @settings(deadline=10000, suppress_health_check=[HealthCheck.filter_too_much]) + def test_quantize_op(self, nrows: int, ncols: int) -> None: + ebits = 8 + mbits = 7 + bias = 127 + max_pos = (1 << ((1 << ebits) - 2 - bias)) * (2 - 2 ** (-mbits)) + min_pos = 2 ** (1 - bias - mbits) + bounding_box_size = 16 + print("MSFP parameters", bounding_box_size, ebits, mbits, bias) + input_data = torch.rand(nrows, ncols).float() + quantized_data = torch.ops.fb.FloatToMSFPQuantized( + input_data.cuda(), + bounding_box_size, + ebits, + mbits, + bias, + min_pos, + max_pos, + ) + dequantized_data = torch.ops.fb.MSFPQuantizedToFloat( + quantized_data.cuda(), ebits, mbits, bias + ) + torch.testing.assert_allclose( + dequantized_data.cpu(), input_data, rtol=1, atol=0 + ) + + if __name__ == "__main__": unittest.main() diff --git a/fbgemm_gpu/test/sparse_ops_test.py b/fbgemm_gpu/test/sparse_ops_test.py index 49000de0b..7bdcbd3e7 100644 --- a/fbgemm_gpu/test/sparse_ops_test.py +++ b/fbgemm_gpu/test/sparse_ops_test.py @@ -817,45 +817,45 @@ def test_jagged_2d_to_dense( lengths = torch.from_numpy(lengths_) offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths) - ref_embeddings = torch.rand(total_lengths, D) - ref_output_embeddings = var_list_to_coo( + ref_values = torch.rand(total_lengths, D) + ref_output_values = var_list_to_coo( lengths, - ref_embeddings, + ref_values, max_sequence_length, D, ).to_dense() # test cpu forward if is_half: - embeddings = ref_embeddings.clone().half().detach().requires_grad_(True) + values = ref_values.clone().half().detach().requires_grad_(True) else: - embeddings = ref_embeddings.clone().detach().requires_grad_(True) - output_embeddings = torch.ops.fbgemm.jagged_2d_to_dense( - embeddings=embeddings, + values = ref_values.clone().detach().requires_grad_(True) + output_values = torch.ops.fbgemm.jagged_2d_to_dense( + values=values, offsets=offsets, max_sequence_length=max_sequence_length, ) - torch.testing.assert_allclose(ref_output_embeddings, output_embeddings) + torch.testing.assert_allclose(ref_output_values, output_values) if torch.cuda.is_available(): # test gpu forward - ref_embeddings = ref_embeddings.cuda() + ref_values = ref_values.cuda() if is_half: - embeddings = ref_embeddings.clone().half().detach().requires_grad_(True) + values = ref_values.clone().half().detach().requires_grad_(True) else: - embeddings = ref_embeddings.clone().detach().requires_grad_(True) + values = ref_values.clone().detach().requires_grad_(True) offsets = offsets.cuda() - ref_output_embeddings = ref_output_embeddings.cuda() - output_embeddings = torch.ops.fbgemm.jagged_2d_to_dense( - embeddings=embeddings, + ref_output_values = ref_output_values.cuda() + output_values = torch.ops.fbgemm.jagged_2d_to_dense( + values=values, offsets=offsets, max_sequence_length=max_sequence_length, ) - torch.testing.assert_allclose(ref_output_embeddings, output_embeddings) + torch.testing.assert_allclose(ref_output_values, output_values) # test gpu backward - output_embeddings.backward(ref_output_embeddings) - torch.testing.assert_allclose(ref_embeddings, embeddings.grad) + output_values.backward(ref_output_values) + torch.testing.assert_allclose(ref_values, values.grad) def test_jagged_2d_to_dense_truncation(self) -> None: # Test the case where max_sequence_length < max(lengths[i]) @@ -866,42 +866,42 @@ def test_jagged_2d_to_dense_truncation(self) -> None: embedding_dim = 16 max_sequence_length = 2 - ref_embeddings = torch.rand(total_lengths, embedding_dim) - ref_output_embeddings = var_list_to_coo( + ref_values = torch.rand(total_lengths, embedding_dim) + ref_output_values = var_list_to_coo( lengths, - ref_embeddings, + ref_values, 3, embedding_dim, ).to_dense()[:, :max_sequence_length, :] # test cpu forward - embeddings = ref_embeddings.clone().detach().requires_grad_(True) - output_embeddings = torch.ops.fbgemm.jagged_2d_to_dense( - embeddings=embeddings, + values = ref_values.clone().detach().requires_grad_(True) + output_values = torch.ops.fbgemm.jagged_2d_to_dense( + values=values, offsets=offsets, max_sequence_length=max_sequence_length, ) - torch.testing.assert_allclose(ref_output_embeddings, output_embeddings) + torch.testing.assert_allclose(ref_output_values, output_values) if torch.cuda.is_available(): # test gpu forward - ref_embeddings = ref_embeddings.cuda() - embeddings = ref_embeddings.clone().detach().requires_grad_(True) + ref_values = ref_values.cuda() + values = ref_values.clone().detach().requires_grad_(True) offsets = offsets.cuda() - ref_output_embeddings = ref_output_embeddings.cuda() - output_embeddings = torch.ops.fbgemm.jagged_2d_to_dense( - embeddings=embeddings, + ref_output_values = ref_output_values.cuda() + output_values = torch.ops.fbgemm.jagged_2d_to_dense( + values=values, offsets=offsets, max_sequence_length=max_sequence_length, ) - torch.testing.assert_allclose(ref_output_embeddings, output_embeddings) + torch.testing.assert_allclose(ref_output_values, output_values) # test gpu backward - expected_grad = ref_embeddings + expected_grad = ref_values expected_grad[4, :] = 0 # due to truncation expected_grad = expected_grad.cuda() - output_embeddings.backward(ref_output_embeddings) - torch.testing.assert_allclose(expected_grad, embeddings.grad) + output_values.backward(ref_output_values) + torch.testing.assert_allclose(expected_grad, values.grad) @settings( verbosity=Verbosity.verbose, @@ -1188,6 +1188,99 @@ def test_histogram_binning_calibration_by_feature( ) ) + # pyre-ignore [56]: Invalid decoration, was not able to infer the type of argument + @given(data_type=st.sampled_from([torch.half, torch.float32])) + @settings(verbosity=Verbosity.verbose, deadline=None) + def test_generic_histogram_binning_calibration_by_feature( + self, data_type: torch.dtype + ) -> None: + num_bins = 5000 + num_segments = 42 + + logit = torch.tensor([-0.0018, 0.0085, 0.0090, 0.0003, 0.0029]).type(data_type) + + segment_value = torch.tensor([40, 31, 32, 13, 31]) + lengths = torch.tensor([[1], [1], [1], [1], [1]]) + + num_interval = num_bins * (num_segments + 1) + bin_num_examples = torch.empty([num_interval], dtype=torch.float64).fill_(0.0) + bin_num_positives = torch.empty([num_interval], dtype=torch.float64).fill_(0.0) + + lower_bound = 0.0 + upper_bound = 1.0 + w = (upper_bound - lower_bound) / num_bins + bin_boundaries = torch.arange( + lower_bound + w, upper_bound - w / 2, w, dtype=torch.float64 + ) + + ( + calibrated_prediction, + bin_ids, + ) = torch.ops.fbgemm.generic_histogram_binning_calibration_by_feature( + logit=logit, + segment_value=segment_value, + segment_lengths=lengths, + num_segments=num_segments, + bin_num_examples=bin_num_examples, + bin_num_positives=bin_num_positives, + bin_boundaries=bin_boundaries, + positive_weight=0.4, + bin_ctr_in_use_after=10000, + bin_ctr_weight_value=0.9995, + ) + + expected_calibrated_prediction = torch.tensor( + [0.2853, 0.2875, 0.2876, 0.2858, 0.2863] + ).type(data_type) + expected_bin_ids = torch.tensor( + [206426, 161437, 166437, 71428, 161431], dtype=torch.long + ) + + torch.testing.assert_allclose( + calibrated_prediction, + expected_calibrated_prediction, + rtol=1e-03, + atol=1e-03, + ) + + self.assertTrue( + torch.equal( + bin_ids.long(), + expected_bin_ids, + ) + ) + + if torch.cuda.is_available(): + ( + calibrated_prediction_gpu, + bin_ids_gpu, + ) = torch.ops.fbgemm.generic_histogram_binning_calibration_by_feature( + logit=logit.cuda(), + segment_value=segment_value.cuda(), + segment_lengths=lengths.cuda(), + num_segments=num_segments, + bin_num_examples=bin_num_examples.cuda(), + bin_num_positives=bin_num_positives.cuda(), + bin_boundaries=bin_boundaries.cuda(), + positive_weight=0.4, + bin_ctr_in_use_after=10000, + bin_ctr_weight_value=0.9995, + ) + + torch.testing.assert_allclose( + calibrated_prediction_gpu, + expected_calibrated_prediction.cuda(), + rtol=1e-03, + atol=1e-03, + ) + + self.assertTrue( + torch.equal( + bin_ids_gpu.long(), + expected_bin_ids.cuda(), + ) + ) + @settings(verbosity=Verbosity.verbose, deadline=None) def test_segment_sum_csr(self) -> None: segment_sum_cpu = torch.ops.fbgemm.segment_sum_csr( diff --git a/fbgemm_gpu/test/split_embedding_inference_converter_test.py b/fbgemm_gpu/test/split_embedding_inference_converter_test.py index 8c56b34dc..6a0f7fd56 100644 --- a/fbgemm_gpu/test/split_embedding_inference_converter_test.py +++ b/fbgemm_gpu/test/split_embedding_inference_converter_test.py @@ -207,7 +207,6 @@ def test_quantize_workflow( rtol=1.0e-1, ) - @unittest.skipIf(open_source, "Not yet in OSS") @given( use_cpu=st.booleans() if gpu_available else st.just(True), use_array_for_index_remapping=st.booleans(), @@ -284,7 +283,6 @@ def test_l2_norm_pruning_workflow( rtol=1.0e-1, ) - @unittest.skipIf(open_source, "Not yet in OSS") @given( T=st.integers(min_value=1, max_value=10), D=st.integers(min_value=2, max_value=128), diff --git a/fbgemm_gpu/test/split_table_batched_embeddings_test.py b/fbgemm_gpu/test/split_table_batched_embeddings_test.py index e71eda48b..6fe24af76 100644 --- a/fbgemm_gpu/test/split_table_batched_embeddings_test.py +++ b/fbgemm_gpu/test/split_table_batched_embeddings_test.py @@ -1043,7 +1043,6 @@ def test_backward_dense( rtol=5.0e-3 if weights_precision == SparseType.FP16 else 1.0e-4, ) - # pyre-fixme[29]: `Union[Tensor, torch.nn.Module]` is not a function. cc = split_table_batched_embeddings_ops.DenseTableBatchedEmbeddingBagsCodegen( [(E, D) for (E, D) in zip(Es, Ds)], # NOTE: only SUM pooling can work with per_sample_weights! @@ -1562,7 +1561,6 @@ def execute_backward_adagrad_( # noqa C901 ) if use_cpu: # NOTE: GPU version of SplitTableBatchedEmbeddingBagsCodegen doesn't support double. - # pyre-fixme[29]: `Union[Tensor, torch.nn.Module]` is not a function. cc = cc.double() per_sample_weights = to_device(xw.contiguous().view(-1), use_cpu) @@ -2705,12 +2703,29 @@ def execute_nbit_forward_( pooling_mode == split_table_batched_embeddings_ops.PoolingMode.SUM or not weighted ) + # No bag ops only work on GPUs, no mixed, no weighted + assume( + not use_cpu + or pooling_mode != split_table_batched_embeddings_ops.PoolingMode.NONE + ) + assume( + not mixed + or pooling_mode != split_table_batched_embeddings_ops.PoolingMode.NONE + ) + assume( + not weighted + or pooling_mode != split_table_batched_embeddings_ops.PoolingMode.NONE + ) mode = "sum" + do_pooling = True if pooling_mode == split_table_batched_embeddings_ops.PoolingMode.SUM: mode = "sum" elif pooling_mode == split_table_batched_embeddings_ops.PoolingMode.MEAN: mode = "mean" + else: + mode = "sum" + do_pooling = False E = int(10 ** log_E) if not mixed_weights_ty: @@ -2725,7 +2740,7 @@ def execute_nbit_forward_( if not mixed: Ds = [D] * T - Es = [int(1e4)] * T + Es = [E] * T else: Ds = [ round_up( @@ -2739,10 +2754,16 @@ def execute_nbit_forward_( np.random.randint(low=int(0.5 * E), high=int(2.0 * E)) for _ in range(T) ] - bs = [ - to_device(torch.nn.EmbeddingBag(E, D, mode=mode, sparse=True), use_cpu) - for (E, D) in zip(Es, Ds) - ] + if do_pooling: + bs = [ + to_device(torch.nn.EmbeddingBag(E, D, mode=mode, sparse=True), use_cpu) + for (E, D) in zip(Es, Ds) + ] + else: + bs = [ + to_device(torch.nn.Embedding(E, D, sparse=True), use_cpu) + for (E, D) in zip(Es, Ds) + ] if use_cpu: managed = [split_table_batched_embeddings_ops.EmbeddingLocation.HOST] * T @@ -2907,19 +2928,31 @@ def comp(i: int) -> np.ndarray: else cc(indices.int(), offsets.int(), xw.contiguous().view(-1).cpu()) ) - if B == 0: + if do_pooling and B == 0: self.assertEqual(fc2.size(), (0, cc.total_D)) return fs = ( - [b_indices(b, x, use_cpu=use_cpu) for (b, x) in zip(bs, xs)] + [ + b_indices(b, x, use_cpu=use_cpu, do_pooling=do_pooling) + for (b, x) in zip(bs, xs) + ] if not weighted else [ - b_indices(b, x, per_sample_weights=xw.view(-1), use_cpu=use_cpu) + b_indices( + b, + x, + per_sample_weights=xw.view(-1), + use_cpu=use_cpu, + do_pooling=do_pooling, + ) for (b, x, xw) in zip(bs, xs, xws) ] ) - f = torch.cat([f.view(B, -1) for f in fs], dim=1) + if do_pooling: + f = torch.cat([f.view(B, -1) for f in fs], dim=1) + else: + f = torch.cat(fs, dim=0).view(-1, D) torch.testing.assert_allclose( fc2.float().cpu(), f.float().cpu(), @@ -2940,6 +2973,7 @@ def comp(i: int) -> np.ndarray: [ split_table_batched_embeddings_ops.PoolingMode.SUM, split_table_batched_embeddings_ops.PoolingMode.MEAN, + split_table_batched_embeddings_ops.PoolingMode.NONE, ] ), weights_ty=st.sampled_from( @@ -3017,6 +3051,7 @@ def test_nbit_forward_int( [ split_table_batched_embeddings_ops.PoolingMode.SUM, split_table_batched_embeddings_ops.PoolingMode.MEAN, + split_table_batched_embeddings_ops.PoolingMode.NONE, ] ), weights_ty=st.sampled_from( @@ -3440,14 +3475,14 @@ def test_bounds_check( warning.cuda(), ) indices_copy = indices.clone() - torch.ops.fb.bounds_check_indices( + torch.ops.fbgemm.bounds_check_indices( rows_per_table, indices, offsets, bounds_check_mode, warning ) # we don't modify when we are in-bounds. torch.testing.assert_allclose(indices_copy, indices) indices[:] = torch.iinfo(dtype).max if bounds_check_mode != BoundsCheckMode.FATAL: - torch.ops.fb.bounds_check_indices( + torch.ops.fbgemm.bounds_check_indices( rows_per_table, indices, offsets, bounds_check_mode, warning ) torch.testing.assert_allclose(indices, torch.zeros_like(indices)) @@ -3456,7 +3491,7 @@ def test_bounds_check( else: if use_cpu and indices.numel(): with self.assertRaises(RuntimeError): - torch.ops.fb.bounds_check_indices( + torch.ops.fbgemm.bounds_check_indices( rows_per_table, indices, offsets, bounds_check_mode, warning ) # It would be nice to test the CUDA implementation of BoundsCheckMode==FATAL, @@ -3470,7 +3505,7 @@ def test_bounds_check( if offsets.numel() > 1: offsets[-1] += 100 if bounds_check_mode != BoundsCheckMode.FATAL: - torch.ops.fb.bounds_check_indices( + torch.ops.fbgemm.bounds_check_indices( rows_per_table, indices, offsets, bounds_check_mode, warning ) if offsets.numel() > 0: @@ -3480,11 +3515,11 @@ def test_bounds_check( if bounds_check_mode == BoundsCheckMode.WARNING: # -1 because when we have 2 elements in offsets, we have only 1 # warning for the pair. - self.assertEqual(warning.item(), min(2, offsets.numel() - 1)) + self.assertGreaterEqual(warning.item(), min(2, offsets.numel() - 1)) else: if use_cpu and indices.numel(): with self.assertRaises(RuntimeError): - torch.ops.fb.bounds_check_indices( + torch.ops.fbgemm.bounds_check_indices( rows_per_table, indices, offsets, bounds_check_mode, warning ) diff --git a/fbgemm_gpu/test/tensor_assert_test.cpp b/fbgemm_gpu/test/tensor_assert_test.cpp new file mode 100644 index 000000000..6df0dc2d0 --- /dev/null +++ b/fbgemm_gpu/test/tensor_assert_test.cpp @@ -0,0 +1,32 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * Copyright (c) Intel Corporation. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include "fbgemm_gpu/sparse_ops_utils.h" + +using namespace ::testing; + +TEST(tensor_assert_test, gpu_asserts) { + at::Tensor on_cpu_empty; + + ASSERT_EQ(on_cpu_empty.numel(), 0); + EXPECT_NO_THROW(TENSOR_ON_CPU(on_cpu_empty)); + ASSERT_TRUE(torch_tensor_empty_or_on_cuda_gpu_check(on_cpu_empty)); + EXPECT_NO_THROW(TENSOR_EMPTY_OR_ON_CUDA_GPU(on_cpu_empty)); + EXPECT_ANY_THROW(TENSOR_ON_CUDA_GPU(on_cpu_empty)); + + auto on_cpu_non_empty = at::randint(10, 32); + const auto on_cuda_non_empty = on_cpu_non_empty.to(at::device(at::kCUDA)); + + ASSERT_NE(on_cpu_non_empty.numel(), 0); + EXPECT_NO_THROW(TENSOR_ON_CPU(on_cpu_non_empty)); + EXPECT_ANY_THROW(TENSOR_ON_CPU(on_cuda_non_empty)); + EXPECT_NO_THROW(TENSOR_ON_CUDA_GPU(on_cuda_non_empty)); + EXPECT_NO_THROW(TENSOR_EMPTY_OR_ON_CUDA_GPU(on_cuda_non_empty)); +} diff --git a/fbgemm_gpu/test/uvm_test.py b/fbgemm_gpu/test/uvm_test.py index 6aca447cc..33fc7df9f 100644 --- a/fbgemm_gpu/test/uvm_test.py +++ b/fbgemm_gpu/test/uvm_test.py @@ -39,13 +39,13 @@ class UvmTest(unittest.TestCase): @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) def test_is_uvm_tensor(self, sizes: List[int], vanilla: bool) -> None: op = ( - torch.ops.fb.new_managed_tensor + torch.ops.fbgemm.new_managed_tensor if not vanilla - else torch.ops.fb.new_vanilla_managed_tensor + else torch.ops.fbgemm.new_vanilla_managed_tensor ) uvm_t = op(torch.empty(0, device="cuda:0", dtype=torch.float), sizes) - assert torch.ops.fb.is_uvm_tensor(uvm_t) - assert torch.ops.fb.uvm_storage(uvm_t) + assert torch.ops.fbgemm.is_uvm_tensor(uvm_t) + assert torch.ops.fbgemm.uvm_storage(uvm_t) @unittest.skipIf(*gpu_unavailable) @given( @@ -55,19 +55,19 @@ def test_is_uvm_tensor(self, sizes: List[int], vanilla: bool) -> None: @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) def test_uvm_to_cpu(self, sizes: List[int], vanilla: bool) -> None: op = ( - torch.ops.fb.new_managed_tensor + torch.ops.fbgemm.new_managed_tensor if not vanilla - else torch.ops.fb.new_vanilla_managed_tensor + else torch.ops.fbgemm.new_vanilla_managed_tensor ) uvm_t = op(torch.empty(0, device="cuda:0", dtype=torch.float), sizes) - cpu_t = torch.ops.fb.uvm_to_cpu(uvm_t) - assert not torch.ops.fb.is_uvm_tensor(cpu_t) - assert torch.ops.fb.uvm_storage(cpu_t) + cpu_t = torch.ops.fbgemm.uvm_to_cpu(uvm_t) + assert not torch.ops.fbgemm.is_uvm_tensor(cpu_t) + assert torch.ops.fbgemm.uvm_storage(cpu_t) uvm_t.copy_(cpu_t) - assert torch.ops.fb.is_uvm_tensor(uvm_t) - assert torch.ops.fb.uvm_storage(uvm_t) + assert torch.ops.fbgemm.is_uvm_tensor(uvm_t) + assert torch.ops.fbgemm.uvm_storage(uvm_t) # Test use of cpu tensor after freeing the uvm tensor del uvm_t @@ -87,13 +87,13 @@ def test_enum(self) -> None: @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) def test_cudaMemAdvise(self, sizes: List[int], vanilla: bool) -> None: op = ( - torch.ops.fb.new_managed_tensor + torch.ops.fbgemm.new_managed_tensor if not vanilla - else torch.ops.fb.new_vanilla_managed_tensor + else torch.ops.fbgemm.new_vanilla_managed_tensor ) uvm_t = op(torch.empty(0, device="cuda:0", dtype=torch.float), sizes) - assert torch.ops.fb.is_uvm_tensor(uvm_t) - assert torch.ops.fb.uvm_storage(uvm_t) + assert torch.ops.fbgemm.is_uvm_tensor(uvm_t) + assert torch.ops.fbgemm.uvm_storage(uvm_t) # pyre-ignore[16] cudaMemAdvise(uvm_t, cudaMemoryAdvise.cudaMemAdviseSetAccessedBy) @@ -108,13 +108,13 @@ def test_cudaMemAdvise(self, sizes: List[int], vanilla: bool) -> None: @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) def test_cudaMemPrefetchAsync(self, sizes: List[int], vanilla: bool) -> None: op = ( - torch.ops.fb.new_managed_tensor + torch.ops.fbgemm.new_managed_tensor if not vanilla - else torch.ops.fb.new_vanilla_managed_tensor + else torch.ops.fbgemm.new_vanilla_managed_tensor ) uvm_t = op(torch.empty(0, device="cuda:0", dtype=torch.float), sizes) - assert torch.ops.fb.is_uvm_tensor(uvm_t) - assert torch.ops.fb.uvm_storage(uvm_t) + assert torch.ops.fbgemm.is_uvm_tensor(uvm_t) + assert torch.ops.fbgemm.uvm_storage(uvm_t) cudaMemPrefetchAsync(uvm_t) @@ -130,20 +130,20 @@ def test_cudaMemPrefetchAsync(self, sizes: List[int], vanilla: bool) -> None: @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) def test_uvm_to_device(self, sizes: List[int], vanilla: bool) -> None: op = ( - torch.ops.fb.new_managed_tensor + torch.ops.fbgemm.new_managed_tensor if not vanilla - else torch.ops.fb.new_vanilla_managed_tensor + else torch.ops.fbgemm.new_vanilla_managed_tensor ) uvm_t = op(torch.empty(0, device="cuda:0", dtype=torch.float), sizes) - assert torch.ops.fb.is_uvm_tensor(uvm_t) - assert torch.ops.fb.uvm_storage(uvm_t) + assert torch.ops.fbgemm.is_uvm_tensor(uvm_t) + assert torch.ops.fbgemm.uvm_storage(uvm_t) # Reference uvm tensor from second cuda device device_prototype = torch.empty(0, device="cuda:1") - second_t = torch.ops.fb.uvm_to_device(uvm_t, device_prototype) + second_t = torch.ops.fbgemm.uvm_to_device(uvm_t, device_prototype) - assert torch.ops.fb.is_uvm_tensor(second_t) - assert torch.ops.fb.uvm_storage(second_t) + assert torch.ops.fbgemm.is_uvm_tensor(second_t) + assert torch.ops.fbgemm.uvm_storage(second_t) assert second_t.device == device_prototype.device @unittest.skipIf(*gpu_unavailable) @@ -156,19 +156,19 @@ def test_uvm_to_device(self, sizes: List[int], vanilla: bool) -> None: @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) def test_uvm_slice(self, sizes: List[int], vanilla: bool) -> None: op = ( - torch.ops.fb.new_managed_tensor + torch.ops.fbgemm.new_managed_tensor if not vanilla - else torch.ops.fb.new_vanilla_managed_tensor + else torch.ops.fbgemm.new_vanilla_managed_tensor ) uvm_t = op(torch.empty(0, device="cuda:0", dtype=torch.float), sizes) - assert torch.ops.fb.is_uvm_tensor(uvm_t) - assert torch.ops.fb.uvm_storage(uvm_t) + assert torch.ops.fbgemm.is_uvm_tensor(uvm_t) + assert torch.ops.fbgemm.uvm_storage(uvm_t) # Reference uvm tensor from second cuda device second_t = uvm_t[0] - assert torch.ops.fb.is_uvm_tensor(second_t) - assert torch.ops.fb.uvm_storage(second_t) + assert torch.ops.fbgemm.is_uvm_tensor(second_t) + assert torch.ops.fbgemm.uvm_storage(second_t) @unittest.skipIf(*gpu_unavailable) @given( @@ -179,6 +179,28 @@ def test_uvm_slice(self, sizes: List[int], vanilla: bool) -> None: ) @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) def test_uvm_memadviceDontFork(self, sizes: List[int], vanilla: bool) -> None: + op = ( + torch.ops.fbgemm.new_managed_tensor + if not vanilla + else torch.ops.fbgemm.new_vanilla_managed_tensor + ) + uvm_t = op(torch.empty(0, device="cuda:0", dtype=torch.float), sizes) + assert torch.ops.fbgemm.is_uvm_tensor(uvm_t) + assert torch.ops.fbgemm.uvm_storage(uvm_t) + + cpu_t = torch.ops.fbgemm.uvm_to_cpu(uvm_t) + + torch.ops.fbgemm.uvm_mem_advice_dont_fork(cpu_t) + + @unittest.skipIf(*gpu_unavailable) + @given( + sizes=st.lists( + st.integers(min_value=1, max_value=(512)), min_size=1, max_size=3 + ), + vanilla=st.booleans(), + ) + @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) + def test_uvm_to_cpu_clone(self, sizes: List[int], vanilla: bool) -> None: op = ( torch.ops.fb.new_managed_tensor if not vanilla @@ -188,9 +210,10 @@ def test_uvm_memadviceDontFork(self, sizes: List[int], vanilla: bool) -> None: assert torch.ops.fb.is_uvm_tensor(uvm_t) assert torch.ops.fb.uvm_storage(uvm_t) - cpu_t = torch.ops.fb.uvm_to_cpu(uvm_t) + cpu_clone = torch.ops.fb.uvm_to_cpu_clone(uvm_t) - torch.ops.fb.uvm_mem_advice_dont_fork(cpu_t) + assert not torch.ops.fb.is_uvm_tensor(cpu_clone) + assert not torch.ops.fb.uvm_storage(cpu_clone) if __name__ == "__main__": diff --git a/src/DirectConv.h b/src/DirectConv.h new file mode 100644 index 000000000..54816bd21 --- /dev/null +++ b/src/DirectConv.h @@ -0,0 +1,159 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "./CodeCache.h" +#include "fbgemm/ConvUtils.h" +#include "fbgemm/Fbgemm.h" +#include "fbgemm/Utils.h" +/*#define FBGEMM_LOG_CODE 1*/ + +namespace fbgemm { + +namespace x86 = asmjit::x86; + +/** + * @brief Generate instructions for initializing the C registers to 0. + */ +void initCRegs(x86::Emitter* a, int rowRegs, int colRegs); + +static asmjit::JitRuntime& runtime() { + static asmjit::JitRuntime rt; //< JIT Runtime for asmjit, + // depents on other static + // variables. Required to prevent + // initialization order fiasco + return rt; +} + +template +class DirectConvCodeGenBase { + public: + using jit_micro_kernel_fp = void (*)( + const TA* bufferA, + const TB* bufferB, + const TB* b_pf, + TC* bufferC, + int kc, + int ldc); + + static std::mutex rtMutex_; ///< Control access to runtime; + + // The hash depends on accumulate, mc, nc, ncb, kcb, nr, mr + static CodeCache< + std::tuple, + jit_micro_kernel_fp> + codeCache_; ///< JIT Code Cache for reuse. + + /** + * @brief Generate instructions for storing the C registers back to the + * memory. + */ + template + void storeCRegs( + x86::Emitter* a, + int rowRegs, + int colRegs, + x86::Gp C_Offset, + x86::Gp ldcReg, + bool accum); + + /** + * @brief Generate filename to dump generated code + * (debug-only) + */ + template + static std::string getCodeLoggingFile( + bool accum, + int mc, + int nc, + int NCB, + int KCB, + int MR, + int NR) { + std::ostringstream oss; + oss << "directconv_"; + if (std::is_same::value) { + oss << "acc16_"; + } else if (std::is_same::value) { + oss << "acc32_"; + } else { + oss << "unknown_"; + } + oss << "accum-" + std::to_string(accum) << "_MC-" + std::to_string(mc) + << "_NC-" + std::to_string(nc) << "_NCB-" + std::to_string(NCB) + << "_KCB-" + std::to_string(KCB) << "_MR-" + std::to_string(MR) + << "_NR-" + std::to_string(NR); + if (instSet == inst_set_t::avx512_vnni) { + oss << "_avx512vnni"; + } else if (instSet == inst_set_t::avx512) { + oss << "_avx512"; + } else if (instSet == inst_set_t::avx512_ymm) { + oss << "_avx512_ymm"; + } else if (instSet == inst_set_t::avx2) { + oss << "_avx2"; + } + oss << ".txt"; + return oss.str(); + } + + /** + * @brief Get or Create the instructions for macro-kernel. + * + * If the problem size (mc, nc) and accumulation flag (accum) can be found in + * the code cache (a hash map), then get the macro-kernel instructions + * directly from it. Otherwise, create the instructions for macro-kernel, and + * store that into the code cache. + */ + template + jit_micro_kernel_fp + getOrCreateDirectConv(bool accum, int32_t mc, int32_t nc, int32_t kc); + + /** + * @brief Generate instructions for computing block in the rank-k update. + */ + template + void genComputeBlock( + x86::Emitter* a, + x86::Gp buffer_A, + x86::Gp buffer_B, + x86::Gp B_pf, + int rowRegs, + int colRegs, + int lda); + /** + * @brief Generate instructions for computing block in the rank-k update. + */ + template + void genComputeBlockDirectConv( + x86::Emitter* a, + x86::Gp buffer_A, + x86::Gp buffer_B, + x86::Gp B_pf, + int rowRegs, + int colRegs, + int strideXich); +}; + +template +std::mutex DirectConvCodeGenBase::rtMutex_; + +template +CodeCache< + std::tuple, + typename DirectConvCodeGenBase::jit_micro_kernel_fp> + DirectConvCodeGenBase::codeCache_; + +}; // namespace fbgemm diff --git a/src/GenerateKernel.h b/src/GenerateKernel.h index 47ba5a2d6..336a39f7c 100644 --- a/src/GenerateKernel.h +++ b/src/GenerateKernel.h @@ -14,7 +14,7 @@ #include #include "./CodeCache.h" #include "fbgemm/Fbgemm.h" -/*#define FBGEMM_LOG_CODE 1*/ +//#define FBGEMM_LOG_CODE 1 namespace fbgemm { diff --git a/src/GenerateKernelDirectConvU8S8S32ACC32.cc b/src/GenerateKernelDirectConvU8S8S32ACC32.cc new file mode 100644 index 000000000..1a2d5a112 --- /dev/null +++ b/src/GenerateKernelDirectConvU8S8S32ACC32.cc @@ -0,0 +1,475 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include +#include "./CodeGenHelpers.h" +#include "./DirectConv.h" + +namespace fbgemm { + +namespace x86 = asmjit::x86; +/** + * Generate AVX512 instructions for computing block in the rank-k update of + * 32-bit Accumulation kernel. + * + * this compute block implements the following register blocking + // register blocking: + // leverage vpmaddubsw instructions + for (int _icb = icb; _icb < icb + row_interleave; _icb ++ ) { + for (int _oc = oc; _oc < oc + mRegBLockSize; _oc ++) { + for (int _ow = ow; _ow < std::min(ow + 12, OUT_DIM[1]); _ow ++) { + out[_oc + _ow * OC] += + input[_ich + (_ow + s * stride[1]) * IC + r * IC * IN_DIM[1]] + * + weights[(((((_oc/8) * (IC/4) + icb/4) * K[0] + r) * K[1] + s) + *8 + (_oc % 8)) * 4 + (_icb % 4)]; + } + } + } + * + */ + +/** + * Generate AVX512 instructions for storing the C registers back to the memory + * in 32-bit Accumulation kernel. + */ +template <> +template +void DirectConvCodeGenBase::storeCRegs( + x86::Emitter* a, + int rowRegs, + int colRegs, + x86::Gp C_Offset, + x86::Gp ldcReg, + bool accum) { + using VecT = typename simd_info::vec_reg_t; + static constexpr int vectorLen = simd_info::WIDTH_BYTES; + + for (int i = 0; i < rowRegs; ++i) { + if (i != 0) { + a->add(C_Offset, ldcReg); + } else { + a->xor_(C_Offset.r32(), C_Offset.r32()); + } + for (int j = 0; j < colRegs; ++j) { + if (accum) { + a->vpaddd( + VecT(i * colRegs + j), + VecT(i * colRegs + j), + x86::dword_ptr( + a->zcx(), C_Offset, 0, j * vectorLen * sizeof(int8_t))); + } + a->vmovups( + x86::dword_ptr(a->zcx(), C_Offset, 0, j * vectorLen * sizeof(int8_t)), + VecT(i * colRegs + j)); + } + } +} + +template <> +template +void DirectConvCodeGenBase:: + genComputeBlockDirectConv( + x86::Emitter* a, + x86::Gp buffer_A, + x86::Gp buffer_B, + x86::Gp B_pf, + int rowRegs, + int colRegs, + int strideXich) { + static constexpr int vectorLen = simd_info::WIDTH_BYTES; + using VecRegT = typename simd_info::vec_reg_t; + constexpr int numRegs = simd_info::NUM_VEC_REGS; + + // used for matrix A + VecRegT AReg(numRegs - 1); + + // used for matrix B + VecRegT BReg(numRegs - 2); + + // Contains 16-bit 1s + VecRegT oneReg(numRegs - 3); + + // temporary register + VecRegT res1(numRegs - 4); + + for (int j = 0; j < colRegs; ++j) { + // load B + emitLoadDWord( + a, BReg, x86::dword_ptr(buffer_B, j * vectorLen * sizeof(int8_t))); + // load A, broadcast and fmas + for (int i = 0; i < rowRegs; ++i) { + a->vpbroadcastd( + AReg, x86::dword_ptr(buffer_A, (i * strideXich) * sizeof(uint8_t))); + a->vpmaddubsw(res1, AReg, BReg); + a->vpmaddwd(res1, oneReg, res1); + a->vpaddd(VecRegT(i * colRegs + j), res1, VecRegT(i * colRegs + j)); + } + // a->prefetcht0(x86::dword_ptr(B_pf, j * vectorLen * sizeof(int8_t))); + } +} + +/** + * Get or Create the AVX256 instructions for 32-bit Accumulation macro-kernel. + * + * This function implements a direct convolution kernel that is specialized + * for kernel size (2, 6) and input_height (IN_DIM[0]) = 2. + * + * More specifically the implementation has the following requirements: + * * Weights has layout {OC/8, KH, KW, IC/4, 8, 4} + * * kernel size (2, 6), IN_DIM[0] = 2, therefore: OUT_DIM[0] = 1 + * * Features are in channel last format + * + * mRegBlockSize = 12: the number of avx2 registers for output + * nRegBlockSize = 8: the # of output elements in one avx2 register + * row_interleave = 4: the horizontal reduction size for vpmaddubsw instruction + * O1: output_width: OUT_DIM[1] + * i1Xich: input_width multiply input_channel: IN_DIM[1] x IC + * strideXich: stride multiply input_channel: stride[1] x input_channel + * + * + * The kernel implements the following algorithm: + +for (int ow = 0; ow < OUT_DIM[1]; ow+=12) { + L1 blocking: following weights are in L1 cache + for (int s = 0; s < K[1]; ++s) { + for (int r = 0; r < K[0]; ++r) { + for (int icb = 0; icb < IC; icb+=row_interleave) { + + // register blocking: + // leverage vpmaddubsw instructions + for (int _icb = icb; _icb < icb + row_interleave; _icb ++ ) { + for (int _oc = oc; _oc < oc + mRegBLockSize; _oc ++) { + for (int _ow = ow; _ow < std::min(ow + 12, OUT_DIM[1]); _ow ++) { + out[_oc + _ow * OC] += + input[_ich + (_ow + s * stride[1]) * IC + r * IC * IN_DIM[1]] + * + weights[(((((_oc/8) * (IC/4) + icb/4) * K[0] + r) * K[1] + s) + *8 + (_oc % 8)) * 4 + (_icb % 4)]; + + // If we get rid of the brackets, and substitute corrresponding +variables + // + // input[_ich + _ow * IC + s * strideXich + r * i1Xich] + // * + // weights[(((((_oc/8) * (IC/4) + icb/4) * K[0] + r) * K[1] + s) + // *8 + (_oc % 8)) * 4 + (_icb % 4)]; + } + } + } + + } + } + } + * + */ +template <> +template +DirectConvCodeGenBase::jit_micro_kernel_fp +DirectConvCodeGenBase::getOrCreateDirectConv( + bool accum, + int32_t O1, + int32_t i1Xich, + int32_t strideXich) { + using VecRegT = typename simd_info::vec_reg_t; + constexpr int numRegs = simd_info::NUM_VEC_REGS; + static constexpr int vectorLen = simd_info::WIDTH_BYTES; + + std::tuple kernelSig; + // int ichSize = 32; + int mRegBlockSize = 12; + int nRegBlockSize = 8; + // int nRegBlockSizeMin; + int row_interleave = 4; + + kernelSig = std::make_tuple( + accum, O1, i1Xich, strideXich, i1Xich, mRegBlockSize, nRegBlockSize); + + return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp { + asmjit::CodeHolder code; + code.init(runtime().environment()); + x86::Assembler assembler(&code); + x86::Emitter* a = assembler.as(); +#if defined(FBGEMM_LOG_CODE) + // generated code logging + FILE* codeLogfile = fopen( + getCodeLoggingFile( + accum, O1, i1Xich, strideXich, i1Xich, mRegBlockSize, nRegBlockSize) + .c_str(), + "w"); + asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile); + if (codeLogger) { + code.setLogger(codeLogger); + } +#endif + + const int maxMRegs = mRegBlockSize; + (void)maxMRegs; // Suppress unused variable warning + const int maxNRegs = nRegBlockSize * row_interleave / vectorLen; + assert( + maxMRegs * maxNRegs <= numRegs - 4 && + "MRegs x NRegs is above available registers (MAX_REGS - 4)"); + + int O1RegBlocks = O1 / mRegBlockSize; + int O1RegBlocksRem = O1 % mRegBlockSize; + + // arguments to the function created + x86::Gp buffer_A = a->zdi(); + x86::Gp buffer_B = a->zsi(); + x86::Gp B_pf = a->zdx(); + x86::Gp CBase = a->zcx(); + x86::Gp ichXk1 = a->gpz(8); + x86::Gp ldcReg = a->gpz(9); + + asmjit::FuncDetail func; + func.init( + asmjit::FuncSignatureT< + void, + uint8_t*, + int8_t*, + int8_t*, + int32_t*, + int, + int>(asmjit::CallConv::kIdHost), + a->environment()); + + asmjit::FuncFrame frame; + frame.init(func); + + auto dirtyVecRegs = asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15); + if (numRegs >= 16) { + dirtyVecRegs |= asmjit::Support::bitMask(16, 17, 18, 19, 20, 21, 22, 23) | + asmjit::Support::bitMask(24, 25, 26, 27, 28, 29, 30, 31); + } + + frame.setDirtyRegs(x86::Reg::kGroupVec, dirtyVecRegs); + frame.setDirtyRegs( + x86::Reg::kGroupGp, + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + + asmjit::FuncArgsAssignment args(&func); + args.assignAll(buffer_A, buffer_B, B_pf, CBase, ichXk1, ldcReg); + + args.updateFuncFrame(frame); + frame.finalize(); + + a->emitProlog(frame); + a->emitArgsAssignment(frame, args); + + asmjit::Label LoopMBlocks = a->newLabel(); + // asmjit::Label LoopOBlocks = a->newLabel(); + // asmjit::Label LoopNBlocks = a->newLabel(); + + x86::Gp buffer_B_saved = a->gpz(10); + x86::Gp C_Offset = a->gpz(11); + // x86::Gp B_pf_saved = a->gpz(12); + x86::Gp iIdx = a->gpz(13); + // x86::Gp jIdx = a->gpz(14); + x86::Gp kIdx = a->gpz(15); + // x86::Gp B_pf = a->gpz(8); + + VecRegT oneReg(numRegs - 3); + + gen16BitVectorOne(a, oneReg); + a->imul(ldcReg, ldcReg, static_cast(sizeof(int32_t))); + // a->xor_(C_Offset.r32(), C_Offset.r32()); + + // a->mov(B_pf_saved, B_pf); + + int colRegs = maxNRegs; + + auto issueLoopOverK = [&](int rowRegs) { + // loopKLabel: corresponds to loop "r" where r = 0 + // loopK0Label: corresponds to loop "r" where r = 1 + asmjit::Label LoopKLabel = a->newLabel(); + asmjit::Label LoopK0Label = a->newLabel(); + + // Init C (result) vector registers + initCRegs(a, rowRegs, colRegs); + + // Loops over K: input channel + // a.k.a this issueLoopOverK code block generates code + // corresponding to the "ich" loop of the psedo-code + a->xor_(kIdx.r32(), kIdx.r32()); + a->bind(LoopKLabel); + + // k is incremented by row_interleave + a->add(kIdx, static_cast(row_interleave)); + + // this ComputeBlock generates code correspondent to + // the above psedu-code since the kernel_height loop (loop "r"). + // And because K[0] == 2 and IN_DIM[2] (requirement #2), + // we can unroll loop "r" here. Thus this following + // genComputeBlockDirectConv generates code for loop "r" = 0 + genComputeBlockDirectConv( + a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, strideXich); + + // update buffer_A address for next k iteration + a->add( + buffer_A, static_cast(row_interleave * sizeof(uint8_t))); + + // update buffer_B address for next k iteration + a->add(buffer_B, static_cast(8 * sizeof(int32_t))); + a->add(B_pf, static_cast(8 * sizeof(int32_t))); + + a->cmp(kIdx, ichXk1); + a->jl(LoopKLabel); + + a->sub(buffer_A, ichXk1); + + a->add(buffer_A, static_cast(i1Xich)); + + a->xor_(kIdx.r32(), kIdx.r32()); + a->bind(LoopK0Label); + + // k is incremented by row_interleave + a->add(kIdx, static_cast(row_interleave)); + + // this ComputeBlock generates code that corresponds + // to the kernel_height loop (loop "r") in the psedu-code above. + // And the following genComputeBlockDirectConv + // generates code for loop "r" where "r" = 1 + genComputeBlockDirectConv( + a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, strideXich); + + // update buffer_A address for next k iteration + a->add( + buffer_A, static_cast(row_interleave * sizeof(uint8_t))); + + // update buffer_B address for next k iteration + a->add(buffer_B, static_cast(8 * sizeof(int32_t))); + a->add(B_pf, static_cast(8 * sizeof(int32_t))); + + a->cmp(kIdx, ichXk1); + a->jl(LoopK0Label); + + a->sub(buffer_A, ichXk1); + + // store C matrix + storeCRegs(a, rowRegs, colRegs, C_Offset, ldcReg, accum); + }; + + if (O1RegBlocks > 0) { + // move 0 to iteration variables + a->xor_(iIdx.r32(), iIdx.r32()); + + // iIdex loop corresponds to kernel_width loop (loop "s") + // in the direct conv loops + a->bind(LoopMBlocks); + a->inc(iIdx); + + // save B_buffer address + a->mov(buffer_B_saved, buffer_B); + + issueLoopOverK(mRegBlockSize); + + int rowRegs = mRegBlockSize; + + // reset A + a->sub(buffer_A, static_cast(i1Xich)); + + // increment A for next block + a->add( + buffer_A, + static_cast(rowRegs * strideXich * sizeof(uint8_t))); + + // B for next block + a->mov(buffer_B, buffer_B_saved); + + // increment C for next B block + // ldcReg already multiplied with 4 (sizeof(int32_t)) + a->imul( + C_Offset, ldcReg, static_cast(rowRegs * sizeof(int8_t))); + a->add(CBase, C_Offset); + + // a->add(CBase, static_cast(12*16*4)); + // storeCRegs(a, 12, 1, C_Offset, ldcReg, accum); + + a->cmp(iIdx, O1RegBlocks); + a->jl(LoopMBlocks); + } + + // generate code for remainder + if (O1RegBlocksRem > 0) { + issueLoopOverK(O1RegBlocksRem); + } + + a->emitEpilog(frame); + + jit_micro_kernel_fp fn; + asmjit::Error err; + { + std::unique_lock lock(rtMutex_); + err = runtime().add(&fn, &code); + } + if (err) { + std::cout << "Error: in fn add" << std::endl; + return nullptr; + } + +#if defined(FBGEMM_LOG_CODE) + fclose(codeLogfile); + delete codeLogger; +#endif + + return fn; + }); +} + +/** + * Instantiate the inst_set_t::avx512 instructions for store kernel. + * + */ +template void DirectConvCodeGenBase:: + storeCRegs( + x86::Emitter* a, + int rowRegs, + int colRegs, + x86::Gp C_Offset, + x86::Gp ldcReg, + bool accum); + +/** + * Instantiate the inst_set_t::avx512_ymm instructions for store kernel. + * + */ +template void DirectConvCodeGenBase:: + storeCRegs( + x86::Emitter* a, + int rowRegs, + int colRegs, + x86::Gp C_Offset, + x86::Gp ldcReg, + bool accum); + +/** + * Instantiate the inst_set_t::avx2 instructions for store kernel. + * + */ +template void DirectConvCodeGenBase:: + storeCRegs( + x86::Emitter* a, + int rowRegs, + int colRegs, + x86::Gp C_Offset, + x86::Gp ldcReg, + bool accum); + +/** + * Instantiate the AVX2 instructions for 32-bit Accumulation macro-kernel. + * + */ +template DirectConvCodeGenBase:: + jit_micro_kernel_fp + DirectConvCodeGenBase:: + getOrCreateDirectConv( + bool accum, + int32_t O1, + int32_t i1Xich, + int32_t strideXich); + +} // namespace fbgemm From 0cfb79250dc7e0e0b41679340f9bbbf6915fa539 Mon Sep 17 00:00:00 2001 From: Li Li Date: Wed, 2 Mar 2022 12:41:02 -0600 Subject: [PATCH 05/76] Enable `split_table_batched_embeddings_test.py` (#10) * * added skipIfRocm and TEST_WITH_ROCM in split_table_batched_embeddings_test. * added __any_sync_fbgemm that replaces __any_sync. * 26 tests ran in split_table_batched_embeddings_test 10 skipped. * *Renamed __any_sync_fbgemm to __any_sync and changed its implementation to a more generic one. *Added 'reason' message of skipIfRocm. * *enabled use_array_for_index_remapping in test_nbit_forward_int and test_nbit_forward_fp. *enabled test_nbit_forward_pruning. * deleted 'assert(false)' tthat are related to __any_sync function. --- ...edding_forward_quantized_split_template.cu | 9 ++-- .../include/fbgemm_gpu/fbgemm_cuda_utils.cuh | 10 +++-- fbgemm_gpu/src/split_embeddings_cache_cuda.cu | 9 ++-- .../split_table_batched_embeddings_test.py | 41 +++++++++++-------- fbgemm_gpu/test/test_utils.py | 16 ++++++++ third_party/hipify_torch | 2 +- 6 files changed, 58 insertions(+), 29 deletions(-) diff --git a/fbgemm_gpu/codegen/embedding_forward_quantized_split_template.cu b/fbgemm_gpu/codegen/embedding_forward_quantized_split_template.cu index 85a08e153..a521cd599 100644 --- a/fbgemm_gpu/codegen/embedding_forward_quantized_split_template.cu +++ b/fbgemm_gpu/codegen/embedding_forward_quantized_split_template.cu @@ -493,7 +493,11 @@ __global__ void int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_{ uint32_t subwarp_id = threadIdx.x / 4; uint32_t subwarp_tid = threadIdx.x % 4; +#ifdef __HIP_PLATFORM_HCC__ + uint64_t subwarp_mask = static_cast(0xF) << (4 * subwarp_id); +#else uint32_t subwarp_mask = static_cast(0xF) << (4 * subwarp_id); +#endif for (int32_t l_start = 0; l_start + subwarp_id < L; l_start += kWarpSize / 4) { int32_t idx = indices[indices_start + l_start + subwarp_id]; uint32_t slot_start = pruned_hash_function(static_cast(idx)) % capacity; @@ -515,14 +519,13 @@ __global__ void int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_{ // FIXME: __any_sync with mask isn't supported by HIP yet. // See https://fburl.com/fvy7j0lq for the similar context. // assert false here with https://fburl.com/pfm7enw2 - assert(false); - if (__any(found)) { + if (__any_sync(subwarp_mask, found)) { #else if (__any_sync(subwarp_mask, found)) { #endif break; #ifdef __HIP_PLATFORM_HCC__ - } else if (__any(empty)) { + } else if (__any_sync(subwarp_mask, empty)) { #else } else if (__any_sync(subwarp_mask, empty)) { #endif diff --git a/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh b/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh index 9de3cd43e..29a673b70 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh @@ -1086,8 +1086,6 @@ dequantize_permuted_int4(uint32_t packedVals, __half2 shift_scale) { half shift_scale_x = __low2half(shift_scale); half shift_scale_y = __high2half(shift_scale); - // TODO: Enable this for HIP -#ifndef __HIP_PLATFORM_HCC__ res.vals[0] = hfma2( res.vals[0], __half2( @@ -1112,7 +1110,6 @@ dequantize_permuted_int4(uint32_t packedVals, __half2 shift_scale) { hmul(shift_scale_x, __float2half(32)), hmul(shift_scale_x, __float2half(32))), __half2(shift_scale_y, shift_scale_y)); -#endif return res; } @@ -1621,4 +1618,11 @@ DEVICE_INLINE float float8_min(float8 val) { #undef min #undef max +#ifdef __HIP_PLATFORM_HCC__ +__device__ int __any_sync(uint64_t mask, int predicate) { + uint64_t predicate_bit_pattern = __ballot(predicate); + return (predicate_bit_pattern & mask) > 0; +} +#endif + } // namespace fbgemm_gpu diff --git a/fbgemm_gpu/src/split_embeddings_cache_cuda.cu b/fbgemm_gpu/src/split_embeddings_cache_cuda.cu index 34a55d598..f3f92d5a2 100644 --- a/fbgemm_gpu/src/split_embeddings_cache_cuda.cu +++ b/fbgemm_gpu/src/split_embeddings_cache_cuda.cu @@ -469,8 +469,7 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_find_uncached_kernel( // FIXME: __any_sync with mask isn't supported by HIP yet. // See https://fburl.com/fvy7j0lq for the similar context. // assert false here with https://fburl.com/pfm7enw2 - assert(false); - if (!__any(found)) { + if (!__any_sync(0xFFFFFFFFFFFFFFFF, found)) { #else if (!__any_sync(0xFFFFFFFF, found)) { #endif @@ -1223,8 +1222,7 @@ __global__ __launch_bounds__(kMaxThreads) void lfu_cache_find_uncached_kernel( // FIXME: __any_sync with mask isn't supported by HIP yet. // See https://fburl.com/fvy7j0lq for the similar context. // assert false here with https://fburl.com/pfm7enw2 - assert(false); - if (!__any(found)) { + if (!__any_sync(0xFFFFFFFFFFFFFFFF, found)) { #else if (!__any_sync(0xFFFFFFFF, found)) { #endif @@ -1935,8 +1933,7 @@ __global__ __launch_bounds__(kMaxThreads) void lxu_cache_lookup_kernel( // FIXME: __any_sync with mask isn't supported by HIP yet. // See https://fburl.com/fvy7j0lq for the similar context. // assert false here with https://fburl.com/pfm7enw2 - assert(false); - if (!__any(found)) { + if (!__any_sync(0xFFFFFFFFFFFFFFFF, found)) { #else if (!__any_sync(0xFFFFFFFF, found)) { #endif diff --git a/fbgemm_gpu/test/split_table_batched_embeddings_test.py b/fbgemm_gpu/test/split_table_batched_embeddings_test.py index 6fe24af76..5d53012b7 100644 --- a/fbgemm_gpu/test/split_table_batched_embeddings_test.py +++ b/fbgemm_gpu/test/split_table_batched_embeddings_test.py @@ -29,7 +29,7 @@ if open_source: # pyre-ignore[21] - from test_utils import gpu_available, gpu_unavailable + from test_utils import gpu_available, gpu_unavailable, TEST_WITH_ROCM, skipIfRocm else: from fbgemm_gpu.test.test_utils import gpu_available, gpu_unavailable @@ -400,7 +400,7 @@ def execute_forward_( weights_precision=st.just(SparseType.INT8), weighted=st.booleans(), mixed=st.booleans(), - use_cache=st.booleans(), + use_cache=st.booleans() if not TEST_WITH_ROCM else st.just(False), cache_algorithm=st.sampled_from( split_table_batched_embeddings_ops.CacheAlgorithm ), @@ -459,7 +459,7 @@ def test_forward_int8( weights_precision=st.just(SparseType.FP16), weighted=st.booleans(), mixed=st.booleans(), - use_cache=st.booleans(), + use_cache=st.booleans() if not TEST_WITH_ROCM else st.just(False), cache_algorithm=st.sampled_from( split_table_batched_embeddings_ops.CacheAlgorithm ), @@ -518,7 +518,7 @@ def test_forward_fp16( weights_precision=st.just(SparseType.FP32), weighted=st.booleans(), mixed=st.booleans(), - use_cache=st.booleans(), + use_cache=st.booleans() if not TEST_WITH_ROCM else st.just(False), cache_algorithm=st.sampled_from( split_table_batched_embeddings_ops.CacheAlgorithm ), @@ -842,6 +842,7 @@ def test_nbit_forward_fused_pooled_emb_quant( equal_nan=True, ) + @skipIfRocm() @given( T=st.integers(min_value=1, max_value=3), D=st.integers(min_value=2, max_value=256), @@ -1057,6 +1058,7 @@ def test_backward_dense( param.requires_grad = False torch.autograd.gradcheck(cc, (indices, offsets, per_sample_weights)) + @skipIfRocm() @given( T=st.integers(min_value=1, max_value=5), D=st.integers(min_value=2, max_value=256), @@ -1066,7 +1068,7 @@ def test_backward_dense( weights_precision=st.sampled_from([SparseType.FP16, SparseType.FP32]), weighted=st.booleans(), mixed=st.booleans(), - use_cache=st.booleans(), + use_cache=st.booleans() if not TEST_WITH_ROCM else st.just(False), cache_algorithm=st.sampled_from( split_table_batched_embeddings_ops.CacheAlgorithm ), @@ -1623,7 +1625,7 @@ def execute_backward_adagrad_( # noqa C901 weighted=st.booleans(), row_wise=st.booleans(), mixed=st.booleans(), - use_cache=st.booleans(), + use_cache=st.booleans() if not TEST_WITH_ROCM else st.just(False), cache_algorithm=st.sampled_from( split_table_batched_embeddings_ops.CacheAlgorithm ), @@ -1685,7 +1687,7 @@ def test_backward_adagrad_fp16_pmSUM( # noqa C901 weighted=st.booleans(), row_wise=st.booleans(), mixed=st.booleans(), - use_cache=st.booleans(), + use_cache=st.booleans() if not TEST_WITH_ROCM else st.just(False), cache_algorithm=st.sampled_from( split_table_batched_embeddings_ops.CacheAlgorithm ), @@ -1747,7 +1749,7 @@ def test_backward_adagrad_fp16_pmMEAN( # noqa C901 weighted=st.booleans(), row_wise=st.booleans(), mixed=st.booleans(), - use_cache=st.booleans(), + use_cache=st.booleans() if not TEST_WITH_ROCM else st.just(False), cache_algorithm=st.sampled_from( split_table_batched_embeddings_ops.CacheAlgorithm ), @@ -1809,7 +1811,7 @@ def test_backward_adagrad_fp16_pmNONE( # noqa C901 weighted=st.booleans(), row_wise=st.booleans(), mixed=st.booleans(), - use_cache=st.booleans(), + use_cache=st.booleans() if not TEST_WITH_ROCM else st.just(False), cache_algorithm=st.sampled_from( split_table_batched_embeddings_ops.CacheAlgorithm ), @@ -1871,11 +1873,11 @@ def test_backward_adagrad_fp32_pmSUM( # noqa C901 weighted=st.booleans(), row_wise=st.booleans(), mixed=st.booleans(), - use_cache=st.booleans(), + use_cache=st.booleans() if not TEST_WITH_ROCM else st.just(False), cache_algorithm=st.sampled_from( split_table_batched_embeddings_ops.CacheAlgorithm ), - use_cpu=st.booleans() if gpu_available else st.just(True), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), exact=st.booleans(), ) @settings( @@ -1933,7 +1935,7 @@ def test_backward_adagrad_fp32_pmMEAN( # noqa C901 weighted=st.booleans(), row_wise=st.booleans(), mixed=st.booleans(), - use_cache=st.booleans(), + use_cache=st.booleans() if not TEST_WITH_ROCM else st.just(False), cache_algorithm=st.sampled_from( split_table_batched_embeddings_ops.CacheAlgorithm ), @@ -1984,6 +1986,7 @@ def test_backward_adagrad_fp32_pmNONE( # noqa C901 ) @unittest.skipIf(*gpu_unavailable) + @skipIfRocm() @given( T=st.integers(min_value=1, max_value=5), D=st.integers(min_value=2, max_value=256), @@ -2450,6 +2453,7 @@ def execute_backward_optimizers_( # noqa C901 rtol=1.0e-4, ) + @skipIfRocm() @given( T=st.integers(min_value=1, max_value=5), D=st.integers(min_value=2, max_value=256), @@ -2509,6 +2513,7 @@ def test_backward_optimizers_adam( # noqa C901 use_cpu, ) + @skipIfRocm() @given( T=st.integers(min_value=1, max_value=5), D=st.integers(min_value=2, max_value=256), @@ -2568,6 +2573,7 @@ def test_backward_optimizers_adagrad( # noqa C901 use_cpu, ) + @skipIfRocm() @given( T=st.integers(min_value=1, max_value=5), D=st.integers(min_value=2, max_value=256), @@ -2627,6 +2633,7 @@ def test_backward_optimizers_lamb( # noqa C901 use_cpu, ) + @skipIfRocm() @given( T=st.integers(min_value=1, max_value=5), D=st.integers(min_value=2, max_value=256), @@ -2983,11 +2990,11 @@ def comp(i: int) -> np.ndarray: # TODO: implement for SparseType.INT2, ] ), - use_cache=st.booleans(), + use_cache=st.booleans() if not TEST_WITH_ROCM else st.just(False), cache_algorithm=st.sampled_from( split_table_batched_embeddings_ops.CacheAlgorithm ), - use_cpu=st.booleans() if gpu_available else st.just(True), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), use_array_for_index_remapping=st.booleans(), mixed_weights_ty=st.booleans(), output_dtype=st.sampled_from( @@ -3060,11 +3067,11 @@ def test_nbit_forward_int( SparseType.FP32, ] ), - use_cache=st.booleans(), + use_cache=st.booleans() if not TEST_WITH_ROCM else st.just(False), cache_algorithm=st.sampled_from( split_table_batched_embeddings_ops.CacheAlgorithm ), - use_cpu=st.booleans() if gpu_available else st.just(True), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), use_array_for_index_remapping=st.booleans(), mixed_weights_ty=st.booleans(), output_dtype=st.sampled_from( @@ -3116,6 +3123,7 @@ def test_nbit_forward_fp( ) @unittest.skipIf(*gpu_unavailable) + @skipIfRocm() @given( T=st.integers(min_value=1, max_value=5), D=st.integers(min_value=2, max_value=256), @@ -3364,6 +3372,7 @@ def test_cache_update_function(self, L: int, H: int, S: int) -> None: assert unique_cache_miss_count == expect_out assert cache_miss_forward_count <= unique_cache_miss_count + @skipIfRocm() @given(N=st.integers(min_value=1, max_value=8)) @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) def test_cache_miss_counter(self, N: int) -> None: diff --git a/fbgemm_gpu/test/test_utils.py b/fbgemm_gpu/test/test_utils.py index 2ce2ddd10..4b437343e 100644 --- a/fbgemm_gpu/test/test_utils.py +++ b/fbgemm_gpu/test/test_utils.py @@ -10,7 +10,11 @@ import hypothesis.strategies as st import numpy as np import torch +import os +from functools import wraps +import unittest +TEST_WITH_ROCM = os.getenv('FBGEMM_TEST_WITH_ROCM', '0') == '1' # Eigen/Python round 0.5 away from 0, Numpy rounds to even round_to_nearest: Callable[[np.ndarray], np.ndarray] = np.vectorize(round) @@ -167,3 +171,15 @@ def cpu_and_maybe_gpu() -> st.SearchStrategy[List[torch.device]]: def cpu_only() -> st.SearchStrategy[List[torch.device]]: return st.sampled_from([torch.device("cpu")]) + + +def skipIfRocm(reason="test doesn't currently work on the ROCm stack"): + def skipIfRocmDecorator(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + if TEST_WITH_ROCM: + raise unittest.SkipTest(reason) + else: + fn(*args, **kwargs) + return wrapper + return skipIfRocmDecorator \ No newline at end of file diff --git a/third_party/hipify_torch b/third_party/hipify_torch index 88bd87904..3816549ca 160000 --- a/third_party/hipify_torch +++ b/third_party/hipify_torch @@ -1 +1 @@ -Subproject commit 88bd87904aaf5d68b908af9fe2ef6b32dbbcf45e +Subproject commit 3816549caf28490acc1302859596e33659b46b67 From f13af4463166bf12966749ffcb91f3b78c956a50 Mon Sep 17 00:00:00 2001 From: liligwu Date: Mon, 7 Mar 2022 21:15:44 +0000 Subject: [PATCH 06/76] *Enable use_cache. *Enable split_embedding_inference_converter_test.py by diabling use_cpu. --- fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py | 2 +- fbgemm_gpu/test/split_embedding_inference_converter_test.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py index 017381100..7fdac306b 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py @@ -20,7 +20,7 @@ from fbgemm_gpu.split_embedding_configs import SparseType from torch import Tensor, nn -ASSOC = 32 +ASSOC = 32 if torch.version.hip is None else 64 # Maximum number of times prefetch() can be called without # a corresponding forward() call MAX_PREFETCH_DEPTH = 100 diff --git a/fbgemm_gpu/test/split_embedding_inference_converter_test.py b/fbgemm_gpu/test/split_embedding_inference_converter_test.py index 6a0f7fd56..da9046f94 100644 --- a/fbgemm_gpu/test/split_embedding_inference_converter_test.py +++ b/fbgemm_gpu/test/split_embedding_inference_converter_test.py @@ -27,7 +27,7 @@ if open_source: # pyre-ignore[21] - from test_utils import gpu_available + from test_utils import gpu_available, TEST_WITH_ROCM else: from fbgemm_gpu.test.test_utils import gpu_available @@ -134,7 +134,7 @@ class QuantizedSplitEmbeddingsTest(unittest.TestCase): SparseType.FP32, ] ), - use_cpu=st.booleans() if gpu_available else st.just(True), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), pruning_ratio=st.sampled_from([None, 0.0]), ) @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) From 25e5b718001c63b6524025a8b96ab3adddf79155 Mon Sep 17 00:00:00 2001 From: liligwu Date: Mon, 7 Mar 2022 21:38:04 +0000 Subject: [PATCH 07/76] Skip use_cpu. --- .../test/split_table_batched_embeddings_test.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/fbgemm_gpu/test/split_table_batched_embeddings_test.py b/fbgemm_gpu/test/split_table_batched_embeddings_test.py index 6fe24af76..ca102832f 100644 --- a/fbgemm_gpu/test/split_table_batched_embeddings_test.py +++ b/fbgemm_gpu/test/split_table_batched_embeddings_test.py @@ -29,7 +29,7 @@ if open_source: # pyre-ignore[21] - from test_utils import gpu_available, gpu_unavailable + from test_utils import gpu_available, gpu_unavailable, TEST_WITH_ROCM, skipIfRocm else: from fbgemm_gpu.test.test_utils import gpu_available, gpu_unavailable @@ -842,6 +842,7 @@ def test_nbit_forward_fused_pooled_emb_quant( equal_nan=True, ) + @skipIfRocm() @given( T=st.integers(min_value=1, max_value=3), D=st.integers(min_value=2, max_value=256), @@ -1057,6 +1058,7 @@ def test_backward_dense( param.requires_grad = False torch.autograd.gradcheck(cc, (indices, offsets, per_sample_weights)) + @skipIfRocm() @given( T=st.integers(min_value=1, max_value=5), D=st.integers(min_value=2, max_value=256), @@ -1875,7 +1877,7 @@ def test_backward_adagrad_fp32_pmSUM( # noqa C901 cache_algorithm=st.sampled_from( split_table_batched_embeddings_ops.CacheAlgorithm ), - use_cpu=st.booleans() if gpu_available else st.just(True), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), exact=st.booleans(), ) @settings( @@ -1984,6 +1986,7 @@ def test_backward_adagrad_fp32_pmNONE( # noqa C901 ) @unittest.skipIf(*gpu_unavailable) + @skipIfRocm() @given( T=st.integers(min_value=1, max_value=5), D=st.integers(min_value=2, max_value=256), @@ -2450,6 +2453,7 @@ def execute_backward_optimizers_( # noqa C901 rtol=1.0e-4, ) + @skipIfRocm() @given( T=st.integers(min_value=1, max_value=5), D=st.integers(min_value=2, max_value=256), @@ -2509,6 +2513,7 @@ def test_backward_optimizers_adam( # noqa C901 use_cpu, ) + @skipIfRocm() @given( T=st.integers(min_value=1, max_value=5), D=st.integers(min_value=2, max_value=256), @@ -2568,6 +2573,7 @@ def test_backward_optimizers_adagrad( # noqa C901 use_cpu, ) + @skipIfRocm() @given( T=st.integers(min_value=1, max_value=5), D=st.integers(min_value=2, max_value=256), @@ -2627,6 +2633,7 @@ def test_backward_optimizers_lamb( # noqa C901 use_cpu, ) + @skipIfRocm() @given( T=st.integers(min_value=1, max_value=5), D=st.integers(min_value=2, max_value=256), @@ -2987,7 +2994,7 @@ def comp(i: int) -> np.ndarray: cache_algorithm=st.sampled_from( split_table_batched_embeddings_ops.CacheAlgorithm ), - use_cpu=st.booleans() if gpu_available else st.just(True), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), use_array_for_index_remapping=st.booleans(), mixed_weights_ty=st.booleans(), output_dtype=st.sampled_from( @@ -3064,7 +3071,7 @@ def test_nbit_forward_int( cache_algorithm=st.sampled_from( split_table_batched_embeddings_ops.CacheAlgorithm ), - use_cpu=st.booleans() if gpu_available else st.just(True), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), use_array_for_index_remapping=st.booleans(), mixed_weights_ty=st.booleans(), output_dtype=st.sampled_from( @@ -3116,6 +3123,7 @@ def test_nbit_forward_fp( ) @unittest.skipIf(*gpu_unavailable) + @skipIfRocm() @given( T=st.integers(min_value=1, max_value=5), D=st.integers(min_value=2, max_value=256), @@ -3364,6 +3372,7 @@ def test_cache_update_function(self, L: int, H: int, S: int) -> None: assert unique_cache_miss_count == expect_out assert cache_miss_forward_count <= unique_cache_miss_count + @skipIfRocm() @given(N=st.integers(min_value=1, max_value=8)) @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) def test_cache_miss_counter(self, N: int) -> None: From dcbe19f5e47108db79e9841aaf6261f4b6e62a92 Mon Sep 17 00:00:00 2001 From: liligwu Date: Mon, 7 Mar 2022 23:08:28 +0000 Subject: [PATCH 08/76] Enable test_nbit_cache_pipeline and test_cache_miss_counter. --- fbgemm_gpu/test/split_table_batched_embeddings_test.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/fbgemm_gpu/test/split_table_batched_embeddings_test.py b/fbgemm_gpu/test/split_table_batched_embeddings_test.py index ca102832f..a7f3aeab2 100644 --- a/fbgemm_gpu/test/split_table_batched_embeddings_test.py +++ b/fbgemm_gpu/test/split_table_batched_embeddings_test.py @@ -3123,7 +3123,6 @@ def test_nbit_forward_fp( ) @unittest.skipIf(*gpu_unavailable) - @skipIfRocm() @given( T=st.integers(min_value=1, max_value=5), D=st.integers(min_value=2, max_value=256), @@ -3372,7 +3371,6 @@ def test_cache_update_function(self, L: int, H: int, S: int) -> None: assert unique_cache_miss_count == expect_out assert cache_miss_forward_count <= unique_cache_miss_count - @skipIfRocm() @given(N=st.integers(min_value=1, max_value=8)) @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) def test_cache_miss_counter(self, N: int) -> None: From fda048eb59b1b3380f3a9bdd680ff354ccd35fb7 Mon Sep 17 00:00:00 2001 From: liligwu Date: Mon, 7 Mar 2022 23:00:12 +0000 Subject: [PATCH 09/76] Enable quantize_ops_test.py --- fbgemm_gpu/test/quantize_ops_test.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/fbgemm_gpu/test/quantize_ops_test.py b/fbgemm_gpu/test/quantize_ops_test.py index 959494eff..04be08b05 100644 --- a/fbgemm_gpu/test/quantize_ops_test.py +++ b/fbgemm_gpu/test/quantize_ops_test.py @@ -21,6 +21,7 @@ fused_rowwise_nbit_quantize_dequantize_reference, bytes_to_half_floats, gpu_available, + skipIfRocm, ) except Exception: torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") @@ -41,6 +42,7 @@ class TestFused8BitRowwiseQuantizationConversion(unittest.TestCase): # pyre-fixme[56]: Pyre was not able to infer the type of argument # `hypothesis.strategies.integers($parameter$min_value = 0, $parameter$max_value = # 100)` to decorator factory `hypothesis.given`. + @skipIfRocm() @given( nrows=st.integers(min_value=0, max_value=100), ncols=st.integers(min_value=0, max_value=100), @@ -110,6 +112,7 @@ def test_quantize_and_dequantize_op_cuda_large_nrows(self) -> None: class TestFusedNBitRowwiseQuantizationConversion(unittest.TestCase): # pyre-ignore [56]: Invalid decoration, was not able to infer the type of argument + @skipIfRocm() @given( nrows=st.integers(min_value=0, max_value=100), ncols=st.integers(min_value=0, max_value=100), @@ -247,6 +250,7 @@ def test_quantize_and_dequantize_op_cuda_large_nrows(self) -> None: class TestDenseMLPQuantizationConversion(unittest.TestCase): # pyre-ignore [56]: Invalid decoration, was not able to infer the type of argument + @skipIfRocm @given( nrows=st.integers(min_value=0, max_value=100), ncols=st.integers(min_value=0, max_value=100), From cf307b65d68c88d168101dbaec4586703becef78 Mon Sep 17 00:00:00 2001 From: liligwu Date: Mon, 7 Mar 2022 23:30:57 +0000 Subject: [PATCH 10/76] Remove @skipIfRocm for test_nbit_cache_pipeline and test_cache_miss_counter. --- fbgemm_gpu/test/split_table_batched_embeddings_test.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/fbgemm_gpu/test/split_table_batched_embeddings_test.py b/fbgemm_gpu/test/split_table_batched_embeddings_test.py index 5d53012b7..855ccb8d5 100644 --- a/fbgemm_gpu/test/split_table_batched_embeddings_test.py +++ b/fbgemm_gpu/test/split_table_batched_embeddings_test.py @@ -3123,7 +3123,6 @@ def test_nbit_forward_fp( ) @unittest.skipIf(*gpu_unavailable) - @skipIfRocm() @given( T=st.integers(min_value=1, max_value=5), D=st.integers(min_value=2, max_value=256), @@ -3372,7 +3371,6 @@ def test_cache_update_function(self, L: int, H: int, S: int) -> None: assert unique_cache_miss_count == expect_out assert cache_miss_forward_count <= unique_cache_miss_count - @skipIfRocm() @given(N=st.integers(min_value=1, max_value=8)) @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) def test_cache_miss_counter(self, N: int) -> None: From 2d66ea8ebb1b026a8f8d7e5ba23027558f391d3b Mon Sep 17 00:00:00 2001 From: liligwu Date: Mon, 7 Mar 2022 23:51:36 +0000 Subject: [PATCH 11/76] *Uncondition use_cache in split_table_batched_embeddings_test.py *Remove @skipIfRocm for TestFused8BitRowwiseQuantizationConversion and TestFusedNBitRowwiseQuantizationConversion --- fbgemm_gpu/test/quantize_ops_test.py | 2 -- .../split_table_batched_embeddings_test.py | 24 +++++++++---------- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/fbgemm_gpu/test/quantize_ops_test.py b/fbgemm_gpu/test/quantize_ops_test.py index 04be08b05..ef07b3d47 100644 --- a/fbgemm_gpu/test/quantize_ops_test.py +++ b/fbgemm_gpu/test/quantize_ops_test.py @@ -42,7 +42,6 @@ class TestFused8BitRowwiseQuantizationConversion(unittest.TestCase): # pyre-fixme[56]: Pyre was not able to infer the type of argument # `hypothesis.strategies.integers($parameter$min_value = 0, $parameter$max_value = # 100)` to decorator factory `hypothesis.given`. - @skipIfRocm() @given( nrows=st.integers(min_value=0, max_value=100), ncols=st.integers(min_value=0, max_value=100), @@ -112,7 +111,6 @@ def test_quantize_and_dequantize_op_cuda_large_nrows(self) -> None: class TestFusedNBitRowwiseQuantizationConversion(unittest.TestCase): # pyre-ignore [56]: Invalid decoration, was not able to infer the type of argument - @skipIfRocm() @given( nrows=st.integers(min_value=0, max_value=100), ncols=st.integers(min_value=0, max_value=100), diff --git a/fbgemm_gpu/test/split_table_batched_embeddings_test.py b/fbgemm_gpu/test/split_table_batched_embeddings_test.py index 855ccb8d5..a7f3aeab2 100644 --- a/fbgemm_gpu/test/split_table_batched_embeddings_test.py +++ b/fbgemm_gpu/test/split_table_batched_embeddings_test.py @@ -400,7 +400,7 @@ def execute_forward_( weights_precision=st.just(SparseType.INT8), weighted=st.booleans(), mixed=st.booleans(), - use_cache=st.booleans() if not TEST_WITH_ROCM else st.just(False), + use_cache=st.booleans(), cache_algorithm=st.sampled_from( split_table_batched_embeddings_ops.CacheAlgorithm ), @@ -459,7 +459,7 @@ def test_forward_int8( weights_precision=st.just(SparseType.FP16), weighted=st.booleans(), mixed=st.booleans(), - use_cache=st.booleans() if not TEST_WITH_ROCM else st.just(False), + use_cache=st.booleans(), cache_algorithm=st.sampled_from( split_table_batched_embeddings_ops.CacheAlgorithm ), @@ -518,7 +518,7 @@ def test_forward_fp16( weights_precision=st.just(SparseType.FP32), weighted=st.booleans(), mixed=st.booleans(), - use_cache=st.booleans() if not TEST_WITH_ROCM else st.just(False), + use_cache=st.booleans(), cache_algorithm=st.sampled_from( split_table_batched_embeddings_ops.CacheAlgorithm ), @@ -1068,7 +1068,7 @@ def test_backward_dense( weights_precision=st.sampled_from([SparseType.FP16, SparseType.FP32]), weighted=st.booleans(), mixed=st.booleans(), - use_cache=st.booleans() if not TEST_WITH_ROCM else st.just(False), + use_cache=st.booleans(), cache_algorithm=st.sampled_from( split_table_batched_embeddings_ops.CacheAlgorithm ), @@ -1625,7 +1625,7 @@ def execute_backward_adagrad_( # noqa C901 weighted=st.booleans(), row_wise=st.booleans(), mixed=st.booleans(), - use_cache=st.booleans() if not TEST_WITH_ROCM else st.just(False), + use_cache=st.booleans(), cache_algorithm=st.sampled_from( split_table_batched_embeddings_ops.CacheAlgorithm ), @@ -1687,7 +1687,7 @@ def test_backward_adagrad_fp16_pmSUM( # noqa C901 weighted=st.booleans(), row_wise=st.booleans(), mixed=st.booleans(), - use_cache=st.booleans() if not TEST_WITH_ROCM else st.just(False), + use_cache=st.booleans(), cache_algorithm=st.sampled_from( split_table_batched_embeddings_ops.CacheAlgorithm ), @@ -1749,7 +1749,7 @@ def test_backward_adagrad_fp16_pmMEAN( # noqa C901 weighted=st.booleans(), row_wise=st.booleans(), mixed=st.booleans(), - use_cache=st.booleans() if not TEST_WITH_ROCM else st.just(False), + use_cache=st.booleans(), cache_algorithm=st.sampled_from( split_table_batched_embeddings_ops.CacheAlgorithm ), @@ -1811,7 +1811,7 @@ def test_backward_adagrad_fp16_pmNONE( # noqa C901 weighted=st.booleans(), row_wise=st.booleans(), mixed=st.booleans(), - use_cache=st.booleans() if not TEST_WITH_ROCM else st.just(False), + use_cache=st.booleans(), cache_algorithm=st.sampled_from( split_table_batched_embeddings_ops.CacheAlgorithm ), @@ -1873,7 +1873,7 @@ def test_backward_adagrad_fp32_pmSUM( # noqa C901 weighted=st.booleans(), row_wise=st.booleans(), mixed=st.booleans(), - use_cache=st.booleans() if not TEST_WITH_ROCM else st.just(False), + use_cache=st.booleans(), cache_algorithm=st.sampled_from( split_table_batched_embeddings_ops.CacheAlgorithm ), @@ -1935,7 +1935,7 @@ def test_backward_adagrad_fp32_pmMEAN( # noqa C901 weighted=st.booleans(), row_wise=st.booleans(), mixed=st.booleans(), - use_cache=st.booleans() if not TEST_WITH_ROCM else st.just(False), + use_cache=st.booleans(), cache_algorithm=st.sampled_from( split_table_batched_embeddings_ops.CacheAlgorithm ), @@ -2990,7 +2990,7 @@ def comp(i: int) -> np.ndarray: # TODO: implement for SparseType.INT2, ] ), - use_cache=st.booleans() if not TEST_WITH_ROCM else st.just(False), + use_cache=st.booleans(), cache_algorithm=st.sampled_from( split_table_batched_embeddings_ops.CacheAlgorithm ), @@ -3067,7 +3067,7 @@ def test_nbit_forward_int( SparseType.FP32, ] ), - use_cache=st.booleans() if not TEST_WITH_ROCM else st.just(False), + use_cache=st.booleans(), cache_algorithm=st.sampled_from( split_table_batched_embeddings_ops.CacheAlgorithm ), From e642a481311dc49ae6e643bf2bd2642eb64aae16 Mon Sep 17 00:00:00 2001 From: liligwu Date: Tue, 8 Mar 2022 02:36:30 +0000 Subject: [PATCH 12/76] Fix backward tests and test_cache_pipeline in split_table_batched_embeddings_test.py --- fbgemm_gpu/codegen/embedding_backward_split_template.cu | 7 +++++-- fbgemm_gpu/test/split_table_batched_embeddings_test.py | 8 +------- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/fbgemm_gpu/codegen/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/embedding_backward_split_template.cu index e80f4041c..ec8703040 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/embedding_backward_split_template.cu @@ -767,9 +767,14 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ C10_CUDA_KERNEL_LAUNCH_CHECK(); int shared_kb = max_shared_bytes >> 10; // V100: 64 KB; A100: 96 KB. +#ifndef __HIP_PLATFORM_HCC__ // Use 2/3 of the available GPU shared mem; leave rooms for L1$. int used_shared_kb = round_down(shared_kb * 2 / 3, 16); TORCH_CHECK(used_shared_kb > 0); +#else + // MI100 has independent shared mem and L1 + int used_shared_kb = shared_kb; +#endif int used_shared_bytes = used_shared_kb << 10; Tensor linear_indices, linear_indices_sorted; @@ -932,12 +937,10 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ {% else %} if (D <= {{ 128 * kMaxVecsPerThread }}) { {% endif %} -#ifndef __HIP_PLATFORM_HCC__ // Stay under used_shared_kb of shared memory (V100: 64 KB; A100: 96 KB), BT_block_size must be a power of two. while (BT_block_size * sizeof(at::acc_type<{{ "scalar_t" if dense else "cache_t" }}, true>) * 4 * kWarpSize * {{ kMaxVecsPerThread }} >= used_shared_bytes) { BT_block_size /= 2; } -#endif TORCH_CHECK(BT_block_size >= 1); if (std::is_same<{{ "scalar_t" if dense else "emb_t" }}, double>::value) { // Otherwise we see CUDA kernel launch failures despite the above checks. diff --git a/fbgemm_gpu/test/split_table_batched_embeddings_test.py b/fbgemm_gpu/test/split_table_batched_embeddings_test.py index a7f3aeab2..6f85f0d6c 100644 --- a/fbgemm_gpu/test/split_table_batched_embeddings_test.py +++ b/fbgemm_gpu/test/split_table_batched_embeddings_test.py @@ -842,7 +842,6 @@ def test_nbit_forward_fused_pooled_emb_quant( equal_nan=True, ) - @skipIfRocm() @given( T=st.integers(min_value=1, max_value=3), D=st.integers(min_value=2, max_value=256), @@ -1058,7 +1057,6 @@ def test_backward_dense( param.requires_grad = False torch.autograd.gradcheck(cc, (indices, offsets, per_sample_weights)) - @skipIfRocm() @given( T=st.integers(min_value=1, max_value=5), D=st.integers(min_value=2, max_value=256), @@ -1986,7 +1984,7 @@ def test_backward_adagrad_fp32_pmNONE( # noqa C901 ) @unittest.skipIf(*gpu_unavailable) - @skipIfRocm() + # @skipIfRocm() @given( T=st.integers(min_value=1, max_value=5), D=st.integers(min_value=2, max_value=256), @@ -2453,7 +2451,6 @@ def execute_backward_optimizers_( # noqa C901 rtol=1.0e-4, ) - @skipIfRocm() @given( T=st.integers(min_value=1, max_value=5), D=st.integers(min_value=2, max_value=256), @@ -2513,7 +2510,6 @@ def test_backward_optimizers_adam( # noqa C901 use_cpu, ) - @skipIfRocm() @given( T=st.integers(min_value=1, max_value=5), D=st.integers(min_value=2, max_value=256), @@ -2573,7 +2569,6 @@ def test_backward_optimizers_adagrad( # noqa C901 use_cpu, ) - @skipIfRocm() @given( T=st.integers(min_value=1, max_value=5), D=st.integers(min_value=2, max_value=256), @@ -2633,7 +2628,6 @@ def test_backward_optimizers_lamb( # noqa C901 use_cpu, ) - @skipIfRocm() @given( T=st.integers(min_value=1, max_value=5), D=st.integers(min_value=2, max_value=256), From d0d294a95e743c14e717b21d54f7ce80c963d04c Mon Sep 17 00:00:00 2001 From: liligwu Date: Tue, 8 Mar 2022 02:38:46 +0000 Subject: [PATCH 13/76] A minor change of removing a commented line. --- fbgemm_gpu/test/split_table_batched_embeddings_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/fbgemm_gpu/test/split_table_batched_embeddings_test.py b/fbgemm_gpu/test/split_table_batched_embeddings_test.py index 6f85f0d6c..bdd5d3ce4 100644 --- a/fbgemm_gpu/test/split_table_batched_embeddings_test.py +++ b/fbgemm_gpu/test/split_table_batched_embeddings_test.py @@ -1984,7 +1984,6 @@ def test_backward_adagrad_fp32_pmNONE( # noqa C901 ) @unittest.skipIf(*gpu_unavailable) - # @skipIfRocm() @given( T=st.integers(min_value=1, max_value=5), D=st.integers(min_value=2, max_value=256), From 146f2df15ecd165c4d64c63797c2a2054f6e1e12 Mon Sep 17 00:00:00 2001 From: liligwu Date: Tue, 8 Mar 2022 15:15:51 +0000 Subject: [PATCH 14/76] Remove skipIfRocm import in split_table_batched_embeddings_test.py. --- fbgemm_gpu/test/split_table_batched_embeddings_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fbgemm_gpu/test/split_table_batched_embeddings_test.py b/fbgemm_gpu/test/split_table_batched_embeddings_test.py index bdd5d3ce4..9690eccc0 100644 --- a/fbgemm_gpu/test/split_table_batched_embeddings_test.py +++ b/fbgemm_gpu/test/split_table_batched_embeddings_test.py @@ -29,7 +29,7 @@ if open_source: # pyre-ignore[21] - from test_utils import gpu_available, gpu_unavailable, TEST_WITH_ROCM, skipIfRocm + from test_utils import gpu_available, gpu_unavailable, TEST_WITH_ROCM else: from fbgemm_gpu.test.test_utils import gpu_available, gpu_unavailable From 0c86f2be3eb86a33cf46bacb4d2174c4301dabd9 Mon Sep 17 00:00:00 2001 From: liligwu Date: Fri, 11 Mar 2022 16:23:33 +0000 Subject: [PATCH 15/76] *Removed post_hipify logic in setup.py. *Removed two headerfiles that have been deleted in upstream. --- .../fbgemm_gpu/hipcub_namespace_postfix.cuh | 21 --- .../fbgemm_gpu/hipcub_namespace_prefix.cuh | 16 -- fbgemm_gpu/setup.py | 139 ++++++------------ 3 files changed, 42 insertions(+), 134 deletions(-) delete mode 100644 fbgemm_gpu/fbgemm_gpu/hipcub_namespace_postfix.cuh delete mode 100644 fbgemm_gpu/fbgemm_gpu/hipcub_namespace_prefix.cuh diff --git a/fbgemm_gpu/fbgemm_gpu/hipcub_namespace_postfix.cuh b/fbgemm_gpu/fbgemm_gpu/hipcub_namespace_postfix.cuh deleted file mode 100644 index 8922edbba..000000000 --- a/fbgemm_gpu/fbgemm_gpu/hipcub_namespace_postfix.cuh +++ /dev/null @@ -1,21 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * All rights reserved. - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#undef FBGEMM_GPU_CUB_NS_PREFIX - -#ifdef FBGEMM_CUB_USE_NAMESPACE - -#undef CUB_NS_PREFIX -#undef CUB_NS_POSTFIX - -#define FBGEMM_GPU_CUB_NS_PREFIX fbgemm_gpu:: - -#else - -#define FBGEMM_GPU_CUB_NS_PREFIX - -#endif diff --git a/fbgemm_gpu/fbgemm_gpu/hipcub_namespace_prefix.cuh b/fbgemm_gpu/fbgemm_gpu/hipcub_namespace_prefix.cuh deleted file mode 100644 index c977653fa..000000000 --- a/fbgemm_gpu/fbgemm_gpu/hipcub_namespace_prefix.cuh +++ /dev/null @@ -1,16 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * All rights reserved. - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#ifdef FBGEMM_CUB_USE_NAMESPACE - -#undef CUB_NS_PREFIX -#undef CUB_NS_POSTFIX - -#define CUB_NS_PREFIX namespace fbgemm_gpu { -#define CUB_NS_POSTFIX } // namespace fbgemm_gpu - -#endif diff --git a/fbgemm_gpu/setup.py b/fbgemm_gpu/setup.py index 7145a16b9..bb2ea6f62 100644 --- a/fbgemm_gpu/setup.py +++ b/fbgemm_gpu/setup.py @@ -148,65 +148,17 @@ def build_extension(self, ext): hipify_python.hipify( project_directory=cur_dir, output_directory=cur_dir, - includes="codegen/*", + header_include_dirs=[ + os.path.join(cur_dir, "include"), + os.path.join(cur_dir, "src"), + cur_dir, + ], + includes=["*"], + extra_files=CUDA_source, show_detailed=True, is_pytorch_extension=True, clean_ctx=clean_ctx) - def replace_pattern(hip_file, pattern_map): - patterns = {} - for regexp in pattern_map: - patterns[regexp] = re.compile(regexp.format(exclude="")) - with tempfile.NamedTemporaryFile(mode="w", delete=False) as tmp_file: - with open(hip_file) as src_file: - for line in src_file: - for regexp in pattern_map: - pattern = pattern_map[regexp] - exclude = pattern[0] - replacement = pattern[1] - in_regexp = regexp.format(exclude="") - if len(pattern_map[regexp]) == 4: - all_ori = pattern[2] - all_new = pattern[3] - else: - all_ori = None - all_new = None - if re.search(in_regexp, line) and \ - (exclude is None or not re.search(regexp.format(exclude=exclude), line)): - ori = line - if all_ori is not None and all_ori in line: - line = line.replace(all_ori, all_new) - else: - line = patterns[regexp].sub(replacement, line) - - tmp_file.write(line) - - shutil.copystat(hip_file, tmp_file.name) - shutil.move(tmp_file.name, hip_file) - - def post_hipify(hip_file): - replace_pattern(hip_file, {"(#include.*\"codegen.*){exclude}[.]cuh": ["_hip", "\\1_hip.cuh"], - "{exclude}cub(::DeviceRunLengthEncode)": ["hip", "hipcub\\1"], - "(#include.*[<\"].*){exclude}cub(.*)[.]cuh": ["hip", "\\1hipcub\\2.hpp"], - "(#include.*[<\"]fbgemm_gpu.*)({exclude}[.]cuh)": ["_hip", "\\1_hip\\2", "cuda", "hip"], - "cudaCpuDeviceId": [None, "hipCpuDeviceId"], - "split_embeddings_utils[.]cuh": [None, "split_embeddings_utils_hip.cuh"]}) - - abs_build_path = os.path.join(cur_dir, build_codegen_path) - for f in cpp_cuda_output_files: - if f.endswith(".cu"): - hip_f = os.path.join(abs_build_path, f.replace("cuda", "hip").replace(".cu", ".hip")) - post_hipify(hip_f) - - for s in ["codegen", "src"]: - for f in os.listdir(s): - if f.endswith(".hip") or f.endswith("hip.cuh"): - hip_f = os.path.join(s, f) - post_hipify(hip_f) - - os.system("hipify-perl src/split_embeddings_utils.cuh > src/split_embeddings_utils_hip.cuh") - post_hipify("src/split_embeddings_utils_hip.cuh") - super().build_extension(ext) if is_rocm_pytorch: @@ -230,26 +182,10 @@ def post_hipify(hip_file): cpu_only_build = True sys.argv.remove("--cpu_only") -setup( - name="fbgemm_gpu", - install_requires=[ - "torch", - "Jinja2", - "click", - "hypothesis", - ], - version="0.0.1", - long_description=long_description, - ext_modules=[ - CUDAExtension( - name="fbgemm_gpu_py", - sources=[ +CUDA_source = [ os.path.join(cur_dir, build_codegen_path, "{}".format(f)) for f in cpp_cuda_output_files + cpp_cpu_output_files - ] - + cpp_asmjit_files - + cpp_fbgemm_files - + [ + ] + cpp_asmjit_files + cpp_fbgemm_files + [ os.path.join(cur_dir, "codegen/embedding_forward_split_cpu.cpp"), os.path.join(cur_dir, "codegen/embedding_forward_quantized_host_cpu.cpp"), os.path.join(cur_dir, "codegen/embedding_forward_quantized_host.cpp"), @@ -278,8 +214,9 @@ def post_hipify(hip_file): os.path.join(cur_dir, "src/jagged_tensor_ops.cu"), os.path.join(cur_dir, "src/histogram_binning_calibration_ops.cu"), os.path.join(cur_dir, "src/split_embeddings_utils.cu"), - ], - include_dirs=[ + ] + +common_included = [ cur_dir, os.path.join(cur_dir, "include"), os.path.join(cur_dir, "../include"), @@ -288,34 +225,42 @@ def post_hipify(hip_file): os.path.join(cur_dir, "../third_party/asmjit/src/core"), os.path.join(cur_dir, "../third_party/asmjit/src/x86"), os.path.join(cur_dir, "../third_party/cpuinfo/include"), - ] + include_dirs, + ] + +CUDA_include = common_included + include_dirs + +Cpp_source = [ + os.path.join(cur_dir, build_codegen_path, "{}".format(f)) + for f in cpp_cpu_output_files + ] + cpp_asmjit_files + cpp_fbgemm_files + [ + os.path.join(cur_dir, "codegen/embedding_forward_split_cpu.cpp"), + os.path.join(cur_dir, "codegen/embedding_forward_quantized_host_cpu.cpp"), + os.path.join(cur_dir, "codegen/embedding_backward_dense_host_cpu.cpp"), + ] + +setup( + name="fbgemm_gpu", + install_requires=[ + "torch", + "Jinja2", + "click", + "hypothesis", + ], + version="0.0.1", + long_description=long_description, + ext_modules=[ + CUDAExtension( + name="fbgemm_gpu_py", + sources=CUDA_source, + include_dirs=CUDA_include, extra_compile_args={"cxx": extra_compile_args + ["-DFBGEMM_GPU_WITH_CUDA"], "nvcc": ["-U__CUDA_NO_HALF_CONVERSIONS__"]}, libraries=libraries, ) if not cpu_only_build else CppExtension( name="fbgemm_gpu_py", - sources=[ - os.path.join(cur_dir, build_codegen_path, "{}".format(f)) - for f in cpp_cpu_output_files - ] - + cpp_asmjit_files - + cpp_fbgemm_files - + [ - os.path.join(cur_dir, "codegen/embedding_forward_split_cpu.cpp"), - os.path.join(cur_dir, "codegen/embedding_forward_quantized_host_cpu.cpp"), - os.path.join(cur_dir, "codegen/embedding_backward_dense_host_cpu.cpp"), - ], - include_dirs=[ - cur_dir, - os.path.join(cur_dir, "include"), - os.path.join(cur_dir, "../include"), - os.path.join(cur_dir, "../src"), - os.path.join(cur_dir, "../third_party/asmjit/src"), - os.path.join(cur_dir, "../third_party/asmjit/src/core"), - os.path.join(cur_dir, "../third_party/asmjit/src/x86"), - os.path.join(cur_dir, "../third_party/cpuinfo/include"), - ], + sources=Cpp_source, + include_dirs=common_included, extra_compile_args={"cxx": extra_compile_args}, ) ], From edd330616d281f642483a9704df196eca035eef3 Mon Sep 17 00:00:00 2001 From: liligwu Date: Mon, 14 Mar 2022 16:30:41 +0000 Subject: [PATCH 16/76] Pointing hipify_torch to the newer commit. --- third_party/hipify_torch | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/hipify_torch b/third_party/hipify_torch index 3816549ca..0f7fc0d5a 160000 --- a/third_party/hipify_torch +++ b/third_party/hipify_torch @@ -1 +1 @@ -Subproject commit 3816549caf28490acc1302859596e33659b46b67 +Subproject commit 0f7fc0d5a45c5a4578275bcedff088ebf33772ed From 309a3a17ae0d97e957d934d571652bde268ae06d Mon Sep 17 00:00:00 2001 From: Li Li Date: Wed, 16 Mar 2022 11:42:10 -0500 Subject: [PATCH 17/76] Fixing #include by defining NEW_GENERATOR_PATH in setup.py. (#19) --- .../embedding_backward_template_helpers.cuh | 4 ++++ fbgemm_gpu/setup.py | 11 ++++++++--- fbgemm_gpu/src/split_embeddings_cache_cuda.cu | 4 ++++ 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/embedding_backward_template_helpers.cuh b/fbgemm_gpu/include/fbgemm_gpu/embedding_backward_template_helpers.cuh index ad0804fe8..e662a6d7a 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/embedding_backward_template_helpers.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/embedding_backward_template_helpers.cuh @@ -15,7 +15,11 @@ #include #include +#if !defined(NEW_GENERATOR_PATH) #include +#else +#include +#endif #include #include #include diff --git a/fbgemm_gpu/setup.py b/fbgemm_gpu/setup.py index bb2ea6f62..6c5c00b06 100644 --- a/fbgemm_gpu/setup.py +++ b/fbgemm_gpu/setup.py @@ -40,6 +40,11 @@ # Get the long description from the relevant file cur_dir = os.path.dirname(os.path.realpath(__file__)) +generator_flag = [] +torch_dir = torch.__path__[0] +if os.path.exists(os.path.join(torch_dir, "include", "ATen", "cuda", "CUDAGeneratorImpl.h")): + generator_flag = ["-DNEW_GENERATOR_PATH"] + with open(os.path.join(cur_dir, "README.md"), encoding="utf-8") as f: long_description = f.read() @@ -253,15 +258,15 @@ def build_extension(self, ext): name="fbgemm_gpu_py", sources=CUDA_source, include_dirs=CUDA_include, - extra_compile_args={"cxx": extra_compile_args + ["-DFBGEMM_GPU_WITH_CUDA"], - "nvcc": ["-U__CUDA_NO_HALF_CONVERSIONS__"]}, + extra_compile_args={"cxx": extra_compile_args + ["-DFBGEMM_GPU_WITH_CUDA"] + generator_flag, + "nvcc": ["-U__CUDA_NO_HALF_CONVERSIONS__"] + generator_flag}, libraries=libraries, ) if not cpu_only_build else CppExtension( name="fbgemm_gpu_py", sources=Cpp_source, include_dirs=common_included, - extra_compile_args={"cxx": extra_compile_args}, + extra_compile_args={"cxx": extra_compile_args + generator_flag}, ) ], cmdclass={"build_ext": FBGEMM_GPU_BuildExtension}, diff --git a/fbgemm_gpu/src/split_embeddings_cache_cuda.cu b/fbgemm_gpu/src/split_embeddings_cache_cuda.cu index f3f92d5a2..82de1444b 100644 --- a/fbgemm_gpu/src/split_embeddings_cache_cuda.cu +++ b/fbgemm_gpu/src/split_embeddings_cache_cuda.cu @@ -15,7 +15,11 @@ #include #include +#if !defined(NEW_GENERATOR_PATH) #include +#else +#include +#endif #include #include #include From 358eaf507dea9a624fe4b94c02cac5299643f6ed Mon Sep 17 00:00:00 2001 From: Li Li Date: Wed, 16 Mar 2022 13:30:47 -0500 Subject: [PATCH 18/76] Disabling all use_cpu in the tests. (#20) --- ...plit_embedding_inference_converter_test.py | 4 +-- .../split_table_batched_embeddings_test.py | 32 +++++++++---------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/fbgemm_gpu/test/split_embedding_inference_converter_test.py b/fbgemm_gpu/test/split_embedding_inference_converter_test.py index da9046f94..1318d2f09 100644 --- a/fbgemm_gpu/test/split_embedding_inference_converter_test.py +++ b/fbgemm_gpu/test/split_embedding_inference_converter_test.py @@ -208,7 +208,7 @@ def test_quantize_workflow( ) @given( - use_cpu=st.booleans() if gpu_available else st.just(True), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), use_array_for_index_remapping=st.booleans(), ) @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) @@ -288,7 +288,7 @@ def test_l2_norm_pruning_workflow( D=st.integers(min_value=2, max_value=128), log_E=st.integers(min_value=3, max_value=5), pruning_ratio=st.floats(min_value=0.0, max_value=1.0, exclude_max=True), - use_cpu=st.booleans() if gpu_available else st.just(True), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), use_array_for_index_remapping=st.booleans(), ) @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) diff --git a/fbgemm_gpu/test/split_table_batched_embeddings_test.py b/fbgemm_gpu/test/split_table_batched_embeddings_test.py index 9690eccc0..0e3ac52c0 100644 --- a/fbgemm_gpu/test/split_table_batched_embeddings_test.py +++ b/fbgemm_gpu/test/split_table_batched_embeddings_test.py @@ -411,7 +411,7 @@ def execute_forward_( split_table_batched_embeddings_ops.PoolingMode.NONE, ] ), - use_cpu=st.booleans() if gpu_available else st.just(True), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), ) @settings( verbosity=Verbosity.verbose, @@ -470,7 +470,7 @@ def test_forward_int8( split_table_batched_embeddings_ops.PoolingMode.NONE, ] ), - use_cpu=st.booleans() if gpu_available else st.just(True), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), ) @settings( verbosity=Verbosity.verbose, @@ -529,7 +529,7 @@ def test_forward_fp16( split_table_batched_embeddings_ops.PoolingMode.NONE, ] ), - use_cpu=st.booleans() if gpu_available else st.just(True), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), ) @settings( verbosity=Verbosity.verbose, @@ -859,7 +859,7 @@ def test_nbit_forward_fused_pooled_emb_quant( split_table_batched_embeddings_ops.PoolingMode.NONE, ] ), - use_cpu=st.booleans() if gpu_available else st.just(True), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), ) @settings( verbosity=Verbosity.verbose, @@ -1078,7 +1078,7 @@ def test_backward_dense( split_table_batched_embeddings_ops.PoolingMode.NONE, ] ), - use_cpu=st.booleans() if gpu_available else st.just(True), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), exact=st.booleans(), ) @settings( @@ -1627,7 +1627,7 @@ def execute_backward_adagrad_( # noqa C901 cache_algorithm=st.sampled_from( split_table_batched_embeddings_ops.CacheAlgorithm ), - use_cpu=st.booleans() if gpu_available else st.just(True), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), exact=st.booleans(), ) @settings( @@ -1689,7 +1689,7 @@ def test_backward_adagrad_fp16_pmSUM( # noqa C901 cache_algorithm=st.sampled_from( split_table_batched_embeddings_ops.CacheAlgorithm ), - use_cpu=st.booleans() if gpu_available else st.just(True), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), exact=st.booleans(), ) @settings( @@ -1751,7 +1751,7 @@ def test_backward_adagrad_fp16_pmMEAN( # noqa C901 cache_algorithm=st.sampled_from( split_table_batched_embeddings_ops.CacheAlgorithm ), - use_cpu=st.booleans() if gpu_available else st.just(True), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), exact=st.booleans(), ) @settings( @@ -1813,7 +1813,7 @@ def test_backward_adagrad_fp16_pmNONE( # noqa C901 cache_algorithm=st.sampled_from( split_table_batched_embeddings_ops.CacheAlgorithm ), - use_cpu=st.booleans() if gpu_available else st.just(True), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), exact=st.booleans(), ) @settings( @@ -1937,7 +1937,7 @@ def test_backward_adagrad_fp32_pmMEAN( # noqa C901 cache_algorithm=st.sampled_from( split_table_batched_embeddings_ops.CacheAlgorithm ), - use_cpu=st.booleans() if gpu_available else st.just(True), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), exact=st.booleans(), ) @settings( @@ -2472,7 +2472,7 @@ def execute_backward_optimizers_( # noqa C901 split_table_batched_embeddings_ops.PoolingMode.NONE, ] ), - use_cpu=st.booleans() if gpu_available else st.just(True), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), ) @settings( verbosity=Verbosity.verbose, @@ -2532,7 +2532,7 @@ def test_backward_optimizers_adam( # noqa C901 split_table_batched_embeddings_ops.PoolingMode.NONE, ] ), - use_cpu=st.booleans() if gpu_available else st.just(True), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), ) @settings( verbosity=Verbosity.verbose, @@ -2590,7 +2590,7 @@ def test_backward_optimizers_adagrad( # noqa C901 split_table_batched_embeddings_ops.PoolingMode.NONE, ] ), - use_cpu=st.booleans() if gpu_available else st.just(True), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), ) @settings( verbosity=Verbosity.verbose, @@ -2644,7 +2644,7 @@ def test_backward_optimizers_lamb( # noqa C901 split_table_batched_embeddings_ops.PoolingMode.NONE, ] ), - use_cpu=st.booleans() if gpu_available else st.just(True), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), ) @settings( verbosity=Verbosity.verbose, @@ -3223,7 +3223,7 @@ def test_nbit_cache_pipeline( T=st.integers(min_value=1, max_value=10), B=st.integers(min_value=1, max_value=64), L=st.integers(min_value=0, max_value=64), - use_cpu=st.booleans() if gpu_available else st.just(True), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), use_cpu_hashtable=st.booleans(), ) @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) @@ -3434,7 +3434,7 @@ def test_cache_miss_counter(self, N: int) -> None: BoundsCheckMode.IGNORE, ] ), - use_cpu=st.booleans() if gpu_available else st.just(True), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), dtype=st.sampled_from( [ torch.int64, From 3a915a844222bf5313d9ab81897afbe2cfc60da5 Mon Sep 17 00:00:00 2001 From: Pruthvi Madugundu Date: Thu, 17 Mar 2022 03:34:53 +0530 Subject: [PATCH 19/76] Change py3.8 syntax to py3.7 syntax (#18) --- .../split_table_batched_embeddings_benchmark.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py index 9c6e9e0b7..cc296353e 100644 --- a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py @@ -785,8 +785,10 @@ def uvm( offsets = torch.tensor(([0] + np.cumsum(lengths).tolist())).int().cuda() per_sample_weights = None if weighted: - assert (this_rs_uvm_weights := rs_uvm[2]) is not None - assert (this_rs_gpu_weights := rs_gpu[2]) is not None + this_rs_uvm_weights = rs_uvm[2] + assert this_rs_uvm_weights is not None + this_rs_gpu_weights = rs_gpu[2] + assert this_rs_gpu_weights is not None per_sample_weights = torch.cat( [this_rs_uvm_weights, this_rs_gpu_weights] ) @@ -1607,8 +1609,10 @@ def nbit_uvm( offsets = torch.tensor(([0] + np.cumsum(lengths).tolist())).int().cuda() per_sample_weights = None if weighted: - assert (this_rs_uvm_weights := rs_uvm[2]) is not None - assert (this_rs_gpu_weights := rs_gpu[2]) is not None + this_rs_uvm_weights = rs_uvm[2] + assert this_rs_uvm_weights is not None + this_rs_gpu_weights = rs_gpu[2] + assert this_rs_gpu_weights is not None per_sample_weights = torch.cat( [this_rs_uvm_weights, this_rs_gpu_weights] ) From 40928bab81f7e1f403605ce238fd0d14aec3f60c Mon Sep 17 00:00:00 2001 From: Li Li Date: Thu, 31 Mar 2022 15:10:50 -0500 Subject: [PATCH 20/76] Match upstream setup (#21) * An attempt of matching upstream setup.py. * Move hipify() to CMakeList.txt. * Removing hipify from the python script. * Matching upstream setup.py * #Removing the unnecessary funcitons and statements in Hip.cmake. #Reforming some of the compilation option lists in CMakeList.txt. * Updating hipify_torch (CMake API) * #Adding automatically detection for CUDA and ROCm. #Removing the debug code in embedding_backward_code_generator.py. #Adding 'gfx90a' in FBGEMM_ROCM_ARCH. #Minor changes on message and indentation. --- fbgemm_gpu/CMakeLists.txt | 197 ++++++++++---- fbgemm_gpu/cmake/Hip.cmake | 163 ++++++++++++ .../embedding_backward_code_generator.py | 2 +- fbgemm_gpu/setup.py | 244 +----------------- third_party/hipify_torch | 2 +- 5 files changed, 315 insertions(+), 293 deletions(-) create mode 100644 fbgemm_gpu/cmake/Hip.cmake diff --git a/fbgemm_gpu/CMakeLists.txt b/fbgemm_gpu/CMakeLists.txt index c62f8585b..181008fc9 100644 --- a/fbgemm_gpu/CMakeLists.txt +++ b/fbgemm_gpu/CMakeLists.txt @@ -3,7 +3,7 @@ cmake_minimum_required(VERSION 3.11.0 FATAL_ERROR) project( fbgemm_gpu VERSION 0.0.1 - LANGUAGES CXX C CUDA) + LANGUAGES CXX C) set(message_line "-------------------------------------------------------------") @@ -13,23 +13,35 @@ if(SKBUILD) message("The project is built using scikit-build") endif() -set(default_cuda_architectures 60 61 70 75 80) -set(cuda_architectures_doc - "CUDA architectures to build for. Default is ${default_cuda_architectures}") -set(cuda_architectures - "${default_cuda_architectures}" - CACHE STRING "${cuda_architectures_doc}") +if(EXISTS "/usr/bin/nvidia-smi") + message("NVIDIA GPU detected.") + option(USE_CUDA "Use CUDA" ON) + option(USE_ROCM "Use ROCm" OFF) +elseif(EXISTS "/opt/rocm/bin/rocm-smi") + message("AMD GPU detected.") + option(USE_CUDA "Use CUDA" OFF) + option(USE_ROCM "Use ROCm" ON) +else() + message("Unable to detect GPU vendor") + message(FATAL_ERROR "") +endif() -message("${message_line}") -message("fbgemm_gpu:") -message("Building for cuda_architectures = \"${cuda_architectures}\"") -message("${message_line}") +if((USE_CUDA EQUAL ON AND USE_ROCM EQUAL ON) OR (USE_CUDA EQUAL OFF AND USE_ROCM EQUAL OFF)) + message(FATAL_ERROR "Please choose either CUDA or ROCm.") +endif() -find_package(Torch REQUIRED) -find_package(PythonExtensions REQUIRED) +if(USE_CUDA) + set(default_cuda_architectures 60 61 70 75 80) + set(cuda_architectures_doc + "CUDA architectures to build for. Default is ${default_cuda_architectures}") + set(cuda_architectures + "${default_cuda_architectures}" + CACHE STRING "${cuda_architectures_doc}") -set(FBGEMM ${CMAKE_CURRENT_SOURCE_DIR}/..) -set(THIRDPARTY ${FBGEMM}/third_party) + message("${message_line}") + message("fbgemm_gpu:") + message("Building for cuda_architectures = \"${cuda_architectures}\"") + message("${message_line}") # # Toch Cuda Extensions are normally compiled with the flags below. However we @@ -37,13 +49,47 @@ set(THIRDPARTY ${FBGEMM}/third_party) # constructor exists to convert from "int" to "__half" errors in # gen_embedding_forward_quantized_split_[un]weighted_codegen_cuda.cu # + set(TORCH_CUDA_OPTIONS + --expt-relaxed-constexpr + -D__CUDA_NO_HALF_OPERATORS__ + # -D__CUDA_NO_HALF_CONVERSIONS__ + -D__CUDA_NO_BFLOAT16_CONVERSIONS__ + -D__CUDA_NO_HALF2_OPERATORS__) +endif() + +find_package(Torch REQUIRED) +find_package(PythonExtensions REQUIRED) + +set(FBGEMM ${CMAKE_CURRENT_SOURCE_DIR}/..) +set(THIRDPARTY ${FBGEMM}/third_party) + +if(USE_ROCM) + if(NOT DEFINED ENV{PYTORCH_ROCM_ARCH}) + SET(FBGEMM_ROCM_ARCH gfx900;gfx906;gfx908;gfx90a) + else() + SET(FBGEMM_ROCM_ARCH $ENV{PYTORCH_ROCM_ARCH}) + endif() + + list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake" "${THIRDPARTY}/hipify_torch/cmake") + include(Hip) + if(NOT FBGEMM_HAVE_HIP) + message(FATAL_ERROR "Not able to find HIP installation.") + endif() + include(Hipify) + list (APPEND CMAKE_PREFIX_PATH /opt/rocm/hip /opt/rocm) + set(CMAKE_MODULE_PATH ${HIP_PATH}/cmake ${CMAKE_MODULE_PATH}) + + find_package(rocBLAS REQUIRED) + find_package(hipFFT REQUIRED) + find_package(hipRAND REQUIRED) + find_package(rocRAND REQUIRED) + find_package(hipSPARSE REQUIRED) + find_package(OpenMP REQUIRED) + + message("${message_line}") + message(STATUS "hip found ${ROCM_FOUND}") +endif() -set(TORCH_CUDA_OPTIONS - --expt-relaxed-constexpr - -D__CUDA_NO_HALF_OPERATORS__ - # -D__CUDA_NO_HALF_CONVERSIONS__ - -D__CUDA_NO_BFLOAT16_CONVERSIONS__ - -D__CUDA_NO_HALF2_OPERATORS__) # # GENERATED CUDA, CPP and Python code @@ -131,17 +177,37 @@ set(codegen_dependencies ${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/sparse_ops_utils.h ) -add_custom_command( - OUTPUT ${gen_cpu_source_files} ${gen_gpu_source_files} ${gen_python_files} - COMMAND - "${PYTHON_EXECUTABLE}" - "${CMAKE_CODEGEN_DIR}/embedding_backward_code_generator.py" - "--opensource" - DEPENDS "${codegen_dependencies}") +if(USE_CUDA) + add_custom_command( + OUTPUT ${gen_cpu_source_files} ${gen_gpu_source_files} ${gen_python_files} + COMMAND + "${PYTHON_EXECUTABLE}" + "${CMAKE_CODEGEN_DIR}/embedding_backward_code_generator.py" + "--opensource" + DEPENDS "${codegen_dependencies}") + + set_source_files_properties( + ${gen_cpu_source_files} PROPERTIES COMPILE_OPTIONS + "-mavx2;-mf16c;-mfma;-fopenmp") +elseif(USE_ROCM) + execute_process( + COMMAND + "${PYTHON_EXECUTABLE}" + "${CMAKE_CODEGEN_DIR}/embedding_backward_code_generator.py" + "--opensource") + + set(header_include_dir + ${CMAKE_CURRENT_SOURCE_DIR}/include + ${CMAKE_CURRENT_SOURCE_DIR}/src + ${CMAKE_CURRENT_SOURCE_DIR} + ) + hipify(CUDA_SOURCE_DIR ${PROJECT_SOURCE_DIR} HEADER_INCLUDE_DIR ${header_include_dir}) + + set_source_files_properties( + ${gen_cpu_source_files} PROPERTIES COMPILE_OPTIONS + "-mavx2;-mf16c;-mfma") +endif() -set_source_files_properties( - ${gen_cpu_source_files} PROPERTIES COMPILE_OPTIONS - "-mavx2;-mf16c;-mfma;-fopenmp") set_source_files_properties( ${gen_cpu_source_files} PROPERTIES @@ -180,15 +246,15 @@ set(cpp_fbgemm_files_avx2 "../src/EmbeddingSpMDMAvx2.cc" set_source_files_properties(${cpp_fbgemm_files_avx2} PROPERTIES COMPILE_OPTIONS "-mavx2;-mf16c;-mfma") +set(cpp_fbgemm_files ${cpp_fbgemm_files_normal} ${cpp_fbgemm_files_avx2}) set(cpp_fbgemm_files_avx512 "../src/EmbeddingSpMDMAvx512.cc") - -set_source_files_properties( - ${cpp_fbgemm_files_avx512} - PROPERTIES COMPILE_OPTIONS - "-mavx2;-mf16c;-mfma;-mavx512f;-mavx512bw;-mavx512dq;-mavx512vl") - -set(cpp_fbgemm_files ${cpp_fbgemm_files_normal} ${cpp_fbgemm_files_avx2} - ${cpp_fbgemm_files_avx512}) +if(USE_CUDA) + set_source_files_properties( + ${cpp_fbgemm_files_avx512} + PROPERTIES COMPILE_OPTIONS + "-mavx2;-mf16c;-mfma;-mavx512f;-mavx512bw;-mavx512dq;-mavx512vl") + list(APPEND cpp_fbgemm_files ${cpp_fbgemm_files_avx512}) +endif() set(cpp_fbgemm_files_include_directories ${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/include @@ -223,10 +289,12 @@ set(fbgemm_gpu_sources_cpu src/sparse_ops_gpu.cpp src/split_table_batched_embeddings.cpp) -set_source_files_properties( - ${fbgemm_gpu_sources_cpu} PROPERTIES COMPILE_OPTIONS - "-mavx;-mf16c;-mfma;-mavx2;-fopenmp") - +set(fbgemm_gpu_sources_cpu_option "-mavx;-mf16c;-mfma;-mavx2") +if(USE_CUDA) + set_source_files_properties( + ${fbgemm_gpu_sources_cpu} PROPERTIES COMPILE_OPTIONS + "${fbgemm_gpu_sources_cpu_option};-fopenmp") +endif() set(fbgemm_gpu_sources_gpu codegen/embedding_bounds_check.cu src/cumem_utils.cu src/histogram_binning_calibration_ops.cu src/jagged_tensor_ops.cu @@ -248,22 +316,45 @@ set_source_files_properties( set(fbgemm_gpu_sources ${fbgemm_gpu_sources_gpu} ${fbgemm_gpu_sources_cpu}) -# -# MODULE -# - -add_library(fbgemm_gpu_py MODULE ${fbgemm_gpu_sources} ${gen_source_files} - ${cpp_asmjit_files} ${cpp_fbgemm_files}) +if(USE_ROCM) + set(abspath_gen_source_files) + foreach(filename_gen_source_file ${gen_source_files}) + list(APPEND abspath_gen_source_files "${CMAKE_BINARY_DIR}/${filename_gen_source_file}") + endforeach() +endif() -target_compile_definitions(fbgemm_gpu_py PRIVATE FBGEMM_CUB_USE_NAMESPACE) +if(USE_CUDA) + add_library(fbgemm_gpu_py MODULE ${fbgemm_gpu_sources} ${gen_source_files} + ${cpp_asmjit_files} ${cpp_fbgemm_files}) + set_property(TARGET fbgemm_gpu_py PROPERTY CUDA_ARCHITECTURES + "${cuda_architectures}") + target_compile_definitions(fbgemm_gpu_py PRIVATE FBGEMM_CUB_USE_NAMESPACE) + set_property(TARGET fbgemm_gpu_py PROPERTY CXX_STANDARD 17) +elseif(USE_ROCM) + get_hipified_list("${fbgemm_gpu_sources}" fbgemm_gpu_sources) + get_hipified_list("${abspath_gen_source_files}" abspath_gen_source_files) + get_hipified_list("${cpp_fbgemm_files}" cpp_fbgemm_files) + + set(FBGEMM_ALL_HIP_FILES ${fbgemm_gpu_sources} ${abspath_gen_source_files} ${cpp_fbgemm_files}) + set_source_files_properties(${FBGEMM_ALL_HIP_FILES} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1) + hip_include_directories("${cpp_fbgemm_files_include_directories}") + + hip_add_library(fbgemm_gpu_py SHARED ${cpp_asmjit_files} ${FBGEMM_ALL_HIP_FILES} ${FBGEMM_HIP_HCC_LIBRARIES} + HIPCC_OPTIONS ${HIP_CXX_FLAGS}) + target_include_directories(fbgemm_gpu_py PUBLIC ${FBGEMM_HIP_INCLUDE} ${ROCRAND_INCLUDE}) + set_property(TARGET fbgemm_gpu_py PROPERTY HIP_ARCHITECTURES ${FBGEMM_ROCM_ARCH}) + + # For ROCm5.1 + list (GET TORCH_INCLUDE_DIRS 0 TORCH_PATH) + if(EXISTS "${TORCH_PATH}/ATen/cuda/CUDAGeneratorImpl.h") + target_compile_definitions(fbgemm_gpu_py PRIVATE NEW_GENERATOR_PATH) + endif() +endif() -set_property(TARGET fbgemm_gpu_py PROPERTY CUDA_ARCHITECTURES - "${cuda_architectures}") set_target_properties(fbgemm_gpu_py PROPERTIES PREFIX "") target_link_libraries(fbgemm_gpu_py ${TORCH_LIBRARIES}) target_include_directories(fbgemm_gpu_py PRIVATE ${TORCH_INCLUDE_DIRS}) -set_property(TARGET fbgemm_gpu_py PROPERTY CXX_STANDARD 17) install(TARGETS fbgemm_gpu_py DESTINATION fbgemm_gpu) diff --git a/fbgemm_gpu/cmake/Hip.cmake b/fbgemm_gpu/cmake/Hip.cmake new file mode 100644 index 000000000..d6a6d7550 --- /dev/null +++ b/fbgemm_gpu/cmake/Hip.cmake @@ -0,0 +1,163 @@ +set(FBGEMM_HAVE_HIP FALSE) + +IF(NOT DEFINED ENV{ROCM_PATH}) + SET(ROCM_PATH /opt/rocm) +ELSE() + SET(ROCM_PATH $ENV{ROCM_PATH}) +ENDIF() + +# HIP_PATH +IF(NOT DEFINED ENV{HIP_PATH}) + SET(HIP_PATH ${ROCM_PATH}/hip) +ELSE() + SET(HIP_PATH $ENV{HIP_PATH}) +ENDIF() + +IF(NOT EXISTS ${HIP_PATH}) + return() +ENDIF() + +# HCC_PATH +IF(NOT DEFINED ENV{HCC_PATH}) + SET(HCC_PATH ${ROCM_PATH}/hcc) +ELSE() + SET(HCC_PATH $ENV{HCC_PATH}) +ENDIF() + +# HSA_PATH +IF(NOT DEFINED ENV{HSA_PATH}) + SET(HSA_PATH ${ROCM_PATH}/hsa) +ELSE() + SET(HSA_PATH $ENV{HSA_PATH}) +ENDIF() + +# ROCBLAS_PATH +IF(NOT DEFINED ENV{ROCBLAS_PATH}) + SET(ROCBLAS_PATH ${ROCM_PATH}/rocblas) +ELSE() + SET(ROCBLAS_PATH $ENV{ROCBLAS_PATH}) +ENDIF() + +# ROCSPARSE_PATH +IF(NOT DEFINED ENV{ROCSPARSE_PATH}) + SET(ROCSPARSE_PATH ${ROCM_PATH}/rocsparse) +ELSE() + SET(ROCSPARSE_PATH $ENV{ROCSPARSE_PATH}) +ENDIF() + +# ROCFFT_PATH +IF(NOT DEFINED ENV{ROCFFT_PATH}) + SET(ROCFFT_PATH ${ROCM_PATH}/rocfft) +ELSE() + SET(ROCFFT_PATH $ENV{ROCFFT_PATH}) +ENDIF() + +# HIPSPARSE_PATH +IF(NOT DEFINED ENV{HIPSPARSE_PATH}) + SET(HIPSPARSE_PATH ${ROCM_PATH}/hipsparse) +ELSE() + SET(HIPSPARSE_PATH $ENV{HIPSPARSE_PATH}) +ENDIF() + +# THRUST_PATH +IF(DEFINED ENV{THRUST_PATH}) + SET(THRUST_PATH $ENV{THRUST_PATH}) +ELSE() + SET(THRUST_PATH ${ROCM_PATH}/include) +ENDIF() + +# HIPRAND_PATH +IF(NOT DEFINED ENV{HIPRAND_PATH}) + SET(HIPRAND_PATH ${ROCM_PATH}/hiprand) +ELSE() + SET(HIPRAND_PATH $ENV{HIPRAND_PATH}) +ENDIF() + +# ROCRAND_PATH +IF(NOT DEFINED ENV{ROCRAND_PATH}) + SET(ROCRAND_PATH ${ROCM_PATH}/rocrand) +ELSE() + SET(ROCRAND_PATH $ENV{ROCRAND_PATH}) +ENDIF() + +# MIOPEN_PATH +IF(NOT DEFINED ENV{MIOPEN_PATH}) + SET(MIOPEN_PATH ${ROCM_PATH}/miopen) +ELSE() + SET(MIOPEN_PATH $ENV{MIOPEN_PATH}) +ENDIF() + +IF(NOT DEFINED ENV{FBGEMM_ROCM_ARCH}) + SET(FBGEMM_ROCM_ARCH gfx900;gfx906;gfx908;gfx90a) +ELSE() + SET(FBGEMM_ROCM_ARCH $ENV{FBGEMM_ROCM_ARCH}) +ENDIF() + +# Add HIP to the CMAKE Module Path +set(CMAKE_MODULE_PATH ${HIP_PATH}/cmake ${CMAKE_MODULE_PATH}) + +# Disable Asserts In Code (Can't use asserts on HIP stack.) +ADD_DEFINITIONS(-DNDEBUG) + +# Find the HIP Package +find_package(HIP) + +IF(HIP_FOUND) + set(FBGEMM_HAVE_HIP TRUE) + + if(HIP_COMPILER STREQUAL clang) + set(hip_library_name amdhip64) + else() + set(hip_library_name hip_hcc) + endif() + message("HIP library name: ${hip_library_name}") + + set(CMAKE_HCC_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG}) + set(CMAKE_HCC_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE}) + FIND_LIBRARY(FBGEMM_HIP_HCC_LIBRARIES ${hip_library_name} HINTS ${HIP_PATH}/lib) + + list(APPEND HIP_CXX_FLAGS -D__HIP_NO_HALF_OPERATORS__=1) + # list(APPEND HIP_CXX_FLAGS -D__HIP_NO_HALF_CONVERSIONS__=1) + list(APPEND HIP_CXX_FLAGS -D__HIP_NO_BFLOAT16_CONVERSIONS__=1) + list(APPEND HIP_CXX_FLAGS -D__HIP_NO_HALF2_OPERATORS__=1) + list(APPEND HIP_CXX_FLAGS -mavx2) + list(APPEND HIP_CXX_FLAGS -mf16c) + list(APPEND HIP_CXX_FLAGS -mfma) + list(APPEND HIP_CXX_FLAGS -std=c++17) + + set(HIP_HCC_FLAGS ${HIP_CXX_FLAGS}) + # Ask hcc to generate device code during compilation so we can use + # host linker to link. + list(APPEND HIP_HCC_FLAGS -fno-gpu-rdc) + list(APPEND HIP_HCC_FLAGS -Wno-defaulted-function-deleted) + foreach(fbgemm_rocm_arch ${FBGEMM_ROCM_ARCH}) + list(APPEND HIP_HCC_FLAGS --amdgpu-target=${fbgemm_rocm_arch}) + endforeach() + + set(hip_DIR ${HIP_PATH}/lib/cmake/hip) + set(hsa-runtime64_DIR ${ROCM_PATH}/lib/cmake/hsa-runtime64) + set(AMDDeviceLibs_DIR ${ROCM_PATH}/lib/cmake/AMDDeviceLibs) + set(amd_comgr_DIR ${ROCM_PATH}/lib/cmake/amd_comgr) + set(rocrand_DIR ${ROCRAND_PATH}/lib/cmake/rocrand) + set(hiprand_DIR ${HIPRAND_PATH}/lib/cmake/hiprand) + set(rocblas_DIR ${ROCBLAS_PATH}/lib/cmake/rocblas) + set(miopen_DIR ${MIOPEN_PATH}/lib/cmake/miopen) + set(rocfft_DIR ${ROCFFT_PATH}/lib/cmake/rocfft) + set(hipfft_DIR ${HIPFFT_PATH}/lib/cmake/hipfft) + set(hipsparse_DIR ${HIPSPARSE_PATH}/lib/cmake/hipsparse) + set(rccl_DIR ${RCCL_PATH}/lib/cmake/rccl) + set(rocprim_DIR ${ROCPRIM_PATH}/lib/cmake/rocprim) + set(hipcub_DIR ${HIPCUB_PATH}/lib/cmake/hipcub) + set(rocthrust_DIR ${ROCTHRUST_PATH}/lib/cmake/rocthrust) + set(ROCclr_DIR ${ROCM_PATH}/rocclr/lib/cmake/rocclr) + + find_package(hip REQUIRED) + + set(ROCRAND_INCLUDE ${ROCRAND_PATH}/include) + + set(FBGEMM_HIP_INCLUDE ${ROCM_PATH}/include ${FBGEMM_HIP_INCLUDE}) + set(FBGEMM_HIP_INCLUDE ${hip_INCLUDE_DIRS} $ $ ${FBGEMM_HIP_INCLUDE}) + + hip_include_directories(${FBGEMM_HIP_INCLUDE} ${ROCRAND_INCLUDE}) + +ENDIF() diff --git a/fbgemm_gpu/codegen/embedding_backward_code_generator.py b/fbgemm_gpu/codegen/embedding_backward_code_generator.py index 3acb2386d..23cc84696 100644 --- a/fbgemm_gpu/codegen/embedding_backward_code_generator.py +++ b/fbgemm_gpu/codegen/embedding_backward_code_generator.py @@ -9,7 +9,6 @@ import os from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple - import jinja2 args: argparse.Namespace @@ -965,3 +964,4 @@ def main() -> None: if __name__ == "__main__": main() + # hipify_gen() diff --git a/fbgemm_gpu/setup.py b/fbgemm_gpu/setup.py index 6c5c00b06..f33dc0f91 100644 --- a/fbgemm_gpu/setup.py +++ b/fbgemm_gpu/setup.py @@ -3,271 +3,39 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import glob import os -import shutil -import sysconfig import sys -import re -import tempfile -from codegen.embedding_backward_code_generator import emb_codegen -from setuptools import setup -from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension -import torch -sys.path.append("..") -from third_party.hipify_torch.hipify import hipify_python +from skbuild import setup cpu_only_build = False -cur_dir = os.path.dirname(os.path.realpath(__file__)) + cub_include_path = os.getenv("CUB_DIR", None) if cub_include_path is None: print( "CUDA CUB directory environment variable not set. Using default CUB location." ) -build_codegen_path = "build/codegen" -py_path = "python" - -is_rocm_pytorch = False -maj_ver, min_ver, _ = torch.__version__.split('.') -if int(maj_ver) > 1 or (int(maj_ver) == 1 and int(min_ver) >= 5): - from torch.utils.cpp_extension import ROCM_HOME - is_rocm_pytorch = True \ - if ((torch.version.hip is not None) and (ROCM_HOME is not None)) \ - else False # Get the long description from the relevant file cur_dir = os.path.dirname(os.path.realpath(__file__)) -generator_flag = [] -torch_dir = torch.__path__[0] -if os.path.exists(os.path.join(torch_dir, "include", "ATen", "cuda", "CUDAGeneratorImpl.h")): - generator_flag = ["-DNEW_GENERATOR_PATH"] - with open(os.path.join(cur_dir, "README.md"), encoding="utf-8") as f: long_description = f.read() -extra_compile_args = sysconfig.get_config_var("CFLAGS").split() -extra_compile_args += ["-mavx2", "-mf16c", "-mfma"] -if not is_rocm_pytorch: - extra_compile_args += ["-mavx512f", "-mavx512bw", "-mavx512dq", "-mavx512vl"] - -OPTIMIZERS = [ - "adagrad", - "adam", - "approx_rowwise_adagrad", - "approx_sgd", - "lamb", - "lars_sgd", - "partial_rowwise_adam", - "partial_rowwise_lamb", - "rowwise_adagrad", - "sgd", - "rowwise_weighted_adagrad" -] - -cpp_asmjit_files = glob.glob("../third_party/asmjit/src/asmjit/*/*.cpp") - -cpp_fbgemm_files = [ - "../src/EmbeddingSpMDMAvx2.cc", - "../src/EmbeddingSpMDM.cc", - "../src/EmbeddingSpMDMNBit.cc", - "../src/QuantUtils.cc", - "../src/QuantUtilsAvx2.cc", - "../src/RefImplementations.cc", - "../src/RowWiseSparseAdagradFused.cc", - "../src/SparseAdagrad.cc", - "../src/Utils.cc", -] - -if not is_rocm_pytorch: - cpp_fbgemm_files += ["../src/EmbeddingSpMDMAvx512.cc"] - -cpp_cpu_output_files = ( - [ - "gen_embedding_forward_quantized_unweighted_codegen_cpu.cpp", - "gen_embedding_forward_quantized_weighted_codegen_cpu.cpp", - "gen_embedding_backward_dense_split_cpu.cpp", - ] - + [ - "gen_embedding_backward_split_{}_cpu.cpp".format(optimizer) - for optimizer in OPTIMIZERS - ] - + [ - "gen_embedding_backward_{}_split_cpu.cpp".format(optimizer) - for optimizer in OPTIMIZERS - ] -) - -cpp_cuda_output_files = ( - [ - "gen_embedding_forward_dense_weighted_codegen_cuda.cu", - "gen_embedding_forward_dense_unweighted_codegen_cuda.cu", - "gen_embedding_forward_quantized_split_unweighted_codegen_cuda.cu", - "gen_embedding_forward_quantized_split_weighted_codegen_cuda.cu", - "gen_embedding_forward_split_weighted_codegen_cuda.cu", - "gen_embedding_forward_split_unweighted_codegen_cuda.cu", - "gen_embedding_backward_split_indice_weights_codegen_cuda.cu", - "gen_embedding_backward_dense_indice_weights_codegen_cuda.cu", - "gen_embedding_backward_dense_split_unweighted_cuda.cu", - "gen_embedding_backward_dense_split_weighted_cuda.cu", - ] - + [ - "gen_embedding_backward_{}_split_{}_cuda.cu".format(optimizer, weighted) - for optimizer in OPTIMIZERS - for weighted in [ - "weighted", - "unweighted", - ] - ] - + [ - "gen_embedding_backward_split_{}.cpp".format(optimizer) - for optimizer in OPTIMIZERS - ] -) - -py_output_files = ["lookup_{}.py".format(optimizer) for optimizer in OPTIMIZERS] - - -def generate_jinja_files(): - abs_build_path = os.path.join(cur_dir, build_codegen_path) - if not os.path.exists(abs_build_path): - os.makedirs(abs_build_path) - emb_codegen(install_dir=abs_build_path, is_fbcode=False) - - dst_python_path = os.path.join(cur_dir, py_path) - if not os.path.exists(dst_python_path): - os.makedirs(dst_python_path) - for filename in py_output_files: - shutil.copy2(os.path.join(abs_build_path, filename), dst_python_path) - shutil.copy2(os.path.join(cur_dir, "codegen", "lookup_args.py"), dst_python_path) - - -class FBGEMM_GPU_BuildExtension(BuildExtension.with_options(no_python_abi_suffix=True)): - def build_extension(self, ext): - if not is_rocm_pytorch: - generate_jinja_files() - else: - with hipify_python.GeneratedFileCleaner(keep_intermediates=True) as clean_ctx: - hipify_python.hipify( - project_directory=cur_dir, - output_directory=cur_dir, - header_include_dirs=[ - os.path.join(cur_dir, "include"), - os.path.join(cur_dir, "src"), - cur_dir, - ], - includes=["*"], - extra_files=CUDA_source, - show_detailed=True, - is_pytorch_extension=True, - clean_ctx=clean_ctx) - - super().build_extension(ext) - -if is_rocm_pytorch: - generate_jinja_files() - rocm_include_dirs = ["/opt/rocm/include/hiprand", "/opt/rocm/include/rocrand"] - libraries = [] -else: - rocm_include_dirs = [] - libraries = ["nvidia-ml"] - -include_dirs = [ cur_dir, - os.path.join(cur_dir, "include"), - os.path.join(cur_dir, "include/fbgemm_gpu"), - ] + rocm_include_dirs +import torch -if cub_include_path is not None: - include_dirs += [cub_include_path] +torch_root = os.path.dirname(torch.__file__) # Handle command line args before passing to main setup() method. if "--cpu_only" in sys.argv: cpu_only_build = True sys.argv.remove("--cpu_only") -CUDA_source = [ - os.path.join(cur_dir, build_codegen_path, "{}".format(f)) - for f in cpp_cuda_output_files + cpp_cpu_output_files - ] + cpp_asmjit_files + cpp_fbgemm_files + [ - os.path.join(cur_dir, "codegen/embedding_forward_split_cpu.cpp"), - os.path.join(cur_dir, "codegen/embedding_forward_quantized_host_cpu.cpp"), - os.path.join(cur_dir, "codegen/embedding_forward_quantized_host.cpp"), - os.path.join(cur_dir, "codegen/embedding_backward_dense_host_cpu.cpp"), - os.path.join(cur_dir, "codegen/embedding_backward_dense_host.cpp"), - os.path.join(cur_dir, "codegen/embedding_bounds_check_host.cpp"), - os.path.join(cur_dir, "codegen/embedding_bounds_check_host_cpu.cpp"), - os.path.join(cur_dir, "codegen/embedding_bounds_check.cu"), - os.path.join(cur_dir, "src/split_embeddings_cache_cuda.cu"), - os.path.join(cur_dir, "src/split_table_batched_embeddings.cpp"), - os.path.join(cur_dir, "src/cumem_utils.cu"), - os.path.join(cur_dir, "src/cumem_utils_host.cpp"), - os.path.join(cur_dir, "src/quantize_ops_cpu.cpp"), - os.path.join(cur_dir, "src/quantize_ops_gpu.cpp"), - os.path.join(cur_dir, "src/quantize_ops.cu"), - os.path.join(cur_dir, "src/cpu_utils.cpp"), - os.path.join(cur_dir, "src/sparse_ops_cpu.cpp"), - os.path.join(cur_dir, "src/sparse_ops_gpu.cpp"), - os.path.join(cur_dir, "src/sparse_ops.cu"), - os.path.join(cur_dir, "src/merge_pooled_embeddings_gpu.cpp"), - os.path.join(cur_dir, "src/permute_pooled_embedding_ops.cu"), - os.path.join(cur_dir, "src/permute_pooled_embedding_ops_gpu.cpp"), - os.path.join(cur_dir, "src/layout_transform_ops_cpu.cpp"), - os.path.join(cur_dir, "src/layout_transform_ops_gpu.cpp"), - os.path.join(cur_dir, "src/layout_transform_ops.cu"), - os.path.join(cur_dir, "src/jagged_tensor_ops.cu"), - os.path.join(cur_dir, "src/histogram_binning_calibration_ops.cu"), - os.path.join(cur_dir, "src/split_embeddings_utils.cu"), - ] - -common_included = [ - cur_dir, - os.path.join(cur_dir, "include"), - os.path.join(cur_dir, "../include"), - os.path.join(cur_dir, "../src"), - os.path.join(cur_dir, "../third_party/asmjit/src"), - os.path.join(cur_dir, "../third_party/asmjit/src/core"), - os.path.join(cur_dir, "../third_party/asmjit/src/x86"), - os.path.join(cur_dir, "../third_party/cpuinfo/include"), - ] - -CUDA_include = common_included + include_dirs - -Cpp_source = [ - os.path.join(cur_dir, build_codegen_path, "{}".format(f)) - for f in cpp_cpu_output_files - ] + cpp_asmjit_files + cpp_fbgemm_files + [ - os.path.join(cur_dir, "codegen/embedding_forward_split_cpu.cpp"), - os.path.join(cur_dir, "codegen/embedding_forward_quantized_host_cpu.cpp"), - os.path.join(cur_dir, "codegen/embedding_backward_dense_host_cpu.cpp"), - ] - setup( name="fbgemm_gpu", - install_requires=[ - "torch", - "Jinja2", - "click", - "hypothesis", - ], version="0.0.1", long_description=long_description, - ext_modules=[ - CUDAExtension( - name="fbgemm_gpu_py", - sources=CUDA_source, - include_dirs=CUDA_include, - extra_compile_args={"cxx": extra_compile_args + ["-DFBGEMM_GPU_WITH_CUDA"] + generator_flag, - "nvcc": ["-U__CUDA_NO_HALF_CONVERSIONS__"] + generator_flag}, - libraries=libraries, - ) if not cpu_only_build else - CppExtension( - name="fbgemm_gpu_py", - sources=Cpp_source, - include_dirs=common_included, - extra_compile_args={"cxx": extra_compile_args + generator_flag}, - ) - ], - cmdclass={"build_ext": FBGEMM_GPU_BuildExtension}, + packages=["fbgemm_gpu"], + cmake_args=[f"-DCMAKE_PREFIX_PATH={torch_root}"], ) diff --git a/third_party/hipify_torch b/third_party/hipify_torch index 0f7fc0d5a..59e17e5fc 160000 --- a/third_party/hipify_torch +++ b/third_party/hipify_torch @@ -1 +1 @@ -Subproject commit 0f7fc0d5a45c5a4578275bcedff088ebf33772ed +Subproject commit 59e17e5fcf00d4fb7c0a64cd727ca08e5100d9bd From 69abf78f15e309881b8100b96101b5b7534e462b Mon Sep 17 00:00:00 2001 From: Reza Rahimi Date: Fri, 1 Apr 2022 12:23:28 -0700 Subject: [PATCH 21/76] Enable merge_pooled_embeddings op. in ROCm (#15) * Enable merge_pooled_embeddings op. in ROCm * Enabling the merge pool ops. Co-authored-by: liligwu --- fbgemm_gpu/CMakeLists.txt | 5 +- fbgemm_gpu/cmake/Hip.cmake | 3 +- .../src/merge_pooled_embeddings_gpu.cpp | 96 ++++++++++++++++++- .../test/merge_pooled_embeddings_test.py | 6 +- 4 files changed, 99 insertions(+), 11 deletions(-) diff --git a/fbgemm_gpu/CMakeLists.txt b/fbgemm_gpu/CMakeLists.txt index 181008fc9..7d7d45b3d 100644 --- a/fbgemm_gpu/CMakeLists.txt +++ b/fbgemm_gpu/CMakeLists.txt @@ -85,6 +85,7 @@ if(USE_ROCM) find_package(rocRAND REQUIRED) find_package(hipSPARSE REQUIRED) find_package(OpenMP REQUIRED) + find_package(rocPRIM REQUIRED) message("${message_line}") message(STATUS "hip found ${ROCM_FOUND}") @@ -281,7 +282,7 @@ set(fbgemm_gpu_sources_cpu src/input_combine_cpu.cpp src/layout_transform_ops_cpu.cpp src/layout_transform_ops_gpu.cpp - # src/merge_pooled_embeddings_cpu.cpp src/merge_pooled_embeddings_gpu.cpp + src/merge_pooled_embeddings_cpu.cpp src/merge_pooled_embeddings_gpu.cpp src/permute_pooled_embedding_ops_gpu.cpp src/quantize_ops_cpu.cpp src/quantize_ops_gpu.cpp @@ -341,7 +342,7 @@ elseif(USE_ROCM) hip_add_library(fbgemm_gpu_py SHARED ${cpp_asmjit_files} ${FBGEMM_ALL_HIP_FILES} ${FBGEMM_HIP_HCC_LIBRARIES} HIPCC_OPTIONS ${HIP_CXX_FLAGS}) - target_include_directories(fbgemm_gpu_py PUBLIC ${FBGEMM_HIP_INCLUDE} ${ROCRAND_INCLUDE}) + target_include_directories(fbgemm_gpu_py PUBLIC ${FBGEMM_HIP_INCLUDE} ${ROCRAND_INCLUDE} ${ROCM_SMI_INCLUDE}) set_property(TARGET fbgemm_gpu_py PROPERTY HIP_ARCHITECTURES ${FBGEMM_ROCM_ARCH}) # For ROCm5.1 diff --git a/fbgemm_gpu/cmake/Hip.cmake b/fbgemm_gpu/cmake/Hip.cmake index d6a6d7550..9adbb6322 100644 --- a/fbgemm_gpu/cmake/Hip.cmake +++ b/fbgemm_gpu/cmake/Hip.cmake @@ -154,10 +154,11 @@ IF(HIP_FOUND) find_package(hip REQUIRED) set(ROCRAND_INCLUDE ${ROCRAND_PATH}/include) + set(ROCM_SMI_INCLUDE ${ROCM_PATH}/rocm_smi/include) set(FBGEMM_HIP_INCLUDE ${ROCM_PATH}/include ${FBGEMM_HIP_INCLUDE}) set(FBGEMM_HIP_INCLUDE ${hip_INCLUDE_DIRS} $ $ ${FBGEMM_HIP_INCLUDE}) - hip_include_directories(${FBGEMM_HIP_INCLUDE} ${ROCRAND_INCLUDE}) + hip_include_directories(${FBGEMM_HIP_INCLUDE} ${ROCRAND_INCLUDE} ${ROCM_SMI_INCLUDE}) ENDIF() diff --git a/fbgemm_gpu/src/merge_pooled_embeddings_gpu.cpp b/fbgemm_gpu/src/merge_pooled_embeddings_gpu.cpp index 04ec601ea..8562895c5 100644 --- a/fbgemm_gpu/src/merge_pooled_embeddings_gpu.cpp +++ b/fbgemm_gpu/src/merge_pooled_embeddings_gpu.cpp @@ -15,9 +15,93 @@ #include #include -// FIXME: Enable merge_pooled_embeddings for HIP. -// AMD GPUs don't seem to have nvml equivalent library support. -#ifndef __HIP_PLATFORM_HCC__ +#ifdef __HIP_PLATFORM_HCC__ +#include "rocm_smi/rocm_smi.h" +#include "hip/hip_runtime.h" + +#include + +#include "fbgemm_gpu/merge_pooled_embeddings.h" +#include "fbgemm_gpu/sparse_ops_utils.h" + +using Tensor = at::Tensor; + +#define RSMI_CHECK(fn) \ + do { \ + rsmi_status_t ret = (fn); \ + TORCH_CHECK((ret) == RSMI_STATUS_SUCCESS); \ + } while (0) + +#define RSMI_DEVICE_PCI_BUS_ID_BUFFER_SIZE 16 + +using Node = int64_t; +using Links = int64_t; +template +using AdjacencyMatrix = std::function; + +namespace { + +AdjacencyMatrix get_nvlink_matrix() { + auto world_size = at::cuda::getNumGPUs(); + RSMI_CHECK(rsmi_init(0)); + + // Note that ROCm_SMI uses a different numbering method to ROCm runtime, + // so we need to learn the mapping by using the bus ID. + uint32_t device_count; + RSMI_CHECK(rsmi_num_monitor_devices(&device_count)); + + std::unordered_map rocm_device_to_rsmi_device; + + for (const auto i : c10::irange(device_count)) { + uint64_t pci_info; + RSMI_CHECK(rsmi_dev_pci_id_get(i, &pci_info)); + uint64_t domain, bus, device, function; + domain = (pci_info >> 32) & 0xffffffff; + bus = (pci_info >> 8) & 0xff; + device = (pci_info >> 3) & 0x1f; + function = pci_info & 0x7; + // Different form CUDA, we do not get the PCI BUS ID as a char* and we need to reconstruct it. + char pci_bus_id_str[RSMI_DEVICE_PCI_BUS_ID_BUFFER_SIZE]; + sprintf(pci_bus_id_str, "%04X:%02X:%02X.%0X", domain, bus, device, function); + + std::array pci_bus_id; + std::copy( + &pci_bus_id_str[0], + &pci_bus_id_str[RSMI_DEVICE_PCI_BUS_ID_BUFFER_SIZE], + pci_bus_id.data()); + int32_t node = 0; + auto err = hipDeviceGetByPCIBusId(&node, pci_bus_id.data()); + if (err == hipSuccess) { + rocm_device_to_rsmi_device.insert({node, i}); + } else { + // flush the last error - this can occur when e.g. we set + // HIP_VISIBLE_DEVICES to a subset of the available GPUs in the system. + hipGetLastError(); + } + } + + std::vector links(world_size * world_size); + for (const auto i : c10::irange(world_size)) { + auto src_rsmi_device = rocm_device_to_rsmi_device.find(i); + if (src_rsmi_device != rocm_device_to_rsmi_device.end()){ + for (const auto j : c10::irange(world_size)) { + auto dst_rsmi_device = rocm_device_to_rsmi_device.find(j); + if (dst_rsmi_device != rocm_device_to_rsmi_device.end()){ + bool is_active; + RSMI_CHECK(rsmi_is_P2P_accessible(src_rsmi_device->second, dst_rsmi_device->second, &is_active)); + if (is_active) { + links[i * world_size + j] += 1; + } + } + } + } + } + RSMI_CHECK(rsmi_shut_down()); + return [=](Node i, Node j) { return links[i * world_size + j]; }; +} +} // namespace + +#else // CUDA #include #include @@ -106,7 +190,9 @@ AdjacencyMatrix get_nvlink_matrix() { return [=](Node i, Node j) { return links[i * world_size + j]; }; } - +} // namespace +#endif +namespace { // Hilariously unoptimized, but algorithmic correctness matters more here, and // we only do it once. AdjacencyMatrix get_intermediate_node(AdjacencyMatrix links) { @@ -409,4 +495,4 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { "all_to_one_device(Tensor[] input_tensors, Device target_device) -> Tensor[]"); DISPATCH_TO_CUDA("all_to_one_device", fbgemm_gpu::all_to_one_device); } -#endif + diff --git a/fbgemm_gpu/test/merge_pooled_embeddings_test.py b/fbgemm_gpu/test/merge_pooled_embeddings_test.py index e09075e88..fe2c318c4 100644 --- a/fbgemm_gpu/test/merge_pooled_embeddings_test.py +++ b/fbgemm_gpu/test/merge_pooled_embeddings_test.py @@ -31,7 +31,7 @@ @unittest.skipIf(*gpu_unavailable) -@unittest.skipIf(open_source, "Not supported in open source yet") +#@unittest.skipIf(open_source, "Not supported in open source yet") class MergePooledEmbeddingsTest(unittest.TestCase): @given( num_ads=st.integers(min_value=1, max_value=10), @@ -39,7 +39,7 @@ class MergePooledEmbeddingsTest(unittest.TestCase): ads_tables=st.integers(min_value=1, max_value=32), num_gpus=st.integers(min_value=1, max_value=torch.cuda.device_count()), non_default_stream=st.booleans(), - r=st.randoms(use_true_random=False), + r=st.randoms(), ) # Can instantiate 8 contexts which takes a long time. @settings(verbosity=Verbosity.verbose, max_examples=40, deadline=None) @@ -93,7 +93,7 @@ def ref(pooled_ad_embeddings, batch_indices): num_inputs=st.integers(min_value=1, max_value=10), num_gpus=st.integers(min_value=1, max_value=torch.cuda.device_count()), non_default_stream=st.booleans(), - r=st.randoms(use_true_random=False), + r=st.randoms(), ) # Can instantiate 8 contexts which takes a long time. @settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None) From bfac874a440806bca187b86f0e47f70f11c4648a Mon Sep 17 00:00:00 2001 From: liligwu Date: Thu, 14 Apr 2022 20:02:48 +0000 Subject: [PATCH 22/76] Fixing test_lxu_cache_lookup in AMD devices where warp_siize=64 --- fbgemm_gpu/test/split_table_batched_embeddings_test.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/fbgemm_gpu/test/split_table_batched_embeddings_test.py b/fbgemm_gpu/test/split_table_batched_embeddings_test.py index 2324ed2f1..15dff64ac 100644 --- a/fbgemm_gpu/test/split_table_batched_embeddings_test.py +++ b/fbgemm_gpu/test/split_table_batched_embeddings_test.py @@ -3982,6 +3982,8 @@ def test_lxu_cache_lookup(self) -> None: # Testing all miss. linear_cache_indices_0 = torch.tensor( [32, 33, 34, 35, 36, 100, 1000, 1725] + ).cuda() if ASSOC == 32 else torch.tensor( + [64, 65, 66, 67, 68, 100, 1000, 1725] ).cuda() lxu_locations = torch.ops.fbgemm.lxu_cache_lookup( linear_cache_indices_0, lxu_cache_state_gpu, max_index From 1cf7e8440093310256fbbe188a5c8ebf6bce8157 Mon Sep 17 00:00:00 2001 From: liligwu Date: Fri, 15 Apr 2022 18:44:51 +0000 Subject: [PATCH 23/76] * Enabling the specificationn of hip architecture by using PYTORCH_ROCM_ARCH. # Enabling building on Pytorch 1.11. --- fbgemm_gpu/CMakeLists.txt | 4 ++-- fbgemm_gpu/cmake/Hip.cmake | 7 +------ 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/fbgemm_gpu/CMakeLists.txt b/fbgemm_gpu/CMakeLists.txt index 07bf39bfe..115271b16 100644 --- a/fbgemm_gpu/CMakeLists.txt +++ b/fbgemm_gpu/CMakeLists.txt @@ -381,9 +381,9 @@ elseif(USE_ROCM) hip_include_directories("${cpp_fbgemm_files_include_directories}") hip_add_library(fbgemm_gpu_py SHARED ${cpp_asmjit_files} ${FBGEMM_ALL_HIP_FILES} ${FBGEMM_HIP_HCC_LIBRARIES} - HIPCC_OPTIONS ${HIP_CXX_FLAGS}) + HIPCC_OPTIONS ${HIP_HCC_FLAGS}) target_include_directories(fbgemm_gpu_py PUBLIC ${FBGEMM_HIP_INCLUDE} ${ROCRAND_INCLUDE} ${ROCM_SMI_INCLUDE}) - set_property(TARGET fbgemm_gpu_py PROPERTY HIP_ARCHITECTURES ${FBGEMM_ROCM_ARCH}) + # set_property(TARGET fbgemm_gpu_py PROPERTY HIP_ARCHITECTURES ${FBGEMM_ROCM_ARCH}) # For ROCm5.1 list (GET TORCH_INCLUDE_DIRS 0 TORCH_PATH) diff --git a/fbgemm_gpu/cmake/Hip.cmake b/fbgemm_gpu/cmake/Hip.cmake index 9adbb6322..cdc225e9d 100644 --- a/fbgemm_gpu/cmake/Hip.cmake +++ b/fbgemm_gpu/cmake/Hip.cmake @@ -87,17 +87,12 @@ ELSE() SET(MIOPEN_PATH $ENV{MIOPEN_PATH}) ENDIF() -IF(NOT DEFINED ENV{FBGEMM_ROCM_ARCH}) - SET(FBGEMM_ROCM_ARCH gfx900;gfx906;gfx908;gfx90a) -ELSE() - SET(FBGEMM_ROCM_ARCH $ENV{FBGEMM_ROCM_ARCH}) -ENDIF() - # Add HIP to the CMAKE Module Path set(CMAKE_MODULE_PATH ${HIP_PATH}/cmake ${CMAKE_MODULE_PATH}) # Disable Asserts In Code (Can't use asserts on HIP stack.) ADD_DEFINITIONS(-DNDEBUG) +ADD_DEFINITIONS(-DUSE_ROCM) # Find the HIP Package find_package(HIP) From 5b3328788a21f0a25911d14117adc52600c8308d Mon Sep 17 00:00:00 2001 From: liligwu Date: Tue, 19 Apr 2022 20:18:33 +0000 Subject: [PATCH 24/76] *Fixing the unit tests in sparse_ops_test.py. *Fixing the path of Atomic.cuh path in embedding_backward_template_helpers.cuh. --- .../embedding_backward_template_helpers.cuh | 5 +- fbgemm_gpu/src/sparse_ops_cpu.cpp | 116 ++++++------------ 2 files changed, 44 insertions(+), 77 deletions(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/embedding_backward_template_helpers.cuh b/fbgemm_gpu/include/fbgemm_gpu/embedding_backward_template_helpers.cuh index e662a6d7a..0635c32e0 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/embedding_backward_template_helpers.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/embedding_backward_template_helpers.cuh @@ -25,8 +25,11 @@ #include #include #include +#if !defined(NEW_ATOMIC_PATH) #include - +#else +#include +#endif #include #include #include diff --git a/fbgemm_gpu/src/sparse_ops_cpu.cpp b/fbgemm_gpu/src/sparse_ops_cpu.cpp index 09916d9d5..d60b3503f 100644 --- a/fbgemm_gpu/src/sparse_ops_cpu.cpp +++ b/fbgemm_gpu/src/sparse_ops_cpu.cpp @@ -1543,62 +1543,6 @@ void _generic_histogram_binning_calibration_by_feature_cpu_kernel( } } -template -void _generic_histogram_binning_calibration_by_feature_cpu_kernel( - const int64_t num_logits, - const int64_t num_bins, - const int64_t num_segments, - const int64_t num_lengths, - const double recalibrate_value, - const int64_t bin_ctr_in_use_after, - const double bin_ctr_weight_value, - const T* const logit_data, - const int64_t* const segment_value_data, - const int64_t* const segment_lengths_data, - const double* const bin_num_examples_data, - const double* const bin_num_positives_data, - const double* const bin_boundaries, - int64_t* const dense_segment_value_data, - T* const calibrated_prediction_data, - int64_t* const bin_ids_data) { - int k = 0; - for (const auto i : c10::irange(num_lengths)) { - if (segment_lengths_data[i] > 0) { - // Add 1 to distinguish between 0 inserted by densification vs. original - // value. - dense_segment_value_data[i] = segment_value_data[k] + 1; - ++k; - } - } - - for (const auto i : c10::irange(num_logits)) { - const T pre_sigmoid = logit_data[i] + recalibrate_value; - const double uncalibrated = 1.0 / (1.0 + std::exp(-pre_sigmoid)); - - const int curr_bin_id = - std::lower_bound( - bin_boundaries, bin_boundaries + num_bins, uncalibrated) - - bin_boundaries; - - const int64_t curr_segment_value = - dense_segment_value_data[i] > num_segments - ? 0 - : std::max(0L, dense_segment_value_data[i] * num_bins); - - bin_ids_data[i] = curr_bin_id + curr_segment_value; - - const auto curr_bin_num_examples = bin_num_examples_data[bin_ids_data[i]]; - if (curr_bin_num_examples > bin_ctr_in_use_after) { - const auto curr_bin_ctr = - bin_num_positives_data[bin_ids_data[i]] / curr_bin_num_examples; - calibrated_prediction_data[i] = curr_bin_ctr * bin_ctr_weight_value + - uncalibrated * (1.0 - bin_ctr_weight_value); - } else { - calibrated_prediction_data[i] = uncalibrated; - } - } -} - std::tuple generic_histogram_binning_calibration_by_feature_cpu( const Tensor& logit, const Tensor& segment_value, @@ -1623,31 +1567,51 @@ std::tuple generic_histogram_binning_calibration_by_feature_cpu( // dense_segment_value is used as a temporary storage. Tensor dense_segment_value = - at::zeros({logit.numel()}, segment_value.options()); + at::empty({logit.numel()}, segment_value.options()); + AT_DISPATCH_INDEX_TYPES( + segment_value.scalar_type(), "to_dense_representation_cpu_wrapper", [&] { + using segment_value_t = index_t; + AT_DISPATCH_INDEX_TYPES( + segment_lengths.scalar_type(), "to_dense_representation_cpu", [&] { + using segment_length_t = index_t; + _to_dense_representation( + segment_lengths.numel(), + segment_value.data_ptr(), + segment_lengths.data_ptr(), + dense_segment_value.data_ptr()); + }); + }); + Tensor calibrated_prediction = at::empty_like(logit); Tensor bin_ids = at::empty({logit.numel()}, logit.options().dtype(at::kLong)); const double recalibrate_value = std::log(positive_weight); AT_DISPATCH_FLOATING_TYPES_AND_HALF( logit.type(), - "generic_histogram_binning_calibration_by_feature_cpu", - [&]() { - _generic_histogram_binning_calibration_by_feature_cpu_kernel( - logit.numel(), - bin_boundaries.numel() + 1, - num_segments, - segment_lengths.numel(), - recalibrate_value, - bin_ctr_in_use_after, - bin_ctr_weight_value, - logit.data_ptr(), - segment_value.data_ptr(), - segment_lengths.data_ptr(), - bin_num_examples.data_ptr(), - bin_num_positives.data_ptr(), - bin_boundaries.data_ptr(), - dense_segment_value.data_ptr(), - calibrated_prediction.data_ptr(), - bin_ids.data_ptr()); + "generic_histogram_binning_calibration_by_feature_cpu_wrapper", + [&] { + using logit_t = scalar_t; + AT_DISPATCH_INDEX_TYPES( + segment_value.scalar_type(), + "generic_histogram_binning_calibration_by_feature_cpu", + [&] { + using segment_value_t = index_t; + _generic_histogram_binning_calibration_by_feature_cpu_kernel< + logit_t, + segment_value_t>( + logit.numel(), + bin_boundaries.numel() + 1, + num_segments, + recalibrate_value, + bin_ctr_in_use_after, + bin_ctr_weight_value, + logit.data_ptr(), + dense_segment_value.data_ptr(), + bin_num_examples.data_ptr(), + bin_num_positives.data_ptr(), + bin_boundaries.data_ptr(), + calibrated_prediction.data_ptr(), + bin_ids.data_ptr()); + }); }); return std::make_tuple(calibrated_prediction, bin_ids); From 0d5a012fb4031396fe677ab4ab842da0076323ad Mon Sep 17 00:00:00 2001 From: liligwu Date: Wed, 20 Apr 2022 14:28:09 +0000 Subject: [PATCH 25/76] Enable use_cpu in the tests. --- ...plit_embedding_inference_converter_test.py | 8 +++--- .../split_table_batched_embeddings_test.py | 28 +++++++++---------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/fbgemm_gpu/test/split_embedding_inference_converter_test.py b/fbgemm_gpu/test/split_embedding_inference_converter_test.py index 10869e590..382fb57c2 100644 --- a/fbgemm_gpu/test/split_embedding_inference_converter_test.py +++ b/fbgemm_gpu/test/split_embedding_inference_converter_test.py @@ -27,7 +27,7 @@ if open_source: # pyre-ignore[21] - from test_utils import gpu_available, TEST_WITH_ROCM + from test_utils import gpu_available else: from fbgemm_gpu.test.test_utils import gpu_available @@ -134,7 +134,7 @@ class QuantizedSplitEmbeddingsTest(unittest.TestCase): SparseType.INT2, ] ), - use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), + use_cpu=st.booleans(), pruning_ratio=st.sampled_from([None, 0.0]), ) @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) @@ -208,7 +208,7 @@ def test_quantize_workflow( ) @given( - use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), + use_cpu=st.booleans(), use_array_for_index_remapping=st.booleans(), quantize_type=st.sampled_from( [ @@ -298,7 +298,7 @@ def test_l2_norm_pruning_workflow( D=st.integers(min_value=2, max_value=128), log_E=st.integers(min_value=3, max_value=5), pruning_ratio=st.floats(min_value=0.0, max_value=1.0, exclude_max=True), - use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), + use_cpu=st.booleans(), use_array_for_index_remapping=st.booleans(), ) @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) diff --git a/fbgemm_gpu/test/split_table_batched_embeddings_test.py b/fbgemm_gpu/test/split_table_batched_embeddings_test.py index 15dff64ac..706a39333 100644 --- a/fbgemm_gpu/test/split_table_batched_embeddings_test.py +++ b/fbgemm_gpu/test/split_table_batched_embeddings_test.py @@ -30,7 +30,7 @@ if open_source: # pyre-ignore[21] - from test_utils import gpu_available, gpu_unavailable, TEST_WITH_ROCM + from test_utils import gpu_available, gpu_unavailable else: from fbgemm_gpu.test.test_utils import gpu_available, gpu_unavailable @@ -1094,7 +1094,7 @@ def test_nbit_forward_fused_pooled_emb_quant( split_table_batched_embeddings_ops.PoolingMode.NONE, ] ), - use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), + use_cpu=st.booleans(), ) @settings( verbosity=Verbosity.verbose, @@ -1313,7 +1313,7 @@ def test_backward_dense( split_table_batched_embeddings_ops.PoolingMode.NONE, ] ), - use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), + use_cpu=st.booleans(), exact=st.booleans(), ) @settings( @@ -1862,7 +1862,7 @@ def execute_backward_adagrad_( # noqa C901 cache_algorithm=st.sampled_from( split_table_batched_embeddings_ops.CacheAlgorithm ), - use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), + use_cpu=st.booleans(), exact=st.booleans(), ) @settings( @@ -1924,7 +1924,7 @@ def test_backward_adagrad_fp16_pmSUM( # noqa C901 cache_algorithm=st.sampled_from( split_table_batched_embeddings_ops.CacheAlgorithm ), - use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), + use_cpu=st.booleans(), exact=st.booleans(), ) @settings( @@ -1986,7 +1986,7 @@ def test_backward_adagrad_fp16_pmMEAN( # noqa C901 cache_algorithm=st.sampled_from( split_table_batched_embeddings_ops.CacheAlgorithm ), - use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), + use_cpu=st.booleans(), exact=st.booleans(), ) @settings( @@ -2048,7 +2048,7 @@ def test_backward_adagrad_fp16_pmNONE( # noqa C901 cache_algorithm=st.sampled_from( split_table_batched_embeddings_ops.CacheAlgorithm ), - use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), + use_cpu=st.booleans(), exact=st.booleans(), ) @settings( @@ -2110,7 +2110,7 @@ def test_backward_adagrad_fp32_pmSUM( # noqa C901 cache_algorithm=st.sampled_from( split_table_batched_embeddings_ops.CacheAlgorithm ), - use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), + use_cpu=st.booleans(), exact=st.booleans(), ) @settings( @@ -2172,7 +2172,7 @@ def test_backward_adagrad_fp32_pmMEAN( # noqa C901 cache_algorithm=st.sampled_from( split_table_batched_embeddings_ops.CacheAlgorithm ), - use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), + use_cpu=st.booleans(), exact=st.booleans(), ) @settings( @@ -2699,7 +2699,7 @@ def execute_backward_optimizers_( # noqa C901 split_table_batched_embeddings_ops.PoolingMode.NONE, ] ), - use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), + use_cpu=st.booleans(), ) @settings( verbosity=Verbosity.verbose, @@ -2759,7 +2759,7 @@ def test_backward_optimizers_adam( # noqa C901 split_table_batched_embeddings_ops.PoolingMode.NONE, ] ), - use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), + use_cpu=st.booleans(), ) @settings( verbosity=Verbosity.verbose, @@ -2818,7 +2818,7 @@ def test_backward_optimizers_adagrad( # noqa C901 split_table_batched_embeddings_ops.PoolingMode.NONE, ] ), - use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), + use_cpu=st.booleans(), ) @settings( verbosity=Verbosity.verbose, @@ -2872,7 +2872,7 @@ def test_backward_optimizers_lamb( # noqa C901 split_table_batched_embeddings_ops.PoolingMode.NONE, ] ), - use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), + use_cpu=st.booleans(), ) @settings( verbosity=Verbosity.verbose, @@ -3806,7 +3806,7 @@ def test_nbit_cache_miss_counter(self, N: int) -> None: BoundsCheckMode.IGNORE, ] ), - use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), + use_cpu=st.booleans(), dtype=st.sampled_from( [ torch.int64, From 1718605ca9f77572d5037e3a55a428cc18a7f954 Mon Sep 17 00:00:00 2001 From: liligwu Date: Wed, 20 Apr 2022 19:35:54 +0000 Subject: [PATCH 26/76] *Taking @skipIfRocm back in the test_utils.py. *Fixing cublasGemmStridedBatchedEx in sparse_ops.cu. --- .../embedding_backward_template_helpers.cuh | 58 +++---------------- fbgemm_gpu/src/sparse_ops.cu | 51 ++++++++++++++++ fbgemm_gpu/test/test_utils.py | 12 ++++ 3 files changed, 70 insertions(+), 51 deletions(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/embedding_backward_template_helpers.cuh b/fbgemm_gpu/include/fbgemm_gpu/embedding_backward_template_helpers.cuh index 0635c32e0..4c4724cbe 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/embedding_backward_template_helpers.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/embedding_backward_template_helpers.cuh @@ -1,35 +1,28 @@ /* - * Copyright (c) Facebook, Inc. and its affiliates. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ -// clang-format off -#include "fbgemm_gpu/cub_namespace_prefix.cuh" -#include -#include -#include -#include "fbgemm_gpu/cub_namespace_postfix.cuh" -// clang-format on - #include #include +#include +#include +#include #if !defined(NEW_GENERATOR_PATH) #include #else #include #endif -#include -#include -#include #include -#include #if !defined(NEW_ATOMIC_PATH) #include #else #include #endif +#include + #include #include #include @@ -40,43 +33,6 @@ #include "fbgemm_cuda_utils.cuh" #include "sparse_ops_utils.h" -inline at::Tensor asynchronous_complete_cumsum(at::Tensor t_in) { - at::cuda::OptionalCUDAGuard device_guard; - device_guard.set_index(t_in.get_device()); - size_t temp_storage_bytes = 0; - TORCH_CHECK(t_in.is_contiguous()); - TORCH_CHECK(t_in.dtype() == at::kInt || t_in.dtype() == at::kLong); - // CUB only handles up to INT_MAX elements. - TORCH_CHECK(t_in.numel() < std::numeric_limits::max()); - TORCH_CHECK(t_in.dim() == 1); - auto t_out = at::empty({t_in.numel() + 1}, t_in.options()); - t_out[0].zero_(); - AT_DISPATCH_INTEGRAL_TYPES( - t_in.scalar_type(), "cub_inclusive_sum_wrapper1", ([&] { - AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::InclusiveSum( - nullptr, - temp_storage_bytes, - t_in.data_ptr(), - t_out.data_ptr() + 1, - t_in.numel(), - at::cuda::getCurrentCUDAStream())); - })); - auto temp_storage = at::empty( - {static_cast(temp_storage_bytes)}, - t_in.options().dtype(at::kByte)); - AT_DISPATCH_INTEGRAL_TYPES( - t_in.scalar_type(), "cub_inclusive_sum_wrapper2", ([&] { - AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::InclusiveSum( - temp_storage.data_ptr(), - temp_storage_bytes, - t_in.data_ptr(), - t_out.data_ptr() + 1, - t_in.numel(), - at::cuda::getCurrentCUDAStream())); - })); - return t_out; -} - class FixedDivisor { public: explicit FixedDivisor(const int32_t d) : d_(d) { @@ -157,4 +113,4 @@ DEVICE_INLINE int64_t gpuAtomicIncrement(int64_t* p) { return static_cast(atomicAdd( reinterpret_cast(p), static_cast(1))); -} +} \ No newline at end of file diff --git a/fbgemm_gpu/src/sparse_ops.cu b/fbgemm_gpu/src/sparse_ops.cu index e29413327..bb80b3f10 100644 --- a/fbgemm_gpu/src/sparse_ops.cu +++ b/fbgemm_gpu/src/sparse_ops.cu @@ -26,6 +26,10 @@ #include "fbgemm_gpu/fbgemm_cuda_utils.cuh" #include "fbgemm_gpu/split_embeddings_utils.cuh" +#ifdef __HIP_PLATFORM_HCC__ +#include +#endif + using Tensor = at::Tensor; namespace fbgemm_gpu { @@ -2122,6 +2126,52 @@ Tensor permute102_baddbmm_permute102_cuda( // C (m, b, n) = A (m, b, k) * B (b, k, n) ---> row major // C (m, b, n) = (B^T (b, k, n) * A^T (m, b, k))^T ---> column major +#ifdef __HIP_PLATFORM_HCC__ + float alpha = 1.0f; + float beta = 1.0f; + + auto Btype = HIPBLAS_R_16F; + auto ldb = n; + auto strideB = n * k; + + auto Atype = HIPBLAS_R_16F; + auto lda = k * batch_size; + auto strideA = k; + + auto Ctype = HIPBLAS_R_16F; + auto ldc = n * batch_size; + auto strideC = n; + +auto computeType = HIPBLAS_R_32F; + + auto result = hipblasGemmStridedBatchedEx( + handle, + HIPBLAS_OP_N, + HIPBLAS_OP_N, + n, + m, + k, + &alpha, + B.data_ptr(), + Btype, + ldb, + strideB, + A.data_ptr(), + Atype, + lda, + strideA, + &beta, + C.data_ptr(), + Ctype, + ldc, + strideC, + batch_size, + computeType, + HIPBLAS_GEMM_DEFAULT); + TORCH_CHECK(result == CUBLAS_STATUS_SUCCESS); + return C; +} +#else float alpha = 1.0f; float beta = 1.0f; @@ -2166,6 +2216,7 @@ Tensor permute102_baddbmm_permute102_cuda( TORCH_CHECK(result == CUBLAS_STATUS_SUCCESS); return C; } +#endif // Kernel for permuting the indices and weights. Used for permutation of // table-wise partitioned sequence embeddings diff --git a/fbgemm_gpu/test/test_utils.py b/fbgemm_gpu/test/test_utils.py index f2db2b890..82de24ea2 100644 --- a/fbgemm_gpu/test/test_utils.py +++ b/fbgemm_gpu/test/test_utils.py @@ -14,6 +14,7 @@ from functools import wraps import unittest +TEST_WITH_ROCM = os.getenv('FBGEMM_TEST_WITH_ROCM', '0') == '1' # Eigen/Python round 0.5 away from 0, Numpy rounds to even round_to_nearest: Callable[[np.ndarray], np.ndarray] = np.vectorize(round) @@ -185,3 +186,14 @@ def cpu_and_maybe_gpu() -> st.SearchStrategy[List[torch.device]]: def cpu_only() -> st.SearchStrategy[List[torch.device]]: return st.sampled_from([torch.device("cpu")]) + +def skipIfRocm(reason="test doesn't currently work on the ROCm stack"): + def skipIfRocmDecorator(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + if TEST_WITH_ROCM: + raise unittest.SkipTest(reason) + else: + fn(*args, **kwargs) + return wrapper + return skipIfRocmDecorator \ No newline at end of file From bc902a3e3c131280a36fdc215182c73406ba0396 Mon Sep 17 00:00:00 2001 From: liligwu Date: Wed, 20 Apr 2022 22:00:50 +0000 Subject: [PATCH 27/76] Cleaning up the code. --- fbgemm_gpu/CMakeLists.txt | 3 +-- fbgemm_gpu/src/sparse_ops.cu | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/fbgemm_gpu/CMakeLists.txt b/fbgemm_gpu/CMakeLists.txt index 99d4acbc9..e82988d5a 100644 --- a/fbgemm_gpu/CMakeLists.txt +++ b/fbgemm_gpu/CMakeLists.txt @@ -385,8 +385,7 @@ elseif(USE_ROCM) hip_add_library(fbgemm_gpu_py SHARED ${cpp_asmjit_files} ${FBGEMM_ALL_HIP_FILES} ${FBGEMM_HIP_HCC_LIBRARIES} HIPCC_OPTIONS ${HIP_HCC_FLAGS}) target_include_directories(fbgemm_gpu_py PUBLIC ${FBGEMM_HIP_INCLUDE} ${ROCRAND_INCLUDE} ${ROCM_SMI_INCLUDE}) - # set_property(TARGET fbgemm_gpu_py PROPERTY HIP_ARCHITECTURES ${FBGEMM_ROCM_ARCH}) - + # For ROCm5.1 list (GET TORCH_INCLUDE_DIRS 0 TORCH_PATH) if(EXISTS "${TORCH_PATH}/ATen/cuda/CUDAGeneratorImpl.h") diff --git a/fbgemm_gpu/src/sparse_ops.cu b/fbgemm_gpu/src/sparse_ops.cu index bb80b3f10..4c97688d8 100644 --- a/fbgemm_gpu/src/sparse_ops.cu +++ b/fbgemm_gpu/src/sparse_ops.cu @@ -2142,7 +2142,7 @@ Tensor permute102_baddbmm_permute102_cuda( auto ldc = n * batch_size; auto strideC = n; -auto computeType = HIPBLAS_R_32F; + auto computeType = HIPBLAS_R_32F; auto result = hipblasGemmStridedBatchedEx( handle, From 9a5a33b6d120108ef51f9de9f3de2125986017b7 Mon Sep 17 00:00:00 2001 From: Li Li Date: Thu, 21 Apr 2022 13:58:29 -0500 Subject: [PATCH 28/76] Enabling cuda (#25) * Alinging with upstream with merge_pooled_embeddings_test.py and enabling cuda. * Disabling use_cpu in split_table_batched_embeddings_test since it's still unstable. Co-authored-by: root --- fbgemm_gpu/CMakeLists.txt | 21 ++++++------ .../src/merge_pooled_embeddings_gpu.cpp | 1 + .../test/merge_pooled_embeddings_test.py | 28 +--------------- .../split_table_batched_embeddings_test.py | 32 +++++++++---------- 4 files changed, 29 insertions(+), 53 deletions(-) diff --git a/fbgemm_gpu/CMakeLists.txt b/fbgemm_gpu/CMakeLists.txt index e82988d5a..0b2c3cb63 100644 --- a/fbgemm_gpu/CMakeLists.txt +++ b/fbgemm_gpu/CMakeLists.txt @@ -316,7 +316,7 @@ if(NOT FBGEMM_CPU_ONLY) codegen/embedding_bounds_check_host.cpp src/cumem_utils_host.cpp src/layout_transform_ops_gpu.cpp - src/merge_pooled_embeddings_cpu.cpp src/merge_pooled_embeddings_gpu.cpp + # src/merge_pooled_embeddings_cpu.cpp src/merge_pooled_embeddings_gpu.cpp src/permute_pooled_embedding_ops_gpu.cpp src/quantize_ops_gpu.cpp src/sparse_ops_gpu.cpp @@ -385,21 +385,22 @@ elseif(USE_ROCM) hip_add_library(fbgemm_gpu_py SHARED ${cpp_asmjit_files} ${FBGEMM_ALL_HIP_FILES} ${FBGEMM_HIP_HCC_LIBRARIES} HIPCC_OPTIONS ${HIP_HCC_FLAGS}) target_include_directories(fbgemm_gpu_py PUBLIC ${FBGEMM_HIP_INCLUDE} ${ROCRAND_INCLUDE} ${ROCM_SMI_INCLUDE}) - - # For ROCm5.1 - list (GET TORCH_INCLUDE_DIRS 0 TORCH_PATH) - if(EXISTS "${TORCH_PATH}/ATen/cuda/CUDAGeneratorImpl.h") - target_compile_definitions(fbgemm_gpu_py PRIVATE NEW_GENERATOR_PATH) - endif() - if(EXISTS "${TORCH_PATH}/ATen/cuda/Atomic.cuh") - target_compile_definitions(fbgemm_gpu_py PRIVATE NEW_ATOMIC_PATH) - endif() +endif() +list (GET TORCH_INCLUDE_DIRS 0 TORCH_PATH) +if(EXISTS "${TORCH_PATH}/ATen/cuda/CUDAGeneratorImpl.h") + target_compile_definitions(fbgemm_gpu_py PRIVATE NEW_GENERATOR_PATH) +endif() +if(EXISTS "${TORCH_PATH}/ATen/cuda/Atomic.cuh") + target_compile_definitions(fbgemm_gpu_py PRIVATE NEW_ATOMIC_PATH) endif() set_target_properties(fbgemm_gpu_py PROPERTIES PREFIX "") target_link_libraries(fbgemm_gpu_py ${TORCH_LIBRARIES}) target_include_directories(fbgemm_gpu_py PRIVATE ${TORCH_INCLUDE_DIRS}) +if(USE_CUDA) + set_property(TARGET fbgemm_gpu_py PROPERTY CXX_STANDARD 17) +endif() install(TARGETS fbgemm_gpu_py DESTINATION fbgemm_gpu) diff --git a/fbgemm_gpu/src/merge_pooled_embeddings_gpu.cpp b/fbgemm_gpu/src/merge_pooled_embeddings_gpu.cpp index 2f9642242..18bcf8cb7 100644 --- a/fbgemm_gpu/src/merge_pooled_embeddings_gpu.cpp +++ b/fbgemm_gpu/src/merge_pooled_embeddings_gpu.cpp @@ -482,6 +482,7 @@ std::vector all_to_one_device( output_tensors.reserve(input_tensors.size()); for (const auto& tensor : input_tensors) { + TORCH_CHECK(tensor.is_cuda()); output_tensors.push_back( tensor.device() != target_device ? at::empty(tensor.sizes(), tensor.options().device(target_device)) diff --git a/fbgemm_gpu/test/merge_pooled_embeddings_test.py b/fbgemm_gpu/test/merge_pooled_embeddings_test.py index 732477310..1cf5b8ba7 100644 --- a/fbgemm_gpu/test/merge_pooled_embeddings_test.py +++ b/fbgemm_gpu/test/merge_pooled_embeddings_test.py @@ -31,7 +31,7 @@ @unittest.skipIf(*gpu_unavailable) -#@unittest.skipIf(open_source, "Not supported in open source yet") +@unittest.skipIf(open_source, "Not supported in open source yet") class MergePooledEmbeddingsTest(unittest.TestCase): @given( num_ads=st.integers(min_value=1, max_value=10), @@ -119,32 +119,6 @@ def test_all_to_one_device( self.assertEqual(o.device, dst_device) torch.testing.assert_close(o.cpu(), i) - @given( - num_inputs=st.integers(min_value=1, max_value=10), - num_gpus=st.integers(min_value=1, max_value=torch.cuda.device_count()), - non_default_stream=st.booleans(), - r=st.randoms(), - ) - # Can instantiate 8 contexts which takes a long time. - @settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None) - def test_all_to_one_device( - self, - num_inputs, - num_gpus, - non_default_stream, - r, - ) -> None: - dst_device = torch.device(f"cuda:{r.randint(0, num_gpus - 1)}") - with torch.cuda.device(dst_device): - inputs = [torch.randn(10, 20) for _ in range(num_inputs)] - cuda_inputs = [ - input.to(f"cuda:{i % num_gpus}") for i, input in enumerate(inputs) - ] - cuda_outputs = torch.ops.fbgemm.all_to_one_device(cuda_inputs, dst_device) - for i, o in zip(inputs, cuda_outputs): - self.assertEqual(o.device, dst_device) - torch.testing.assert_allclose(o.cpu(), i) - if __name__ == "__main__": unittest.main() diff --git a/fbgemm_gpu/test/split_table_batched_embeddings_test.py b/fbgemm_gpu/test/split_table_batched_embeddings_test.py index 0bdfad9b5..19500629d 100644 --- a/fbgemm_gpu/test/split_table_batched_embeddings_test.py +++ b/fbgemm_gpu/test/split_table_batched_embeddings_test.py @@ -33,7 +33,7 @@ if open_source: # pyre-ignore[21] - from test_utils import gpu_available, gpu_unavailable + from test_utils import gpu_available, gpu_unavailable, TEST_WITH_ROCM else: from fbgemm_gpu.test.test_utils import gpu_available, gpu_unavailable @@ -1097,7 +1097,7 @@ def test_nbit_forward_fused_pooled_emb_quant( split_table_batched_embeddings_ops.PoolingMode.NONE, ] ), - use_cpu=st.booleans(), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), ) @settings( verbosity=Verbosity.verbose, @@ -1316,7 +1316,7 @@ def test_backward_dense( split_table_batched_embeddings_ops.PoolingMode.NONE, ] ), - use_cpu=st.booleans(), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), exact=st.booleans(), ) @settings( @@ -1865,7 +1865,7 @@ def execute_backward_adagrad_( # noqa C901 cache_algorithm=st.sampled_from( split_table_batched_embeddings_ops.CacheAlgorithm ), - use_cpu=st.booleans(), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), exact=st.booleans(), ) @settings( @@ -1927,7 +1927,7 @@ def test_backward_adagrad_fp16_pmSUM( # noqa C901 cache_algorithm=st.sampled_from( split_table_batched_embeddings_ops.CacheAlgorithm ), - use_cpu=st.booleans(), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), exact=st.booleans(), ) @settings( @@ -1989,7 +1989,7 @@ def test_backward_adagrad_fp16_pmMEAN( # noqa C901 cache_algorithm=st.sampled_from( split_table_batched_embeddings_ops.CacheAlgorithm ), - use_cpu=st.booleans(), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), exact=st.booleans(), ) @settings( @@ -2051,7 +2051,7 @@ def test_backward_adagrad_fp16_pmNONE( # noqa C901 cache_algorithm=st.sampled_from( split_table_batched_embeddings_ops.CacheAlgorithm ), - use_cpu=st.booleans(), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), exact=st.booleans(), ) @settings( @@ -2113,7 +2113,7 @@ def test_backward_adagrad_fp32_pmSUM( # noqa C901 cache_algorithm=st.sampled_from( split_table_batched_embeddings_ops.CacheAlgorithm ), - use_cpu=st.booleans(), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), exact=st.booleans(), ) @settings( @@ -2175,7 +2175,7 @@ def test_backward_adagrad_fp32_pmMEAN( # noqa C901 cache_algorithm=st.sampled_from( split_table_batched_embeddings_ops.CacheAlgorithm ), - use_cpu=st.booleans(), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), exact=st.booleans(), ) @settings( @@ -2725,7 +2725,7 @@ def execute_backward_optimizers_( # noqa C901 split_table_batched_embeddings_ops.PoolingMode.NONE, ] ), - use_cpu=st.booleans(), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), ) @settings( verbosity=Verbosity.verbose, @@ -2785,7 +2785,7 @@ def test_backward_optimizers_adam( # noqa C901 split_table_batched_embeddings_ops.PoolingMode.NONE, ] ), - use_cpu=st.booleans() if torch.cuda.is_available() else st.just(True), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), weight_decay_mode=st.sampled_from( [ WeightDecayMode.L2, @@ -2852,7 +2852,7 @@ def test_backward_optimizers_adagrad( # noqa C901 split_table_batched_embeddings_ops.PoolingMode.NONE, ] ), - use_cpu=st.booleans(), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), ) @settings( verbosity=Verbosity.verbose, @@ -2906,7 +2906,7 @@ def test_backward_optimizers_lamb( # noqa C901 split_table_batched_embeddings_ops.PoolingMode.NONE, ] ), - use_cpu=st.booleans(), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), ) @settings( verbosity=Verbosity.verbose, @@ -3470,7 +3470,7 @@ def test_nbit_forward_uvm_cache( T=st.integers(min_value=1, max_value=5), B=st.integers(min_value=1, max_value=8), L=st.integers(min_value=0, max_value=8), - use_cpu=st.booleans() if gpu_available else st.just(True), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), use_cpu_hashtable=st.booleans(), use_array_for_index_remapping=st.booleans(), ) @@ -3840,7 +3840,7 @@ def test_nbit_cache_miss_counter(self, N: int) -> None: BoundsCheckMode.IGNORE, ] ), - use_cpu=st.booleans(), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), dtype=st.sampled_from( [ torch.int64, @@ -4099,7 +4099,7 @@ def test_lxu_cache_lookup(self) -> None: SparseType.INT8, ] ), - use_cpu=st.booleans() if gpu_available else st.just(True), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), test_internal=st.booleans(), ) @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) From 6490dbcdfa1ae334bf34740283a330f010bc99c7 Mon Sep 17 00:00:00 2001 From: Li Li Date: Thu, 21 Apr 2022 13:58:29 -0500 Subject: [PATCH 29/76] Enabling cuda (#25) * Alinging with upstream with merge_pooled_embeddings_test.py and enabling cuda. * Disabling use_cpu in split_table_batched_embeddings_test since it's still unstable. Co-authored-by: root --- fbgemm_gpu/CMakeLists.txt | 21 ++++++------ .../src/merge_pooled_embeddings_gpu.cpp | 1 + .../test/merge_pooled_embeddings_test.py | 28 +--------------- .../split_table_batched_embeddings_test.py | 32 +++++++++---------- 4 files changed, 29 insertions(+), 53 deletions(-) diff --git a/fbgemm_gpu/CMakeLists.txt b/fbgemm_gpu/CMakeLists.txt index e82988d5a..0b2c3cb63 100644 --- a/fbgemm_gpu/CMakeLists.txt +++ b/fbgemm_gpu/CMakeLists.txt @@ -316,7 +316,7 @@ if(NOT FBGEMM_CPU_ONLY) codegen/embedding_bounds_check_host.cpp src/cumem_utils_host.cpp src/layout_transform_ops_gpu.cpp - src/merge_pooled_embeddings_cpu.cpp src/merge_pooled_embeddings_gpu.cpp + # src/merge_pooled_embeddings_cpu.cpp src/merge_pooled_embeddings_gpu.cpp src/permute_pooled_embedding_ops_gpu.cpp src/quantize_ops_gpu.cpp src/sparse_ops_gpu.cpp @@ -385,21 +385,22 @@ elseif(USE_ROCM) hip_add_library(fbgemm_gpu_py SHARED ${cpp_asmjit_files} ${FBGEMM_ALL_HIP_FILES} ${FBGEMM_HIP_HCC_LIBRARIES} HIPCC_OPTIONS ${HIP_HCC_FLAGS}) target_include_directories(fbgemm_gpu_py PUBLIC ${FBGEMM_HIP_INCLUDE} ${ROCRAND_INCLUDE} ${ROCM_SMI_INCLUDE}) - - # For ROCm5.1 - list (GET TORCH_INCLUDE_DIRS 0 TORCH_PATH) - if(EXISTS "${TORCH_PATH}/ATen/cuda/CUDAGeneratorImpl.h") - target_compile_definitions(fbgemm_gpu_py PRIVATE NEW_GENERATOR_PATH) - endif() - if(EXISTS "${TORCH_PATH}/ATen/cuda/Atomic.cuh") - target_compile_definitions(fbgemm_gpu_py PRIVATE NEW_ATOMIC_PATH) - endif() +endif() +list (GET TORCH_INCLUDE_DIRS 0 TORCH_PATH) +if(EXISTS "${TORCH_PATH}/ATen/cuda/CUDAGeneratorImpl.h") + target_compile_definitions(fbgemm_gpu_py PRIVATE NEW_GENERATOR_PATH) +endif() +if(EXISTS "${TORCH_PATH}/ATen/cuda/Atomic.cuh") + target_compile_definitions(fbgemm_gpu_py PRIVATE NEW_ATOMIC_PATH) endif() set_target_properties(fbgemm_gpu_py PROPERTIES PREFIX "") target_link_libraries(fbgemm_gpu_py ${TORCH_LIBRARIES}) target_include_directories(fbgemm_gpu_py PRIVATE ${TORCH_INCLUDE_DIRS}) +if(USE_CUDA) + set_property(TARGET fbgemm_gpu_py PROPERTY CXX_STANDARD 17) +endif() install(TARGETS fbgemm_gpu_py DESTINATION fbgemm_gpu) diff --git a/fbgemm_gpu/src/merge_pooled_embeddings_gpu.cpp b/fbgemm_gpu/src/merge_pooled_embeddings_gpu.cpp index 2f9642242..18bcf8cb7 100644 --- a/fbgemm_gpu/src/merge_pooled_embeddings_gpu.cpp +++ b/fbgemm_gpu/src/merge_pooled_embeddings_gpu.cpp @@ -482,6 +482,7 @@ std::vector all_to_one_device( output_tensors.reserve(input_tensors.size()); for (const auto& tensor : input_tensors) { + TORCH_CHECK(tensor.is_cuda()); output_tensors.push_back( tensor.device() != target_device ? at::empty(tensor.sizes(), tensor.options().device(target_device)) diff --git a/fbgemm_gpu/test/merge_pooled_embeddings_test.py b/fbgemm_gpu/test/merge_pooled_embeddings_test.py index 732477310..1cf5b8ba7 100644 --- a/fbgemm_gpu/test/merge_pooled_embeddings_test.py +++ b/fbgemm_gpu/test/merge_pooled_embeddings_test.py @@ -31,7 +31,7 @@ @unittest.skipIf(*gpu_unavailable) -#@unittest.skipIf(open_source, "Not supported in open source yet") +@unittest.skipIf(open_source, "Not supported in open source yet") class MergePooledEmbeddingsTest(unittest.TestCase): @given( num_ads=st.integers(min_value=1, max_value=10), @@ -119,32 +119,6 @@ def test_all_to_one_device( self.assertEqual(o.device, dst_device) torch.testing.assert_close(o.cpu(), i) - @given( - num_inputs=st.integers(min_value=1, max_value=10), - num_gpus=st.integers(min_value=1, max_value=torch.cuda.device_count()), - non_default_stream=st.booleans(), - r=st.randoms(), - ) - # Can instantiate 8 contexts which takes a long time. - @settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None) - def test_all_to_one_device( - self, - num_inputs, - num_gpus, - non_default_stream, - r, - ) -> None: - dst_device = torch.device(f"cuda:{r.randint(0, num_gpus - 1)}") - with torch.cuda.device(dst_device): - inputs = [torch.randn(10, 20) for _ in range(num_inputs)] - cuda_inputs = [ - input.to(f"cuda:{i % num_gpus}") for i, input in enumerate(inputs) - ] - cuda_outputs = torch.ops.fbgemm.all_to_one_device(cuda_inputs, dst_device) - for i, o in zip(inputs, cuda_outputs): - self.assertEqual(o.device, dst_device) - torch.testing.assert_allclose(o.cpu(), i) - if __name__ == "__main__": unittest.main() diff --git a/fbgemm_gpu/test/split_table_batched_embeddings_test.py b/fbgemm_gpu/test/split_table_batched_embeddings_test.py index 0bdfad9b5..19500629d 100644 --- a/fbgemm_gpu/test/split_table_batched_embeddings_test.py +++ b/fbgemm_gpu/test/split_table_batched_embeddings_test.py @@ -33,7 +33,7 @@ if open_source: # pyre-ignore[21] - from test_utils import gpu_available, gpu_unavailable + from test_utils import gpu_available, gpu_unavailable, TEST_WITH_ROCM else: from fbgemm_gpu.test.test_utils import gpu_available, gpu_unavailable @@ -1097,7 +1097,7 @@ def test_nbit_forward_fused_pooled_emb_quant( split_table_batched_embeddings_ops.PoolingMode.NONE, ] ), - use_cpu=st.booleans(), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), ) @settings( verbosity=Verbosity.verbose, @@ -1316,7 +1316,7 @@ def test_backward_dense( split_table_batched_embeddings_ops.PoolingMode.NONE, ] ), - use_cpu=st.booleans(), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), exact=st.booleans(), ) @settings( @@ -1865,7 +1865,7 @@ def execute_backward_adagrad_( # noqa C901 cache_algorithm=st.sampled_from( split_table_batched_embeddings_ops.CacheAlgorithm ), - use_cpu=st.booleans(), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), exact=st.booleans(), ) @settings( @@ -1927,7 +1927,7 @@ def test_backward_adagrad_fp16_pmSUM( # noqa C901 cache_algorithm=st.sampled_from( split_table_batched_embeddings_ops.CacheAlgorithm ), - use_cpu=st.booleans(), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), exact=st.booleans(), ) @settings( @@ -1989,7 +1989,7 @@ def test_backward_adagrad_fp16_pmMEAN( # noqa C901 cache_algorithm=st.sampled_from( split_table_batched_embeddings_ops.CacheAlgorithm ), - use_cpu=st.booleans(), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), exact=st.booleans(), ) @settings( @@ -2051,7 +2051,7 @@ def test_backward_adagrad_fp16_pmNONE( # noqa C901 cache_algorithm=st.sampled_from( split_table_batched_embeddings_ops.CacheAlgorithm ), - use_cpu=st.booleans(), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), exact=st.booleans(), ) @settings( @@ -2113,7 +2113,7 @@ def test_backward_adagrad_fp32_pmSUM( # noqa C901 cache_algorithm=st.sampled_from( split_table_batched_embeddings_ops.CacheAlgorithm ), - use_cpu=st.booleans(), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), exact=st.booleans(), ) @settings( @@ -2175,7 +2175,7 @@ def test_backward_adagrad_fp32_pmMEAN( # noqa C901 cache_algorithm=st.sampled_from( split_table_batched_embeddings_ops.CacheAlgorithm ), - use_cpu=st.booleans(), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), exact=st.booleans(), ) @settings( @@ -2725,7 +2725,7 @@ def execute_backward_optimizers_( # noqa C901 split_table_batched_embeddings_ops.PoolingMode.NONE, ] ), - use_cpu=st.booleans(), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), ) @settings( verbosity=Verbosity.verbose, @@ -2785,7 +2785,7 @@ def test_backward_optimizers_adam( # noqa C901 split_table_batched_embeddings_ops.PoolingMode.NONE, ] ), - use_cpu=st.booleans() if torch.cuda.is_available() else st.just(True), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), weight_decay_mode=st.sampled_from( [ WeightDecayMode.L2, @@ -2852,7 +2852,7 @@ def test_backward_optimizers_adagrad( # noqa C901 split_table_batched_embeddings_ops.PoolingMode.NONE, ] ), - use_cpu=st.booleans(), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), ) @settings( verbosity=Verbosity.verbose, @@ -2906,7 +2906,7 @@ def test_backward_optimizers_lamb( # noqa C901 split_table_batched_embeddings_ops.PoolingMode.NONE, ] ), - use_cpu=st.booleans(), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), ) @settings( verbosity=Verbosity.verbose, @@ -3470,7 +3470,7 @@ def test_nbit_forward_uvm_cache( T=st.integers(min_value=1, max_value=5), B=st.integers(min_value=1, max_value=8), L=st.integers(min_value=0, max_value=8), - use_cpu=st.booleans() if gpu_available else st.just(True), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), use_cpu_hashtable=st.booleans(), use_array_for_index_remapping=st.booleans(), ) @@ -3840,7 +3840,7 @@ def test_nbit_cache_miss_counter(self, N: int) -> None: BoundsCheckMode.IGNORE, ] ), - use_cpu=st.booleans(), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), dtype=st.sampled_from( [ torch.int64, @@ -4099,7 +4099,7 @@ def test_lxu_cache_lookup(self) -> None: SparseType.INT8, ] ), - use_cpu=st.booleans() if gpu_available else st.just(True), + use_cpu=st.booleans() if (gpu_available and not TEST_WITH_ROCM) else st.just(False) if (gpu_available and TEST_WITH_ROCM) else st.just(True), test_internal=st.booleans(), ) @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) From 785afb88f8b26fceffd442921df48b7b6962a1a9 Mon Sep 17 00:00:00 2001 From: Li Li Date: Thu, 5 May 2022 18:19:55 +0000 Subject: [PATCH 30/76] Removing building and testing bash scripts. --- fbgemm_gpu/build.sh | 4 ---- fbgemm_gpu/test/run.sh | 20 -------------------- 2 files changed, 24 deletions(-) delete mode 100755 fbgemm_gpu/build.sh delete mode 100755 fbgemm_gpu/test/run.sh diff --git a/fbgemm_gpu/build.sh b/fbgemm_gpu/build.sh deleted file mode 100755 index f181dd6a9..000000000 --- a/fbgemm_gpu/build.sh +++ /dev/null @@ -1,4 +0,0 @@ -#!/bin/bash - -export MAX_JOBS=32 -python3.6 setup.py build develop 2>&1 | tee build.log diff --git a/fbgemm_gpu/test/run.sh b/fbgemm_gpu/test/run.sh deleted file mode 100755 index 2e9347833..000000000 --- a/fbgemm_gpu/test/run.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash - -# exit immediately on failure, or if an undefined variable is used -set -eux - -export FBGEMM_TEST_WITH_ROCM=1 - -python layout_transform_ops_test.py --verbose - -python permute_pooled_embedding_modules_test.py --verbose - -python sparse_ops_test.py --verbose - -python merge_pooled_embeddings_test.py --verbose - -python quantize_ops_test.py --verbose - -python split_embedding_inference_converter_test.py --verbose - -python split_table_batched_embeddings_test.py --verbose \ No newline at end of file From bbd0ad14cfe1ed5cd8f4da7edf5b8c60b5e2e6dd Mon Sep 17 00:00:00 2001 From: liligwu Date: Mon, 9 May 2022 22:05:17 +0000 Subject: [PATCH 31/76] * Addressing the comments in PR review ROCm changes #1102. * Reoganize CMakeList.txt and minimize the differences to the upsream. --- fbgemm_gpu/CMakeLists.txt | 165 +++++++----------- fbgemm_gpu/build.sh | 7 + fbgemm_gpu/cmake/Hip.cmake | 25 ++- .../embedding_backward_code_generator.py | 1 - .../embedding_backward_split_template.cu | 1 + .../embedding_forward_template_helpers.cuh | 4 - .../embedding_backward_template_helpers.cuh | 8 - .../fbgemm_gpu/hipcub_namespace_postfix.cuh | 21 --- .../fbgemm_gpu/hipcub_namespace_postfix.hpp | 21 --- .../fbgemm_gpu/hipcub_namespace_prefix.cuh | 16 -- .../fbgemm_gpu/hipcub_namespace_prefix.hpp | 16 -- fbgemm_gpu/run_all.sh | 39 ----- fbgemm_gpu/src/cumem_utils.h | 2 +- fbgemm_gpu/src/split_embeddings_cache_cuda.cu | 8 - 14 files changed, 97 insertions(+), 237 deletions(-) create mode 100644 fbgemm_gpu/build.sh delete mode 100644 fbgemm_gpu/include/fbgemm_gpu/hipcub_namespace_postfix.cuh delete mode 100644 fbgemm_gpu/include/fbgemm_gpu/hipcub_namespace_postfix.hpp delete mode 100644 fbgemm_gpu/include/fbgemm_gpu/hipcub_namespace_prefix.cuh delete mode 100644 fbgemm_gpu/include/fbgemm_gpu/hipcub_namespace_prefix.hpp delete mode 100755 fbgemm_gpu/run_all.sh diff --git a/fbgemm_gpu/CMakeLists.txt b/fbgemm_gpu/CMakeLists.txt index 09c1669ac..5ca034d2a 100644 --- a/fbgemm_gpu/CMakeLists.txt +++ b/fbgemm_gpu/CMakeLists.txt @@ -41,18 +41,11 @@ else() LANGUAGES CXX C CUDA) endif() -if(USE_CUDA) - set(default_cuda_architectures 60 61 70 75 80) - set(cuda_architectures_doc - "CUDA architectures to build for. Default is ${default_cuda_architectures}") - set(cuda_architectures - "${default_cuda_architectures}" - CACHE STRING "${cuda_architectures_doc}") +find_package(Torch REQUIRED) +find_package(PythonExtensions REQUIRED) - message("${message_line}") - message("fbgemm_gpu:") - message("Building for cuda_architectures = \"${cuda_architectures}\"") - message("${message_line}") +set(FBGEMM ${CMAKE_CURRENT_SOURCE_DIR}/..) +set(THIRDPARTY ${FBGEMM}/third_party) if(DEFINED GLIBCXX_USE_CXX11_ABI) if(${GLIBCXX_USE_CXX11_ABI} EQUAL 1) @@ -68,49 +61,24 @@ endif() # constructor exists to convert from "int" to "__half" errors in # gen_embedding_forward_quantized_split_[un]weighted_codegen_cuda.cu # - set(TORCH_CUDA_OPTIONS - --expt-relaxed-constexpr - -D__CUDA_NO_HALF_OPERATORS__ - # -D__CUDA_NO_HALF_CONVERSIONS__ - -D__CUDA_NO_BFLOAT16_CONVERSIONS__ - -D__CUDA_NO_HALF2_OPERATORS__) -endif() -find_package(Torch REQUIRED) -find_package(PythonExtensions REQUIRED) +set(TORCH_CUDA_OPTIONS + --expt-relaxed-constexpr + -D__CUDA_NO_HALF_OPERATORS__ + # -D__CUDA_NO_HALF_CONVERSIONS__ + -D__CUDA_NO_BFLOAT16_CONVERSIONS__ + -D__CUDA_NO_HALF2_OPERATORS__) -set(FBGEMM ${CMAKE_CURRENT_SOURCE_DIR}/..) -set(THIRDPARTY ${FBGEMM}/third_party) if(USE_ROCM) - if(NOT DEFINED ENV{PYTORCH_ROCM_ARCH}) - SET(FBGEMM_ROCM_ARCH gfx900;gfx906;gfx908;gfx90a) - else() - SET(FBGEMM_ROCM_ARCH $ENV{PYTORCH_ROCM_ARCH}) - endif() - list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake" "${THIRDPARTY}/hipify_torch/cmake") include(Hip) - if(NOT FBGEMM_HAVE_HIP) - message(FATAL_ERROR "Not able to find HIP installation.") - endif() include(Hipify) - list (APPEND CMAKE_PREFIX_PATH /opt/rocm/hip /opt/rocm) - set(CMAKE_MODULE_PATH ${HIP_PATH}/cmake ${CMAKE_MODULE_PATH}) - - find_package(rocBLAS REQUIRED) - find_package(hipFFT REQUIRED) - find_package(hipRAND REQUIRED) - find_package(rocRAND REQUIRED) - find_package(hipSPARSE REQUIRED) - find_package(OpenMP REQUIRED) - find_package(rocPRIM REQUIRED) - + message("${message_line}") message(STATUS "hip found ${ROCM_FOUND}") endif() - # # GENERATED CUDA, CPP and Python code # @@ -197,31 +165,18 @@ set(codegen_dependencies ${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/embedding_backward_template_helpers.cuh ${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/embedding_common.h ${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/fbgemm_cuda_utils.cuh - ${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/quantize_ops_gpu.h ${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/quantize_ops_utils.h ${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/split_embeddings_utils.cuh ${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/sparse_ops_utils.h ) -if(USE_CUDA) - add_custom_command( - OUTPUT ${gen_cpu_source_files} ${gen_gpu_source_files} - ${gen_gpu_host_source_files} ${gen_python_files} +if(USE_ROCM) + execute_process( COMMAND - "${PYTHON_EXECUTABLE}" - "${CMAKE_CODEGEN_DIR}/embedding_backward_code_generator.py" - "--opensource" - DEPENDS "${codegen_dependencies}") - - set_source_files_properties( - ${gen_cpu_source_files} PROPERTIES COMPILE_OPTIONS - "-mavx2;-mf16c;-mfma;-fopenmp") -elseif(USE_ROCM) - execute_process( - COMMAND - "${PYTHON_EXECUTABLE}" - "${CMAKE_CODEGEN_DIR}/embedding_backward_code_generator.py" - "--opensource") + "${PYTHON_EXECUTABLE}" + "${CMAKE_CODEGEN_DIR}/embedding_backward_code_generator.py" + "--opensource" + DEPENDS "${codegen_dependencies}") set(header_include_dir ${CMAKE_CURRENT_SOURCE_DIR}/include @@ -229,12 +184,21 @@ elseif(USE_ROCM) ${CMAKE_CURRENT_SOURCE_DIR} ) hipify(CUDA_SOURCE_DIR ${PROJECT_SOURCE_DIR} HEADER_INCLUDE_DIR ${header_include_dir}) - - set_source_files_properties( - ${gen_cpu_source_files} PROPERTIES COMPILE_OPTIONS - "-mavx2;-mf16c;-mfma") +else() + add_custom_command( + OUTPUT ${gen_cpu_source_files} ${gen_gpu_source_files} + ${gen_gpu_host_source_files} ${gen_python_files} + COMMAND + "${PYTHON_EXECUTABLE}" + "${CMAKE_CODEGEN_DIR}/embedding_backward_code_generator.py" + "--opensource" + DEPENDS "${codegen_dependencies}") endif() +set_source_files_properties( + ${gen_cpu_source_files} PROPERTIES COMPILE_OPTIONS + "-mavx2;-mf16c;-mfma;-fopenmp") + set_source_files_properties( ${gen_cpu_source_files} PROPERTIES @@ -285,14 +249,18 @@ set(cpp_fbgemm_files_avx2 "../src/EmbeddingSpMDMAvx2.cc" set_source_files_properties(${cpp_fbgemm_files_avx2} PROPERTIES COMPILE_OPTIONS "-mavx2;-mf16c;-mfma") -set(cpp_fbgemm_files ${cpp_fbgemm_files_normal} ${cpp_fbgemm_files_avx2}) set(cpp_fbgemm_files_avx512 "../src/EmbeddingSpMDMAvx512.cc") -if(USE_CUDA) - set_source_files_properties( - ${cpp_fbgemm_files_avx512} - PROPERTIES COMPILE_OPTIONS - "-mavx2;-mf16c;-mfma;-mavx512f;-mavx512bw;-mavx512dq;-mavx512vl") - list(APPEND cpp_fbgemm_files ${cpp_fbgemm_files_avx512}) + +set_source_files_properties( + ${cpp_fbgemm_files_avx512} + PROPERTIES COMPILE_OPTIONS + "-mavx2;-mf16c;-mfma;-mavx512f;-mavx512bw;-mavx512dq;-mavx512vl") + +if(USE_ROCM) + set(cpp_fbgemm_files ${cpp_fbgemm_files_normal} ${cpp_fbgemm_files_avx2}) +else() + set(cpp_fbgemm_files ${cpp_fbgemm_files_normal} ${cpp_fbgemm_files_avx2} + ${cpp_fbgemm_files_avx512}) endif() set(cpp_fbgemm_files_include_directories @@ -356,12 +324,9 @@ if(NOT FBGEMM_CPU_ONLY) endif() endif() -set(fbgemm_gpu_sources_cpu_option "-mavx;-mf16c;-mfma;-mavx2") -if(USE_CUDA) - set_source_files_properties( - ${fbgemm_gpu_sources_cpu} PROPERTIES COMPILE_OPTIONS - "${fbgemm_gpu_sources_cpu_option};-fopenmp") -endif() +set_source_files_properties( + ${fbgemm_gpu_sources_cpu} PROPERTIES COMPILE_OPTIONS + "-mavx;-mf16c;-mfma;-mavx2;-fopenmp") if(NOT FBGEMM_CPU_ONLY) set(fbgemm_gpu_sources_gpu @@ -398,16 +363,11 @@ if(USE_ROCM) endforeach() endif() -if(USE_CUDA) - add_library(fbgemm_gpu_py MODULE ${fbgemm_gpu_sources} ${gen_source_files} - ${cpp_asmjit_files} ${cpp_fbgemm_files}) - set_property(TARGET fbgemm_gpu_py PROPERTY CUDA_ARCHITECTURES - "${cuda_architectures}") - if(NOT FBGEMM_CPU_ONLY) - target_compile_definitions(fbgemm_gpu_py PRIVATE FBGEMM_CUB_USE_NAMESPACE) - endif() - set_property(TARGET fbgemm_gpu_py PROPERTY CXX_STANDARD 17) -elseif(USE_ROCM) +# +# MODULE +# + +if(USE_ROCM) get_hipified_list("${fbgemm_gpu_sources}" fbgemm_gpu_sources) get_hipified_list("${abspath_gen_source_files}" abspath_gen_source_files) get_hipified_list("${cpp_fbgemm_files}" cpp_fbgemm_files) @@ -415,17 +375,26 @@ elseif(USE_ROCM) set(FBGEMM_ALL_HIP_FILES ${fbgemm_gpu_sources} ${abspath_gen_source_files} ${cpp_fbgemm_files}) set_source_files_properties(${FBGEMM_ALL_HIP_FILES} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1) hip_include_directories("${cpp_fbgemm_files_include_directories}") - - hip_add_library(fbgemm_gpu_py SHARED ${cpp_asmjit_files} ${FBGEMM_ALL_HIP_FILES} ${FBGEMM_HIP_HCC_LIBRARIES} + + hip_add_library(fbgemm_gpu_py SHARED ${cpp_asmjit_files} ${FBGEMM_ALL_HIP_FILES} ${FBGEMM_HIP_HCC_LIBRARIES} HIPCC_OPTIONS ${HIP_HCC_FLAGS}) target_include_directories(fbgemm_gpu_py PUBLIC ${FBGEMM_HIP_INCLUDE} ${ROCRAND_INCLUDE} ${ROCM_SMI_INCLUDE}) -endif() -list (GET TORCH_INCLUDE_DIRS 0 TORCH_PATH) -if(EXISTS "${TORCH_PATH}/ATen/cuda/CUDAGeneratorImpl.h") - target_compile_definitions(fbgemm_gpu_py PRIVATE NEW_GENERATOR_PATH) -endif() -if(EXISTS "${TORCH_PATH}/ATen/cuda/Atomic.cuh") - target_compile_definitions(fbgemm_gpu_py PRIVATE NEW_ATOMIC_PATH) + list(GET TORCH_INCLUDE_DIRS 0 TORCH_PATH) + # if(EXISTS "${TORCH_PATH}/ATen/cuda/CUDAGeneratorImpl.h") + # target_compile_definitions(fbgemm_gpu_py PRIVATE NEW_GENERATOR_PATH) + # endif() + # if(EXISTS "${TORCH_PATH}/ATen/cuda/Atomic.cuh") + # target_compile_definitions(fbgemm_gpu_py PRIVATE NEW_ATOMIC_PATH) + # endif() +else() + add_library(fbgemm_gpu_py MODULE ${fbgemm_gpu_sources} ${gen_source_files} + ${cpp_asmjit_files} ${cpp_fbgemm_files}) + set_property(TARGET fbgemm_gpu_py PROPERTY CUDA_ARCHITECTURES + "${cuda_architectures}") + + if(NOT FBGEMM_CPU_ONLY) + target_compile_definitions(fbgemm_gpu_py PRIVATE FBGEMM_CUB_USE_NAMESPACE) + endif() endif() set_target_properties(fbgemm_gpu_py PROPERTIES PREFIX "") diff --git a/fbgemm_gpu/build.sh b/fbgemm_gpu/build.sh new file mode 100644 index 000000000..6ef2c4fba --- /dev/null +++ b/fbgemm_gpu/build.sh @@ -0,0 +1,7 @@ +#!/bin/bash + +export MAX_JOBS=96 +gpu_arch="$(/opt/rocm/bin/rocminfo | grep -o -m 1 'gfx.*')" +export PYTORCH_ROCM_ARCH=$gpu_arch +git clean -dfx +python setup.py build develop 2>&1 | tee build.log \ No newline at end of file diff --git a/fbgemm_gpu/cmake/Hip.cmake b/fbgemm_gpu/cmake/Hip.cmake index cdc225e9d..e729ed065 100644 --- a/fbgemm_gpu/cmake/Hip.cmake +++ b/fbgemm_gpu/cmake/Hip.cmake @@ -94,6 +94,12 @@ set(CMAKE_MODULE_PATH ${HIP_PATH}/cmake ${CMAKE_MODULE_PATH}) ADD_DEFINITIONS(-DNDEBUG) ADD_DEFINITIONS(-DUSE_ROCM) +IF(NOT DEFINED ENV{PYTORCH_ROCM_ARCH}) + SET(FBGEMM_ROCM_ARCH gfx900;gfx906;gfx908;gfx90a) +ELSE() + SET(FBGEMM_ROCM_ARCH $ENV{PYTORCH_ROCM_ARCH}) +ENDIF() + # Find the HIP Package find_package(HIP) @@ -107,6 +113,15 @@ IF(HIP_FOUND) endif() message("HIP library name: ${hip_library_name}") + find_package(hip REQUIRED) + find_package(rocBLAS REQUIRED) + find_package(hipFFT REQUIRED) + find_package(hipRAND REQUIRED) + find_package(rocRAND REQUIRED) + find_package(hipSPARSE REQUIRED) + find_package(OpenMP REQUIRED) + find_package(rocPRIM REQUIRED) + set(CMAKE_HCC_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG}) set(CMAKE_HCC_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE}) FIND_LIBRARY(FBGEMM_HIP_HCC_LIBRARIES ${hip_library_name} HINTS ${HIP_PATH}/lib) @@ -145,9 +160,6 @@ IF(HIP_FOUND) set(hipcub_DIR ${HIPCUB_PATH}/lib/cmake/hipcub) set(rocthrust_DIR ${ROCTHRUST_PATH}/lib/cmake/rocthrust) set(ROCclr_DIR ${ROCM_PATH}/rocclr/lib/cmake/rocclr) - - find_package(hip REQUIRED) - set(ROCRAND_INCLUDE ${ROCRAND_PATH}/include) set(ROCM_SMI_INCLUDE ${ROCM_PATH}/rocm_smi/include) @@ -156,4 +168,9 @@ IF(HIP_FOUND) hip_include_directories(${FBGEMM_HIP_INCLUDE} ${ROCRAND_INCLUDE} ${ROCM_SMI_INCLUDE}) -ENDIF() + list (APPEND CMAKE_PREFIX_PATH /opt/rocm/hip /opt/rocm) + set(CMAKE_MODULE_PATH ${HIP_PATH}/cmake ${CMAKE_MODULE_PATH}) + +ELSE() + message(FATAL_ERROR "Not able to find HIP installation.") +ENDIF() \ No newline at end of file diff --git a/fbgemm_gpu/codegen/embedding_backward_code_generator.py b/fbgemm_gpu/codegen/embedding_backward_code_generator.py index 7b7b015d2..d43de70d9 100644 --- a/fbgemm_gpu/codegen/embedding_backward_code_generator.py +++ b/fbgemm_gpu/codegen/embedding_backward_code_generator.py @@ -1316,4 +1316,3 @@ def main() -> None: if __name__ == "__main__": main() - # hipify_gen() diff --git a/fbgemm_gpu/codegen/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/embedding_backward_split_template.cu index a18ae700e..a8d9c4b70 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/embedding_backward_split_template.cu @@ -968,6 +968,7 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ // over 48 KB per block are architecture-specific, as such they // must use dynamic shared memory (rather than statically sized // arrays) and require an explicit opt-in using cudaFuncSetAttribute()". + #ifndef __HIP_PLATFORM_HCC__ cudaFuncSetAttribute( split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_kernel_cta_per_row_1< diff --git a/fbgemm_gpu/codegen/embedding_forward_template_helpers.cuh b/fbgemm_gpu/codegen/embedding_forward_template_helpers.cuh index 4a7e517a6..24e76e8c7 100644 --- a/fbgemm_gpu/codegen/embedding_forward_template_helpers.cuh +++ b/fbgemm_gpu/codegen/embedding_forward_template_helpers.cuh @@ -10,11 +10,7 @@ #include #include #include -#if !defined(NEW_ATOMIC_PATH) -#include -#else #include -#endif // clang-format off #include "fbgemm_gpu/cub_namespace_prefix.cuh" diff --git a/fbgemm_gpu/include/fbgemm_gpu/embedding_backward_template_helpers.cuh b/fbgemm_gpu/include/fbgemm_gpu/embedding_backward_template_helpers.cuh index 4c4724cbe..e1adf9756 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/embedding_backward_template_helpers.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/embedding_backward_template_helpers.cuh @@ -10,17 +10,9 @@ #include #include #include -#if !defined(NEW_GENERATOR_PATH) -#include -#else #include -#endif #include -#if !defined(NEW_ATOMIC_PATH) -#include -#else #include -#endif #include #include diff --git a/fbgemm_gpu/include/fbgemm_gpu/hipcub_namespace_postfix.cuh b/fbgemm_gpu/include/fbgemm_gpu/hipcub_namespace_postfix.cuh deleted file mode 100644 index 8922edbba..000000000 --- a/fbgemm_gpu/include/fbgemm_gpu/hipcub_namespace_postfix.cuh +++ /dev/null @@ -1,21 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * All rights reserved. - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#undef FBGEMM_GPU_CUB_NS_PREFIX - -#ifdef FBGEMM_CUB_USE_NAMESPACE - -#undef CUB_NS_PREFIX -#undef CUB_NS_POSTFIX - -#define FBGEMM_GPU_CUB_NS_PREFIX fbgemm_gpu:: - -#else - -#define FBGEMM_GPU_CUB_NS_PREFIX - -#endif diff --git a/fbgemm_gpu/include/fbgemm_gpu/hipcub_namespace_postfix.hpp b/fbgemm_gpu/include/fbgemm_gpu/hipcub_namespace_postfix.hpp deleted file mode 100644 index 8922edbba..000000000 --- a/fbgemm_gpu/include/fbgemm_gpu/hipcub_namespace_postfix.hpp +++ /dev/null @@ -1,21 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * All rights reserved. - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#undef FBGEMM_GPU_CUB_NS_PREFIX - -#ifdef FBGEMM_CUB_USE_NAMESPACE - -#undef CUB_NS_PREFIX -#undef CUB_NS_POSTFIX - -#define FBGEMM_GPU_CUB_NS_PREFIX fbgemm_gpu:: - -#else - -#define FBGEMM_GPU_CUB_NS_PREFIX - -#endif diff --git a/fbgemm_gpu/include/fbgemm_gpu/hipcub_namespace_prefix.cuh b/fbgemm_gpu/include/fbgemm_gpu/hipcub_namespace_prefix.cuh deleted file mode 100644 index c977653fa..000000000 --- a/fbgemm_gpu/include/fbgemm_gpu/hipcub_namespace_prefix.cuh +++ /dev/null @@ -1,16 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * All rights reserved. - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#ifdef FBGEMM_CUB_USE_NAMESPACE - -#undef CUB_NS_PREFIX -#undef CUB_NS_POSTFIX - -#define CUB_NS_PREFIX namespace fbgemm_gpu { -#define CUB_NS_POSTFIX } // namespace fbgemm_gpu - -#endif diff --git a/fbgemm_gpu/include/fbgemm_gpu/hipcub_namespace_prefix.hpp b/fbgemm_gpu/include/fbgemm_gpu/hipcub_namespace_prefix.hpp deleted file mode 100644 index c977653fa..000000000 --- a/fbgemm_gpu/include/fbgemm_gpu/hipcub_namespace_prefix.hpp +++ /dev/null @@ -1,16 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * All rights reserved. - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#ifdef FBGEMM_CUB_USE_NAMESPACE - -#undef CUB_NS_PREFIX -#undef CUB_NS_POSTFIX - -#define CUB_NS_PREFIX namespace fbgemm_gpu { -#define CUB_NS_POSTFIX } // namespace fbgemm_gpu - -#endif diff --git a/fbgemm_gpu/run_all.sh b/fbgemm_gpu/run_all.sh deleted file mode 100755 index d7a457c08..000000000 --- a/fbgemm_gpu/run_all.sh +++ /dev/null @@ -1,39 +0,0 @@ -#!/bin/bash - -common_opts="--bag-size 55 \ - --batch-size 65536 \ - --num-embeddings 19300000 \ - --num-tables 1 \ - --iters 5" - -# Run on GPU and get PyTorch-level performance -for D in 64 128 192 256 512; do - for fp in "fp32" "fp16"; do - for alpha in 1 1.15; do - echo "D = ${D}, FP = ${fp}, alpha = ${alpha}" - python3.6 bench/split_table_batched_embeddings_benchmark.py device \ - $common_opts \ - --embedding-dim $D \ - --alpha ${alpha} \ - --weights-precision $fp - done - done -done 2>&1 | tee log_fbgemm_gpu_m1.log - -# Run on GPU and get rocprof-level performance -for D in 64 128 192 256 512; do - for fp in "fp32" "fp16"; do - for alpha in 1 1.15; do - rm -rf rocprof - rm -rf rocprof_tmp - echo "D = ${D}, FP = ${fp}, alpha = ${alpha}" - outf="rocprof_fbgemm_gpu_D_${D}_${fp}_alpha_${alpha}.csv" - rocprof --timestamp on -o $outf -d rocprof -t rocprof_tmp \ - python3.6 bench/split_table_batched_embeddings_benchmark.py device \ - $common_opts \ - --embedding-dim $D \ - --alpha ${alpha} \ - --weights-precision $fp - done - done -done diff --git a/fbgemm_gpu/src/cumem_utils.h b/fbgemm_gpu/src/cumem_utils.h index 75532f435..ef44d9fcf 100644 --- a/fbgemm_gpu/src/cumem_utils.h +++ b/fbgemm_gpu/src/cumem_utils.h @@ -1,5 +1,5 @@ /* - * Copyright (c) Facebook, Inc. and its affiliates. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. diff --git a/fbgemm_gpu/src/split_embeddings_cache_cuda.cu b/fbgemm_gpu/src/split_embeddings_cache_cuda.cu index 6e0c48ccb..12991b0ef 100644 --- a/fbgemm_gpu/src/split_embeddings_cache_cuda.cu +++ b/fbgemm_gpu/src/split_embeddings_cache_cuda.cu @@ -15,11 +15,7 @@ #include #include -#if !defined(NEW_GENERATOR_PATH) -#include -#else #include -#endif #include #include #include @@ -27,11 +23,7 @@ #include #include #include -#if !defined(NEW_ATOMIC_PATH) -#include -#else #include -#endif #include #include #include From 9db83d82584a5c7e0cda0b0f02388e2bb8573d15 Mon Sep 17 00:00:00 2001 From: liligwu Date: Mon, 9 May 2022 23:47:47 +0000 Subject: [PATCH 32/76] Minor changes that minimize the difference to upstream. --- fbgemm_gpu/CMakeLists.txt | 1 - fbgemm_gpu/build.sh | 0 fbgemm_gpu/cmake/Hip.cmake | 2 +- .../fbgemm_gpu/embedding_backward_template_helpers.cuh | 2 +- fbgemm_gpu/include/fbgemm_gpu/enum_utils.h | 4 ++-- fbgemm_gpu/src/cumem_utils.cu | 4 ++-- fbgemm_gpu/src/cumem_utils.h | 2 +- fbgemm_gpu/src/cumem_utils_host.cpp | 7 +++++-- fbgemm_gpu/src/split_embeddings_cache_cuda.cu | 2 +- .../test/split_embedding_inference_converter_test.py | 6 +++--- 10 files changed, 16 insertions(+), 14 deletions(-) mode change 100644 => 100755 fbgemm_gpu/build.sh diff --git a/fbgemm_gpu/CMakeLists.txt b/fbgemm_gpu/CMakeLists.txt index 5ca034d2a..0cb32fb6c 100644 --- a/fbgemm_gpu/CMakeLists.txt +++ b/fbgemm_gpu/CMakeLists.txt @@ -198,7 +198,6 @@ endif() set_source_files_properties( ${gen_cpu_source_files} PROPERTIES COMPILE_OPTIONS "-mavx2;-mf16c;-mfma;-fopenmp") - set_source_files_properties( ${gen_cpu_source_files} PROPERTIES diff --git a/fbgemm_gpu/build.sh b/fbgemm_gpu/build.sh old mode 100644 new mode 100755 diff --git a/fbgemm_gpu/cmake/Hip.cmake b/fbgemm_gpu/cmake/Hip.cmake index e729ed065..0fcbcc5b2 100644 --- a/fbgemm_gpu/cmake/Hip.cmake +++ b/fbgemm_gpu/cmake/Hip.cmake @@ -173,4 +173,4 @@ IF(HIP_FOUND) ELSE() message(FATAL_ERROR "Not able to find HIP installation.") -ENDIF() \ No newline at end of file +ENDIF() diff --git a/fbgemm_gpu/include/fbgemm_gpu/embedding_backward_template_helpers.cuh b/fbgemm_gpu/include/fbgemm_gpu/embedding_backward_template_helpers.cuh index e1adf9756..0a192d85f 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/embedding_backward_template_helpers.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/embedding_backward_template_helpers.cuh @@ -105,4 +105,4 @@ DEVICE_INLINE int64_t gpuAtomicIncrement(int64_t* p) { return static_cast(atomicAdd( reinterpret_cast(p), static_cast(1))); -} \ No newline at end of file +} diff --git a/fbgemm_gpu/include/fbgemm_gpu/enum_utils.h b/fbgemm_gpu/include/fbgemm_gpu/enum_utils.h index cf2fdbf68..8a9ccdf17 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/enum_utils.h +++ b/fbgemm_gpu/include/fbgemm_gpu/enum_utils.h @@ -1,5 +1,5 @@ /* - * Copyright (c) Facebook, Inc. and its affiliates. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. @@ -82,4 +82,4 @@ static inline enum_result enum_query() { return enum_registration::enum_query(); } -} // namespace fbgemm_gpu \ No newline at end of file +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/src/cumem_utils.cu b/fbgemm_gpu/src/cumem_utils.cu index e40676175..b94a67a60 100644 --- a/fbgemm_gpu/src/cumem_utils.cu +++ b/fbgemm_gpu/src/cumem_utils.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) Facebook, Inc. and its affiliates. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. @@ -409,4 +409,4 @@ FBGEMM_GPU_ENUM_REGISTER_START(uvm, cudaMemoryAdvise){ } FBGEMM_GPU_ENUM_REGISTER_END #endif -} // namespace fbgemm_gpu \ No newline at end of file +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/src/cumem_utils.h b/fbgemm_gpu/src/cumem_utils.h index ef44d9fcf..43469fd39 100644 --- a/fbgemm_gpu/src/cumem_utils.h +++ b/fbgemm_gpu/src/cumem_utils.h @@ -52,4 +52,4 @@ Tensor uvm_to_cpu_clone(Tensor t); FBGEMM_GPU_ENUM_CREATE_TAG(uvm) -} // namespace fbgemm_gpu \ No newline at end of file +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/src/cumem_utils_host.cpp b/fbgemm_gpu/src/cumem_utils_host.cpp index 249d2e27e..7728ad00e 100644 --- a/fbgemm_gpu/src/cumem_utils_host.cpp +++ b/fbgemm_gpu/src/cumem_utils_host.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) Facebook, Inc. and its affiliates. + * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. @@ -15,6 +15,7 @@ using Tensor = at::Tensor; namespace fbgemm_gpu { +// Deprecated for fb namespace! Please use fbgemm namespace instead! TORCH_LIBRARY_FRAGMENT(fb, m) { m.def("is_uvm_tensor(Tensor t) -> bool", TORCH_FN(is_uvm_tensor)); m.def("uvm_storage(Tensor t) -> bool", TORCH_FN(uvm_storage)); @@ -67,10 +68,12 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { "uvm_mem_advice_dont_fork(Tensor t) -> ()", TORCH_FN(uvm_mem_advice_dont_fork)); + m.def("uvm_to_cpu_clone(Tensor t) -> Tensor", TORCH_FN(uvm_to_cpu_clone)); + #ifndef __HIP_PLATFORM_HCC__ // FIXME: some advanced "cudaMemAdvise" flags are not supported by HIP. m.def(FBGEMM_GPU_ENUM_OP(uvm, fbgemm_gpu_uvm_enum_query)); #endif } -} // namespace fbgemm_gpu \ No newline at end of file +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/src/split_embeddings_cache_cuda.cu b/fbgemm_gpu/src/split_embeddings_cache_cuda.cu index 12991b0ef..29c18e90f 100644 --- a/fbgemm_gpu/src/split_embeddings_cache_cuda.cu +++ b/fbgemm_gpu/src/split_embeddings_cache_cuda.cu @@ -15,10 +15,10 @@ #include #include -#include #include #include #include +#include #include #include #include diff --git a/fbgemm_gpu/test/split_embedding_inference_converter_test.py b/fbgemm_gpu/test/split_embedding_inference_converter_test.py index 296f55d52..bf789862d 100644 --- a/fbgemm_gpu/test/split_embedding_inference_converter_test.py +++ b/fbgemm_gpu/test/split_embedding_inference_converter_test.py @@ -135,7 +135,7 @@ class QuantizedSplitEmbeddingsTest(unittest.TestCase): SparseType.INT2, ] ), - use_cpu=st.booleans(), + use_cpu=st.booleans() if gpu_available else st.just(True), pruning_ratio=st.sampled_from([None, 0.0]), ) @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) @@ -209,7 +209,7 @@ def test_quantize_workflow( ) @given( - use_cpu=st.booleans(), + use_cpu=st.booleans() if gpu_available else st.just(True), use_array_for_index_remapping=st.booleans(), quantize_type=st.sampled_from( [ @@ -299,7 +299,7 @@ def test_l2_norm_pruning_workflow( D=st.integers(min_value=2, max_value=128), log_E=st.integers(min_value=3, max_value=5), pruning_ratio=st.floats(min_value=0.0, max_value=1.0, exclude_max=True), - use_cpu=st.booleans(), + use_cpu=st.booleans() if gpu_available else st.just(True), use_array_for_index_remapping=st.booleans(), ) @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) From eabd0a8723efff806eb3c4e8947ee754f56dfdba Mon Sep 17 00:00:00 2001 From: liligwu Date: Mon, 9 May 2022 23:57:42 +0000 Subject: [PATCH 33/76] A minor change on a blank line. --- fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp b/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp index bf2c3dee9..b9e1a1400 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp +++ b/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp @@ -186,6 +186,7 @@ class Split{{ "NoBag" if nobag else "" }}LookupFunction_{{ optimizer }}_Op : {% for (var, _) in args.saved_data %} ctx->saved_data["{{ var }}"] = {{ var }}; {% endfor %} + {% if not nobag %} #ifdef __HIP_PLATFORM_HCC__ constexpr int32_t BT_block_size = 64; From 2038008a4854279ae84cac2e7ebfb565826cbcaf Mon Sep 17 00:00:00 2001 From: liligwu Date: Tue, 10 May 2022 18:00:49 +0000 Subject: [PATCH 34/76] Fixing indentation and commented code in CMakeList.txt --- fbgemm_gpu/CMakeLists.txt | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/fbgemm_gpu/CMakeLists.txt b/fbgemm_gpu/CMakeLists.txt index 0cb32fb6c..0ff4888db 100644 --- a/fbgemm_gpu/CMakeLists.txt +++ b/fbgemm_gpu/CMakeLists.txt @@ -165,6 +165,7 @@ set(codegen_dependencies ${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/embedding_backward_template_helpers.cuh ${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/embedding_common.h ${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/fbgemm_cuda_utils.cuh + ${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/quantize_ops_gpu.h ${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/quantize_ops_utils.h ${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/split_embeddings_utils.cuh ${CMAKE_CURRENT_SOURCE_DIR}/include/fbgemm_gpu/sparse_ops_utils.h @@ -179,10 +180,10 @@ if(USE_ROCM) DEPENDS "${codegen_dependencies}") set(header_include_dir - ${CMAKE_CURRENT_SOURCE_DIR}/include - ${CMAKE_CURRENT_SOURCE_DIR}/src - ${CMAKE_CURRENT_SOURCE_DIR} - ) + ${CMAKE_CURRENT_SOURCE_DIR}/include + ${CMAKE_CURRENT_SOURCE_DIR}/src + ${CMAKE_CURRENT_SOURCE_DIR}) + hipify(CUDA_SOURCE_DIR ${PROJECT_SOURCE_DIR} HEADER_INCLUDE_DIR ${header_include_dir}) else() add_custom_command( @@ -379,12 +380,6 @@ if(USE_ROCM) HIPCC_OPTIONS ${HIP_HCC_FLAGS}) target_include_directories(fbgemm_gpu_py PUBLIC ${FBGEMM_HIP_INCLUDE} ${ROCRAND_INCLUDE} ${ROCM_SMI_INCLUDE}) list(GET TORCH_INCLUDE_DIRS 0 TORCH_PATH) - # if(EXISTS "${TORCH_PATH}/ATen/cuda/CUDAGeneratorImpl.h") - # target_compile_definitions(fbgemm_gpu_py PRIVATE NEW_GENERATOR_PATH) - # endif() - # if(EXISTS "${TORCH_PATH}/ATen/cuda/Atomic.cuh") - # target_compile_definitions(fbgemm_gpu_py PRIVATE NEW_ATOMIC_PATH) - # endif() else() add_library(fbgemm_gpu_py MODULE ${fbgemm_gpu_sources} ${gen_source_files} ${cpp_asmjit_files} ${cpp_fbgemm_files}) From 020207804e1aa08ed1a35c95802098ffcfd24fae Mon Sep 17 00:00:00 2001 From: liligwu Date: Tue, 10 May 2022 18:08:11 +0000 Subject: [PATCH 35/76] Removing build script. --- fbgemm_gpu/build.sh | 7 ------- 1 file changed, 7 deletions(-) delete mode 100755 fbgemm_gpu/build.sh diff --git a/fbgemm_gpu/build.sh b/fbgemm_gpu/build.sh deleted file mode 100755 index 6ef2c4fba..000000000 --- a/fbgemm_gpu/build.sh +++ /dev/null @@ -1,7 +0,0 @@ -#!/bin/bash - -export MAX_JOBS=96 -gpu_arch="$(/opt/rocm/bin/rocminfo | grep -o -m 1 'gfx.*')" -export PYTORCH_ROCM_ARCH=$gpu_arch -git clean -dfx -python setup.py build develop 2>&1 | tee build.log \ No newline at end of file From 9cf8856ef73abd64a924ba87dcf9e7bd9bf40c7d Mon Sep 17 00:00:00 2001 From: liligwu Date: Wed, 11 May 2022 17:46:59 +0000 Subject: [PATCH 36/76] Addressing the second batch of comments of https://github.com/pytorch/FBGEMM/pull/1102 --- fbgemm_gpu/CMakeLists.txt | 20 +++++++++---------- fbgemm_gpu/cmake/Hip.cmake | 8 ++++---- ...edding_forward_quantized_split_template.cu | 11 ---------- .../include/fbgemm_gpu/fbgemm_cuda_utils.cuh | 19 +++++------------- fbgemm_gpu/src/cumem_utils.cu | 13 +----------- fbgemm_gpu/src/quantize_ops.cu | 6 ++---- fbgemm_gpu/src/split_embeddings_cache_cuda.cu | 6 ------ fbgemm_gpu/test/test_utils.py | 3 ++- 8 files changed, 24 insertions(+), 62 deletions(-) diff --git a/fbgemm_gpu/CMakeLists.txt b/fbgemm_gpu/CMakeLists.txt index 0ff4888db..f4010c598 100644 --- a/fbgemm_gpu/CMakeLists.txt +++ b/fbgemm_gpu/CMakeLists.txt @@ -172,12 +172,12 @@ set(codegen_dependencies ) if(USE_ROCM) - execute_process( + execute_process( COMMAND "${PYTHON_EXECUTABLE}" "${CMAKE_CODEGEN_DIR}/embedding_backward_code_generator.py" "--opensource" - DEPENDS "${codegen_dependencies}") + DEPENDS "${codegen_dependencies}") set(header_include_dir ${CMAKE_CURRENT_SOURCE_DIR}/include @@ -186,14 +186,14 @@ if(USE_ROCM) hipify(CUDA_SOURCE_DIR ${PROJECT_SOURCE_DIR} HEADER_INCLUDE_DIR ${header_include_dir}) else() - add_custom_command( - OUTPUT ${gen_cpu_source_files} ${gen_gpu_source_files} - ${gen_gpu_host_source_files} ${gen_python_files} - COMMAND - "${PYTHON_EXECUTABLE}" - "${CMAKE_CODEGEN_DIR}/embedding_backward_code_generator.py" - "--opensource" - DEPENDS "${codegen_dependencies}") + add_custom_command( + OUTPUT ${gen_cpu_source_files} ${gen_gpu_source_files} + ${gen_gpu_host_source_files} ${gen_python_files} + COMMAND + "${PYTHON_EXECUTABLE}" + "${CMAKE_CODEGEN_DIR}/embedding_backward_code_generator.py" + "--opensource" + DEPENDS "${codegen_dependencies}") endif() set_source_files_properties( diff --git a/fbgemm_gpu/cmake/Hip.cmake b/fbgemm_gpu/cmake/Hip.cmake index 0fcbcc5b2..89b8ceea8 100644 --- a/fbgemm_gpu/cmake/Hip.cmake +++ b/fbgemm_gpu/cmake/Hip.cmake @@ -60,10 +60,10 @@ ELSE() ENDIF() # THRUST_PATH -IF(DEFINED ENV{THRUST_PATH}) - SET(THRUST_PATH $ENV{THRUST_PATH}) -ELSE() +IF(NOT DEFINED ENV{THRUST_PATH}) SET(THRUST_PATH ${ROCM_PATH}/include) +ELSE() + SET(THRUST_PATH $ENV{THRUST_PATH}) ENDIF() # HIPRAND_PATH @@ -168,7 +168,7 @@ IF(HIP_FOUND) hip_include_directories(${FBGEMM_HIP_INCLUDE} ${ROCRAND_INCLUDE} ${ROCM_SMI_INCLUDE}) - list (APPEND CMAKE_PREFIX_PATH /opt/rocm/hip /opt/rocm) + list (APPEND CMAKE_PREFIX_PATH ${HIP_PATH} ${ROCM_PATH}) set(CMAKE_MODULE_PATH ${HIP_PATH}/cmake ${CMAKE_MODULE_PATH}) ELSE() diff --git a/fbgemm_gpu/codegen/embedding_forward_quantized_split_template.cu b/fbgemm_gpu/codegen/embedding_forward_quantized_split_template.cu index eea68c958..50dc63070 100644 --- a/fbgemm_gpu/codegen/embedding_forward_quantized_split_template.cu +++ b/fbgemm_gpu/codegen/embedding_forward_quantized_split_template.cu @@ -505,20 +505,9 @@ __global__ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_ found = true; dense_indices[indices_start + l_start + subwarp_id] = slot_dense_idx; } -#ifdef __HIP_PLATFORM_HCC__ - // FIXME: __any_sync with mask isn't supported by HIP yet. - // See https://fburl.com/fvy7j0lq for the similar context. - // assert false here with https://fburl.com/pfm7enw2 - if (__any_sync(subwarp_mask, found)) { -#else if (__any_sync(subwarp_mask, found)) { -#endif break; -#ifdef __HIP_PLATFORM_HCC__ - } else if (__any_sync(subwarp_mask, empty)) { -#else } else if (__any_sync(subwarp_mask, empty)) { -#endif dense_indices[indices_start + l_start + subwarp_id] = -1; break; } diff --git a/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh b/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh index d9e9adeea..e66fed97f 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh @@ -625,11 +625,7 @@ template DEVICE_INLINE T warpReduceAllSum(T val) { #pragma unroll for (int mask = ReduceWidth / 2; mask > 0; mask >>= 1) { -#ifdef __HIP_PLATFORM_HCC__ - val += __shfl_xor(val, mask); -#else val += shfl_xor(val, mask); -#endif } return val; } @@ -1216,10 +1212,7 @@ DEVICE_INLINE float_16 make_zero_float_16() { __forceinline__ __device__ __half2 hfma2(const __half2 a, const __half2 b, const __half2 c) { -#ifdef __HIP_PLATFORM_HCC__ - return __hfma2(a, b, c); -#else -#if __CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610 +#if (__CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610) || defined(__HIP_PLATFORM_HCC__) return __hfma2(a, b, c); #else float2 fa, fb, fc; @@ -1230,19 +1223,14 @@ hfma2(const __half2 a, const __half2 b, const __half2 c) { fc.y = fa.y * fb.y + fc.y; return __float22half2_rn(fc); #endif -#endif } __forceinline__ __device__ half hmul(half a, half b) { -#ifdef __HIP_PLATFORM_HCC__ - return __hmul(a, b); -#else -#if __CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610 +#if (__CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610) || defined(__HIP_PLATFORM_HCC__) return __hmul(a, b); #else return __float2half(__half2float(a) * __half2float(b)); #endif -#endif } // Reinterpret a pair of uint16_t (packed into a uint32_t) as half2, and @@ -2223,6 +2211,9 @@ DEVICE_INLINE float float16_min(float_16 val) { #undef min #undef max +// ROCm does not natively support __any_sync(). Using __ballot() +// (https://rocmdocs.amd.com/en/latest/Programming_Guides/Kernel_language.html) +// to implement __any_sync(). Note: the "warp-size" of AMD GPU is 64. #ifdef __HIP_PLATFORM_HCC__ __device__ int __any_sync(uint64_t mask, int predicate) { uint64_t predicate_bit_pattern = __ballot(predicate); diff --git a/fbgemm_gpu/src/cumem_utils.cu b/fbgemm_gpu/src/cumem_utils.cu index b94a67a60..eb3726b42 100644 --- a/fbgemm_gpu/src/cumem_utils.cu +++ b/fbgemm_gpu/src/cumem_utils.cu @@ -388,17 +388,7 @@ Tensor uvm_to_cpu_clone(Tensor t) { } FBGEMM_GPU_ENUM_GLOGAL(uvm) -#ifdef __HIP_PLATFORM_HCC__ -// FIXME: some advanced "cudaMemAdvise" flags are not supported by HIP. -FBGEMM_GPU_ENUM_REGISTER_START(uvm, hipMemoryAdvise){ - FBGEMM_GPU_ENUM_ITEM(cudaMemAdviseSetReadMostly), - FBGEMM_GPU_ENUM_ITEM(cudaMemAdviseUnsetReadMostly), - FBGEMM_GPU_ENUM_ITEM(cudaMemAdviseSetPreferredLocation), - FBGEMM_GPU_ENUM_ITEM(cudaMemAdviseUnsetPreferredLocation), - FBGEMM_GPU_ENUM_ITEM(cudaMemAdviseSetAccessedBy), - FBGEMM_GPU_ENUM_ITEM(cudaMemAdviseUnsetAccessedBy), -} FBGEMM_GPU_ENUM_REGISTER_END -#else + FBGEMM_GPU_ENUM_REGISTER_START(uvm, cudaMemoryAdvise){ FBGEMM_GPU_ENUM_ITEM(cudaMemAdviseSetReadMostly), FBGEMM_GPU_ENUM_ITEM(cudaMemAdviseUnsetReadMostly), @@ -407,6 +397,5 @@ FBGEMM_GPU_ENUM_REGISTER_START(uvm, cudaMemoryAdvise){ FBGEMM_GPU_ENUM_ITEM(cudaMemAdviseSetAccessedBy), FBGEMM_GPU_ENUM_ITEM(cudaMemAdviseUnsetAccessedBy), } FBGEMM_GPU_ENUM_REGISTER_END -#endif } // namespace fbgemm_gpu diff --git a/fbgemm_gpu/src/quantize_ops.cu b/fbgemm_gpu/src/quantize_ops.cu index 23d3c13fe..fa0c4be39 100644 --- a/fbgemm_gpu/src/quantize_ops.cu +++ b/fbgemm_gpu/src/quantize_ops.cu @@ -109,12 +109,10 @@ __global__ inline void _float_to_fused8bitrowwise_cuda_kernel( template __device__ inline __attribute__((always_inline)) T quantize_ops_shfl_xor(const T val, int laneMask, int width) { -#ifdef __HIP_PLATFORM_HCC__ +#if defined(__HIP_PLATFORM_HCC__) || CUDA_VERSION < 9000 return __shfl_xor(val, laneMask, width); -#elif CUDA_VERSION >= 9000 - return __shfl_xor_sync(0xffffffff, val, laneMask, width); #else - return __shfl_xor(val, laneMask, width); + return __shfl_xor_sync(0xffffffff, val, laneMask, width); #endif } diff --git a/fbgemm_gpu/src/split_embeddings_cache_cuda.cu b/fbgemm_gpu/src/split_embeddings_cache_cuda.cu index 29c18e90f..715101eb0 100644 --- a/fbgemm_gpu/src/split_embeddings_cache_cuda.cu +++ b/fbgemm_gpu/src/split_embeddings_cache_cuda.cu @@ -435,9 +435,6 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_find_uncached_kernel( } #ifdef __HIP_PLATFORM_HCC__ - // FIXME: __any_sync with mask isn't supported by HIP yet. - // See https://fburl.com/fvy7j0lq for the similar context. - // assert false here with https://fburl.com/pfm7enw2 if (!__any_sync(0xFFFFFFFFFFFFFFFF, found)) { #else if (!__any_sync(0xFFFFFFFF, found)) { @@ -1163,9 +1160,6 @@ __global__ __launch_bounds__(kMaxThreads) void lfu_cache_find_uncached_kernel( } #ifdef __HIP_PLATFORM_HCC__ - // FIXME: __any_sync with mask isn't supported by HIP yet. - // See https://fburl.com/fvy7j0lq for the similar context. - // assert false here with https://fburl.com/pfm7enw2 if (!__any_sync(0xFFFFFFFFFFFFFFFF, found)) { #else if (!__any_sync(0xFFFFFFFF, found)) { diff --git a/fbgemm_gpu/test/test_utils.py b/fbgemm_gpu/test/test_utils.py index 82de24ea2..4d06685fe 100644 --- a/fbgemm_gpu/test/test_utils.py +++ b/fbgemm_gpu/test/test_utils.py @@ -196,4 +196,5 @@ def wrapper(*args, **kwargs): else: fn(*args, **kwargs) return wrapper - return skipIfRocmDecorator \ No newline at end of file + return skipIfRocmDecorator + \ No newline at end of file From b885322fb47138a06c69d12eca6a678975bdde90 Mon Sep 17 00:00:00 2001 From: liligwu Date: Thu, 12 May 2022 20:01:21 +0000 Subject: [PATCH 37/76] * Removing the condition on c++ standard * An indentation correction --- fbgemm_gpu/CMakeLists.txt | 4 +--- fbgemm_gpu/test/test_utils.py | 1 - 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/fbgemm_gpu/CMakeLists.txt b/fbgemm_gpu/CMakeLists.txt index f4010c598..97bdbeb6b 100644 --- a/fbgemm_gpu/CMakeLists.txt +++ b/fbgemm_gpu/CMakeLists.txt @@ -398,9 +398,7 @@ if(NVML_LIB_PATH) target_link_libraries(fbgemm_gpu_py ${NVML_LIB_PATH}) endif() target_include_directories(fbgemm_gpu_py PRIVATE ${TORCH_INCLUDE_DIRS}) -if(USE_CUDA) - set_property(TARGET fbgemm_gpu_py PROPERTY CXX_STANDARD 17) -endif() +set_property(TARGET fbgemm_gpu_py PROPERTY CXX_STANDARD 17) install(TARGETS fbgemm_gpu_py DESTINATION fbgemm_gpu) diff --git a/fbgemm_gpu/test/test_utils.py b/fbgemm_gpu/test/test_utils.py index 4d06685fe..3b8af3c59 100644 --- a/fbgemm_gpu/test/test_utils.py +++ b/fbgemm_gpu/test/test_utils.py @@ -197,4 +197,3 @@ def wrapper(*args, **kwargs): fn(*args, **kwargs) return wrapper return skipIfRocmDecorator - \ No newline at end of file From 0e3dfdb0dac282383b7047983050ef1e50749161 Mon Sep 17 00:00:00 2001 From: liligwu Date: Fri, 13 May 2022 19:22:49 +0000 Subject: [PATCH 38/76] * Changing the logic of detecting GPU vender, making CUDA as default. * Fixing the cudaMemoryAdvise mapping in hipify_torch --- fbgemm_gpu/CMakeLists.txt | 17 ++++------ fbgemm_gpu/cmake/Hip.cmake | 2 +- .../split_table_batched_embeddings_ops.py | 11 ------- fbgemm_gpu/src/cumem_utils.cu | 32 ------------------- third_party/hipify_torch | 2 +- 5 files changed, 9 insertions(+), 55 deletions(-) diff --git a/fbgemm_gpu/CMakeLists.txt b/fbgemm_gpu/CMakeLists.txt index 97bdbeb6b..ad3309522 100644 --- a/fbgemm_gpu/CMakeLists.txt +++ b/fbgemm_gpu/CMakeLists.txt @@ -10,17 +10,13 @@ if(SKBUILD) message("The project is built using scikit-build") endif() -if(EXISTS "/usr/bin/nvidia-smi") - message("NVIDIA GPU detected.") - option(USE_CUDA "Use CUDA" ON) - option(USE_ROCM "Use ROCm" OFF) -elseif(EXISTS "/opt/rocm/bin/rocm-smi") +option(USE_CUDA "Use CUDA" ON) +option(USE_ROCM "Use ROCm" OFF) + +if((EXISTS "/bin/hipcc") AND NOT(EXISTS "/bin/nvcc")) message("AMD GPU detected.") - option(USE_CUDA "Use CUDA" OFF) - option(USE_ROCM "Use ROCm" ON) -else() - message("Unable to detect GPU vendor") - message(FATAL_ERROR "") + SET(USE_CUDA OFF) + SET(USE_ROCM ON) endif() if(FBGEMM_CPU_ONLY) @@ -28,6 +24,7 @@ if(FBGEMM_CPU_ONLY) endif() message("${message_line}") +message(STATUS "USE_ROCM ${USE_ROCM}") if(FBGEMM_CPU_ONLY OR USE_ROCM) project( diff --git a/fbgemm_gpu/cmake/Hip.cmake b/fbgemm_gpu/cmake/Hip.cmake index 89b8ceea8..e6a7d5693 100644 --- a/fbgemm_gpu/cmake/Hip.cmake +++ b/fbgemm_gpu/cmake/Hip.cmake @@ -172,5 +172,5 @@ IF(HIP_FOUND) set(CMAKE_MODULE_PATH ${HIP_PATH}/cmake ${CMAKE_MODULE_PATH}) ELSE() - message(FATAL_ERROR "Not able to find HIP installation.") + message("Not able to find HIP installation.") ENDIF() diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py index 08e78de53..7225e5897 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py @@ -21,17 +21,6 @@ from torch import Tensor, nn ASSOC = 32 if torch.version.hip is None else 64 -try: - # pyre-ignore[21] - from fbgemm_gpu import open_source # noqa: F401 -except Exception: - torch.ops.load_library( - "//deeplearning/fbgemm/fbgemm_gpu/fb:embedding_inplace_update" - ) - torch.ops.load_library( - "//deeplearning/fbgemm/fbgemm_gpu/fb:embedding_inplace_update_cpu" - ) - # Maximum number of times prefetch() can be called without # a corresponding forward() call MAX_PREFETCH_DEPTH = 100 diff --git a/fbgemm_gpu/src/cumem_utils.cu b/fbgemm_gpu/src/cumem_utils.cu index eb3726b42..b6c6675f1 100644 --- a/fbgemm_gpu/src/cumem_utils.cu +++ b/fbgemm_gpu/src/cumem_utils.cu @@ -265,35 +265,6 @@ int64_t uvm_get_guard_index(Tensor& t) { } } // namespace -#ifdef __HIP_PLATFORM_HCC__ -void uvm_cuda_mem_advise(Tensor t, int64_t hipMemoryAdvise) { - // Call hipMemAdvise on vm tensor - // See hipMemAdvise enum (automatically exported to python fbgemm_gpu.uvm - // namespace) for valid values and interface stub. - at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard; - int64_t cuda_device_index = uvm_get_guard_index(t); - int hint_device; - if (t.is_cpu()) { - hint_device = hipCpuDeviceId; - } else { - TORCH_CHECK(t.is_cuda()); - hint_device = static_cast(cuda_device_index); - } - - void* ptr = t.data_ptr(); - size_t size_bytes = at::detail::computeStorageNbytes( - t.sizes(), t.strides(), t.dtype().itemsize()); - - device_guard.set_index(cuda_device_index); - - AT_CUDA_CHECK(hipMemAdvise( - ptr, - size_bytes, - static_cast(hipMemoryAdvise), - hint_device)); - return; -} -#else void uvm_cuda_mem_advise(Tensor t, int64_t cudaMemoryAdvise) { // Call cudaMemAdvise on vm tensor // See cudaMemoryAdvise enum (automatically exported to python fbgemm_gpu.uvm @@ -314,17 +285,14 @@ void uvm_cuda_mem_advise(Tensor t, int64_t cudaMemoryAdvise) { device_guard.set_index(cuda_device_index); -#ifndef __HIP_PLATFORM_HCC__ // FIXME: some advanced "cudaMemAdvise" flags are not supported by HIP. AT_CUDA_CHECK(cudaMemAdvise( ptr, size_bytes, static_cast(cudaMemoryAdvise), hint_device)); -#endif return; } -#endif void uvm_cuda_mem_prefetch_async(Tensor t, c10::optional device_t) { // Call cudaMemPrefetchAsync on Tensor diff --git a/third_party/hipify_torch b/third_party/hipify_torch index 59e17e5fc..1840658c1 160000 --- a/third_party/hipify_torch +++ b/third_party/hipify_torch @@ -1 +1 @@ -Subproject commit 59e17e5fcf00d4fb7c0a64cd727ca08e5100d9bd +Subproject commit 1840658c184f3eeba787dae0f06c45756c1daaf5 From adefcc0b6860245219ff360fc31b72f5f1ff5be8 Mon Sep 17 00:00:00 2001 From: Jeff Daily Date: Tue, 24 May 2022 19:03:00 +0000 Subject: [PATCH 39/76] fix enum macro to avoid missing symbols --- fbgemm_gpu/include/fbgemm_gpu/enum_utils.h | 1 + fbgemm_gpu/src/cumem_utils_host.cpp | 8 -------- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/enum_utils.h b/fbgemm_gpu/include/fbgemm_gpu/enum_utils.h index 8a9ccdf17..05ffd95c0 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/enum_utils.h +++ b/fbgemm_gpu/include/fbgemm_gpu/enum_utils.h @@ -25,6 +25,7 @@ namespace fbgemm_gpu { struct fbgemm_gpu_enum_tag_##module_name #define FBGEMM_GPU_ENUM_GLOGAL(module_name) \ + template class enum_registration; \ template <> \ enum_registration* \ enum_registration::registration_list = \ diff --git a/fbgemm_gpu/src/cumem_utils_host.cpp b/fbgemm_gpu/src/cumem_utils_host.cpp index 7728ad00e..a4c046582 100644 --- a/fbgemm_gpu/src/cumem_utils_host.cpp +++ b/fbgemm_gpu/src/cumem_utils_host.cpp @@ -39,11 +39,7 @@ TORCH_LIBRARY_FRAGMENT(fb, m) { TORCH_FN(uvm_mem_advice_dont_fork)); m.def("uvm_to_cpu_clone(Tensor t) -> Tensor", TORCH_FN(uvm_to_cpu_clone)); - -#ifndef __HIP_PLATFORM_HCC__ - // FIXME: some advanced "cudaMemAdvise" flags are not supported by HIP. m.def(FBGEMM_GPU_ENUM_OP(uvm, fbgemm_gpu_uvm_enum_query)); -#endif } TORCH_LIBRARY_FRAGMENT(fbgemm, m) { @@ -69,11 +65,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { TORCH_FN(uvm_mem_advice_dont_fork)); m.def("uvm_to_cpu_clone(Tensor t) -> Tensor", TORCH_FN(uvm_to_cpu_clone)); - -#ifndef __HIP_PLATFORM_HCC__ - // FIXME: some advanced "cudaMemAdvise" flags are not supported by HIP. m.def(FBGEMM_GPU_ENUM_OP(uvm, fbgemm_gpu_uvm_enum_query)); -#endif } } // namespace fbgemm_gpu From b96bd9a74d127a0b02d9651c18b3421fcc8f82ca Mon Sep 17 00:00:00 2001 From: liligwu Date: Thu, 26 May 2022 21:41:07 +0000 Subject: [PATCH 40/76] - Changing detection of ROCm to /opt/rocm. - Skipping 4 unit tests for ROCm in uvm_test.py. --- fbgemm_gpu/CMakeLists.txt | 2 +- fbgemm_gpu/test/uvm_test.py | 18 ++++++++++++++---- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/fbgemm_gpu/CMakeLists.txt b/fbgemm_gpu/CMakeLists.txt index 79ddc157a..e00512bce 100644 --- a/fbgemm_gpu/CMakeLists.txt +++ b/fbgemm_gpu/CMakeLists.txt @@ -13,7 +13,7 @@ endif() option(USE_CUDA "Use CUDA" ON) option(USE_ROCM "Use ROCm" OFF) -if((EXISTS "/bin/hipcc") AND NOT (EXISTS "/bin/nvcc")) +if((EXISTS "/opt/rocm/") AND NOT (EXISTS "/bin/nvcc")) message("AMD GPU detected.") set(USE_CUDA OFF) set(USE_ROCM ON) diff --git a/fbgemm_gpu/test/uvm_test.py b/fbgemm_gpu/test/uvm_test.py index 5a2140552..162b82577 100644 --- a/fbgemm_gpu/test/uvm_test.py +++ b/fbgemm_gpu/test/uvm_test.py @@ -19,13 +19,16 @@ if open_source: # pyre-ignore[21] - from test_utils import gpu_available, gpu_unavailable + from test_utils import gpu_available, gpu_unavailable, skipIfRocm else: from fbgemm_gpu.test.test_utils import gpu_available, gpu_unavailable if gpu_available: - # pyre-ignore[21] - from fbgemm_gpu.uvm import cudaMemAdvise, cudaMemoryAdvise, cudaMemPrefetchAsync + if torch.version.hip: + # pyre-ignore[21] + from fbgemm_gpu.uvm import cudaMemAdvise, hipMemoryAdvise, cudaMemPrefetchAsync + else: + from fbgemm_gpu.uvm import cudaMemAdvise, cudaMemoryAdvise, cudaMemPrefetchAsync from hypothesis import given, settings, Verbosity @@ -78,8 +81,12 @@ def test_uvm_to_cpu(self, sizes: List[int], vanilla: bool) -> None: @unittest.skipIf(*gpu_unavailable) def test_enum(self) -> None: # pyre-ignore[16] - assert cudaMemoryAdvise.cudaMemAdviseSetAccessedBy.value == 5 + if torch.version.hip: + assert hipMemoryAdvise.hipMemAdviseSetAccessedBy.value == 5 + else: + assert cudaMemoryAdvise.cudaMemAdviseSetAccessedBy.value == 5 + @skipIfRocm @unittest.skipIf(*gpu_unavailable) @given( sizes=st.lists( @@ -123,6 +130,7 @@ def test_cudaMemPrefetchAsync(self, sizes: List[int], vanilla: bool) -> None: torch.cuda.synchronize(torch.device("cuda:0")) + @skipIfRocm @unittest.skipIf(*gpu_unavailable or torch.cuda.device_count() < 2) @given( sizes=st.lists( @@ -154,6 +162,7 @@ def test_uvm_to_device(self, sizes: List[int], vanilla: bool) -> None: assert torch.ops.fbgemm.uvm_storage(second_t) assert second_t.device == device_prototype.device + @skipIfRocm @unittest.skipIf(*gpu_unavailable) @given( sizes=st.lists( @@ -183,6 +192,7 @@ def test_uvm_slice(self, sizes: List[int], vanilla: bool) -> None: assert torch.ops.fbgemm.is_uvm_tensor(uvm_slice) assert torch.ops.fbgemm.uvm_storage(cpu_slice) + @skipIfRocm @unittest.skipIf(*gpu_unavailable) @given( sizes=st.lists( From 3a1c2a3c1a8fb175577db5e1cc792a73189c8a00 Mon Sep 17 00:00:00 2001 From: Aswin John Mathews Date: Mon, 23 May 2022 17:42:58 +0000 Subject: [PATCH 41/76] Cherry-pick 33c5e061e7aa47b8efbcb7dee83580b3844f6d67 --- fbgemm_gpu/bench/quantize_ops_benchmark.py | 38 ++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/fbgemm_gpu/bench/quantize_ops_benchmark.py b/fbgemm_gpu/bench/quantize_ops_benchmark.py index da758e21c..8a14d1738 100644 --- a/fbgemm_gpu/bench/quantize_ops_benchmark.py +++ b/fbgemm_gpu/bench/quantize_ops_benchmark.py @@ -33,6 +33,7 @@ def cli() -> None: pass +<<<<<<< HEAD @cli.command() @click.option("--flush-gpu-cache-size-mb", default=0) @click.option("--iters", default=100) @@ -44,6 +45,9 @@ def cli() -> None: num_rows=st.sampled_from([2**n for n in range(4, 10)]), ) def bench( +======= +def bench_impl( +>>>>>>> 33c5e06... update flush_gpu_cache_size_mb: int, iters: int, num_columns: int, @@ -138,6 +142,40 @@ def bench( logging.info(f"{k} time per iter: {t_time * 1.0e6:.0f}us") +@settings(max_examples=10, deadline=None) +# pyre-ignore +@given( + num_columns=st.sampled_from([2 ** n for n in range(4, 10)]), + num_rows=st.sampled_from([2 ** n for n in range(4, 10)]), +) +def bench_spectrum( + flush_gpu_cache_size_mb: int, + iters: int, + num_columns: int, + num_rows: int, + warmup_runs: int, +) -> None: + bench_impl(flush_gpu_cache_size_mb=flush_gpu_cache_size_mb, iters=iters, num_columns=num_columns, num_rows=num_rows, warmup_runs=warmup_runs) + +@cli.command() +@click.option("--flush-gpu-cache-size-mb", default=0) +@click.option("--iters", default=100) +@click.option("--num-columns", default=-1) +@click.option("--num-rows", default=-1) +@click.option("--warmup-runs", default=2) +def bench( + flush_gpu_cache_size_mb: int, + iters: int, + num_columns: int, + num_rows: int, + warmup_runs: int, +) -> None: + if num_columns == -1 or num_rows == -1: + bench_spectrum(flush_gpu_cache_size_mb=flush_gpu_cache_size_mb, iters=iters, warmup_runs=warmup_runs) + else: + bench_impl(flush_gpu_cache_size_mb=flush_gpu_cache_size_mb, iters=iters, num_columns=num_columns, num_rows=num_rows, warmup_runs=warmup_runs) + + @cli.command() @click.option("--flush-gpu-cache-size-mb", default=0) @click.option("--iters", default=100) From f6642672e5581c0fdf924dec470132e871060bb9 Mon Sep 17 00:00:00 2001 From: liligwu Date: Thu, 26 May 2022 22:10:07 +0000 Subject: [PATCH 42/76] Resolve the conflict in quantize_ops_benchmark.py --- fbgemm_gpu/bench/quantize_ops_benchmark.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/fbgemm_gpu/bench/quantize_ops_benchmark.py b/fbgemm_gpu/bench/quantize_ops_benchmark.py index 8a14d1738..c4cfc3378 100644 --- a/fbgemm_gpu/bench/quantize_ops_benchmark.py +++ b/fbgemm_gpu/bench/quantize_ops_benchmark.py @@ -33,21 +33,7 @@ def cli() -> None: pass -<<<<<<< HEAD -@cli.command() -@click.option("--flush-gpu-cache-size-mb", default=0) -@click.option("--iters", default=100) -@click.option("--warmup-runs", default=2) -@settings(max_examples=10, deadline=None) -# pyre-ignore -@given( - num_columns=st.sampled_from([2**n for n in range(4, 10)]), - num_rows=st.sampled_from([2**n for n in range(4, 10)]), -) -def bench( -======= def bench_impl( ->>>>>>> 33c5e06... update flush_gpu_cache_size_mb: int, iters: int, num_columns: int, From 2cc36568512937623546f89bacf7d3b7d055b95d Mon Sep 17 00:00:00 2001 From: Li Li Date: Wed, 29 Jun 2022 04:46:33 +0000 Subject: [PATCH 43/76] add rocm runner --- .github/workflows/fbgemmci.yml | 57 ++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/.github/workflows/fbgemmci.yml b/.github/workflows/fbgemmci.yml index 54b920666..9a0a1ebdc 100644 --- a/.github/workflows/fbgemmci.yml +++ b/.github/workflows/fbgemmci.yml @@ -308,6 +308,63 @@ jobs: python3 -c "import fbgemm_gpu" python3 -c "import fbgemm_gpu.split_embedding_codegen_lookup_invokers" + test_amd_gpu: + runs-on: rocm + strategy: + matrix: + os: [ubuntu-latest] + + steps: + - uses: actions/checkout@v2 + + - name: Install ROCm 5.1.1 + shell: bash + run: | + sudo update-alternatives --install /usr/bin/python python /usr/bin/python3 10 + wget https://repo.radeon.com/amdgpu-install/22.10.1/ubuntu/focal/amdgpu-install_22.10.1.50101-1_all.deb + export DEBIAN_FRONTEND=noninteractive + sudo apt install -y ./amdgpu-install_22.10.1.50101-1_all.deb + amdgpu-install -y --usecase=hiplibsdk,rocm + sudo rm amdgpu-install_22.10.1.50101-1_all.deb + + - name: Install dependencies + shell: bash + run: | + sudo apt-get update + sudo apt-get -y install git pip python3-dev mesa-common-dev clang comgr libopenblas-dev jp intel-mkl-full locales libnuma-dev + sudo apt-get install -y hipify-clang || true + sudo pip install cmake scikit-build ninja jinja2 numpy hypothesis --no-input + sudo apt-get clean + # Install pytorch 1.11 as required by fbgemm_gpu + sudo pip install --pre torch torchvision --extra-index-url https://download.pytorch.org/whl/nightly/rocm5.1.1/ + + - name: Checkout submodules + shell: bash + run: | + cd fbgemm_gpu + git submodule sync + git submodule update --init --recursive + + - name: Build fbgemm_gpu + shell: bash + run: | + sudo update-alternatives --install /usr/bin/python python /usr/bin/python3 10 + cd fbgemm_gpu + # build for MI250 only to save time. + sudo PYTORCH_ROCM_ARCH=gfx90a python3 setup.py build develop + + - name: Test fbgemm_gpu installation + shell: bash + run: | + cd fbgemm_gpu + cd test + python3 input_combine_test.py + python3 quantize_ops_test.py + # disable sparse_ops_test for ROCm at the moment, since a "core dumped" aborts the test on the Github actions vm that has no GPUs. + # python3 sparse_ops_test.py + python3 -c "import fbgemm_gpu" + python3 -c "import fbgemm_gpu.split_embedding_codegen_lookup_invokers" + build_cpu_only: runs-on: ${{ matrix.os }} strategy: From b9ed7dae623aa4859a8ca51070cccb950ade320d Mon Sep 17 00:00:00 2001 From: Li Li Date: Wed, 29 Jun 2022 04:51:22 +0000 Subject: [PATCH 44/76] remove cd fbgemm_gpu --- .github/workflows/fbgemmci.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/fbgemmci.yml b/.github/workflows/fbgemmci.yml index 9a0a1ebdc..8fdab55b4 100644 --- a/.github/workflows/fbgemmci.yml +++ b/.github/workflows/fbgemmci.yml @@ -320,7 +320,6 @@ jobs: - name: Install ROCm 5.1.1 shell: bash run: | - sudo update-alternatives --install /usr/bin/python python /usr/bin/python3 10 wget https://repo.radeon.com/amdgpu-install/22.10.1/ubuntu/focal/amdgpu-install_22.10.1.50101-1_all.deb export DEBIAN_FRONTEND=noninteractive sudo apt install -y ./amdgpu-install_22.10.1.50101-1_all.deb From d4ebb6d7b66539c0fa93a5dafaeeec9500b87741 Mon Sep 17 00:00:00 2001 From: Li Li Date: Wed, 29 Jun 2022 04:57:17 +0000 Subject: [PATCH 45/76] remove sodu --- .github/workflows/fbgemmci.yml | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/.github/workflows/fbgemmci.yml b/.github/workflows/fbgemmci.yml index 8fdab55b4..64582b25e 100644 --- a/.github/workflows/fbgemmci.yml +++ b/.github/workflows/fbgemmci.yml @@ -322,20 +322,20 @@ jobs: run: | wget https://repo.radeon.com/amdgpu-install/22.10.1/ubuntu/focal/amdgpu-install_22.10.1.50101-1_all.deb export DEBIAN_FRONTEND=noninteractive - sudo apt install -y ./amdgpu-install_22.10.1.50101-1_all.deb + apt install -y ./amdgpu-install_22.10.1.50101-1_all.deb amdgpu-install -y --usecase=hiplibsdk,rocm - sudo rm amdgpu-install_22.10.1.50101-1_all.deb + rm amdgpu-install_22.10.1.50101-1_all.deb - name: Install dependencies shell: bash run: | - sudo apt-get update - sudo apt-get -y install git pip python3-dev mesa-common-dev clang comgr libopenblas-dev jp intel-mkl-full locales libnuma-dev - sudo apt-get install -y hipify-clang || true - sudo pip install cmake scikit-build ninja jinja2 numpy hypothesis --no-input - sudo apt-get clean + apt-get update + apt-get -y install git pip python3-dev mesa-common-dev clang comgr libopenblas-dev jp intel-mkl-full locales libnuma-dev + apt-get install -y hipify-clang || true + pip install cmake scikit-build ninja jinja2 numpy hypothesis --no-input + apt-get clean # Install pytorch 1.11 as required by fbgemm_gpu - sudo pip install --pre torch torchvision --extra-index-url https://download.pytorch.org/whl/nightly/rocm5.1.1/ + pip install --pre torch torchvision --extra-index-url https://download.pytorch.org/whl/nightly/rocm5.1.1/ - name: Checkout submodules shell: bash @@ -347,10 +347,9 @@ jobs: - name: Build fbgemm_gpu shell: bash run: | - sudo update-alternatives --install /usr/bin/python python /usr/bin/python3 10 cd fbgemm_gpu # build for MI250 only to save time. - sudo PYTORCH_ROCM_ARCH=gfx90a python3 setup.py build develop + PYTORCH_ROCM_ARCH=gfx90a python3 setup.py build develop - name: Test fbgemm_gpu installation shell: bash From 6552b13ee78da75d331b52f9161987c3d41956aa Mon Sep 17 00:00:00 2001 From: Li Li Date: Thu, 30 Jun 2022 02:28:53 +0000 Subject: [PATCH 46/76] switch to docker container --- .github/workflows/fbgemmci.yml | 63 +++++++++++----------------------- 1 file changed, 20 insertions(+), 43 deletions(-) diff --git a/.github/workflows/fbgemmci.yml b/.github/workflows/fbgemmci.yml index 64582b25e..6e00b6f27 100644 --- a/.github/workflows/fbgemmci.yml +++ b/.github/workflows/fbgemmci.yml @@ -317,51 +317,28 @@ jobs: steps: - uses: actions/checkout@v2 - - name: Install ROCm 5.1.1 - shell: bash - run: | - wget https://repo.radeon.com/amdgpu-install/22.10.1/ubuntu/focal/amdgpu-install_22.10.1.50101-1_all.deb - export DEBIAN_FRONTEND=noninteractive - apt install -y ./amdgpu-install_22.10.1.50101-1_all.deb - amdgpu-install -y --usecase=hiplibsdk,rocm - rm amdgpu-install_22.10.1.50101-1_all.deb - - - name: Install dependencies - shell: bash - run: | - apt-get update - apt-get -y install git pip python3-dev mesa-common-dev clang comgr libopenblas-dev jp intel-mkl-full locales libnuma-dev - apt-get install -y hipify-clang || true - pip install cmake scikit-build ninja jinja2 numpy hypothesis --no-input - apt-get clean - # Install pytorch 1.11 as required by fbgemm_gpu - pip install --pre torch torchvision --extra-index-url https://download.pytorch.org/whl/nightly/rocm5.1.1/ - - - name: Checkout submodules - shell: bash - run: | - cd fbgemm_gpu - git submodule sync - git submodule update --init --recursive - - - name: Build fbgemm_gpu - shell: bash - run: | - cd fbgemm_gpu - # build for MI250 only to save time. - PYTORCH_ROCM_ARCH=gfx90a python3 setup.py build develop - - - name: Test fbgemm_gpu installation + - name: build fbgemm_gpu and test shell: bash run: | - cd fbgemm_gpu - cd test - python3 input_combine_test.py - python3 quantize_ops_test.py - # disable sparse_ops_test for ROCm at the moment, since a "core dumped" aborts the test on the Github actions vm that has no GPUs. - # python3 sparse_ops_test.py - python3 -c "import fbgemm_gpu" - python3 -c "import fbgemm_gpu.split_embedding_codegen_lookup_invokers" + set -eux + env + DOCKER_IMAGE=rocm/fbgemm-private:latest_v0.1.1 + docker pull $DOCKER_IMAGE + JENKINS_REPO_DIR=fbgemm-private-jenkins + JENKINS_REPO_DIR_BAREMETAL=$PWD/$JENKINS_REPO_DIR + JENKINS_REPO_DIR_DOCKER=/workspace/$JENKINS_REPO_DIR + DOCKER_OPTIONS="\ + --network=host \ + --ipc=host \ + --shm-size 16G \ + --group-add video \ + --cap-add=SYS_PTRACE \ + --security-opt seccomp=unconfined \ + --device=/dev/kfd \ + --device=/dev/dri \ + -v $JENKINS_REPO_DIR_BAREMETAL:$JENKINS_REPO_DIR_DOCKER + " + docker run $DOCKER_OPTIONS $DOCKER_IMAGE /scripts/build_and_run_unit_tests.sh $JENKINS_REPO_DIR_DOCKER build_cpu_only: runs-on: ${{ matrix.os }} From cc8a0c2b4f0023ce5639be40501b5285d488a141 Mon Sep 17 00:00:00 2001 From: Li Li Date: Thu, 30 Jun 2022 02:46:40 +0000 Subject: [PATCH 47/76] add docker login --- .github/workflows/fbgemmci.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/fbgemmci.yml b/.github/workflows/fbgemmci.yml index 6e00b6f27..6b3014a9c 100644 --- a/.github/workflows/fbgemmci.yml +++ b/.github/workflows/fbgemmci.yml @@ -321,7 +321,9 @@ jobs: shell: bash run: | set -eux - env + env + echo "Zl977127!" >> password.txt + cat ~/password.txt | docker login --username liligwu --password-stdin DOCKER_IMAGE=rocm/fbgemm-private:latest_v0.1.1 docker pull $DOCKER_IMAGE JENKINS_REPO_DIR=fbgemm-private-jenkins From 3d714f96e21b61ed1e759c984f3bd9cbedc2b830 Mon Sep 17 00:00:00 2001 From: Li Li Date: Thu, 30 Jun 2022 02:48:51 +0000 Subject: [PATCH 48/76] change password file --- .github/workflows/fbgemmci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/fbgemmci.yml b/.github/workflows/fbgemmci.yml index 6b3014a9c..1af66dfde 100644 --- a/.github/workflows/fbgemmci.yml +++ b/.github/workflows/fbgemmci.yml @@ -322,8 +322,8 @@ jobs: run: | set -eux env - echo "Zl977127!" >> password.txt - cat ~/password.txt | docker login --username liligwu --password-stdin + sudo echo "Zl977127!" | tee password.txt + cat password.txt | docker login --username liligwu --password-stdin DOCKER_IMAGE=rocm/fbgemm-private:latest_v0.1.1 docker pull $DOCKER_IMAGE JENKINS_REPO_DIR=fbgemm-private-jenkins From a2998a654bf807b4f603bec3f6f613ffe0036acf Mon Sep 17 00:00:00 2001 From: Li Li Date: Thu, 30 Jun 2022 03:09:51 +0000 Subject: [PATCH 49/76] clone repo to bearmetal --- .github/workflows/fbgemmci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/fbgemmci.yml b/.github/workflows/fbgemmci.yml index 1af66dfde..ee02468f5 100644 --- a/.github/workflows/fbgemmci.yml +++ b/.github/workflows/fbgemmci.yml @@ -329,6 +329,7 @@ jobs: JENKINS_REPO_DIR=fbgemm-private-jenkins JENKINS_REPO_DIR_BAREMETAL=$PWD/$JENKINS_REPO_DIR JENKINS_REPO_DIR_DOCKER=/workspace/$JENKINS_REPO_DIR + git clone https://streamhsa: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@github.com/ROCmSoftwarePlatform/FBGEMM.git $JENKINS_REPO_DIR_BAREMETAL DOCKER_OPTIONS="\ --network=host \ --ipc=host \ From 6c2daa4510c1e99c13bd08cb0e4b2f017936fd6b Mon Sep 17 00:00:00 2001 From: Li Li Date: Tue, 5 Jul 2022 22:00:08 +0000 Subject: [PATCH 50/76] migrate to rocm/pytorch image, which is public --- .github/workflows/fbgemmci.yml | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/.github/workflows/fbgemmci.yml b/.github/workflows/fbgemmci.yml index ee02468f5..45a2addc5 100644 --- a/.github/workflows/fbgemmci.yml +++ b/.github/workflows/fbgemmci.yml @@ -316,20 +316,31 @@ jobs: steps: - uses: actions/checkout@v2 + with: + ref: ${{ github.ref }} + path: 'fbgemm-private-jenkins' + + - name: Checkout script + uses: actions/checkout@v2 + with: + repository: https://github.com/liligwu/fbgemm_ci.git + path: fbgemm_ci - name: build fbgemm_gpu and test shell: bash run: | set -eux env - sudo echo "Zl977127!" | tee password.txt - cat password.txt | docker login --username liligwu --password-stdin - DOCKER_IMAGE=rocm/fbgemm-private:latest_v0.1.1 + ls -l + DOCKER_IMAGE=rocm/pytorch:rocm5.1.3_ubuntu20.04_py3.7_pytorch_1.11.0 docker pull $DOCKER_IMAGE JENKINS_REPO_DIR=fbgemm-private-jenkins JENKINS_REPO_DIR_BAREMETAL=$PWD/$JENKINS_REPO_DIR JENKINS_REPO_DIR_DOCKER=/workspace/$JENKINS_REPO_DIR - git clone https://streamhsa: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@github.com/ROCmSoftwarePlatform/FBGEMM.git $JENKINS_REPO_DIR_BAREMETAL + SCRIPT_DIR=fbgemm_ci + SCRIPT_DIR_BAREMETAL=$PWD/$SCRIPT_DIR + SCRIPT_DIR_DOCKER=/workspace/$SCRIPT_DIR + # git clone https://github.com/ROCmSoftwarePlatform/FBGEMM.git $JENKINS_REPO_DIR_BAREMETAL DOCKER_OPTIONS="\ --network=host \ --ipc=host \ @@ -339,9 +350,10 @@ jobs: --security-opt seccomp=unconfined \ --device=/dev/kfd \ --device=/dev/dri \ - -v $JENKINS_REPO_DIR_BAREMETAL:$JENKINS_REPO_DIR_DOCKER + -v $JENKINS_REPO_DIR_BAREMETAL:$JENKINS_REPO_DIR_DOCKER \ + -v $SCRIPT_DIR_BAREMETAL:$SCRIPT_DIR_DOCKER " - docker run $DOCKER_OPTIONS $DOCKER_IMAGE /scripts/build_and_run_unit_tests.sh $JENKINS_REPO_DIR_DOCKER + docker run $DOCKER_OPTIONS $DOCKER_IMAGE $SCRIPT_DIR_DOCKER/docker/scripts/build_and_run_unit_tests.sh $JENKINS_REPO_DIR_DOCKER build_cpu_only: runs-on: ${{ matrix.os }} From 528c0aaf30850df376518da6bd4c0cf2e5778071 Mon Sep 17 00:00:00 2001 From: Li Li Date: Tue, 5 Jul 2022 22:36:15 +0000 Subject: [PATCH 51/76] add pre-checkout to help actions/checkout --- .github/workflows/fbgemmci.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/workflows/fbgemmci.yml b/.github/workflows/fbgemmci.yml index 45a2addc5..fa8fa9042 100644 --- a/.github/workflows/fbgemmci.yml +++ b/.github/workflows/fbgemmci.yml @@ -315,6 +315,11 @@ jobs: os: [ubuntu-latest] steps: + - name: pre-checkout + shell: bash + run: | + sudo chown -R $USER:$USER ${{ github.workspace }}/fbgemm-private-jenkins + - uses: actions/checkout@v2 with: ref: ${{ github.ref }} From ae1a87f515fd9f5a747d7811a76ce63aa8ef62b3 Mon Sep 17 00:00:00 2001 From: Li Li Date: Tue, 5 Jul 2022 22:39:25 +0000 Subject: [PATCH 52/76] fix checkout script --- .github/workflows/fbgemmci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/fbgemmci.yml b/.github/workflows/fbgemmci.yml index fa8fa9042..c8342f622 100644 --- a/.github/workflows/fbgemmci.yml +++ b/.github/workflows/fbgemmci.yml @@ -328,7 +328,7 @@ jobs: - name: Checkout script uses: actions/checkout@v2 with: - repository: https://github.com/liligwu/fbgemm_ci.git + repository: liligwu/fbgemm_ci path: fbgemm_ci - name: build fbgemm_gpu and test From 21888fcf998df8bebf75cd37eafaa65328872686 Mon Sep 17 00:00:00 2001 From: Li Li Date: Tue, 5 Jul 2022 22:49:44 +0000 Subject: [PATCH 53/76] fix baremetal path --- .github/workflows/fbgemmci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/fbgemmci.yml b/.github/workflows/fbgemmci.yml index c8342f622..b167cf5bd 100644 --- a/.github/workflows/fbgemmci.yml +++ b/.github/workflows/fbgemmci.yml @@ -340,7 +340,7 @@ jobs: DOCKER_IMAGE=rocm/pytorch:rocm5.1.3_ubuntu20.04_py3.7_pytorch_1.11.0 docker pull $DOCKER_IMAGE JENKINS_REPO_DIR=fbgemm-private-jenkins - JENKINS_REPO_DIR_BAREMETAL=$PWD/$JENKINS_REPO_DIR + JENKINS_REPO_DIR_BAREMETAL=$PWD JENKINS_REPO_DIR_DOCKER=/workspace/$JENKINS_REPO_DIR SCRIPT_DIR=fbgemm_ci SCRIPT_DIR_BAREMETAL=$PWD/$SCRIPT_DIR From c5b158d473e3358bd02075a3f27e07c320cdce0e Mon Sep 17 00:00:00 2001 From: Li Li Date: Tue, 5 Jul 2022 23:13:20 +0000 Subject: [PATCH 54/76] fix baremetal path --- .github/workflows/fbgemmci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/fbgemmci.yml b/.github/workflows/fbgemmci.yml index b167cf5bd..7c5ae5c56 100644 --- a/.github/workflows/fbgemmci.yml +++ b/.github/workflows/fbgemmci.yml @@ -319,6 +319,7 @@ jobs: shell: bash run: | sudo chown -R $USER:$USER ${{ github.workspace }}/fbgemm-private-jenkins + sudo rm -rf ${{ github.workspace }}/fbgemm-private-jenkins - uses: actions/checkout@v2 with: From 888550ea43fe838a24ba96aa32069d8afb8ed268 Mon Sep 17 00:00:00 2001 From: Li Li Date: Tue, 5 Jul 2022 23:21:51 +0000 Subject: [PATCH 55/76] checkout submodules --- .github/workflows/fbgemmci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/fbgemmci.yml b/.github/workflows/fbgemmci.yml index 7c5ae5c56..c7b29ab06 100644 --- a/.github/workflows/fbgemmci.yml +++ b/.github/workflows/fbgemmci.yml @@ -324,6 +324,7 @@ jobs: - uses: actions/checkout@v2 with: ref: ${{ github.ref }} + lfs: 'true' path: 'fbgemm-private-jenkins' - name: Checkout script From 6103266a890abf1d102abead7b5fc7b2efae8245 Mon Sep 17 00:00:00 2001 From: Li Li Date: Tue, 5 Jul 2022 23:24:57 +0000 Subject: [PATCH 56/76] upgrade git --- .github/workflows/fbgemmci.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/fbgemmci.yml b/.github/workflows/fbgemmci.yml index c7b29ab06..270290ade 100644 --- a/.github/workflows/fbgemmci.yml +++ b/.github/workflows/fbgemmci.yml @@ -320,6 +320,8 @@ jobs: run: | sudo chown -R $USER:$USER ${{ github.workspace }}/fbgemm-private-jenkins sudo rm -rf ${{ github.workspace }}/fbgemm-private-jenkins + sudo apt update + sudo apt install --only-upgrade git - uses: actions/checkout@v2 with: From 51f90ec74a2fb840a17348f3b07608beb239601f Mon Sep 17 00:00:00 2001 From: Li Li Date: Tue, 5 Jul 2022 23:29:38 +0000 Subject: [PATCH 57/76] upgrade git --- .github/workflows/fbgemmci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/fbgemmci.yml b/.github/workflows/fbgemmci.yml index 270290ade..e8e54a230 100644 --- a/.github/workflows/fbgemmci.yml +++ b/.github/workflows/fbgemmci.yml @@ -320,6 +320,7 @@ jobs: run: | sudo chown -R $USER:$USER ${{ github.workspace }}/fbgemm-private-jenkins sudo rm -rf ${{ github.workspace }}/fbgemm-private-jenkins + sudo add-apt-repository ppa:git-core/ppa sudo apt update sudo apt install --only-upgrade git From b5d9d4f0b27db105be1152b7d5b0480b21c922ce Mon Sep 17 00:00:00 2001 From: Li Li Date: Tue, 5 Jul 2022 23:31:39 +0000 Subject: [PATCH 58/76] upgrade git --- .github/workflows/fbgemmci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/fbgemmci.yml b/.github/workflows/fbgemmci.yml index e8e54a230..b060eb436 100644 --- a/.github/workflows/fbgemmci.yml +++ b/.github/workflows/fbgemmci.yml @@ -322,7 +322,7 @@ jobs: sudo rm -rf ${{ github.workspace }}/fbgemm-private-jenkins sudo add-apt-repository ppa:git-core/ppa sudo apt update - sudo apt install --only-upgrade git + sudo apt -y install --only-upgrade git - uses: actions/checkout@v2 with: From 213667a2de576ebc3f8da48dd59b880729401bb4 Mon Sep 17 00:00:00 2001 From: Li Li Date: Tue, 5 Jul 2022 23:35:34 +0000 Subject: [PATCH 59/76] add logic that cheks workspace directory --- .github/workflows/fbgemmci.yml | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/.github/workflows/fbgemmci.yml b/.github/workflows/fbgemmci.yml index b060eb436..2a49ae0ea 100644 --- a/.github/workflows/fbgemmci.yml +++ b/.github/workflows/fbgemmci.yml @@ -318,8 +318,11 @@ jobs: - name: pre-checkout shell: bash run: | - sudo chown -R $USER:$USER ${{ github.workspace }}/fbgemm-private-jenkins - sudo rm -rf ${{ github.workspace }}/fbgemm-private-jenkins + if [ -d ${{ github.workspace }}/fbgemm-private-jenkins ] + then + sudo chown -R $USER:$USER ${{ github.workspace }}/fbgemm-private-jenkins + sudo rm -rf ${{ github.workspace }}/fbgemm-private-jenkins + fi sudo add-apt-repository ppa:git-core/ppa sudo apt update sudo apt -y install --only-upgrade git From 12dec216666b6ad50d75b6c621d1ac8f52a6b500 Mon Sep 17 00:00:00 2001 From: Li Li Date: Tue, 5 Jul 2022 23:38:03 +0000 Subject: [PATCH 60/76] checkout suubmodules --- .github/workflows/fbgemmci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/fbgemmci.yml b/.github/workflows/fbgemmci.yml index 2a49ae0ea..f256463f8 100644 --- a/.github/workflows/fbgemmci.yml +++ b/.github/workflows/fbgemmci.yml @@ -330,7 +330,7 @@ jobs: - uses: actions/checkout@v2 with: ref: ${{ github.ref }} - lfs: 'true' + submodules: 'true' path: 'fbgemm-private-jenkins' - name: Checkout script From c74102febb3abf843039d9bbac01fcecbbe421fa Mon Sep 17 00:00:00 2001 From: Li Li Date: Wed, 6 Jul 2022 00:00:40 +0000 Subject: [PATCH 61/76] debug thirdparty --- fbgemm_gpu/CMakeLists.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/fbgemm_gpu/CMakeLists.txt b/fbgemm_gpu/CMakeLists.txt index 1dcfbfa37..7ff3df701 100644 --- a/fbgemm_gpu/CMakeLists.txt +++ b/fbgemm_gpu/CMakeLists.txt @@ -43,6 +43,8 @@ find_package(PythonExtensions REQUIRED) set(FBGEMM ${CMAKE_CURRENT_SOURCE_DIR}/..) set(THIRDPARTY ${FBGEMM}/third_party) +message("${message_line}") +message(STATUS "THIRDPARTY ${THIRDPARTY}") if(DEFINED GLIBCXX_USE_CXX11_ABI) if(${GLIBCXX_USE_CXX11_ABI} EQUAL 1) @@ -73,7 +75,7 @@ if(USE_ROCM) include(Hipify) message("${message_line}") - message(STATUS "hip found ${ROCM_FOUND}") + message(STATUS "hip found ${HIP_FOUND}") endif() # From a639c591f2ad5487d6d0e19f9c76a0d7327a838f Mon Sep 17 00:00:00 2001 From: Li Li Date: Wed, 6 Jul 2022 00:17:00 +0000 Subject: [PATCH 62/76] checkout branch --- .github/workflows/fbgemmci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/fbgemmci.yml b/.github/workflows/fbgemmci.yml index f256463f8..47915ed8b 100644 --- a/.github/workflows/fbgemmci.yml +++ b/.github/workflows/fbgemmci.yml @@ -329,7 +329,7 @@ jobs: - uses: actions/checkout@v2 with: - ref: ${{ github.ref }} + ref: 'add_rocm_runner' submodules: 'true' path: 'fbgemm-private-jenkins' From d5e2685825eef1f4d6447175eed0094ea87f7fe8 Mon Sep 17 00:00:00 2001 From: Li Li Date: Wed, 6 Jul 2022 00:24:40 +0000 Subject: [PATCH 63/76] checkout branch --- .github/workflows/fbgemmci.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/fbgemmci.yml b/.github/workflows/fbgemmci.yml index 47915ed8b..bf368d25e 100644 --- a/.github/workflows/fbgemmci.yml +++ b/.github/workflows/fbgemmci.yml @@ -331,7 +331,6 @@ jobs: with: ref: 'add_rocm_runner' submodules: 'true' - path: 'fbgemm-private-jenkins' - name: Checkout script uses: actions/checkout@v2 From a13b45e7a50e2dfe6a09d7ea756ccbef46596c3d Mon Sep 17 00:00:00 2001 From: Li Li Date: Wed, 6 Jul 2022 00:26:58 +0000 Subject: [PATCH 64/76] change working dir --- .github/workflows/fbgemmci.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/fbgemmci.yml b/.github/workflows/fbgemmci.yml index bf368d25e..8301ecf74 100644 --- a/.github/workflows/fbgemmci.yml +++ b/.github/workflows/fbgemmci.yml @@ -318,10 +318,10 @@ jobs: - name: pre-checkout shell: bash run: | - if [ -d ${{ github.workspace }}/fbgemm-private-jenkins ] + if [ -d ${{ github.workspace }} ] then - sudo chown -R $USER:$USER ${{ github.workspace }}/fbgemm-private-jenkins - sudo rm -rf ${{ github.workspace }}/fbgemm-private-jenkins + sudo chown -R $USER:$USER ${{ github.workspace }} + sudo rm -rf ${{ github.workspace }} fi sudo add-apt-repository ppa:git-core/ppa sudo apt update From 4aad1e9ce732aca6b4aecfd3516ff852fff60b7b Mon Sep 17 00:00:00 2001 From: Li Li Date: Wed, 6 Jul 2022 00:28:38 +0000 Subject: [PATCH 65/76] change working dir --- .github/workflows/fbgemmci.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/fbgemmci.yml b/.github/workflows/fbgemmci.yml index 8301ecf74..b04aa4d99 100644 --- a/.github/workflows/fbgemmci.yml +++ b/.github/workflows/fbgemmci.yml @@ -321,7 +321,6 @@ jobs: if [ -d ${{ github.workspace }} ] then sudo chown -R $USER:$USER ${{ github.workspace }} - sudo rm -rf ${{ github.workspace }} fi sudo add-apt-repository ppa:git-core/ppa sudo apt update From 1993520609c4829a903043f4abe93bbfc8223584 Mon Sep 17 00:00:00 2001 From: Li Li Date: Wed, 6 Jul 2022 00:41:56 +0000 Subject: [PATCH 66/76] checkout current branch --- .github/workflows/fbgemmci.yml | 2 +- fbgemm_gpu/CMakeLists.txt | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/.github/workflows/fbgemmci.yml b/.github/workflows/fbgemmci.yml index b04aa4d99..fde64f16f 100644 --- a/.github/workflows/fbgemmci.yml +++ b/.github/workflows/fbgemmci.yml @@ -328,7 +328,7 @@ jobs: - uses: actions/checkout@v2 with: - ref: 'add_rocm_runner' + ref: ${{ github.ref }} submodules: 'true' - name: Checkout script diff --git a/fbgemm_gpu/CMakeLists.txt b/fbgemm_gpu/CMakeLists.txt index 7ff3df701..a4d1544f3 100644 --- a/fbgemm_gpu/CMakeLists.txt +++ b/fbgemm_gpu/CMakeLists.txt @@ -43,8 +43,6 @@ find_package(PythonExtensions REQUIRED) set(FBGEMM ${CMAKE_CURRENT_SOURCE_DIR}/..) set(THIRDPARTY ${FBGEMM}/third_party) -message("${message_line}") -message(STATUS "THIRDPARTY ${THIRDPARTY}") if(DEFINED GLIBCXX_USE_CXX11_ABI) if(${GLIBCXX_USE_CXX11_ABI} EQUAL 1) From 4e86fc9b21b32d6488288e13c2edb8ebf22fa02d Mon Sep 17 00:00:00 2001 From: Li Li Date: Thu, 7 Jul 2022 16:33:34 +0000 Subject: [PATCH 67/76] move build_and_run script to FBGEMM repo --- .github/workflows/fbgemmci.yml | 11 +------ .jenkins/rocm/build_and_test.sh | 57 +++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 10 deletions(-) create mode 100644 .jenkins/rocm/build_and_test.sh diff --git a/.github/workflows/fbgemmci.yml b/.github/workflows/fbgemmci.yml index fde64f16f..19b64cb5e 100644 --- a/.github/workflows/fbgemmci.yml +++ b/.github/workflows/fbgemmci.yml @@ -331,12 +331,6 @@ jobs: ref: ${{ github.ref }} submodules: 'true' - - name: Checkout script - uses: actions/checkout@v2 - with: - repository: liligwu/fbgemm_ci - path: fbgemm_ci - - name: build fbgemm_gpu and test shell: bash run: | @@ -348,9 +342,6 @@ jobs: JENKINS_REPO_DIR=fbgemm-private-jenkins JENKINS_REPO_DIR_BAREMETAL=$PWD JENKINS_REPO_DIR_DOCKER=/workspace/$JENKINS_REPO_DIR - SCRIPT_DIR=fbgemm_ci - SCRIPT_DIR_BAREMETAL=$PWD/$SCRIPT_DIR - SCRIPT_DIR_DOCKER=/workspace/$SCRIPT_DIR # git clone https://github.com/ROCmSoftwarePlatform/FBGEMM.git $JENKINS_REPO_DIR_BAREMETAL DOCKER_OPTIONS="\ --network=host \ @@ -364,7 +355,7 @@ jobs: -v $JENKINS_REPO_DIR_BAREMETAL:$JENKINS_REPO_DIR_DOCKER \ -v $SCRIPT_DIR_BAREMETAL:$SCRIPT_DIR_DOCKER " - docker run $DOCKER_OPTIONS $DOCKER_IMAGE $SCRIPT_DIR_DOCKER/docker/scripts/build_and_run_unit_tests.sh $JENKINS_REPO_DIR_DOCKER + docker run $DOCKER_OPTIONS $DOCKER_IMAGE $JENKINS_REPO_DIR_DOCKER/.jenkins/rocm/build_and_test.sh $JENKINS_REPO_DIR_DOCKER build_cpu_only: runs-on: ${{ matrix.os }} diff --git a/.jenkins/rocm/build_and_test.sh b/.jenkins/rocm/build_and_test.sh new file mode 100644 index 000000000..dd4c4eb09 --- /dev/null +++ b/.jenkins/rocm/build_and_test.sh @@ -0,0 +1,57 @@ +#!/bin/bash + +# exit immediately on failure, or if an undefined variable is used +set -eux + +FBGEMM_REPO_DIR=${1:-/workspace/FBGEMM} + +git config --global --add safe.directory $FBGEMM_REPO_DIR +git config --global --add safe.directory $FBGEMM_REPO_DIR/third_party/asmjit +git config --global --add safe.directory $FBGEMM_REPO_DIR/third_party/cpuinfo +git config --global --add safe.directory $FBGEMM_REPO_DIR/third_party/googletest +git config --global --add safe.directory $FBGEMM_REPO_DIR/third_party/hipify_torch + +# Install dependencies +apt-get update --allow-insecure-repositories && \ + apt-get install -y --allow-unauthenticated \ + git \ + jq \ + sshfs \ + sshpass \ + unzip + +apt-get install -y locales +locale-gen en_US.UTF-8 + +pip3 install click +pip3 install jinja2 +pip3 install ninja +pip3 install scikit-build +pip3 install --upgrade hypothesis + +pip3 list + +# Build fbgemm_gpu +cd $FBGEMM_REPO_DIR/fbgemm_gpu +export MAX_JOBS=`nproc` +export PYTORCH_ROCM_ARCH="gfx908" +CXX=hipcc python setup.py build develop + +export FBGEMM_TEST_WITH_ROCM=1 + +# Test fbgemm_gpu +cd test + +python layout_transform_ops_test.py --verbose + +python permute_pooled_embedding_modules_test.py --verbose + +python sparse_ops_test.py --verbose + +python merge_pooled_embeddings_test.py --verbose + +python quantize_ops_test.py --verbose + +python split_embedding_inference_converter_test.py --verbose + +python split_table_batched_embeddings_test.py --verbose From bd39c030a28ef913a81f3a693ed4b5332e468866 Mon Sep 17 00:00:00 2001 From: Li Li Date: Thu, 7 Jul 2022 16:38:25 +0000 Subject: [PATCH 68/76] remove SCRIPT_DIR_BAREMETAL --- .github/workflows/fbgemmci.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/fbgemmci.yml b/.github/workflows/fbgemmci.yml index 19b64cb5e..5b56f78f7 100644 --- a/.github/workflows/fbgemmci.yml +++ b/.github/workflows/fbgemmci.yml @@ -352,8 +352,7 @@ jobs: --security-opt seccomp=unconfined \ --device=/dev/kfd \ --device=/dev/dri \ - -v $JENKINS_REPO_DIR_BAREMETAL:$JENKINS_REPO_DIR_DOCKER \ - -v $SCRIPT_DIR_BAREMETAL:$SCRIPT_DIR_DOCKER + -v $JENKINS_REPO_DIR_BAREMETAL:$JENKINS_REPO_DIR_DOCKER " docker run $DOCKER_OPTIONS $DOCKER_IMAGE $JENKINS_REPO_DIR_DOCKER/.jenkins/rocm/build_and_test.sh $JENKINS_REPO_DIR_DOCKER From a6e24584a2cbaacf2c2fa1e339da209b357b143f Mon Sep 17 00:00:00 2001 From: Li Li Date: Thu, 7 Jul 2022 16:55:32 +0000 Subject: [PATCH 69/76] change build_and_run permission --- .jenkins/rocm/build_and_test.sh | 0 1 file changed, 0 insertions(+), 0 deletions(-) mode change 100644 => 100755 .jenkins/rocm/build_and_test.sh diff --git a/.jenkins/rocm/build_and_test.sh b/.jenkins/rocm/build_and_test.sh old mode 100644 new mode 100755 From 6c5ccc77575092841d4a3dfe9b9d18881dc0fa76 Mon Sep 17 00:00:00 2001 From: Li Li Date: Thu, 7 Jul 2022 20:01:56 +0000 Subject: [PATCH 70/76] fix the data type matching issue --- .../fbgemm_gpu/split_table_batched_embeddings_ops.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py index eb096be01..d11b06db1 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py @@ -1942,8 +1942,8 @@ def _update_cache_miss_counter( lxu_cache_locations: Tensor, linear_cache_indices: Tensor, ) -> None: - CACHE_MISS = torch.tensor([-1], device=self.current_device, dtype=torch.int32) - CACHE_HIT = torch.tensor([-2], device=self.current_device, dtype=torch.int32) + CACHE_MISS = torch.tensor([-1], device=self.current_device, dtype=linear_cache_indices.dtype) + CACHE_HIT = torch.tensor([-2], device=self.current_device, dtype=linear_cache_indices.dtype) cache_missed_locations = torch.where( lxu_cache_locations == CACHE_MISS, linear_cache_indices, CACHE_HIT @@ -1977,8 +1977,8 @@ def _update_tablewise_cache_miss( linear_cache_indices: Tensor, offsets: Tensor, ) -> None: - CACHE_MISS = torch.tensor([-1], device=self.current_device, dtype=torch.int32) - CACHE_HIT = torch.tensor([-2], device=self.current_device, dtype=torch.int32) + CACHE_MISS = torch.tensor([-1], device=self.current_device, dtype=linear_cache_indices.dtype) + CACHE_HIT = torch.tensor([-2], device=self.current_device, dtype=linear_cache_indices.dtype) # pyre-ignore[6]: # Incompatible parameter type [6]: Expected `typing.Sized` for 1st From dc7d7ca485bc36ccee375a575928ab181d996f98 Mon Sep 17 00:00:00 2001 From: Li Li Date: Thu, 7 Jul 2022 20:41:22 +0000 Subject: [PATCH 71/76] fix indentation --- fbgemm_gpu/test/uvm_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/test/uvm_test.py b/fbgemm_gpu/test/uvm_test.py index e27f4969c..e34f4ee77 100644 --- a/fbgemm_gpu/test/uvm_test.py +++ b/fbgemm_gpu/test/uvm_test.py @@ -24,8 +24,8 @@ from fbgemm_gpu.test.test_utils import gpu_available, gpu_unavailable, skipIfRocm if gpu_available: - # pyre-ignore[21] - from fbgemm_gpu.uvm import cudaMemAdvise, cudaMemoryAdvise, cudaMemPrefetchAsync + # pyre-ignore[21] + from fbgemm_gpu.uvm import cudaMemAdvise, cudaMemoryAdvise, cudaMemPrefetchAsync from hypothesis import given, settings, Verbosity From 14f2f41a590e099044caccbc647f75287523fdce Mon Sep 17 00:00:00 2001 From: Li Li Date: Fri, 8 Jul 2022 23:45:49 +0000 Subject: [PATCH 72/76] recover the changes in split_table_batched_embeddings_ops.py --- .github/workflows/fbgemmci.yml | 3 +-- .jenkins/rocm/build_and_test.sh | 2 ++ fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py | 4 ++-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/.github/workflows/fbgemmci.yml b/.github/workflows/fbgemmci.yml index 5b56f78f7..2a09a5ae8 100644 --- a/.github/workflows/fbgemmci.yml +++ b/.github/workflows/fbgemmci.yml @@ -337,12 +337,11 @@ jobs: set -eux env ls -l - DOCKER_IMAGE=rocm/pytorch:rocm5.1.3_ubuntu20.04_py3.7_pytorch_1.11.0 + DOCKER_IMAGE=rocm/pytorch:rocm5.1.1_ubuntu20.04_py3.7_pytorch_1.10.0 docker pull $DOCKER_IMAGE JENKINS_REPO_DIR=fbgemm-private-jenkins JENKINS_REPO_DIR_BAREMETAL=$PWD JENKINS_REPO_DIR_DOCKER=/workspace/$JENKINS_REPO_DIR - # git clone https://github.com/ROCmSoftwarePlatform/FBGEMM.git $JENKINS_REPO_DIR_BAREMETAL DOCKER_OPTIONS="\ --network=host \ --ipc=host \ diff --git a/.jenkins/rocm/build_and_test.sh b/.jenkins/rocm/build_and_test.sh index dd4c4eb09..789f1e660 100755 --- a/.jenkins/rocm/build_and_test.sh +++ b/.jenkins/rocm/build_and_test.sh @@ -28,6 +28,8 @@ pip3 install jinja2 pip3 install ninja pip3 install scikit-build pip3 install --upgrade hypothesis +pip3 uninstall -y torch torchvision +pip3 install --pre torch torchvision --extra-index-url https://download.pytorch.org/whl/nightly/rocm5.1.1/ pip3 list diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py index d11b06db1..fa8f5879a 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py @@ -1942,8 +1942,8 @@ def _update_cache_miss_counter( lxu_cache_locations: Tensor, linear_cache_indices: Tensor, ) -> None: - CACHE_MISS = torch.tensor([-1], device=self.current_device, dtype=linear_cache_indices.dtype) - CACHE_HIT = torch.tensor([-2], device=self.current_device, dtype=linear_cache_indices.dtype) + CACHE_MISS = torch.tensor([-1], device=self.current_device, dtype=torch.int32) + CACHE_HIT = torch.tensor([-2], device=self.current_device, dtype=torch.int32) cache_missed_locations = torch.where( lxu_cache_locations == CACHE_MISS, linear_cache_indices, CACHE_HIT From 1fcaff5e41820cacd43d391e1cb50de8f89548cf Mon Sep 17 00:00:00 2001 From: Li Li Date: Mon, 11 Jul 2022 14:37:43 +0000 Subject: [PATCH 73/76] change docker image of ROCm CI to staging_base --- .github/workflows/fbgemmci.yml | 2 +- .jenkins/rocm/build_and_test.sh | 1 - fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py | 4 ++-- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/.github/workflows/fbgemmci.yml b/.github/workflows/fbgemmci.yml index 2a09a5ae8..ffb0bdb7a 100644 --- a/.github/workflows/fbgemmci.yml +++ b/.github/workflows/fbgemmci.yml @@ -337,7 +337,7 @@ jobs: set -eux env ls -l - DOCKER_IMAGE=rocm/pytorch:rocm5.1.1_ubuntu20.04_py3.7_pytorch_1.10.0 + DOCKER_IMAGE=rocm/pytorch:rocm5.1.1_ubuntu20.04_py3.7_pytorch_staging_base docker pull $DOCKER_IMAGE JENKINS_REPO_DIR=fbgemm-private-jenkins JENKINS_REPO_DIR_BAREMETAL=$PWD diff --git a/.jenkins/rocm/build_and_test.sh b/.jenkins/rocm/build_and_test.sh index 789f1e660..1cc3cb10c 100755 --- a/.jenkins/rocm/build_and_test.sh +++ b/.jenkins/rocm/build_and_test.sh @@ -28,7 +28,6 @@ pip3 install jinja2 pip3 install ninja pip3 install scikit-build pip3 install --upgrade hypothesis -pip3 uninstall -y torch torchvision pip3 install --pre torch torchvision --extra-index-url https://download.pytorch.org/whl/nightly/rocm5.1.1/ pip3 list diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py index fa8f5879a..eb096be01 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py @@ -1977,8 +1977,8 @@ def _update_tablewise_cache_miss( linear_cache_indices: Tensor, offsets: Tensor, ) -> None: - CACHE_MISS = torch.tensor([-1], device=self.current_device, dtype=linear_cache_indices.dtype) - CACHE_HIT = torch.tensor([-2], device=self.current_device, dtype=linear_cache_indices.dtype) + CACHE_MISS = torch.tensor([-1], device=self.current_device, dtype=torch.int32) + CACHE_HIT = torch.tensor([-2], device=self.current_device, dtype=torch.int32) # pyre-ignore[6]: # Incompatible parameter type [6]: Expected `typing.Sized` for 1st From b73c3a9d04a484da7f74823cf27e916519032912 Mon Sep 17 00:00:00 2001 From: Li Li Date: Mon, 11 Jul 2022 14:51:20 +0000 Subject: [PATCH 74/76] run docker container as root in ROCm CI --- .github/workflows/fbgemmci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/fbgemmci.yml b/.github/workflows/fbgemmci.yml index ffb0bdb7a..b614aca4e 100644 --- a/.github/workflows/fbgemmci.yml +++ b/.github/workflows/fbgemmci.yml @@ -343,6 +343,7 @@ jobs: JENKINS_REPO_DIR_BAREMETAL=$PWD JENKINS_REPO_DIR_DOCKER=/workspace/$JENKINS_REPO_DIR DOCKER_OPTIONS="\ + --user 0 \ --network=host \ --ipc=host \ --shm-size 16G \ From 0c4b962072aa5c67bdac820614facd85cbad6896 Mon Sep 17 00:00:00 2001 From: liligwu Date: Tue, 12 Jul 2022 19:31:24 +0000 Subject: [PATCH 75/76] remove CXX=hipcc in ROCm CI --- .jenkins/rocm/build_and_test.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.jenkins/rocm/build_and_test.sh b/.jenkins/rocm/build_and_test.sh index 1cc3cb10c..88da472cb 100755 --- a/.jenkins/rocm/build_and_test.sh +++ b/.jenkins/rocm/build_and_test.sh @@ -36,7 +36,7 @@ pip3 list cd $FBGEMM_REPO_DIR/fbgemm_gpu export MAX_JOBS=`nproc` export PYTORCH_ROCM_ARCH="gfx908" -CXX=hipcc python setup.py build develop +python setup.py build develop export FBGEMM_TEST_WITH_ROCM=1 From f7a0c6864d1f31e3dc8325dfead89cba6b3985d8 Mon Sep 17 00:00:00 2001 From: liligwu Date: Wed, 13 Jul 2022 23:33:50 +0000 Subject: [PATCH 76/76] enable more tests in ROCm CI --- .jenkins/rocm/build_and_test.sh | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/.jenkins/rocm/build_and_test.sh b/.jenkins/rocm/build_and_test.sh index 88da472cb..dadd1342c 100755 --- a/.jenkins/rocm/build_and_test.sh +++ b/.jenkins/rocm/build_and_test.sh @@ -5,11 +5,11 @@ set -eux FBGEMM_REPO_DIR=${1:-/workspace/FBGEMM} -git config --global --add safe.directory $FBGEMM_REPO_DIR -git config --global --add safe.directory $FBGEMM_REPO_DIR/third_party/asmjit -git config --global --add safe.directory $FBGEMM_REPO_DIR/third_party/cpuinfo -git config --global --add safe.directory $FBGEMM_REPO_DIR/third_party/googletest -git config --global --add safe.directory $FBGEMM_REPO_DIR/third_party/hipify_torch +git config --global --add safe.directory "$FBGEMM_REPO_DIR" +git config --global --add safe.directory "$FBGEMM_REPO_DIR/third_party/asmjit" +git config --global --add safe.directory "$FBGEMM_REPO_DIR/third_party/cpuinfo" +git config --global --add safe.directory "$FBGEMM_REPO_DIR/third_party/googletest" +git config --global --add safe.directory "$FBGEMM_REPO_DIR/third_party/hipify_torch" # Install dependencies apt-get update --allow-insecure-repositories && \ @@ -33,8 +33,9 @@ pip3 install --pre torch torchvision --extra-index-url https://download.pytorch. pip3 list # Build fbgemm_gpu -cd $FBGEMM_REPO_DIR/fbgemm_gpu -export MAX_JOBS=`nproc` +cd "$FBGEMM_REPO_DIR/fbgemm_gpu" +MAX_JOBS="$(nproc)" +export MAX_JOBS export PYTORCH_ROCM_ARCH="gfx908" python setup.py build develop @@ -43,16 +44,15 @@ export FBGEMM_TEST_WITH_ROCM=1 # Test fbgemm_gpu cd test +python batched_unary_embeddings_test.py --verbose +python input_combine_test.py --verbose +python jagged_tensor_ops_test.py --verbose python layout_transform_ops_test.py --verbose - -python permute_pooled_embedding_modules_test.py --verbose - -python sparse_ops_test.py --verbose - python merge_pooled_embeddings_test.py --verbose - +python metric_ops_test.py --verbose +python permute_pooled_embedding_modules_test.py --verbose python quantize_ops_test.py --verbose - +python sparse_ops_test.py --verbose python split_embedding_inference_converter_test.py --verbose - python split_table_batched_embeddings_test.py --verbose +python uvm_test.py --verbose