diff --git a/.github/scripts/utils_build.bash b/.github/scripts/utils_build.bash
index 82fa3e26a2..709e7b62f4 100644
--- a/.github/scripts/utils_build.bash
+++ b/.github/scripts/utils_build.bash
@@ -370,6 +370,7 @@ install_build_tools () {
     patchelf \
     rhash \
     scikit-build \
+    tbb-devel \
     tbb \
     wheel \
     xz \
diff --git a/.github/workflows/_fbgemm_gpu_cuda_test.yml b/.github/workflows/_fbgemm_gpu_cuda_test.yml
index 03d619cae0..692b6ab7ac 100644
--- a/.github/workflows/_fbgemm_gpu_cuda_test.yml
+++ b/.github/workflows/_fbgemm_gpu_cuda_test.yml
@@ -132,6 +132,9 @@ jobs:
         #   clang-16: error: unknown argument: '-fno-tree-loop-vectorize'
         run: . $PRELUDE; install_cxx_compiler $BUILD_ENV gcc

+      - name: Install Build Tools
+        run: . $PRELUDE; install_build_tools $BUILD_ENV
+
       - name: Install CUDA
         run: . $PRELUDE; install_cuda $BUILD_ENV ${{ matrix.cuda-version }}

diff --git a/.github/workflows/fbgemm_gpu_ci_cpu.yml b/.github/workflows/fbgemm_gpu_ci_cpu.yml
index 911944438a..5f5475acfc 100644
--- a/.github/workflows/fbgemm_gpu_ci_cpu.yml
+++ b/.github/workflows/fbgemm_gpu_ci_cpu.yml
@@ -75,7 +75,7 @@ jobs:
          { arch: arm, instance: "linux.arm64.m7g.4xlarge" },
        ]
        build-target: [ "default" ]
-        python-version: [ "3.9", "3.10", "3.11", "3.12", "3.13" ]
+        python-version: [ "3.10", "3.11", "3.12", "3.13" ]
        compiler: [ "gcc", "clang" ]

    steps:
@@ -149,7 +149,7 @@ jobs:
          { arch: arm, instance: "linux.arm64.m7g.4xlarge", timeout: 30 },
        ]
        build-target: [ "default" ]
-        python-version: [ "3.9", "3.10", "3.11", "3.12", "3.13" ]
+        python-version: [ "3.10", "3.11", "3.12", "3.13" ]
        compiler: [ "gcc", "clang" ]

    needs: build_artifact
diff --git a/cmake/modules/CppLibrary.cmake b/cmake/modules/CppLibrary.cmake
index 92a93a60b6..388d3ac779 100644
--- a/cmake/modules/CppLibrary.cmake
+++ b/cmake/modules/CppLibrary.cmake
@@ -168,6 +168,18 @@ function(cpp_library)
     target_link_libraries(${lib_name} PUBLIC OpenMP::OpenMP_CXX)
   endif()

+  if(NOT TARGET TBB::tbb)
+    find_package(TBB QUIET)
+  endif()
+  if(TBB_FOUND)
+    target_link_libraries(${lib_name} PUBLIC TBB::tbb)
+  else()
+    find_library(TBB_LIB NAMES tbb tbb12 HINTS $ENV{CONDA_PREFIX}/lib /usr/lib/x86_64-linux-gnu /usr/local/lib /lib/x86_64-linux-gnu)
+    if(TBB_LIB)
+      target_link_libraries(${lib_name} PUBLIC ${TBB_LIB})
+    endif()
+  endif()
+
   # Add sanitizer options if needed
   if(args_SANITIZER_OPTIONS)
     target_link_options(${lib_name} PUBLIC
diff --git a/cmake/modules/GpuCppLibrary.cmake b/cmake/modules/GpuCppLibrary.cmake
index 51c30df750..e662848348 100644
--- a/cmake/modules/GpuCppLibrary.cmake
+++ b/cmake/modules/GpuCppLibrary.cmake
@@ -302,6 +302,18 @@ function(gpu_cpp_library)
     list(APPEND library_dependencies ${NVML_LIB_PATH})
   endif()

+  if(NOT TARGET TBB::tbb)
+    find_package(TBB QUIET)
+  endif()
+  if(TBB_FOUND)
+    list(APPEND library_dependencies TBB::tbb)
+  else()
+    find_library(TBB_LIB NAMES tbb tbb12 HINTS $ENV{CONDA_PREFIX}/lib /usr/lib/x86_64-linux-gnu /usr/local/lib /lib/x86_64-linux-gnu)
+    if(TBB_LIB)
+      list(APPEND library_dependencies ${TBB_LIB})
+    endif()
+  endif()
+
   # Link against the external libraries as needed
   target_link_libraries(${lib_name} PRIVATE ${library_dependencies})

diff --git a/fbgemm_gpu/bench/jagged_tensor_benchmark.py b/fbgemm_gpu/bench/jagged_tensor_benchmark.py
index bcc3e27488..51375f0a64 100644
--- a/fbgemm_gpu/bench/jagged_tensor_benchmark.py
+++ b/fbgemm_gpu/bench/jagged_tensor_benchmark.py
@@ -10,8 +10,11 @@

 import functools
 import logging
+import os
 import random
+from contextlib import nullcontext
 from dataclasses import dataclass
+from typing
import Callable import click import fbgemm_gpu @@ -542,6 +545,17 @@ def ref( @click.option("--has-weights", is_flag=True, default=False) @click.option("--weight-type", type=str, default="float") @click.option("--use-selected-lengths-sum", is_flag=True, default=False) +@click.option( + "--export-trace", + is_flag=True, + default=False, + help="Enable export of trace for profiling. Default is False.", +) +@click.option( + "--trace-url", + type=str, + default="keyed_jagged_index_select_dim1_{phase}_trace_{ospid}.json", +) def keyed_jagged_index_select_dim1( num_batches: int, max_seq_length: int, @@ -551,6 +565,8 @@ def keyed_jagged_index_select_dim1( has_weights: bool, weight_type: str, use_selected_lengths_sum: bool, + export_trace: bool, + trace_url: str, ) -> None: jagged_tensor_types = { "float": torch.float, @@ -622,20 +638,28 @@ def keyed_jagged_index_select_dim1( if is_float: values.requires_grad = True - time, output = benchmark_torch_function( - torch.ops.fbgemm.keyed_jagged_index_select_dim1, - ( - values, - lengths, - offsets, - indices, - input_batch_size, - weights, - selected_lengths_sum, - ), - iters=1000, - ) - output = output[0] + def _kineto_trace_handler(p: profile, phase: str) -> None: + p.export_chrome_trace(trace_url.format(phase=phase, ospid=os.getpid())) + + # pyre-ignore[3] + def context_factory(on_trace_ready: Callable[[profile], None]): + return profile(on_trace_ready=on_trace_ready) if export_trace else nullcontext() + + with context_factory(lambda p: _kineto_trace_handler(p, "fwd")): + time, output = benchmark_torch_function( + torch.ops.fbgemm.keyed_jagged_index_select_dim1, + ( + values, + lengths, + offsets, + indices, + input_batch_size, + weights, + selected_lengths_sum, + ), + iters=1000, + ) + output = output[0] # Prepare inputs for the reference run ref_inputs = [] @@ -687,9 +711,12 @@ def keyed_jagged_index_select_dim1_ref( return grad = torch.rand_like(output) - time, _ = benchmark_torch_function( - functools.partial(output.backward, retain_graph=True), (grad,), iters=1000 - ) + + with context_factory(lambda p: _kineto_trace_handler(p, "bwd")): + time, _ = benchmark_torch_function( + functools.partial(output.backward, retain_graph=True), (grad,), iters=1000 + ) + time_ref, _ = benchmark_torch_function( functools.partial(output_ref.backward, retain_graph=True), (grad,), iters=1000 ) diff --git a/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py b/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py index 4ffb7341a5..4dd8b3dbb3 100644 --- a/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py @@ -1506,4 +1506,4 @@ def context_factory(on_trace_ready: Callable[[profile], None]): if __name__ == "__main__": - cli() + cli() \ No newline at end of file diff --git a/fbgemm_gpu/cmake/tbe_sources.py b/fbgemm_gpu/cmake/tbe_sources.py index 82092cc173..b38f862564 100644 --- a/fbgemm_gpu/cmake/tbe_sources.py +++ b/fbgemm_gpu/cmake/tbe_sources.py @@ -176,7 +176,6 @@ "_nobag" if nobag else "", ) for nobag in [ - True, False, ] for weighted in ( @@ -495,7 +494,6 @@ "_nobag" if nobag else "", ) for nobag in [ - True, False, ] for weighted in ( diff --git a/fbgemm_gpu/codegen/genscript/generate_backward_split.py b/fbgemm_gpu/codegen/genscript/generate_backward_split.py index a817232910..50506decb1 100644 --- a/fbgemm_gpu/codegen/genscript/generate_backward_split.py +++ b/fbgemm_gpu/codegen/genscript/generate_backward_split.py @@ -52,7 +52,11 @@ def 
render_backward_templates( return weighted_options = [True, False] - nobag_options = [True, False] if (not is_gwd) else [False] + nobag_options = ( + [True, False] + if (not (is_gwd or kwargs.get("is_hip_optimized_backward"))) + else [False] + ) vbe_options = [True, False] if (kwargs.get("has_vbe_support")) else [False] ssd_options = [True, False] if kwargs.get("has_ssd_support") else [False] template = CodeTemplate.load(template_filepath) @@ -327,8 +331,7 @@ def generate_backward_indices() -> None: @staticmethod def generate_rocm_backward_split(**kwargs: Any) -> None: - # Generate backward device kernels based on weighted (True/False), VBE - # (True/False), no bag (True/False) + # Generate backward device kernels based on weighted (True/False) template_filepath = ( "training/backward/rocm/embedding_backward_split_device_kernel_template.hip" ) @@ -343,6 +346,7 @@ def generate_rocm_backward_split(**kwargs: Any) -> None: "has_ssd_support": False, "dense": False, "gen_once": False, + "is_hip_optimized_backward": True, }, ) @@ -422,6 +426,7 @@ def generate() -> None: "lxu_cache_locations", # 3 "uvm_cache_stats", # 4 "prev_iter_dev", # 5 + "vbe_output_offsets", # 6 ], "aux_int": [ "iter", # 0 diff --git a/fbgemm_gpu/codegen/genscript/optimizer_args.py b/fbgemm_gpu/codegen/genscript/optimizer_args.py index 9d7235af84..9c8924e49f 100644 --- a/fbgemm_gpu/codegen/genscript/optimizer_args.py +++ b/fbgemm_gpu/codegen/genscript/optimizer_args.py @@ -73,9 +73,7 @@ class OptimizerArgsSetItem: "row_counter_dev": "(q!)", "row_counter_uvm": "(r!)", "optim_tensor": "(s!)", - "delta_weights_host": "(t!)", - "delta_weights_dev": "(u!)", - "delta_weights_uvm": "(v!)", + "vbe_output": "(t!)", } ###################################################################### diff --git a/fbgemm_gpu/codegen/genscript/optimizers.py b/fbgemm_gpu/codegen/genscript/optimizers.py index c61e6843f9..8c25dc0d8f 100644 --- a/fbgemm_gpu/codegen/genscript/optimizers.py +++ b/fbgemm_gpu/codegen/genscript/optimizers.py @@ -197,6 +197,9 @@ def rowwise_adagrad() -> Dict[str, Any]: at::acc_type multiplier = 0.0; at::acc_type correction = 0.0; + """ + split_precomputation_preload = split_precomputation + split_precomputation += """ if (threadIdx.x == 0) { auto new_sum_square_grads = g_avg_square; @@ -228,6 +231,38 @@ def rowwise_adagrad() -> Dict[str, Any]: multiplier = SHFL_SYNC(multiplier, 0); correction = SHFL_SYNC(correction, 0); """ + split_precomputation_preload += """ + if (threadIdx.x == 0) { + auto new_sum_square_grads = g_avg_square; + + // Update the optimizer state. 
Use optimizer state offloading only if + // SSD and if enabled by the user + if (enable_optimizer_offloading) { + // Fetch the pointer to the optimizer state along the cache row + auto* optimizer = weight_row_template.template optimizer_state_ptr(); + new_sum_square_grads += optimizer->momentum; + optimizer->momentum = new_sum_square_grads; + + } else { + new_sum_square_grads += momentum1_val; + momentum1[idx] = new_sum_square_grads; + } + + multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps); + if (weight_decay_mode == 1) { + // L2 regularization + correction = 1.0 - multiplier * weight_decay; + } else if (weight_decay_mode == 2 || weight_decay_mode == 5) { + // Decoupled weight decay + correction = 1.0 - learning_rate * weight_decay; + } else { + // default value + correction = 1.0; + } + } + multiplier = SHFL_SYNC(multiplier, 0); + correction = SHFL_SYNC(correction, 0); + """ split_weight_update_cpu = """ at::acc_type g_local_sum_square = 0.0; for (int64_t d = 0; d < D; ++d) { @@ -275,6 +310,7 @@ def rowwise_adagrad() -> Dict[str, Any]: }, ), "split_precomputation": split_precomputation, + "split_precomputation_preload": split_precomputation_preload, "split_weight_update": split_weight_update, "split_post_update": split_post_update, "split_weight_update_cpu": split_weight_update_cpu, diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp index 626838e930..0bc3c5f254 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp @@ -172,7 +172,7 @@ Tensor split_embedding_codegen_lookup_dense_function( c10::SymInt /* max_B = -1 */, c10::SymInt /* max_B_feature_rank = -1 */, c10::SymInt /* vbe_output_size = -1 */, - bool /* mixed_D = true */) { + bool /* mixed_D = false */) { return SplitLookupFunction_Dense_Op::apply( host_weights, weights_offsets, diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh index b9db6e47f8..bb15b24f15 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh @@ -14,6 +14,100 @@ using namespace fbgemm_gpu; +{%- if is_rocm %} +// Helper macro: Generate block_size grad_offset_j_i variables (i from 1 to block_size-1) +#define GRAD_OFFSET(i, j) const auto grad_offset_j_##i = SHFL_SYNC(grad_offset, j + i); +#define L(i, j) int32_t l_j_##i = SHFL_SYNC(l, j + i); +#define B(i, j) int32_t b_j_##i = SHFL_SYNC(b, j + i); +#define D_START(i, j) int32_t D_start_j_##i = SHFL_SYNC(D_start, j + i); +#define IDX_WEIGHT(i, j) at::acc_type idx_weight_j_##i = SHFL_SYNC(idx_weight, j + i); + +#define REPEAT_8(X, j) X(1, j); X(2, j); X(3, j); X(4, j); X(5, j); X(6, j); X(7, j); +#define REPEAT_4(X, j) X(1, j); X(2, j); X(3, j); +#define REPEAT_2(X, j) X(1, j); +#define REPEAT_1(X, j) // No additional variables needed for block size 1 + +#define REPEAT_I_S_8(X, j, m, n) X(1, j, m, n); X(2, j, m, n); X(3, j, m, n); X(4, j, m, n); X(5, j, m, n); X(6, j, m, n); X(7, j, m, n); +#define REPEAT_I_S_4(X, j, m, n) X(1, j, m, n); X(2, j, m, n); X(3, j, m, n); +#define REPEAT_I_S_2(X, j, m, n) X(1, j, m, n); +#define REPEAT_I_S_1(X, j, m, n) // No additional variables needed for block size 1 + 
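For orientation: the REPEAT_* and PROCESS_BLOCK macros added in this template unroll the per-segment gradient accumulation in blocks of 8, 4, 2, and finally 1 elements, broadcasting each element's offsets with SHFL_SYNC and emitting the corresponding fma_/add_ calls as straight-line code; whatever remainder a larger block size leaves behind is picked up by the smaller ones. A minimal Python model of just that blocking control flow (illustrative names only, not part of the patch):

# Python sketch of the 8/4/2/1 blocking used by PROCESS_BLOCK: consume full
# blocks of 8 first, then fall through to 4, 2, and 1 so the tail is covered.
def accumulate_segment(values, block_sizes=(8, 4, 2, 1)):
    total = 0.0
    j = 0
    for block in block_sizes:
        # mirrors: for (; j + (block_size - 1) < n; j += block_size)
        while j + (block - 1) < len(values):
            # in the kernel this body is block-many SHFL_SYNC broadcasts
            # followed by block-many grad_sum fma_/add_ accumulations
            total += sum(values[j:j + block])
            j += block
    return total

assert accumulate_segment(list(range(13))) == sum(range(13))

This remainder handling is what lets the 8-wide path be gated on sizeof(grad_t) <= 2 later in the template while the 4/2/1 paths run unconditionally.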
+// Helper macro: Generate block_size Vec4TAcc objects (i from 1 to block_size-1) +// if nobag and is_index_select +#define GRAD_VEC_N_I(i, grad_offset, grad_stride, d) Vec4TAcc grad_out_vec_##i(&grad_output[grad_offset + l_j_##i * grad_stride + d]); +// elif nobag +#define GRAD_VEC_N(i, d) Vec4TAcc grad_out_vec_##i(&grad_output[l_j_##i][d]); +// elif vbe +#define GRAD_VEC_V(i, d) Vec4TAcc grad_out_vec_##i(&grad_output[0][grad_offset_j_##i + d]); +// else +#define GRAD_VEC(i, d) Vec4TAcc grad_out_vec_##i(&grad_output[b_j_##i][0] + D_start_j_##i + d); + +// Helper macro: Generate block_size fma_ calls (i from 1 to block_size-1) +#define FMA_GRAD(i, vec) grad_sum[vec].fma_(grad_out_vec_##i, idx_weight_j_##i); +// Helper macro: Generate block_size add_ calls (i from 1 to block_size-1) +#define ADD_GRAD(i, vec) grad_sum[vec].add_(grad_out_vec_##i); + +// Core macro: Process blocks of specified size (block_size = 8/4/2/1) +// Parameters: +// - block_size: Size of each block to process +// - unroll_count: Number of unroll iterations for the inner loop +#define PROCESS_BLOCK(block_size, unroll_count, grad_sum, grad_output, grad_offset, vec_start, kThreadGroupSize, threadIdx_x, VEC_WIDTH, D, j, sl, sl_end) \ + for (; j + (block_size - 1) < kThreadGroupSize && sl + j + (block_size - 1) < sl_end; j += block_size) { \ + {%- if nobag %} + int32_t l_j_0 = SHFL_SYNC(l, j); \ + REPEAT_##block_size(L, j) \ + {%- elif vbe %} + /* Generate block_size grad_offset_j_0 ~ grad_offset_j_(block_size-1) */ \ + const auto grad_offset_j_0 = SHFL_SYNC(grad_offset, j); \ + /* Generate subsequent grad_offset_j_1 ~ grad_offset_j_(block_size-1) based on block size */ \ + REPEAT_##block_size(GRAD_OFFSET, j) \ + {%- else %} + int32_t b_j_0 = SHFL_SYNC(b, j); \ + REPEAT_##block_size(B, j) \ + int32_t D_start_j_0 = SHFL_SYNC(D_start, j); \ + REPEAT_##block_size(D_START, j) \ + {%- endif %} + {%- if weighted %} + at::acc_type idx_weight_j_0 = SHFL_SYNC(idx_weight, j); \ + REPEAT_##block_size(IDX_WEIGHT, j) \ + {%- endif %} + {%- set d = "(((vec + vec_start) * kThreadGroupSize + threadIdx.x) * VEC_WIDTH)" %} + \ + for (int32_t vec = 0; vec < unroll_count && (((vec + vec_start) * kThreadGroupSize + threadIdx_x) * VEC_WIDTH) < D; ++vec) { \ + const int32_t d = (((vec + vec_start) * kThreadGroupSize + threadIdx_x) * VEC_WIDTH); \ + /* Generate block_size Vec4TAcc objects and accumulate them */ \ + Vec4TAcc grad_out_vec_0( \ + {%- if nobag and is_index_select %} + &grad_output[grad_offset + l_j_0 * grad_stride + d] \ + {%- elif nobag %} + &grad_output[l_j_0][d] \ + {%- elif vbe %} + &grad_output[0][grad_offset_j_0 + d] \ + {%- else %} + &grad_output[b_j_0][0] + D_start_j_0 + d \ + {%- endif %} + ); \ + {%- if nobag and is_index_select %} + REPEAT_I_S_##block_size(GRAD_VEC_N_I, grad_offset, grad_stride, d) \ + {%- elif nobag %} + REPEAT_##block_size(GRAD_VEC_N, d) \ + {%- elif vbe %} + REPEAT_##block_size(GRAD_VEC_V, d) \ + {%- else %} + REPEAT_##block_size(GRAD_VEC, d) \ + {%- endif %} + \ + {%- if weighted %} + grad_sum[vec].fma_(grad_out_vec_0, idx_weight_j_0); \ + REPEAT_##block_size(FMA_GRAD, vec) \ + {%- else %} + grad_sum[vec].add_(grad_out_vec_0); \ + REPEAT_##block_size(ADD_GRAD, vec) \ + {%- endif %} + } \ + } +{%- endif %} + {%- if gen_once %} {#- /* The kernels in this section will be generated only once for all TBE configs @@ -141,7 +235,25 @@ DEVICE_INLINE void compute_grad_sum_{{ kdesc }}( ? 
sorted_indice_weights[segment_start + sl_j] : 0.0; {%- endif %} - for (int32_t j = 0; j < kThreadGroupSize && sl + j < sl_end; ++j) { + int32_t j = 0; + + {%- if is_rocm %} + // Process blocks of different sizes with loop unrolling + if constexpr (sizeof(grad_t) <= 2) { + PROCESS_BLOCK(8, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \ + vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end) + } + PROCESS_BLOCK(4, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \ + vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end) + PROCESS_BLOCK(2, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \ + vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end) + PROCESS_BLOCK(1, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \ + vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end) + +#undef PROCESS_BLOCK + + {%- else %} + for (; j < kThreadGroupSize && sl + j < sl_end; ++j) { {%- if nobag %} int32_t l_j = SHFL_SYNC(l, j); {%- elif vbe %} @@ -180,6 +292,7 @@ DEVICE_INLINE void compute_grad_sum_{{ kdesc }}( {%- endif %} } } + {%- endif %} } {%- set d_vec = "((vec + vec_start) * kThreadGroupSize + threadIdx.x)" %} @@ -198,4 +311,4 @@ DEVICE_INLINE void compute_grad_sum_{{ kdesc }}( {%- endif %} - // clang-format on + // clang-format on \ No newline at end of file diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp index e071d88768..3fe516891f 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp @@ -109,7 +109,12 @@ enum SSDTensor { gwd_lower_bound, {%- endif %} {# /* if is_gwd */ #} {%- endif %} {# /* if not nobag */ #} + {%- if vbe and not dense %} + {{ "is_experimental" if has_experimental else "false" }}, + std::nullopt /* vbe_output */ + {%- else %} {{ "is_experimental" if has_experimental else "false" }} + {%- endif %} ); if (is_annotate_trace_enabled) { @@ -474,7 +479,12 @@ Tensor {{ fwd_mdesc }}_embedding{{ ndesc }}_codegen_forward{{ desc_suffix }}_cud const int64_t iter, const double gwd_lower_bound, {%- endif %} + {%- if vbe and not dense %} + const bool is_experimental, + std::optional vbe_output = std::nullopt + {%- else %} const bool is_experimental + {%- endif %} ); Tensor @@ -708,7 +718,7 @@ class {{ autograd_func }} : static auto generate_vbe_metadata_op = torch::Dispatcher::singleton() .findSchemaOrThrow("fbgemm::generate_vbe_metadata", "") - .typed(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const int64_t, const bool, const c10::SymInt, const int64_t, const c10::SymInt)>(); + .typed(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const int64_t, const bool, const c10::SymInt, const int64_t, const c10::SymInt, const std::optional&)>(); auto [ vbe_row_output_offsets, @@ -729,7 +739,8 @@ class {{ autograd_func }} : {%- endif %} max_B_feature_rank, info_B_num_bits, - /*total_B=*/offsets.sym_size(0) - 1 + /*total_B=*/offsets.sym_size(0) - 1, + std::nullopt /* pre-allocated vbe_output is not supported in TBE interface V1 or Dense TBE */ ); {%- endif %} @@ -949,7 +960,7 @@ class {{ autograd_func }} : #ifdef USE_ROCM constexpr int32_t BT_block_size = 64; - constexpr int32_t max_segment_length_per_warp = 64; + constexpr int32_t max_segment_length_per_warp = 16384; #else constexpr int32_t BT_block_size = 32; constexpr 
int32_t max_segment_length_per_warp = 32; @@ -1105,7 +1116,7 @@ Tensor {{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function( {%- else %} const c10::SymInt vbe_output_size = -1, {%- endif %} - const bool mixed_D = true + const bool mixed_D = false ) { // TODO: refactor into macro {%- if has_gpu_support %} diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu old mode 100644 new mode 100755 index 6d38d1d99a..9ffaea3a67 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu @@ -23,6 +23,10 @@ #include "fbgemm_gpu/utils/assert_macros.h" #include "fbgemm_gpu/utils/kernel_launcher.cuh" +{%- if is_rocm %} +#include "fbgemm_gpu/rocm/cdna_guard.h" +{%- endif %} + using Tensor = at::Tensor; using namespace fbgemm_gpu; @@ -209,8 +213,127 @@ __global__ __launch_bounds__(kForwardMaxThreads) void 2, offset_idx + D_emb <= weights_numel, offset_idx ) {%- endif %} + int32_t j = 0; + {%- if is_rocm and not ssd and not dense and not use_vec_blocking and not vbe %} + // Currently for split_embedding_codegen_grad_indice_weights_kernel only + if (placement != PlacementType::MANAGED_CACHING) { + for (; j < kWarpSize && l_start + j + 3 < L; j += 4) { + const auto offset_idx_j0 = shfl_sync(offset_idx, j); + const auto offset_idx_j1 = shfl_sync(offset_idx, j+1); + const auto offset_idx_j2 = shfl_sync(offset_idx, j+2); + const auto offset_idx_j3 = shfl_sync(offset_idx, j+3); + + at::acc_type grad_indice_weight0 = 0.0; + at::acc_type grad_indice_weight1 = 0.0; + at::acc_type grad_indice_weight2 = 0.0; + at::acc_type grad_indice_weight3 = 0.0; + + const auto weight_row0 = WeightRowAccessor>(&weights[offset_idx_j0], D); + const auto weight_row1 = WeightRowAccessor>(&weights[offset_idx_j1], D); + const auto weight_row2 = WeightRowAccessor>(&weights[offset_idx_j2], D); + const auto weight_row3 = WeightRowAccessor>(&weights[offset_idx_j3], D); + + #pragma unroll kFixedMaxVecsPerThread + for (int32_t vec = 0; vec < kFixedMaxVecsPerThread && (kWarpSize * vec + threadIdx.x) * kVecWidth < D; ++vec) { + const int32_t d = (kWarpSize * vec + threadIdx.x) * kVecWidth; + + Vec4T> weight0, weight1, weight2, weight3; + weight0 = weight_row0.load(d); + weight1 = weight_row1.load(d); + weight2 = weight_row2.load(d); + weight3 = weight_row3.load(d); + + grad_indice_weight0 += weight0.acc.x * grad_out[vec].acc.x + weight0.acc.y * grad_out[vec].acc.y + + weight0.acc.z * grad_out[vec].acc.z + weight0.acc.w * grad_out[vec].acc.w; + grad_indice_weight1 += weight1.acc.x * grad_out[vec].acc.x + weight1.acc.y * grad_out[vec].acc.y + + weight1.acc.z * grad_out[vec].acc.z + weight1.acc.w * grad_out[vec].acc.w; + grad_indice_weight2 += weight2.acc.x * grad_out[vec].acc.x + weight2.acc.y * grad_out[vec].acc.y + + weight2.acc.z * grad_out[vec].acc.z + weight2.acc.w * grad_out[vec].acc.w; + grad_indice_weight3 += weight3.acc.x * grad_out[vec].acc.x + weight3.acc.y * grad_out[vec].acc.y + + weight3.acc.z * grad_out[vec].acc.z + weight3.acc.w * grad_out[vec].acc.w; + } + + grad_indice_weight0 = warpReduceAllSum>(grad_indice_weight0); + grad_indice_weight1 = warpReduceAllSum>(grad_indice_weight1); + grad_indice_weight2 = warpReduceAllSum>(grad_indice_weight2); + grad_indice_weight3 = warpReduceAllSum>(grad_indice_weight3); + + if (threadIdx.x == 0) { + 
grad_indice_weights[indices_start + l_start + j] = grad_indice_weight0; + grad_indice_weights[indices_start + l_start + j+1] = grad_indice_weight1; + grad_indice_weights[indices_start + l_start + j+2] = grad_indice_weight2; + grad_indice_weights[indices_start + l_start + j+3] = grad_indice_weight3; + } + } + } else { + for (; j < kWarpSize && l_start + j + 3 < L; j += 4) { + const auto offset_idx_j0 = shfl_sync(offset_idx, j); + const auto offset_idx_j1 = shfl_sync(offset_idx, j+1); + const auto offset_idx_j2 = shfl_sync(offset_idx, j+2); + const auto offset_idx_j3 = shfl_sync(offset_idx, j+3); + + const auto cache_idx_j0 = shfl_sync(cache_idx, j); + const auto cache_idx_j1 = shfl_sync(cache_idx, j+1); + const auto cache_idx_j2 = shfl_sync(cache_idx, j+2); + const auto cache_idx_j3 = shfl_sync(cache_idx, j+3); + + at::acc_type grad_indice_weight0 = 0.0; + at::acc_type grad_indice_weight1 = 0.0; + at::acc_type grad_indice_weight2 = 0.0; + at::acc_type grad_indice_weight3 = 0.0; + + const auto weight_row0 = WeightRowAccessor>(&weights[offset_idx_j0], D); + const auto weight_row1 = WeightRowAccessor>(&weights[offset_idx_j1], D); + const auto weight_row2 = WeightRowAccessor>(&weights[offset_idx_j2], D); + const auto weight_row3 = WeightRowAccessor>(&weights[offset_idx_j3], D); + + #pragma unroll kFixedMaxVecsPerThread + for (int32_t vec = 0; vec < kFixedMaxVecsPerThread && (kWarpSize * vec + threadIdx.x) * kVecWidth < D; ++vec) { + const int32_t d = (kWarpSize * vec + threadIdx.x) * kVecWidth; + + Vec4T> weight0, weight1, weight2, weight3; + weight0 = (cache_idx_j0 != kCacheLocationMissing) ? + Vec4T>(&lxu_cache_weights[cache_idx_j0][d]) : + weight_row0.load(d); + + weight1 = (cache_idx_j1 != kCacheLocationMissing) ? + Vec4T>(&lxu_cache_weights[cache_idx_j1][d]) : + weight_row1.load(d); + + weight2 = (cache_idx_j2 != kCacheLocationMissing) ? + Vec4T>(&lxu_cache_weights[cache_idx_j2][d]) : + weight_row2.load(d); + + weight3 = (cache_idx_j3 != kCacheLocationMissing) ? 
+ Vec4T>(&lxu_cache_weights[cache_idx_j3][d]) : + weight_row3.load(d); + + + grad_indice_weight0 += weight0.acc.x * grad_out[vec].acc.x + weight0.acc.y * grad_out[vec].acc.y + + weight0.acc.z * grad_out[vec].acc.z + weight0.acc.w * grad_out[vec].acc.w; + grad_indice_weight1 += weight1.acc.x * grad_out[vec].acc.x + weight1.acc.y * grad_out[vec].acc.y + + weight1.acc.z * grad_out[vec].acc.z + weight1.acc.w * grad_out[vec].acc.w; + grad_indice_weight2 += weight2.acc.x * grad_out[vec].acc.x + weight2.acc.y * grad_out[vec].acc.y + + weight2.acc.z * grad_out[vec].acc.z + weight2.acc.w * grad_out[vec].acc.w; + grad_indice_weight3 += weight3.acc.x * grad_out[vec].acc.x + weight3.acc.y * grad_out[vec].acc.y + + weight3.acc.z * grad_out[vec].acc.z + weight3.acc.w * grad_out[vec].acc.w; + } + + grad_indice_weight0 = warpReduceAllSum>(grad_indice_weight0); + grad_indice_weight1 = warpReduceAllSum>(grad_indice_weight1); + grad_indice_weight2 = warpReduceAllSum>(grad_indice_weight2); + grad_indice_weight3 = warpReduceAllSum>(grad_indice_weight3); - for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) { + if (threadIdx.x == 0) { + grad_indice_weights[indices_start + l_start + j] = grad_indice_weight0; + grad_indice_weights[indices_start + l_start + j+1] = grad_indice_weight1; + grad_indice_weights[indices_start + l_start + j+2] = grad_indice_weight2; + grad_indice_weights[indices_start + l_start + j+3] = grad_indice_weight3; + } + } + } + {%- endif %}{#-/* if is_rocm and not ssd and not dense and not use_vec_blocking and not vbe */#} + for (; j < kWarpSize && l_start + j < L; ++j) { const auto offset_idx_j = shfl_sync(offset_idx, j); {%- if not dense %} const auto {{ locs_or_addrs_idx }}_j = shfl_sync({{ locs_or_addrs_idx }}, j); @@ -359,6 +482,15 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda( auto aligned_grad_output = aligned_grad_output_tensor_for_cuda_backwards(grad_output); CUDA_DEVICE_GUARD(dev_weights); + #ifdef USE_ROCM + if (!rocm::is_supported_cdna()) { + TORCH_WARN_ONCE("Running on non-CDNA architecture. 
Performance may be suboptimal."); + } + else { + // Ensure we're running on a supported CDNA architecture (including MI350) + TORCH_WARN_ONCE("Running on CDNA architecture"); + } + #endif const auto T = D_offsets.size(0) - 1; TORCH_CHECK_GT(T, 0); diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu index 25f7119a7a..b10eb1312e 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu @@ -625,7 +625,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row codegen/embedding_common_code_generator.py for more details */ #} -{{ instantiate_templates(use_subwarp_shuffle=False) }} +{{ instantiate_templates(use_subwarp_shuffle=True) }} //////////////////////////////////////////////////////////////////////////////// #endif diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu index 5137b5766c..091b8d8001 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu @@ -32,6 +32,22 @@ {%- set locs_or_addrs_tensor = "ssd_row_addrs" if ssd else "lxu_cache_locations" %} {%- set locs_or_addrs_type = "int64_t" if ssd else "int32_t" %} +{%- set is_optimized_hip_kernel_supported_mode = is_rocm and + optimizer == "rowwise_adagrad" and + not dense and + not nobag and + not is_index_select and + not is_gwd_kernel and + not vbe and + not ssd %} + +{%- set enable_optimized_hip_mixed_D_kernel = is_rocm and + optimizer == "rowwise_adagrad" and + not dense and + not is_index_select and + not is_gwd_kernel and + not nobag and + not ssd %} #include "fbgemm_gpu/embedding_backward_template_helpers.cuh" #include "fbgemm_gpu/utils/tensor_accessor_builder.h" @@ -333,6 +349,307 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( } } +{%- if enable_optimized_hip_mixed_D_kernel %} +template < + typename emb_t, + typename grad_t, + typename cache_t, + typename index_t, + {%- for ph_name in args.placeholder_tensor_names %} + typename {{ ph_name + "_ph_t"}}, + {%- endfor %} + int32_t kFixedMaxVecsPerThread, + int32_t kThreadGroupSize, + bool kUseVecBlocking> +__global__ __launch_bounds__(kBackwardMaxThreads) void +hip_mixed_d_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_kernel_warp_per_row_1( + const pta::PackedTensorAccessor64 grad_output, + {%- if optimizer != "none" %} + pta::PackedTensorAccessor64 dev_weights, + {%- if not dense %} + pta::PackedTensorAccessor64 uvm_weights, + pta::PackedTensorAccessor64 lxu_cache_weights, + const pta::PackedTensorAccessor32 weights_placements, + {%- endif %} + {%- endif %} + const pta::PackedTensorAccessor32 weights_offsets, + {%- if not nobag or is_index_select %} + const pta::PackedTensorAccessor32 D_offsets, + {%- else %} + int64_t D, + {%- endif %} + const pta::PackedTensorAccessor32 hash_size_cumsum, + const pta::PackedTensorAccessor32 sorted_linear_indices_run, + const pta::PackedTensorAccessor32 sorted_linear_indices_cumulative_run_lengths, + {%- if not nobag %} + const pta::PackedTensorAccessor32 sorted_infos, + {%- else %} + const pta::PackedTensorAccessor32 sorted_infos, + {%- endif %} + {%- if not dense 
%} + const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits> sorted_{{ locs_or_addrs_tensor }}, + const bool use_uniq_cache_locations, + const pta::PackedTensorAccessor32 table_unique_indices_offsets, + {%- endif %} + {%- if weighted %} + const pta::PackedTensorAccessor32, 1, at::RestrictPtrTraits> sorted_indice_weights, + {%- endif %} + const pta::PackedTensorAccessor32 sorted_linear_indices_num_runs, + int32_t max_segment_length_per_warp, + {%- if not dense and optimizer != "none" %} + bool stochastic_rounding, + at::PhiloxCudaState stochastic_rounding_philox_args, + {%- else %} + pta::PackedTensorAccessor64 grad_dev_weights, + {%- endif %} // if not dense and optimizer != "none" + {%- if not nobag and vbe %} + const pta::PackedTensorAccessor32 B_offsets, + const pta::PackedTensorAccessor32 row_output_offsets, + {%- endif %} + {%- if not nobag %} + const int32_t info_B_num_bits, + const uint32_t info_B_mask, + {%- endif %} + const int32_t max_D, + const int32_t max_vecs_per_thread, + {%- if is_index_select %} + const pta::PackedTensorAccessor32 grad_offsets, + const bool permute_output_dim_0_1 + {%- else %} + {{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }} + {%- endif %} +) { + {%- if not nobag %} + int32_t T = D_offsets.size(0) - 1; + {%- else %} + int32_t T = weights_offsets.size(0); + {%- endif %} + const auto start_run_id = blockIdx.x * blockDim.y + threadIdx.y; + +#define SUBWARP_SHFL_SYNC(val, srcLane) __shfl_sync(UINT64_MAX, val, srcLane, kThreadGroupSize) + +#ifdef FBGEMM_USE_SUBWARP_SHUFFLE + const unsigned int shfl_sync_mask = + ((1L << kThreadGroupSize) - 1) << + (threadIdx.y % (kWarpSize / kThreadGroupSize) * kThreadGroupSize); +#else + const unsigned int shfl_sync_mask = 0xffffffffu; +#endif + +#define BROADCAST(val, srcLane) __builtin_amdgcn_readlane(val,srcLane) + + constexpr int VEC_WIDTH = 4; + constexpr auto kIsInt8 = std::is_same::value; + + struct SharedMemory> smem; + const int32_t grad_sum_stride = max_D / VEC_WIDTH; + auto* smem_grad_sum = (kUseVecBlocking || kIsInt8) + ? smem.getPointer() + threadIdx.y * grad_sum_stride + : nullptr; + + constexpr int num_unroll = kThreadGroupSize; + + auto num_run_id = min(sorted_linear_indices_run.size(0), sorted_linear_indices_num_runs[0]); + + for (uint32_t out_run_id = start_run_id * num_unroll; out_run_id < num_run_id; out_run_id += gridDim.x * blockDim.y * num_unroll) { + auto num_valid_id = min(num_unroll, num_run_id - out_run_id); + auto is_valid = threadIdx.x < num_valid_id; + + int32_t s_segment_start = is_valid? sorted_linear_indices_cumulative_run_lengths[(out_run_id + threadIdx.x)] : -1; + int32_t s_segment_end = is_valid? sorted_linear_indices_cumulative_run_lengths[(out_run_id + threadIdx.x + 1)] : -1; + int64_t s_idx = is_valid? sorted_linear_indices_run[out_run_id + threadIdx.x] : -1; + + {%- if not nobag %} + uint32_t s_t_0 = is_valid? reinterpret_cast(&sorted_infos[0])[s_segment_start] : -1; + s_t_0 = s_t_0 >> info_B_num_bits; + {%- else %} + auto s_t_0 = is_valid? sorted_infos[s_segment_start] : -1; + s_t_0 = s_t_0 % T; + {%- endif %} + + int64_t s_hash_size = is_valid? hash_size_cumsum[s_t_0] : -1; + s_idx -= s_hash_size; + {%- if not nobag %} + int32_t s_D_offsets_0 = is_valid? D_offsets[s_t_0] : 0; + int32_t s_D_offsets_1 = is_valid? D_offsets[s_t_0 + 1] : 0; + auto s_D = s_D_offsets_1 - s_D_offsets_0; + {%- endif %} + + int32_t s_table_unique_indice_offset = is_valid? table_unique_indices_offsets[s_t_0] : 0; + int64_t s_weights_offset = is_valid? 
weights_offsets[s_t_0] : 0; + int32_t s_weights_placement = is_valid? weights_placements[s_t_0] : 0; + + {%- for tensor in args.split_tensors %} + {{ args.split_tensor_types[tensor] }}* __restrict__ s_{{ tensor }}; + const auto s_{{ tensor }}_placement = {{ tensor }}_placements[s_t_0]; + const int64_t s_{{ tensor }}_offset = {{ tensor }}_offsets[s_t_0]; + if (static_cast(s_{{ tensor }}_placement) == PlacementType::DEVICE) { + s_{{ tensor }} = &{{ tensor }}_dev[s_{{ tensor }}_offset]; + } else { + s_{{ tensor }} = &{{ tensor }}_uvm[s_{{ tensor }}_offset]; + } + {{ args.split_tensor_types[tensor] }} s_{{tensor}}_val = is_valid? s_{{tensor}}[s_idx] : 0; + + {%- endfor %} + + for (auto i = 0; i < num_valid_id; ++i) { + auto segment_start = SUBWARP_SHFL_SYNC(s_segment_start, i); + auto segment_end = SUBWARP_SHFL_SYNC(s_segment_end, i); + const int32_t SL = segment_end - segment_start; + if (SL >= max_segment_length_per_warp) { + continue; + } + + auto run_id = out_run_id + i; + auto t_0 = SUBWARP_SHFL_SYNC(s_t_0, i); + auto idx = SUBWARP_SHFL_SYNC(s_idx, i); + + {%- if not nobag %} + auto D = SUBWARP_SHFL_SYNC(s_D, i); + {%- endif %} + int32_t table_unique_indice_offset = SUBWARP_SHFL_SYNC(s_table_unique_indice_offset, i); + + {%- for tensor in args.split_tensors %} + const auto {{ tensor }}_placement = SUBWARP_SHFL_SYNC(s_{{ tensor }}_placement, i); + const int64_t {{ tensor }}_offset = SUBWARP_SHFL_SYNC(s_{{ tensor }}_offset, i); + {{ args.split_tensor_types[tensor] }} {{tensor}}_val = SUBWARP_SHFL_SYNC(s_{{ tensor }}_val, i); + {%- endfor %} + + // const int64_t momentum1_offset = SHFL_SYNC(s_momentum1_offset, i); + // const auto momentum1_placement = static_cast(SHFL_SYNC(s_momentum1_placement, i)); + // auto momentum1 = reinterpret_cast*>(SHFL_SYNC(reinterpret_cast(s_momentum1), i)); + // auto momentum1_val = momentum1[idx]; + + // now, each segment corresponds to exactly one table `t` and row in + // that table (`idx`). Thus, we can hoist out some of the book-keeping. + + const int32_t SL_per_warp = div_round_up(SL, blockDim.y); + const int32_t sl_start = 0; + const int32_t sl_end = SL; + + Vec4TAcc grad_sum[kFixedMaxVecsPerThread]; + constexpr int32_t kGroupVecWidth = kThreadGroupSize * VEC_WIDTH; + const int32_t num_vecs = (D + kGroupVecWidth - 1) / kGroupVecWidth; + + compute_grad_sum_{{ kdesc }}< + grad_t, + cache_t, + kFixedMaxVecsPerThread, + kThreadGroupSize, + VEC_WIDTH, + kUseVecBlocking>( + grad_sum, + smem_grad_sum, + grad_output, + {%- if not nobag or is_index_select %} + D_offsets, + {%- endif %} + D, + T, + sorted_infos, + {%- if weighted %} + sorted_indice_weights, + {%- endif %} + {%- if not nobag and vbe %} + B_offsets, + row_output_offsets, + {%- endif %} + {%- if not nobag %} + info_B_num_bits, + info_B_mask, + {%- endif %} + segment_start, + sl_start, + sl_end, + shfl_sync_mask, + num_vecs + ); + + // Copy value to max_vecs to make max_vecs_per_thread known at compile time + // when kUseVecBlocking == false + const int32_t max_vecs = + kUseVecBlocking ? 
max_vecs_per_thread : kFixedMaxVecsPerThread; + + {%- if not dense and optimizer != "none" %} + const int64_t weights_offset = SUBWARP_SHFL_SYNC(s_weights_offset, i); + const int32_t weights_placement = SUBWARP_SHFL_SYNC(s_weights_placement, i); + {{ mdesc }}_{{ optimizer }}_table_update_kernel< + emb_t, + cache_t, + {%- for ph_name in args.placeholder_tensor_names %} + {{ ph_name + "_ph_t" }}, + {%- endfor %} + kFixedMaxVecsPerThread, + kThreadGroupSize, + VEC_WIDTH, + kUseVecBlocking>( + dev_weights, + uvm_weights, + lxu_cache_weights, + weights_placement, + weights_offset, + sorted_{{ locs_or_addrs_tensor }}, + grad_sum, + smem_grad_sum, + smem_grad_sum, // shared_weight_update_row (reuse smem_grad_sum) + stochastic_rounding, + stochastic_rounding_philox_args, + run_id, + use_uniq_cache_locations + ? (run_id - table_unique_indices_offsets[t_0]) + : segment_start, + D, + t_0, + idx, + {%- if is_gwd_kernel %} + global_weight_decay, + {%- elif has_global_weight_decay_support %} + {# /* cases where gwd is not enabled/supported */ #} + 1, // global_weight_decay + {%- endif %} + shfl_sync_mask, + max_vecs, + {%- if ssd %} + enable_optimizer_offloading, + {%- endif %} + {%- for tensor in args.split_tensors %} + {{ tensor }}_placement, + {{ tensor }}_offset, + {{ tensor }}_val, + {%- endfor %} + {{ args.split_kernel_arg_names | join(", ") }} + ); + {%- else %} + // Write deduplicated gradient to grad_dev_weights gradient is sparse + // for split_embedding and dense for dense_embedding + {%- if dense %} + const int64_t weights_offset = weights_offsets[t_0]; + {%- else %} + // Compute offset of sparse gradient + const int64_t weights_offset = run_id * max_D; + idx = 0; + {%- endif %} + store_grad_sum< + emb_t, + cache_t, + kFixedMaxVecsPerThread, + kThreadGroupSize, + VEC_WIDTH, + kUseVecBlocking>( + grad_dev_weights, + grad_sum, + kUseVecBlocking ? 
smem_grad_sum : nullptr, + D, + weights_offset, + idx, + max_vecs + ); + {%- endif %} // if not dense and optimizer != "none" + } + } +} +{%- endif %} + //////////////////////////////////////////////////////////////////////////////// // Explicit Template Instantiations @@ -447,6 +764,85 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row }} {%- endif %} ); + +{%- if enable_optimized_hip_mixed_D_kernel %} + +template __global__ __launch_bounds__(kBackwardMaxThreads) void +hip_mixed_d_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_kernel_warp_per_row_1 +< {{ emb_type }}, + {{ grad_type }}, + {{ cache_type }}, + {{ index_type }}, + {%- for ph_name in args.placeholder_tensor_names %} + {{ ph_type_combo[ph_name].primitive_type }}, + {%- endfor %} + {{ kFixedMaxVecsPerThread }}, + {{ kThreadGroupSize }}, + {{ kUseVecBlocking }} +> ( + const pta::PackedTensorAccessor64<{{ grad_type }}, {{ "1" if is_index_select else "2" }}, at::RestrictPtrTraits> grad_output, + {%- if optimizer != "none" %} + pta::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> dev_weights, + {%- if not dense %} + pta::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> uvm_weights, + pta::PackedTensorAccessor64<{{ cache_type }}, 2, at::RestrictPtrTraits> lxu_cache_weights, + const pta::PackedTensorAccessor32 weights_placements, + {%- endif %} + {%- endif %} + const pta::PackedTensorAccessor32 weights_offsets, + {%- if not nobag or is_index_select %} + const pta::PackedTensorAccessor32 D_offsets, + {%- else %} + int64_t D, + {%- endif %} + const pta::PackedTensorAccessor32 hash_size_cumsum, + const pta::PackedTensorAccessor32<{{ index_type }}, 1, at::RestrictPtrTraits> sorted_linear_indices_run, + const pta::PackedTensorAccessor32 sorted_linear_indices_cumulative_run_lengths, + {%- if not nobag %} + const pta::PackedTensorAccessor32 sorted_infos, + {%- else %} + const pta::PackedTensorAccessor32 sorted_infos, + {%- endif %} + {%- if not dense %} + const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits> sorted_{{ locs_or_addrs_tensor }}, + const bool use_uniq_cache_locations, + const pta::PackedTensorAccessor32 table_unique_indices_offsets, + {%- endif %} + {%- if weighted %} + const pta::PackedTensorAccessor32, 1, at::RestrictPtrTraits> sorted_indice_weights, + {%- endif %} + const pta::PackedTensorAccessor32 sorted_linear_indices_num_runs, + int32_t max_segment_length_per_warp, + {%- if not dense and optimizer != "none" %} + bool stochastic_rounding, + at::PhiloxCudaState stochastic_rounding_philox_args, + {%- else %} + pta::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> grad_dev_weights, + {%- endif %} // if not dense and optimizer != "none" + {%- if not nobag and vbe %} + const pta::PackedTensorAccessor32 B_offsets, + const pta::PackedTensorAccessor32 row_output_offsets, + {%- endif %} + {%- if not nobag %} + const int32_t info_B_num_bits, + const uint32_t info_B_mask, + {%- endif %} + const int32_t max_D, + const int32_t max_vecs_per_thread, + {%- if is_index_select %} + const pta::PackedTensorAccessor32 grad_offsets, + const bool permute_output_dim_0_1 + {%- else %} + {{ args.split_kernel_args_no_defaults | + replace_pta_namespace() | + replace_placeholder_types(ph_type_combo) | + join(",\n ") | + replace("cache_t", cache_type) + }} + {%- endif %} +); + +{%- endif %} {%- endmacro %} {%- macro bulk_template_instantiations(kFixedMaxVecsPerThread, kThreadGroupSize, kUseVecBlocking) %} @@ -530,7 +926,7 @@ 
batch_index_select_dim0_codegen_backward_kernel_warp_per_row codegen/embedding_common_code_generator.py for more details */ #} -{{ instantiate_templates(use_subwarp_shuffle=False) }} +{{ instantiate_templates(use_subwarp_shuffle=True) }} //////////////////////////////////////////////////////////////////////////////// #endif @@ -538,7 +934,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row {%- endif %} -{%- if is_rocm and not is_index_select and optimizer == "rowwise_adagrad" and not dense and not is_gwd_kernel and not vbe and not ssd %} +{%- if is_optimized_hip_kernel_supported_mode %} #include #include #include "fbgemm_gpu/rocm/split_embeddings_common.h" @@ -612,12 +1008,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd {{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }} {%- endif %} ) { - {%- if not nobag %} int32_t T = D_offsets.size(0) - 1; - {%- else %} - int32_t T = weights_offsets.size(0); - {%- endif %} - auto p_output_grad = grad_output.data(); auto p_emb_table = dev_weights.data(); auto p_hash_size_cumsum = hash_size_cumsum.data(); @@ -632,8 +1023,6 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd constexpr int32_t segment_prefetch = 2; constexpr int32_t segment_unroll = 8; constexpr int32_t segment_split = 0; - auto batch = grad_output.size(0); - auto num_rows = dev_weights.size(0) / T / max_D; {%- if weighted %} constexpr bool is_weighted = true; {%- else %} @@ -646,24 +1035,9 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd // weight_decay(_mode) is supplied as args.split_function_args_no_defaults opt_karg.weight_decay_mode = weight_decay_mode_v; opt_karg.weight_decay = weight_decay; - auto batch_mdiv = [](uint32_t d) -> rocm::magic_div_u32_t { - assert(d >= 1 && d <= INT32_MAX); - uint8_t shift; - for(shift = 0; shift < 32; shift++) - if((1U << shift) >= d) - break; - - uint64_t one = 1; - uint64_t magic = ((one << 32) * ((one << shift) - d)) / d + 1; - assert(magic <= 0xffffffffUL); - - rocm::magic_div_u32_t result; - result.magic = magic; - result.shift = shift; - return result; - }(batch); + rocm::split_tbe_backward_hip_kernel_{{kdesc}}< - rocm::{{optimizer}}_optimizer_t, + rocm::{{optimizer}}_optimizer_t, rocm::{{optimizer}}_kernel_arg_t, emb_t, cache_t, @@ -680,16 +1054,11 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd p_sorted_linear_indices_run, p_sorted_linear_indices_cumulative_run_lengths, p_sorted_linear_indices_num_runs, - {%- if not nobag %} info_B_num_bits, info_B_mask, - {%- endif %} p_sorted_infos, - batch_mdiv, max_segment_length_per_warp, emb_dim, - batch, - num_rows, T, opt_karg {%- if weighted %} @@ -784,7 +1153,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd {%- for emb_type in (['float', 'at::Half'] + (['at::Float8_e4m3fnuz'] if is_rocm else ['at::Float8_e4m3fn'])) %} {%- for cache_type in ['float', 'at::Half'] %} {%- for index_type in ['int32_t', 'int64_t'] %} - {%- for kEmbeddingDim in [64, 128, 160, 192, 256] %} + {%- for kEmbeddingDim in [64, 128, 160, 192, 256, 320] %} {%- for kWeighDecayMode in [0, 1, 2] %} {{ hip_template_instantiation( emb_type, diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu old mode 100644 new mode 100755 index 76eba64c99..f29e32024c --- 
a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -48,6 +48,23 @@ using namespace fbgemm_gpu; has_global_weight_decay_support, ssd) %} {%- set desc_suffix = get_desc_suffix(is_gwd_kernel) %} +{%- set is_optimized_hip_kernel_supported_mode = is_rocm and + optimizer == "rowwise_adagrad" and + not dense and + not nobag and + not is_index_select and + not is_gwd_kernel and + not vbe and + not ssd %} + +{%- set enable_optimized_hip_mixed_D_kernel = is_rocm and + optimizer == "rowwise_adagrad" and + not dense and + not is_index_select and + not is_gwd_kernel and + not nobag and + not ssd %} + template < typename emb_t, typename grad_t, @@ -227,8 +244,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( {%- endif %} ); -{%- if is_rocm and optimizer == "rowwise_adagrad" and not dense and not is_index_select - and not is_gwd_kernel and not vbe and not ssd %} +{%- if is_optimized_hip_kernel_supported_mode %} #include "fbgemm_gpu/rocm/split_embeddings_common.h" template < typename emb_t, @@ -299,6 +315,147 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd {%- endif %} ); {%- endif %} + +{%- if enable_optimized_hip_mixed_D_kernel %} + +template < + typename emb_t, + typename grad_t, + typename cache_t, + typename index_t, + {%- for ph_name in args.placeholder_tensor_names %} + typename {{ ph_name + "_ph_t" }}, + {%- endfor %} + int32_t kFixedMaxVecsPerThread, + int32_t kThreadGroupSize, + bool kUseVecBlocking> +__global__ __launch_bounds__(kMaxThreads) void +hip_mixed_d_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_kernel_cta_per_row_1( + const pta::PackedTensorAccessor64 grad_output, + {%- if optimizer != "none" %} + pta::PackedTensorAccessor64 dev_weights, + {%- if not dense %} + pta::PackedTensorAccessor64 uvm_weights, + pta::PackedTensorAccessor64 lxu_cache_weights, + const pta::PackedTensorAccessor32 weights_placements, + {%- endif %} + {%- endif %} // if optimizer != "none" + const pta::PackedTensorAccessor32 weights_offsets, + {%- if not nobag or is_index_select %} + const pta::PackedTensorAccessor32 D_offsets, + {%- else %} + int64_t D, + {%- endif %} + const pta::PackedTensorAccessor32 hash_size_cumsum, + const pta::PackedTensorAccessor32 sorted_linear_indices_run, + const pta::PackedTensorAccessor32 sorted_linear_indices_cumulative_run_lengths, + const pta::PackedTensorAccessor32 long_run_ids, + const pta::PackedTensorAccessor32 num_long_run_ids, + {%- if not nobag %} + const pta::PackedTensorAccessor32 sorted_infos, + {%- else %} + const pta::PackedTensorAccessor32 sorted_infos, + {%- endif %} + {%- if not dense %} + const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits> sorted_{{ locs_or_addrs_tensor }}, + const bool use_uniq_cache_locations, + const pta::PackedTensorAccessor32 table_unique_indices_offsets, + {%- endif %} + {%- if weighted %} + const pta::PackedTensorAccessor32, 1, at::RestrictPtrTraits> sorted_indice_weights, + {%- endif %} + {%- if not dense and optimizer != "none" %} + bool stochastic_rounding, + at::PhiloxCudaState stochastic_rounding_philox_args, + {%- else %} + pta::PackedTensorAccessor64 grad_dev_weights, + {%- if optimizer == "none" %} + const int32_t max_D, + {%- endif %} + {%- endif %} // if not dense and optimizer != "none" + {%- if vbe %} + const pta::PackedTensorAccessor32 B_offsets, + const pta::PackedTensorAccessor32 row_output_offsets, 
+ {%- endif %} + {%- if not nobag %} + const int32_t info_B_num_bits, + const uint32_t info_B_mask, + {%- endif %} + const pta::PackedTensorAccessor32 long_run_id_to_really_long_run_ids, + pta::PackedTensorAccessor32, 2, at::RestrictPtrTraits> temp_grad_accum, + pta::PackedTensorAccessor32 grad_accum_counter, + const int32_t max_segment_length_per_cta, + const bool use_deterministic_algorithms, + const int32_t max_vecs_per_thread, + {{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }} +); + +template < + typename emb_t, + typename grad_t, + typename cache_t, + typename index_t, + {%- for ph_name in args.placeholder_tensor_names %} + typename {{ ph_name + "_ph_t" }}, + {%- endfor %} + int32_t kFixedMaxVecsPerThread, + int32_t kThreadGroupSize, + bool kUseVecBlocking> +__global__ __launch_bounds__(kBackwardMaxThreads) void +hip_mixed_d_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_kernel_warp_per_row_1( + const pta::PackedTensorAccessor64 grad_output, + {%- if optimizer != "none" %} + pta::PackedTensorAccessor64 dev_weights, + {%- if not dense %} + pta::PackedTensorAccessor64 uvm_weights, + pta::PackedTensorAccessor64 lxu_cache_weights, + const pta::PackedTensorAccessor32 weights_placements, + {%- endif %} + {%- endif %} + const pta::PackedTensorAccessor32 weights_offsets, + {%- if not nobag or is_index_select %} + const pta::PackedTensorAccessor32 D_offsets, + {%- else %} + int64_t D, + {%- endif %} + const pta::PackedTensorAccessor32 hash_size_cumsum, + const pta::PackedTensorAccessor32 sorted_linear_indices_run, + const pta::PackedTensorAccessor32 sorted_linear_indices_cumulative_run_lengths, + {%- if not nobag %} + const pta::PackedTensorAccessor32 sorted_infos, + {%- else %} + const pta::PackedTensorAccessor32 sorted_infos, + {%- endif %} + {%- if not dense %} + const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits> sorted_{{ locs_or_addrs_tensor }}, + const bool use_uniq_cache_locations, + const pta::PackedTensorAccessor32 table_unique_indices_offsets, + {%- endif %} + {%- if weighted %} + const pta::PackedTensorAccessor32, 1, at::RestrictPtrTraits> sorted_indice_weights, + {%- endif %} + const pta::PackedTensorAccessor32 sorted_linear_indices_num_runs, + int32_t max_segment_length_per_warp, + {%- if not dense and optimizer != "none" %} + bool stochastic_rounding, + at::PhiloxCudaState stochastic_rounding_philox_args, + {%- else %} + pta::PackedTensorAccessor64 grad_dev_weights, + {%- endif %} // if not dense and optimizer != "none" + {%- if vbe %} + const pta::PackedTensorAccessor32 B_offsets, + const pta::PackedTensorAccessor32 row_output_offsets, + {%- endif %} + {%- if not nobag %} + const int32_t info_B_num_bits, + const uint32_t info_B_mask, + {%- endif %} + const int32_t max_D, + const int32_t max_vecs_per_thread, + {{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }} +); +{%- endif %} + {% if is_index_select %} namespace index_select { {% else %} @@ -652,6 +809,16 @@ Tensor {{ embedding_cuda_op }}( CUDA_DEVICE_GUARD(dev_weights); + #ifdef USE_ROCM + if (!rocm::is_supported_cdna()) { + TORCH_WARN_ONCE("Running on non-CDNA architecture. 
Performance may be suboptimal."); + } + else { + // Ensure we're running on a supported CDNA architecture (including MI350) + TORCH_WARN_ONCE("Running on CDNA architecture"); + } + #endif + {%- if nobag and not is_index_select %} auto max_D = D; {%- endif %} @@ -852,15 +1019,24 @@ Tensor {{ embedding_cuda_op }}( } {%- endif %} - {%- if is_rocm and optimizer == "rowwise_adagrad" and not dense and not is_index_select - and not is_gwd_kernel and not vbe and not ssd %} + {%- if is_optimized_hip_kernel_supported_mode %} {%- set hip_kernel = "hip_split_embedding{}_backward_codegen_{}_{}{}_kernel_warp_per_row_1".format( ndesc, optimizer, wdesc, vdesc, ) - %} + %} + {%- endif %} + + {%- if enable_optimized_hip_mixed_D_kernel %} + {%- set hip_mixed_d_warp_kernel = "hip_mixed_d_split_embedding{}_backward_codegen_{}_{}{}_kernel_warp_per_row_1".format( + ndesc, + optimizer, + wdesc, + vdesc, + ) + %} {%- endif %} AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "{{ embedding_cuda_op }}_2", [&] { @@ -970,8 +1146,11 @@ Tensor {{ embedding_cuda_op }}( auto num_long_run_ids = at::zeros({1}, indices.options().dtype(at::kInt)); const bool use_deterministic_algorithms = at::globalContext().deterministicAlgorithms(); - const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 1024; - + {% if is_rocm %} + const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 4096; + {% else %} + const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 1024; + {%- endif %} Tensor long_run_id_to_really_long_run_ids; if (use_deterministic_algorithms) { long_run_id_to_really_long_run_ids = @@ -1009,6 +1188,10 @@ Tensor {{ embedding_cuda_op }}( {use_deterministic_algorithms ? 0 : grad_accum_counter.numel(), max_D}, aligned_grad_output.options().dtype(std::is_same::value ? 
at::kDouble : at::kFloat)); + {%- if enable_optimized_hip_mixed_D_kernel %} + const static auto use_hip_kernel = fbgemm_gpu::config::is_feature_enabled(fbgemm_gpu::config::FeatureGateName::TBE_ROCM_HIP_BACKWARD_KERNEL); + {%- endif %} + DISPATCH_PLACEHOLDER_TYPES( {%- for ph_name in args.placeholder_tensor_names %} {{ ph_name + "_dev" }}.scalar_type(), @@ -1027,7 +1210,7 @@ Tensor {{ embedding_cuda_op }}( ) %} - const auto backward_cta_per_row_kernel = + auto backward_cta_per_row_kernel = {{ cta_kernel }} ; + + {% if is_rocm %} + int32_t total_L = indices.numel(); + int32_t num_cta_per_row_groups; + int32_t work_group_size; + if (total_L/total_B > 1) { + num_cta_per_row_groups = (kMaxThreads/4) / kWarpSize; + work_group_size = (kMaxThreads/4); + } + else { + num_cta_per_row_groups = kMaxThreads / kWarpSize; + work_group_size = kMaxThreads; + } + {%- else %} + int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize; + const int32_t work_group_size = kMaxThreads; + {%- endif %} + {%- if enable_optimized_hip_mixed_D_kernel %} + auto cta_blockSize = dim3(kThreadGroupSize, num_cta_per_row_groups); + if (max_D <= 128) { + backward_cta_per_row_kernel = + {{ cta_kernel }} + ; + + cta_blockSize = dim3(32, num_cta_per_row_groups); + } + {%- else %} + auto cta_blockSize = dim3(kThreadGroupSize, num_cta_per_row_groups); + {%- endif %} // Compute shared memory size for cta_per_row constexpr auto kCacheAccBytes = sizeof(at::acc_type); - int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize; const size_t cta_per_row_smem_bytes = compute_num_groups_and_dynamic_smem_bytes( &num_cta_per_row_groups, [&] (int num_groups) { @@ -1053,13 +1273,13 @@ Tensor {{ embedding_cuda_op }}( ); const int32_t cta_per_row_grid_size = std::min( - div_round_up(total_unique_indices, kMaxThreads), + div_round_up(total_unique_indices, work_group_size), get_max_thread_blocks_()); FBGEMM_LAUNCH_KERNEL( backward_cta_per_row_kernel, cta_per_row_grid_size, - dim3(kThreadGroupSize, num_cta_per_row_groups), + cta_blockSize, cta_per_row_smem_bytes, at::cuda::getCurrentCUDAStream(), grad_output_accessor, @@ -1161,8 +1381,53 @@ Tensor {{ embedding_cuda_op }}( kThreadGroupSize, kUseVecBlocking>; - // Compute shared memory size for warp_per_row - int32_t num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize; + {%- if is_rocm %} + int32_t num_warp_per_row_groups; + if (total_L/total_B > 1){ + num_warp_per_row_groups = (kBackwardMaxThreads/2) / kThreadGroupSize; + } + else{ + num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize; + } + {%- else %} + int32_t num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize; + {%- endif %} + auto blockSize = dim3(kThreadGroupSize, num_warp_per_row_groups); + {%- if enable_optimized_hip_mixed_D_kernel %} + {%- if vbe %} + if (use_hip_kernel) { + {%- else %} + if (use_hip_kernel && mixed_D) { + {%- endif %} + backward_warp_per_row_kernel = + {{ hip_mixed_d_warp_kernel }} + ; + if (max_D <= 128) { + backward_warp_per_row_kernel = + {{ hip_mixed_d_warp_kernel }} + ; + blockSize = dim3(32, num_warp_per_row_groups); + } + } + {%- endif %} int32_t warp_per_row_smem_bytes = 0; if constexpr (kUseVecBlocking) { @@ -1177,26 +1442,22 @@ Tensor {{ embedding_cuda_op }}( backward_warp_per_row_kernel, used_shared_bytes); } - - auto blockSize = dim3(kThreadGroupSize, num_warp_per_row_groups); - int32_t warp_per_row_grid_size = std::min( div_round_up(total_unique_indices, num_warp_per_row_groups), get_max_thread_blocks_()); #ifdef USE_ROCM - {%- if is_rocm and not is_index_select and 
optimizer == "rowwise_adagrad" and - not dense and not is_gwd_kernel and not vbe and not ssd and not nobag %} + {%- if is_optimized_hip_kernel_supported_mode %} const static auto use_hip_kernel = fbgemm_gpu::config::is_feature_enabled(fbgemm_gpu::config::FeatureGateName::TBE_ROCM_HIP_BACKWARD_KERNEL); - const auto supported_weights_type = dev_weights.scalar_type() == at::ScalarType::Half - || dev_weights.scalar_type() == at::ScalarType::Float; + constexpr bool supported_weights_type = std::is_same_v || std::is_same_v; + constexpr bool supported_grad_type = std::is_same_v || std::is_same_v; - if (use_hip_kernel && supported_weights_type && !mixed_D && rocm::is_supported_cdna()) + if (use_hip_kernel && !mixed_D && supported_weights_type && supported_grad_type && rocm::is_supported_cdna()) { constexpr int segments_per_workgroup = 4; - {%- for kDimSize in [64, 128, 160, 192, 256] %} + {%- for kDimSize in [64, 128, 160, 192, 256, 320] %} {%- for kWeightDecayMode in [0, 1, 2] %} if (max_D == {{ kDimSize }} && weight_decay_mode == {{ kWeightDecayMode }}) { @@ -1221,7 +1482,6 @@ Tensor {{ embedding_cuda_op }}( } {%- endif %} #endif - FBGEMM_LAUNCH_KERNEL( backward_warp_per_row_kernel, warp_per_row_grid_size, diff --git a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip index 2fcbba395e..cd3d645775 100644 --- a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip +++ b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip @@ -27,7 +27,7 @@ #include "fbgemm_gpu/rocm/split_embeddings_common.h" namespace fbgemm_gpu::rocm { -template +template struct rowwise_adagrad_optimizer_t { __device__ rowwise_adagrad_optimizer_t(const rowwise_adagrad_kernel_arg_t& karg_) @@ -36,7 +36,7 @@ struct rowwise_adagrad_optimizer_t } template - __device__ void update(cache_t* acc, emb_t* weight, uint32_t row_index) + __device__ void update(cache_t* acc, emb_t* weight, index_t row_index) { if constexpr(segment_split == 0) { @@ -122,20 +122,11 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( const index_t* p_sorted_linear_indices_run, const int32_t* p_sorted_linear_indices_cumulative_run_lengths, const int32_t* p_sorted_linear_indices_num_runs, - {%- if not nobag %} const int32_t info_B_num_bits, const uint32_t info_B_mask, - {%- endif %} - {%- if not nobag %} const int32_t* p_sorted_infos, - {%- else %} - const int64_t* p_sorted_infos, - {%- endif %} - magic_div_u32_t batch_mdiv, uint32_t max_segment_length_per_warp, uint32_t emb_dim, - uint32_t batch, - uint32_t num_rows, uint32_t num_tables, optimizer_karg_t opt_karg, const float * p_sorted_indice_weights = nullptr) @@ -157,13 +148,9 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( const int32_t segment_start = p_sorted_linear_indices_cumulative_run_lengths[run_id]; const int32_t segment_end = p_sorted_linear_indices_cumulative_run_lengths[run_id + 1]; - {%- if nobag %} - const auto info_0 = p_sorted_infos[segment_start]; - int32_t t_0 = info_0 % num_tables; - {%- else %} const auto info_0 = reinterpret_cast(&p_sorted_infos[0])[segment_start]; const auto t_0 = info_0 >> info_B_num_bits; - {%- endif %} + int64_t hash_size = p_hash_size_cumsum[t_0]; const int64_t emb_idx = linear_index - hash_size; @@ -179,7 +166,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( const int32_t segment_length_mod = segment_length & 
length_mask; cache_t grad_acc[dword_per_row]; - int32_t infos[segment_unroll]; + uint32_t infos[segment_unroll]; grad_t grad_data[dword_per_row * segment_prefetch]; emb_t emb_data[dword_per_row]; float indice_weights[segment_unroll]; @@ -221,22 +208,16 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( // LOOP for(; itr < segment_length_mod; itr += segment_unroll) { - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); - {%- else %} table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[1], batch, table_index, bag_index); - {%- else %} table_index = infos[1] >> info_B_num_bits; bag_index = infos[1] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); if constexpr (!weighted){ #pragma unroll @@ -244,24 +225,20 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( { accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); - {%- else %} + table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); - {%- else %} + table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -284,24 +261,20 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( { accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id, indice_weights[j-2]); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); - {%- else %} + table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id, indice_weights[j-1]); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); - {%- else %} + table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -322,22 +295,16 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( } // LAST - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); - {%- else %} table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + 
table_index * embedding_dim, lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[1], batch, table_index, bag_index); - {%- else %} table_index = infos[1] >> info_B_num_bits; bag_index = infos[1] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); if constexpr (!weighted) { @@ -346,24 +313,20 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( { accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); - {%- else %} + table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); - {%- else %} + table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -377,24 +340,20 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( { accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id, indice_weights[j-2]); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); - {%- else %} + table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id, indice_weights[j-1]); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); - {%- else %} + table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -414,13 +373,10 @@ L_tail_grad_acc: infos[0] = p_sorted_infos[segment_start]; p_sorted_infos++; - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); - {%- else %} table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id); @@ -435,13 +391,10 @@ L_tail_grad_acc: p_sorted_infos++; p_sorted_indice_weights++; - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); - {%- else %} table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id, indice_weights[0]); @@ -452,11 +405,11 @@ L_tail_grad_acc: } // load the old emb weight 
data - load_row_per_warp::run( + load_row_per_warp::run( &emb_data[0], emb_idx, p_emb_table, lane_id); optimizer_t optimizer(opt_karg); optimizer.template update(grad_acc, emb_data, emb_idx); - store_row_per_warp::run(&emb_data[0], p_emb_table + emb_idx * embedding_dim, lane_id); + store_row_per_warp::run(&emb_data[0], p_emb_table + emb_idx * embedding_dim, lane_id); } } // namespace fbgemm_gpu::rocm diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu old mode 100644 new mode 100755 index a39d33e391..acbf4563f3 --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu @@ -84,11 +84,7 @@ using namespace fbgemm_gpu; {#-/* Set the weights row accessor */#} - {%- if is_rocm %} - const auto weights_row = rocm::WeightRowAccessorVec2 - {%- else %} const auto weights_row = WeightRowAccessor - {%- endif %} < {{ 'cache_t' if from_cache else 'emb_t' }}, cache_t @@ -182,11 +178,7 @@ using namespace fbgemm_gpu; {%- endif %} {#-/* Set the weights row accessor */#} - {%- if is_rocm %} - const auto weights_row = rocm::WeightRowAccessorVec2 - {%- else %} const auto weights_row = WeightRowAccessor - {%- endif %} < {{ 'cache_t' if from_cache else 'emb_t' }}, cache_t @@ -319,7 +311,7 @@ using namespace fbgemm_gpu; {%- if is_rocm %} {%- if not nobag %} - rocm::Vec2T vals[kManualUnrollLength * kMaxVecsPerThread]; + Vec4T vals[kManualUnrollLength * kMaxVecsPerThread]; {%- endif %} // Iterate over kThreadGroupSize indices for (auto outer_j = 0; outer_j < kThreadGroupSize && l_start + outer_j < L - L % kManualUnrollLength; outer_j += kManualUnrollLength) @@ -469,10 +461,10 @@ using namespace fbgemm_gpu; {%- endif %} {%- if is_rocm %} - for(auto j = L % kThreadGroupSize - L % kManualUnrollLength; l_start + kThreadGroupSize > L && l_start + j < L; ++j) { + for(auto j = L % kThreadGroupSize - L % kManualUnrollLength; l_start + (kThreadGroupSize) > L && l_start + j < L; ++j) { {%- else %} // Iterate over kThreadGroupSize indices - for (auto j = 0; j < kThreadGroupSize && l_start + j < L; ++j) { + for (auto j = 0; j < (kThreadGroupSize) && l_start + j < L; ++j) { {%- endif %} {%- if dense or lxu_miss_rate != "cache_conflict_miss_rate::zero" %} // Load index from thread j in the group @@ -633,12 +625,7 @@ batch_index_select_dim0_codegen_forward_kernel( #endif // Elements are processed 4 at a time through fbgemm_gpu::Vec4 (CUDA float4, 16 bytes) - // for CUDA devices and 2 at a time for ROCm - {%- if is_rocm %} - constexpr int VEC_WIDTH = 2; - {%- else %} constexpr int VEC_WIDTH = 4; - {%- endif %} {%- if is_rocm %} // Unroll factor for ROCm devices constexpr int kManualUnrollLength = 4; @@ -743,12 +730,8 @@ batch_index_select_dim0_codegen_forward_kernel( const float inv_L = (mean_pooling && L != 0) ? 
static_cast(1.0) / L: static_cast(1.0); // Set up the accumulator buffer - {%- if is_rocm %} - rocm::Vec2T accumulators[kMaxVecsPerThread]; - {%- else %} Vec4T accumulators[kMaxVecsPerThread]; {%- endif %} - {%- endif %} {%- if dense %} {{ embedding_pool_or_store("NULL") }} @@ -930,7 +913,7 @@ batch_index_select_dim0_codegen_forward_kernel {%- endmacro %} {%- macro bulk_template_instantiations(use_cache, kMaxVecsPerThread, kThreadGroupSize) %} - {%- set max_vecs_per_thread = 2 * kMaxVecsPerThread if is_rocm else kMaxVecsPerThread %} + {%- set max_vecs_per_thread = kMaxVecsPerThread %} {%- for emb_type in (['float', 'at::Half'] + (['at::Float8_e4m3fnuz'] if is_rocm else ['at::Float8_e4m3fn'])) %} {%- for cache_type in ['float', 'at::Half'] %} {%- for output_type in ['float', 'at::Half', 'at::BFloat16'] %} diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_v2_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_v2_template.cu index 42f499c6dd..34ce2c6f13 100644 --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_v2_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_v2_template.cu @@ -975,6 +975,13 @@ __global__ void split_embedding_codegen_forward_{{ wdesc }}_v2_kernel( else if (tail_warp_size <= 16) { INVOKE_PROCESS_ALL_INDICES(large_Ls, 16, 0x55) } +#if defined(USE_ROCM) + // not sure step mask value to set when group size is 32 + // while use_lxu_cache is false step mask makes no sense + else if (tail_warp_size <= 32 && !use_lxu_cache) { + INVOKE_PROCESS_ALL_INDICES(large_Ls, 32, 0xf) + } +#endif else { INVOKE_PROCESS_ALL_INDICES(large_Ls, kWarpSize, 0xf) } diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_meta_template.cpp b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_meta_template.cpp index 09630b57cf..e2705d16fd 100644 --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_meta_template.cpp +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_meta_template.cpp @@ -6,6 +6,7 @@ * LICENSE file in the root directory of this source tree. */ +// clang-format off {# // @lint-ignore LINTIGNORE // @lint-ignore-every CLANGFORMAT @@ -103,7 +104,12 @@ Tensor const int64_t iter, const double gwd_lower_bound, {%- endif %} + {%- if vbe and not dense %} + const bool is_experimental, + std::optional vbe_output + {%- else %} const bool is_experimental + {%- endif %} ) { // NB: omitted the device tests TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL {%- if not nobag %} @@ -210,4 +216,4 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { {%- endfor %} {#-/* for is_gwd */#} {%- endif %} {#/* if (not nobag or (not weighted and not vbe)) */#} {%- endfor %} {#-/* for nobag */#} - // clang-format on + // clang-format on diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu old mode 100644 new mode 100755 index 6574bda45e..a3edb6b965 --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu @@ -6,10 +6,10 @@ * LICENSE file in the root directory of this source tree. */ -{# // @lint-ignore LINTIGNORE // @lint-ignore-every CLANGFORMAT // clang-format off +{# // Note: clang-format off doesn't work with this templaterized code, // so we need to keep lint-ignore-every. 
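The optional `vbe_output` argument threaded through the meta wrapper above (and through the CUDA forward wrapper below) lets several TBE ops write into one preallocated VBE output instead of each allocating its own tensor and merging afterwards. The sketch below is a plain-C++ analogue of that reuse-or-allocate pattern only; it is not the FBGEMM API, and `Buffer` / `get_or_allocate_output` are hypothetical names.

```cpp
// Plain-C++ analogue of "use the caller-provided VBE output if present,
// otherwise allocate one" (illustrative names, not the FBGEMM API).
#include <cstdio>
#include <optional>
#include <utility>
#include <vector>

using Buffer = std::vector<float>;

// Hypothetical helper: reuse the caller's buffer when one is supplied,
// otherwise allocate a fresh zero-filled buffer of the requested size.
Buffer get_or_allocate_output(std::optional<Buffer> preallocated,
                              std::size_t vbe_output_size) {
  if (preallocated.has_value()) {
    // Reusing the shared buffer lets several ops fill one flat VBE output
    // without a merge pass afterwards.
    return std::move(*preallocated);
  }
  return Buffer(vbe_output_size, 0.0f);
}

int main() {
  Buffer shared(16, 0.0f);
  Buffer reused = get_or_allocate_output(std::move(shared), 16);  // reused
  Buffer fresh = get_or_allocate_output(std::nullopt, 16);        // allocated
  std::printf("reused=%zu fresh=%zu\n", reused.size(), fresh.size());
  return 0;
}
```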
// See https://fburl.com/dw9ljh4h @@ -31,6 +31,10 @@ #include "fbgemm_gpu/utils/dispatch_macros.h" {%- endif %} +{%- if is_rocm %} +#include "fbgemm_gpu/rocm/cdna_guard.h" +{%- endif %} + {%- if not is_index_select %} //////////////////////////////////////////////////////////////////////////////// // Required for op registrations @@ -391,7 +395,12 @@ batch_index_select_dim0_codegen_forward_cuda( const int64_t iter, const double gwd_lower_bound, {%- endif %} + {%- if vbe and not dense %} + const bool is_experimental, + std::optional vbe_output + {%- else %} const bool is_experimental + {%- endif %} {%- endif %} {#- /*if is_index_select*/ #} ) { {%- if not nobag or is_index_select %} @@ -454,6 +463,16 @@ batch_index_select_dim0_codegen_forward_cuda( CUDA_DEVICE_GUARD(dev_weights); + {% if is_rocm %} + if (!rocm::is_supported_cdna()) { + TORCH_WARN_ONCE("Running on non-CDNA architecture. Performance may be suboptimal."); + } + else { + // Ensure we're running on a supported CDNA architecture (including MI350) + TORCH_WARN_ONCE("Running on CDNA architecture"); + } + {%- endif %} + {%- if not nobag %} int32_t T = D_offsets.numel() - 1; {%- else %} @@ -529,11 +548,24 @@ batch_index_select_dim0_codegen_forward_cuda( o_dtype == SparseType::BF16 || o_dtype == SparseType::INT8); {%- if vbe %} - // Use a 2D tensor to make it compatible with 2D PackedTensorsAccessor of other output + {%- if dense %} output = at::empty( {1, vbe_output_size}, dev_weights.options().dtype(getScalarType(o_dtype)) - ); + ); + {%- else %} + // Use a 2D tensor to make it compatible with 2D PackedTensorsAccessor of other output + TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(vbe_row_output_offsets, vbe_output); + if (vbe_output.has_value()){ + output = vbe_output.value().reshape({1, -1}); + } + else { + output = at::empty( + {1, vbe_output_size}, + dev_weights.options().dtype(getScalarType(o_dtype)) + ); + } + {%- endif %} {#-/* if dense */#} {%- else %} int64_t total_adjusted_D = total_D; if (o_dtype == SparseType::INT8) { @@ -702,12 +734,7 @@ batch_index_select_dim0_codegen_forward_cuda( // kFixedMaxVecsPerThread instead of kMaxVecsPerThread. But // kMaxVecsPerThread and kFixedMaxVecsPerThread are the same // forward - {%- if is_rocm %} - // Account for Vec2 load for ROCm - constexpr auto kMaxVecsPerThread = 2 * kFixedMaxVecsPerThread; - {%- else %} constexpr auto kMaxVecsPerThread = kFixedMaxVecsPerThread; - {%- endif %} const auto grid = min( div_round_up(total_B, kForwardMaxThreads / kThreadGroupSize), @@ -781,9 +808,14 @@ batch_index_select_dim0_codegen_forward_cuda( // if (!is_experimental) } else { // Allocate num warps per table based on max_D + const int num_warps_per_table = B * div_round_up(max_D, kWarpSize * 4); - const uint32_t num_warps_per_threadblock = kForwardMaxThreads / kWarpSize; - + #ifdef USE_ROCM + const uint32_t num_warps_per_threadblock = kForwardMaxThreads / (kWarpSize * 2); + #else + const uint32_t num_warps_per_threadblock = kForwardMaxThreads / kWarpSize; + #endif + const auto kernel_func = (use_lxu_cache ? split_embedding_codegen_forward_{{ wdesc }}_v2_kernel< emb_t, cache_t, output_t, index_t, true> @@ -877,7 +909,12 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { " int iter, " " float gwd_lower_bound, " {%- endif %} + {%- if vbe and not dense %} + " bool is_experimental," + " Tensor? 
vbe_output" + {%- else %} " bool is_experimental" + {%- endif %} ") -> Tensor" {%- if not dense and not nobag and not vbe %} // only split_embedding_codegen_forward_[un]weighted_cuda diff --git a/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_device_kernel_template.cuh b/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_device_kernel_template.cuh index e4fb6c548c..b4c943f769 100644 --- a/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_device_kernel_template.cuh +++ b/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_device_kernel_template.cuh @@ -11,8 +11,42 @@ #include "fbgemm_gpu/utils/tensor_accessor_builder.h" #include "fbgemm_gpu/split_embeddings_utils.cuh" -#define GROUP_REDUCE_ALL_SUM(val, ...) \ - warpReduceAllSum<__VA_ARGS__, kThreadGroupSize>(val, shfl_sync_mask) +{%- set enable_optimized_hip_mixed_D_kernel = is_rocm and + optimizer == "rowwise_adagrad" and + not dense and + not is_index_select and + not is_gwd_kernel and + not nobag and + not ssd %} + +template +DEVICE_INLINE __device__ T subwarp_reduce_add(T value) { + static_assert(kThreadGroupSize == 8 || kThreadGroupSize == 16 || kThreadGroupSize == 32 || kThreadGroupSize == 64, "Wavefront size must be 16/32/64"); + if (kThreadGroupSize == 16) { + // Reduce across 4 groups of 16 threads + value += __shfl_xor(value, 1, 16); + value += __shfl_xor(value, 2, 16); + value += __shfl_xor(value, 4, 16); + value += __shfl_xor(value, 8, 16); + } else if (kThreadGroupSize == 32) { + // Reduce across 2 groups of 32 threads + value += __shfl_xor(value, 1, 32); + value += __shfl_xor(value, 2, 32); + value += __shfl_xor(value, 4, 32); + value += __shfl_xor(value, 8, 32); + value += __shfl_xor(value, 16, 32); + } else if (kThreadGroupSize == 64) { + value += __shfl_xor(value, 1, 64); + value += __shfl_xor(value, 2, 64); + value += __shfl_xor(value, 4, 64); + value += __shfl_xor(value, 8, 64); + value += __shfl_xor(value, 16, 64); + value += __shfl_xor(value, 32, 64); + } + return value; +} + +#define GROUP_REDUCE_ALL_SUM(val, ...) 
subwarp_reduce_add(val) {%- set mdesc = "ssd" if ssd else "split" %} {%- set locs_or_addrs_tensor = "ssd_row_addrs" if ssd else "lxu_cache_locations" %} @@ -176,4 +210,164 @@ DEVICE_INLINE void {{ mdesc }}_{{ optimizer }}_table_update_kernel( {{ split_post_update }} } +{%- if enable_optimized_hip_mixed_D_kernel %} +template < + typename emb_t, + typename cache_t, + {%- for ph_name in args.placeholder_tensor_names %} + {%- set ph_type = "{}_ph_t".format(ph_name) %} + typename {{ ph_type }}, + {%- endfor %} + int32_t kFixedMaxVecsPerThread, + int32_t kThreadGroupSize = kWarpSize, + int32_t VEC_WIDTH, + bool kUseVecBlocking +> +DEVICE_INLINE void {{ mdesc }}_{{ optimizer }}_table_update_kernel( + pta::PackedTensorAccessor64& dev_weights, + pta::PackedTensorAccessor64& uvm_weights, + pta::PackedTensorAccessor64& lxu_cache_weights, + const int32_t weights_placement, + const int64_t weights_offset, + const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits>& sorted_{{ locs_or_addrs_tensor }}, + Vec4TAcc* grad_sum, + Vec4TAcc* smem_grad_sum, + Vec4TAcc* shared_weight_update_row, + const bool stochastic_rounding, + const at::PhiloxCudaState& stochastic_rounding_philox_args, + const uint32_t run_id, + const uint32_t cache_loc_run_id, + const int32_t D, + const int32_t t, + const int64_t idx, + {%- if has_global_weight_decay_support %} + const float global_weight_decay, + {%- endif %} + const uint32_t shfl_sync_mask, + const int32_t max_vecs_per_thread, + {%- if ssd %} + const bool enable_optimizer_offloading, + {%- endif %} + {%- for tensor in args.split_tensors %} + const int32_t {{ tensor }}_placement, + const int64_t {{ tensor }}_offset, + const {{ args.split_tensor_types[tensor] }} {{ tensor }}_val, + {%- endfor %} + {{ args.split_ref_kernel_args | replace_pta_namespace() | join(",\n ") }} +) { + constexpr auto kIsInt8 = std::is_same_v; + // Copy value to max_vecs to make max_vecs_per_thread known at compile time + // when kUseVecBlocking == false + const int32_t max_vecs = + kUseVecBlocking ? 
max_vecs_per_thread : kFixedMaxVecsPerThread; + emb_t* __restrict__ weights {nullptr}; + cache_t* __restrict__ cache_weights {nullptr}; + int32_t D_emb = D; + if constexpr (kIsInt8) { + D_emb += kINT8QparamsBytes; + } + if (static_cast(weights_placement) == PlacementType::DEVICE) { + weights = &dev_weights[weights_offset + idx * D_emb]; + } else { + weights = {{ "nullptr" if ssd else "&uvm_weights[weights_offset + idx * D_emb]" }}; + } + if (static_cast(weights_placement) == PlacementType::MANAGED_CACHING) { + const auto {{ locs_or_addrs_idx }} = sorted_{{ locs_or_addrs_tensor }}[cache_loc_run_id]; + {%- if ssd %} + cache_weights = reinterpret_cast( + *reinterpret_cast(&{{ locs_or_addrs_idx }})); + {%- else %} + if ({{ locs_or_addrs_idx }} != kCacheLocationMissing) { + cache_weights = &lxu_cache_weights[{{ locs_or_addrs_idx }}][0]; + } + {%- endif %} + } + {%- for tensor in args.split_tensors %} + {{ args.split_tensor_types[tensor] }}* __restrict__ {{ tensor }}; + // const auto {{ tensor }}_placement = static_cast({{ tensor }}_placements[t]); + // const int64_t {{ tensor }}_offset = {{ tensor }}_offsets[t]; + if (static_cast({{ tensor }}_placement) == PlacementType::DEVICE) { + {{ tensor }} = &{{ tensor }}_dev[{{ tensor }}_offset]; + } else { + {{ tensor }} = &{{ tensor }}_uvm[{{ tensor }}_offset]; + } + {%- endfor %} + + auto weight_row_template = + WeightRow>( + weights, + cache_weights, + D, + stochastic_rounding, + &stochastic_rounding_philox_args, + threadIdx.x + run_id * blockDim.x); + + float2 qparams_template; + if constexpr (kIsInt8) { + if (!cache_weights) { + qparams_template = weight_row_template.load_qparams(); + } + } + + {%- if not ssd %} + [[maybe_unused]] constexpr auto enable_optimizer_offloading = false; + {%- endif %} + + {{ split_precomputation_preload }} + + {# /* Note: technically, global weight decay (gwd) compensation should be done before + `split_precomputation`). But since decouple mode in `rowwise_adagrad` only computes correction, + the order of applying gwd does not matter. We perform gwd update before `split_weight_update` + below to minimize number of times to load weights. + So, note that the behavior may be different if you want to enable gwd for other optimizers + such as `lamb` or `partial_rowwise_lamb`. 
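Stepping back to the `GROUP_REDUCE_ALL_SUM` change earlier in this file: `subwarp_reduce_add` replaces the generic `warpReduceAllSum` with an explicit XOR-butterfly over `__shfl_xor`, so after log2(group size) exchange rounds every lane of the thread group holds the full sum. The snippet below is a CPU-only analogue of that butterfly (one array slot per lane), shown only to illustrate the data movement; it is not device code and the names are illustrative.

```cpp
// CPU analogue of an XOR-butterfly all-reduce across a power-of-two group:
// each round, every "lane" adds the value of its XOR partner.
#include <cstdio>
#include <vector>

void butterfly_all_reduce(std::vector<float>& lanes) {
  const std::size_t n = lanes.size();  // assumed to be a power of two
  for (std::size_t stride = 1; stride < n; stride <<= 1) {
    std::vector<float> next(n);
    for (std::size_t lane = 0; lane < n; ++lane) {
      next[lane] = lanes[lane] + lanes[lane ^ stride];  // partner lane
    }
    lanes.swap(next);
  }
}

int main() {
  std::vector<float> lanes{1, 2, 3, 4, 5, 6, 7, 8};  // one value per "thread"
  butterfly_all_reduce(lanes);
  for (float v : lanes) std::printf("%g ", v);  // every lane now holds 36
  std::printf("\n");
  return 0;
}
```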
+ */#} + float2 qparams_new; + {{ + generate_optimized_grad_sum_loop_access( + """ + Vec4TAcc weight_new = weight_row_template.load(d, qparams_template); + Vec4TAcc& grad = {grad_vec}; + {global_weight_decay_update} + {split_weight_update} + if (kIsInt8 && !cache_weights) { + shared_weight_update_row[d_vec] = weight_new; + } else { + // qparams_new not used if type is not int8 + weight_row_template.store(weight_new, d, qparams_new); + } + """, + other_formats={ + "split_weight_update": split_weight_update, + "global_weight_decay_update": "weight_new.mul_(global_weight_decay);" if has_global_weight_decay_support else "" + }, + ) + }} + + if constexpr (kIsInt8) { + if (!cache_weights) { + // Calculate new qparams after row update + qparams_new = thrust_find_qparams>( + shared_weight_update_row, D); + weight_row_template.store_qparams(qparams_new); + + // Fetch cached updated row from shared mem and quantize on-the-fly + // when saving to lowp embedding + for (int32_t vec = 0; + (vec * kThreadGroupSize + threadIdx.x) * VEC_WIDTH < D; + ++vec) { + const auto d_vec = vec * kThreadGroupSize + threadIdx.x; + const int32_t d = d_vec * VEC_WIDTH; + weight_row_template.store( + shared_weight_update_row[d_vec], + d, + qparams_new); + } + } + } + + {{ split_post_update }} +} +{%- endif %} + // clang-format on diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp index 3720f1ea42..a2304b3fb3 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp @@ -137,7 +137,12 @@ enum SSDTensor { const double /*gwd_lower_bound*/, {%- endif %} const bool /*is_experimental*/, + {%- if vbe and not dense %} + const int64_t /*output_dtype*/, + std::optional /*vbe_output*/ + {%- else %} const int64_t /*output_dtype*/ + {%- endif %} )>(); auto output = embedding_codegen_forward_op.call( @@ -186,7 +191,12 @@ enum SSDTensor { {%- endif %} {# /* if is_gwd */ #} {%- endif %} {# /* if not nobag */ #} is_experimental, + {%- if vbe and not dense %} + output_dtype, + vbe_output + {%- else %} output_dtype + {%- endif %} ); if (is_annotate_trace_enabled) { @@ -259,7 +269,7 @@ enum SSDTensor { const bool /*use_homogeneous_placements*/, {%- if ssd %} const bool /*enable_optimizer_offloading*/, - {%- endif %} + {%- endif %} {%- if is_gwd %} {%- if "prev_iter_dev" not in args_pt2.split_function_arg_names %} const Tensor& /*prev_iter_dev*/, @@ -359,6 +369,14 @@ enum SSDTensor { // The number of items in the tensorlist differ between devices and is determined at runtime std::vector ret; + {%- if vbe and not dense %} + // To avoid overhead of merging multiple VBE embedding outputs, each embedding ops return + // the same output tensor i.e., vbe_output. To ensure all backward ops are triggered, the embedding + // ops are called in chain. We hence need to pass the grad_outputs to the next embedding op. + // So, if vbe_output is passed, we return the grad_outputs. + Tensor grad_vbe_output = has_vbe_output ? 
grad_outputs[0] : Variable(); + {%- endif %} + {%- if not dense %} ret.push_back(Variable()); // placeholder autograd tensor {%- endif %} @@ -400,18 +418,21 @@ enum SSDTensor { ret.push_back(Variable()); // max_B ret.push_back(Variable()); // max_B_feature_rank ret.push_back(Variable()); // vbe_output_size + {%- if not dense %} + ret.push_back(grad_vbe_output); // vbe_output + {%- endif %} {# /* if not dense */ #} {%- endif %} {# /* if vbe */ #} {%- if not dense %} ret.push_back(Variable()); // aux_tensor ret.push_back(Variable()); // aux_int ret.push_back(Variable()); // aux_float ret.push_back(Variable()); // aux_bool - {%- endif %} + {%- endif %} {# /* if not dense */ #} {%- if ssd %} {%- for tensor in ssd_tensors %} ret.push_back(Variable()); // {{ tensor }} {%- endfor %} - {%- endif %} + {%- endif %} {# /* if ssd */ #} {{ args_pt2.unified_pt2.split_variables | join("\n") }} return ret; {%- endmacro %} @@ -472,6 +493,9 @@ enum SSDTensor { max_B, max_B_feature_rank, vbe_output_size, + {%- if not dense %} + vbe_output, + {%- endif %} {# /* if not dense */ #} {%- endif %} {# /* if vbe */ #} {%- if not dense %} aux_tensor, @@ -504,7 +528,7 @@ enum SSDTensor { TENSORS_EMPTY_OR_ON_SAME_DEVICE({{ name }}[0], {{ name }}[2]); {{ name }}_host = {{ name }}[0]; {{ name }}_placements = {{ name }}[1]; - {{ name }}_offsets = {{ name }}[2]; + {{ name }}_offsets = {{ name }}[2]; } else if ({{ name }}.size() == {{ 5 if name == "weights" else 4 }}) { TENSOR_ON_CUDA_GPU({{ name }}[0]); @@ -514,7 +538,7 @@ enum SSDTensor { {%- if name == "weights" %} TENSORS_EMPTY_OR_ON_SAME_DEVICE({{ name }}[0], {{ name }}[4]); {%- endif %} - {{ name }}_dev = {{ name }}[0]; + {{ name }}_dev = {{ name }}[0]; {{ name }}_uvm = {{ name }}[1]; {{ name }}_placements = {{ name }}[2]; {{ name }}_offsets = {{ name }}[3]; @@ -548,7 +572,7 @@ enum SSDTensor { {%- endmacro %} //////////////////////////////////////////////////////////////////////////////// -// MACROS +// MACROS //////////////////////////////////////////////////////////////////////////////// #define GET_OPTIONAL_TENSOR_VALUE(name, empty_tensor) name.has_value() ? name.value() : empty_tensor; @@ -631,6 +655,9 @@ class {{ autograd_func }} : const c10::SymInt max_B, const c10::SymInt max_B_feature_rank, const c10::SymInt vbe_output_size, + {%- if not dense %} + std::optional vbe_output, + {%- endif %} {# /* if not dense */ #} {%- endif %} {# /* if vbe */ #} {%- if not dense %} std::vector> aux_tensor, @@ -662,6 +689,24 @@ class {{ autograd_func }} : const auto vbe_output_offsets_feature_rank_ = GET_OPTIONAL_TENSOR_VALUE(aux_tensor[IDX_VBE_OUTPUT_OFFSETS_FEATURE_RANK], Tensor()); const auto vbe_B_offsets_rank_per_feature_ = GET_OPTIONAL_TENSOR_VALUE(aux_tensor[IDX_VBE_B_OFFSETS_RANK_PER_FEATURE], Tensor()); const c10::SymInt max_B_ = max_B; + {%- if not dense %} + // The pipeline relies on frontend to supply vbe_output_offsets through aux_tensor + // However, if a model uses old frontend package (i.e., does not include frontend changes from this diff) + // with new backend package, aux_tensor will not contain vbe_output_offsets. + // This means old frontend will send aux_tensor of size 6, but the new backend (from this diff) expects 7, + // which accessing aux_tensor[IDX_VBE_OUTPUT_OFFSETS] can cause segmentation fault + const std::optional vbe_output_offsets = aux_tensor.size() == AUX_TENSOR_SIZE ? 
aux_tensor[IDX_VBE_OUTPUT_OFFSETS] : std::nullopt; + TORCH_CHECK( + vbe_output.has_value() == vbe_output_offsets.has_value(), + "Expected both vbe_output and vbe_output_offsets to either be None or have value. However, vbe_output ", + vbe_output.has_value() ? " has value" : " is None", + " but vbe_output_offsets ", + vbe_output_offsets.has_value() ? " has value." : " is None. ", + "Note: Frontend passes aux_tensor of size ", aux_tensor.size(), + "and backend expects aux_tensor of ", AUX_TENSOR_SIZE, + ". If the aux_tensor size mismatch, please update your frontend/backend package. Contact FBGEMM team for any assistance." + ); + {%- endif %} {%- else %} const auto max_B_ = offsets.sym_size(0) / T; {%- endif %} @@ -698,6 +743,7 @@ class {{ autograd_func }} : TORCH_CHECK(aux_tensor[IDX_LXU_CACHE_LOCATIONS].has_value(), "lxu_cache_locations should have value."); const auto lxu_cache_locations = aux_tensor[IDX_LXU_CACHE_LOCATIONS].value(); const auto is_experimental = aux_bool[IDX_IS_EXPERIMENTAL_TBE]; + const auto mixed_D = static_cast(aux_bool[IDX_MIXED_D]); {%- endif %} // Default values for Dynamo tracing @@ -719,7 +765,8 @@ class {{ autograd_func }} : const bool, const c10::SymInt, const int64_t, - const c10::SymInt)>(); + const c10::SymInt, + const std::optional&)>(); auto [ vbe_row_output_offsets, vbe_b_t_map @@ -739,7 +786,8 @@ class {{ autograd_func }} : {%- endif %} max_B_feature_rank, info_B_num_bits, - /*total_B=*/offsets.sym_size(0) - 1 + /*total_B=*/offsets.sym_size(0) - 1, + vbe_output_offsets ); {%- endif %} // vbe @@ -755,9 +803,9 @@ class {{ autograd_func }} : const auto indice_weights_value = GET_OPTIONAL_TENSOR_VALUE(indice_weights, Tensor()); {%- endif %} - // Setting learning rate tensor with `.fill_()` breaks apf_dlrm bento kernel with + // Setting learning rate tensor with `.fill_()` breaks apf_dlrm bento kernel with // `RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation.` - // This is because if a tensor is saved for backward and it is mutated later, this can cause correctness problems. + // This is because if a tensor is saved for backward and it is mutated later, this can cause correctness problems. // Since the forward compute and backward compute see different data values for this tensor. 
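The cloning note above exists because autograd tracks a version counter on every tensor saved for backward: if that tensor is later mutated in place (for example by `.fill_()`), unpacking it during the backward pass raises the "modified by an inplace operation" error. The toy C++ below mimics that version-counter check with hypothetical `Tensor` / `SavedForBackward` types; it is a conceptual sketch, not ATen's implementation.

```cpp
// Conceptual sketch of a saved-for-backward version check (hypothetical types).
#include <cstdio>
#include <memory>
#include <stdexcept>

struct Tensor {
  std::shared_ptr<int> version = std::make_shared<int>(0);
  float value = 0.0f;
  void fill_(float v) { value = v; ++(*version); }  // in-place op bumps version
  Tensor clone() const {                            // detached copy, fresh counter
    Tensor t;
    t.value = value;
    return t;
  }
};

struct SavedForBackward {
  Tensor saved;
  int saved_version;
  explicit SavedForBackward(const Tensor& t) : saved(t), saved_version(*t.version) {}
  const Tensor& unpack() const {
    if (*saved.version != saved_version) {
      throw std::runtime_error(
          "one of the variables needed for gradient computation "
          "has been modified by an inplace operation");
    }
    return saved;
  }
};

int main() {
  Tensor lr;
  lr.fill_(0.1f);
  SavedForBackward saved_live(lr);           // saves the live tensor
  SavedForBackward saved_clone(lr.clone());  // saves a detached copy
  lr.fill_(0.05f);                           // later in-place update
  try {
    saved_live.unpack();
  } catch (const std::exception& e) {
    std::printf("live: %s\n", e.what());
  }
  std::printf("clone: %g\n", saved_clone.unpack().value);
  return 0;
}
```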
// To work around, we pass the cloned tensor instead the mutated tensor {%- if "learning_rate_tensor" in args_pt2.unified_pt2.split_unpacked_arg_names %} @@ -810,7 +858,7 @@ class {{ autograd_func }} : {%- if not nobag %} ctx->saved_data["max_D"] = max_D; - ctx->saved_data["mixed_D"] = static_cast(aux_bool[IDX_MIXED_D]); + ctx->saved_data["mixed_D"] = mixed_D; ctx->saved_data["pooling_mode"] = pooling_mode; {%- else %} ctx->saved_data["D"] = D; @@ -844,8 +892,12 @@ class {{ autograd_func }} : ctx->saved_data["output_dtype"] = output_dtype; {%- endif %} {%- if vbe %} - ctx->saved_data["max_B"] = max_B_; // for reshaping vbe cpu offsets and grad_output - {%- endif %} + ctx->saved_data["max_B"] = max_B_; // for reshaping vbe cpu offsets and grad_output + // This is needed to determine whether to return grads_output + {%- if not dense %} + ctx->saved_data["has_vbe_output"] = vbe_output.has_value(); + {%- endif %} {# /* if not dense */ #} + {%- endif %} {# /* if vbe */ #} {%- if not dense %} // unpack optim args @@ -978,13 +1030,16 @@ static torch::autograd::variable_list backward( {%- if is_gwd %} const auto gwd_lower_bound = ctx->saved_data["gwd_lower_bound"].toDouble(); {%- endif %} - + {%- if not nobag %} auto output_dtype = ctx->saved_data["output_dtype"].toInt(); {%- endif %} {%- if not dense %} {%- if vbe %} auto max_B = ctx->saved_data["max_B"].toSymInt(); // for reshaping vbe cpu offsets and grad_output + {%- if not dense %} + const auto has_vbe_output = ctx->saved_data["has_vbe_output"].toBool(); // for whether to return grad_output + {%- endif %} {# /* if not dense */ #} {%- endif %} {%- for (var, _ , ivalue_cast, type) in args_pt2.unified_pt2.split_saved_data %} @@ -1005,7 +1060,25 @@ static torch::autograd::variable_list backward( #ifdef USE_ROCM constexpr int32_t BT_block_size = 64; - constexpr int32_t max_segment_length_per_warp = 64; + int32_t max_segment_length_per_warp = 64; + int32_t total_L = indices.numel(); + {%- if (not nobag) and + (optimizer == "rowwise_adagrad") and + (not vbe) and + (not is_gwd) and + (not ssd) and + (not is_index_select) and + (not dense) %} + const auto T = weights_offsets.sym_numel(); + auto total_B = (offsets.size(0) - 1); + const auto B = total_B / T; + {%- for kDimSize in [64, 128, 160, 192, 256, 320] %} + if(!mixed_D && total_L / total_B > 1 && (max_D == {{ kDimSize }})) + { + max_segment_length_per_warp = 16384; + } + {%- endfor %} + {%- endif %} #else constexpr int32_t BT_block_size = 32; constexpr int32_t max_segment_length_per_warp = 32; @@ -1097,7 +1170,7 @@ static torch::autograd::variable_list backward( feature_requires_grad {%- endif %} ); - + Tensor grad_weights_dev; // weighted if (indice_weights.defined()) @@ -1147,7 +1220,7 @@ Tensor {{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function_pt2( {%- else %} const Tensor& placeholder_autograd_tensor, const at::TensorList weights, - {%- endif %} + {%- endif %} {#-/* if dense */#} const Tensor& D_offsets, const c10::SymInt total_D, const c10::SymInt max_D, @@ -1164,20 +1237,32 @@ Tensor {{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function_pt2( const std::vector& aux_int, const std::vector& aux_float, c10::List aux_bool, - {%- endif %} + {%- endif %} {#-/* if not dense */#} {{ args_pt2.unified_pt2.split_function_args | join(", ") }}, const c10::SymInt max_B = -1, const c10::SymInt max_B_feature_rank = -1, - {%- if ssd %} + {%- if not dense %} const c10::SymInt vbe_output_size = -1, - const std::optional& ssd_tensors = std::nullopt + {%- if ssd %} + const 
std::optional& ssd_tensors = std::nullopt, + {%- endif %} {#-/* if ssd */#} + std::optional vbe_output = std::nullopt {%- else %} + {#- /* ssd and pre-allocated vbe_output is not yet supported in Dense TBE */ -#} const c10::SymInt vbe_output_size = -1 - {%- endif %} + {%- endif %} {#-/* if not dense */#} ) { {%- if has_gpu_support or has_cpu_support %} + TORCH_WARN(aux_tensor.size() <= AUX_TENSOR_SIZE, + "aux_tensor.size() should not be larger than ", + AUX_TENSOR_SIZE, + "but found to be ", + aux_tensor.size(), + ". This means frontend package does not match with backend package, so some functionalities might be missing. Please contact FBGEMM team for any assistance." + ); + {%- if not dense %} // Load the config value from JK once static auto is_tbev2_enabled = config::is_feature_enabled(config::FeatureGateName::TBE_V2); @@ -1229,7 +1314,7 @@ Tensor {{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function_pt2( "{{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function is deprecated. Please see https://github.com/pytorch/FBGEMM/discussions/1727 for more detail." ); return Tensor(); - {%- endif %} + {%- endif %} } @@ -1259,16 +1344,19 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { " int[] aux_int, " " float[] aux_float, " " bool[] aux_bool, " + {%- endif %} {#-/* if not dense */#} " {{ args_pt2.unified_pt2.split_function_schemas | join(", ") }}, " " SymInt max_B=-1, " " SymInt max_B_feature_rank=-1, " - {%- if ssd %} + {%- if not dense %} " SymInt vbe_output_size=-1, " - " Tensor[]? ssd_tensors=None " - {%- else %} - " SymInt vbe_output_size=-1 " - {%- endif %} + {%- if ssd %} + " Tensor[]? ssd_tensors=None, " {%- endif %} + " Tensor? vbe_output=None " + {%- else %} + " SymInt vbe_output_size=-1 " + {%- endif %} {#-/* if not dense */#} ") -> Tensor", {PT2_COMPLIANT_TAG}); // We're playing a funny trick here: we're using the autograd diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp index 14deb1af5e..c06dd5efef 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp @@ -143,7 +143,13 @@ Tensor split_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}_pt2_cpu const Tensor& B_offsets, {%- endif %} const bool /*is_experimental = false*/, - const int64_t output_dtype = static_cast(SparseType::FP32)) { + {%- if vbe %} + const int64_t output_dtype = static_cast(SparseType::FP32), + std::optional vbe_output = std::nullopt + {%- else %} + const int64_t output_dtype = static_cast(SparseType::FP32) + {%- endif %} + ){ Tensor offsets_; {%- if vbe %} const int64_t max_B_int = max_B.guard_int(__FILE__, __LINE__); @@ -406,7 +412,12 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { " Tensor B_offsets, " {%- endif %} " bool is_experimental, " + {%- if vbe %} + " int output_dtype, " + " Tensor? 
vbe_output " + {%- else %} " int output_dtype " + {%- endif %} ") -> Tensor" {%- if not nobag and not vbe %} // only split_embedding_codegen_forward_[un]weighted_cuda diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cuda_wrapper_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cuda_wrapper_template.cpp index 1a0cb0fa80..b7070deb83 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cuda_wrapper_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cuda_wrapper_template.cpp @@ -107,7 +107,12 @@ Tensor {{ fwd_mdesc }}_embedding{{ ndesc }}_codegen_forward_{{ desc_suffix }}_pt const double gwd_lower_bound, {%- endif %} const bool is_experimental, + {%- if vbe and not dense %} + const int64_t output_dtype, + std::optional vbe_output + {%- else %} const int64_t output_dtype + {%- endif %} ){ {%- set op = "{}_embedding{}_codegen_forward_{}_cuda".format( fwd_mdesc, ndesc, desc_suffix @@ -155,7 +160,12 @@ Tensor {{ fwd_mdesc }}_embedding{{ ndesc }}_codegen_forward_{{ desc_suffix }}_pt const int64_t /*iter*/, const double /*gwd_lower_bound*/, {%- endif %} + {%- if vbe and not dense %} + const bool, + std::optional /*vbe_output*/ + {%- else %} const bool + {%- endif %} )>(); return op.call( @@ -201,7 +211,12 @@ Tensor {{ fwd_mdesc }}_embedding{{ ndesc }}_codegen_forward_{{ desc_suffix }}_pt iter, gwd_lower_bound, {%- endif %} {# /* if is_gwd */ #} + {%- if vbe and not dense %} + is_experimental, + vbe_output + {%- else %} is_experimental + {%- endif %} ); }; {%- else %} @@ -561,7 +576,12 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { " float gwd_lower_bound, " {%- endif %} " bool is_experimental, " + {%- if vbe and not dense %} + " int output_dtype, " + " Tensor? vbe_output" + {%- else %} " int output_dtype " + {%- endif %} ") -> Tensor" {%- if not nobag and not vbe %} // only split_embedding_codegen_forward_[un]weighted_cuda diff --git a/fbgemm_gpu/codegen/training/pt2/pt2_arg_utils_template.h b/fbgemm_gpu/codegen/training/pt2/pt2_arg_utils_template.h index ec033c89d2..675bb7df9b 100644 --- a/fbgemm_gpu/codegen/training/pt2/pt2_arg_utils_template.h +++ b/fbgemm_gpu/codegen/training/pt2/pt2_arg_utils_template.h @@ -21,7 +21,7 @@ enum ArgIndex_{{ name }} { {%- for var in aux_args[name] %} IDX_{{ var | upper }} = {{ loop.index - 1 }}, {%- endfor %} - {{ name | upper }}_SIZE = {{ name | length }} + {{ name | upper }}_SIZE = {{ aux_args[name] | length }} }; {%- endfor %} diff --git a/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template b/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template index 0ecf71bdb5..6fe7292db4 100644 --- a/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template +++ b/fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template @@ -42,7 +42,7 @@ torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") {%- endif %} -{# This macro generates a code blob to pack Tensor arguments into a TensorList +{# This macro generates a code blob to pack Tensor arguments into a TensorList as number of arguments for some optimizers exceed 64 #} {%- macro pack_tensors(arg) %} {{ arg }}_list = [ @@ -58,7 +58,7 @@ torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") {%- endmacro %} {# This macro generates a code blob to pack optim optional tensor into an optional TensorList. - All optim optional tensors are packed together into `optim_tensor`. 
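The `{{ name | upper }}_SIZE` fix above (counting the entries of `aux_args[name]` rather than the length of the name string) matters because that generated sentinel is what the autograd wrapper compares `aux_tensor.size()` against before indexing newer slots such as `IDX_VBE_OUTPUT_OFFSETS`. Below is a self-contained sketch of that index-enum-plus-sentinel pattern; the enum values and helper name are illustrative only, not the generated code.

```cpp
// Sketch of addressing a packed optional-argument list by enum index, with a
// trailing *_SIZE sentinel used to detect lists built by an older frontend.
#include <cstdio>
#include <optional>
#include <vector>

enum ArgIndexAux {
  IDX_B_OFFSETS = 0,
  IDX_LXU_CACHE_LOCATIONS = 1,
  IDX_VBE_OUTPUT_OFFSETS = 2,  // newest slot, absent in older frontends
  AUX_TENSOR_SIZE = 3          // must equal the number of IDX_* entries above
};

using OptSlot = std::optional<int>;  // stand-in for std::optional<Tensor>

OptSlot get_aux_slot(const std::vector<OptSlot>& aux_tensor, ArgIndexAux idx) {
  // Only dereference the slot if the caller actually sent enough entries.
  if (aux_tensor.size() == static_cast<std::size_t>(AUX_TENSOR_SIZE)) {
    return aux_tensor[idx];
  }
  return std::nullopt;  // older frontend: treat the missing slot as "not set"
}

int main() {
  std::vector<OptSlot> from_old_frontend = {OptSlot{1}, OptSlot{7}};
  std::vector<OptSlot> from_new_frontend = {OptSlot{1}, OptSlot{7}, OptSlot{42}};
  auto a = get_aux_slot(from_old_frontend, IDX_VBE_OUTPUT_OFFSETS);
  auto b = get_aux_slot(from_new_frontend, IDX_VBE_OUTPUT_OFFSETS);
  std::printf("old: %s, new: %d\n", a.has_value() ? "set" : "unset", b.value_or(-1));
  return 0;
}
```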
+ All optim optional tensors are packed together into `optim_tensor`. This poses challenge to handle unpacking in autograd if we do per device (i.e, 3 for cpu and 4 for cuda). Hence, we pack unified args (i.e., 5 items) for readability and programmability. #} @@ -92,14 +92,14 @@ torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") "Please check the frontend and backend version. " ) {{ arg_type }}.append(dict_{{ arg_type }}["{{ var }}"]) - + {%- endfor %} {%- endmacro %} {%- if is_prototype_optimizer %} # Decorate the prototype optimizers which may be deprecated in the future with jit.ignore to avoid -# possible errors from torch.jit.script. +# possible errors from torch.jit.script. # Note that backends can be removed but the lookup invoker is still needed for backward compatibility @torch.jit.ignore {%- endif %} @@ -187,14 +187,15 @@ def invoke( "vbe_B_offsets_rank_per_feature": vbe_metadata.B_offsets_rank_per_feature, "lxu_cache_locations": common_args.lxu_cache_locations, "uvm_cache_stats": common_args.uvm_cache_stats, + "vbe_output_offsets" : None, } dict_aux_int: Dict[str, int] = { - "iter": iter, - "info_B_num_bits": common_args.info_B_num_bits, + "iter": iter, + "info_B_num_bits": common_args.info_B_num_bits, "info_B_mask": common_args.info_B_mask, } - + dict_aux_float: Dict[str, float] = { "gwd_lower_bound": gwd_lower_bound, } @@ -219,7 +220,7 @@ def invoke( {%- else %} dict_aux_tensor["prev_iter_dev"] = prev_iter.dev {%- endif %} - + # optimizer_args {%- if optimizer == "none" %} @@ -302,13 +303,13 @@ def invoke( {{ pack_tensors("row_counter") }} {%- endif %} {%- if "row_counter" in args_pt2.unified_pt2.split_saved_tensorlist_optional %} - + if optimizer_args.use_rowwise_bias_correction and row_counter is not None: row_counter_host = None # not supported on CPU row_counter_dev = row_counter.dev row_counter_uvm = row_counter.uvm row_counter_offsets = row_counter.offsets - row_counter_placements = row_counter.placements + row_counter_placements = row_counter.placements elif optimizer_args.use_rowwise_bias_correction: assert False, "`use_rowwise_bias_correction` is set, `row_counter` cannot be None" else: @@ -316,7 +317,7 @@ def invoke( row_counter_dev = None row_counter_uvm = None row_counter_offsets = None - row_counter_placements = None + row_counter_placements = None {%- endif %} {{ pack_to_list("aux_tensor") }} @@ -358,7 +359,7 @@ def invoke( {%- for name in args_pt2.unified_pt2.split_args_dict["optim_bool"] %} optim_bool.append(dict_optim_bool["{{ name }}"]) {%- endfor %} - {%- endif %} + {%- endif %} return torch.ops.fbgemm.{{ mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function_pt2( # common_args diff --git a/fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py b/fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py index 5d373ea266..9fe94f8d46 100644 --- a/fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +++ b/fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py @@ -509,14 +509,13 @@ def _fbgemm_grouped_gemm_ws( num_tiles = num_m_tiles * NUM_N_TILES if USE_TMA_STORE: - with tl.async_task([0]): - c_desc_ptr = tl.make_tensor_descriptor( - c_ptr + M_start_offset * N, - shape=[m_size, N], - # pyre-ignore - strides=[N, 1], - block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N], - ) + c_desc_ptr = tl.make_tensor_descriptor( + c_ptr + M_start_offset * N, + shape=[m_size, N], + # pyre-ignore + strides=[N, 1], + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N], + ) # Move across tiles next_iterated_tiles = iterated_tiles + num_tiles @@ -534,72 +533,59 @@ def 
_fbgemm_grouped_gemm_ws( m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32) n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32) for k_offset in range(0, K, BLOCK_SIZE_K): - with tl.async_task([0]): - a = tl._experimental_descriptor_load( - a_desc_ptr, - [m_offset, k_offset], - [BLOCK_SIZE_M, BLOCK_SIZE_K], - dtype, - ) - b = tl._experimental_descriptor_load( - b_desc_ptr, - [n_offset, k_offset], - [BLOCK_SIZE_N, BLOCK_SIZE_K], - dtype, - ) - with tl.async_task([1, NUM_CONSUMER_GROUPS]): - if USE_FAST_ACCUM: - accumulator = tl.dot(a, b.T, accumulator) - else: - accumulator += tl.dot(a, b.T) + a = tl._experimental_descriptor_load( + a_desc_ptr, + [m_offset, k_offset], + [BLOCK_SIZE_M, BLOCK_SIZE_K], + dtype, + ) + b = tl._experimental_descriptor_load( + b_desc_ptr, + [n_offset, k_offset], + [BLOCK_SIZE_N, BLOCK_SIZE_K], + dtype, + ) + if USE_FAST_ACCUM: + accumulator = tl.dot(a, b.T, accumulator) + else: + accumulator += tl.dot(a, b.T) if USE_TMA_STORE: - with tl.async_task([1, NUM_CONSUMER_GROUPS]): - m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32) - n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32) - # pyre-ignore - c_desc_ptr.store( - [m_offset, n_offset], - accumulator.to(c_ptr.dtype.element_ty), - ) + m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32) + n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32) + # pyre-ignore + c_desc_ptr.store( + [m_offset, n_offset], + accumulator.to(c_ptr.dtype.element_ty), + ) elif FUSE_SCATTER_ADD: - with tl.async_task([1, NUM_CONSUMER_GROUPS]): - offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange( - 0, BLOCK_SIZE_M - ) - mask = offs_am < m_size - m_offsets = tl.load( - scatter_add_indices + M_start_offset + offs_am, - mask=mask, - cache_modifier=".ca", - ) - offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange( - 0, BLOCK_SIZE_N - ) - c = accumulator.to(c_ptr.dtype.element_ty) - tl.atomic_add( - c_ptr + m_offsets[:, None] * N + offs_bn[None, :], - c, - mask=mask[:, None], - sem="relaxed", - ) + offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + mask = offs_am < m_size + m_offsets = tl.load( + scatter_add_indices + M_start_offset + offs_am, + mask=mask, + cache_modifier=".ca", + ) + offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c = accumulator.to(c_ptr.dtype.element_ty) + tl.atomic_add( + c_ptr + m_offsets[:, None] * N + offs_bn[None, :], + c, + mask=mask[:, None], + sem="relaxed", + ) else: - with tl.async_task([1, NUM_CONSUMER_GROUPS]): - offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange( - 0, BLOCK_SIZE_M - ) - offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange( - 0, BLOCK_SIZE_N - ) - c = accumulator.to(c_ptr.dtype.element_ty) - tl.store( - c_ptr - + (M_start_offset + offs_am[:, None]) * N - + offs_bn[None, :], - c, - mask=offs_am[:, None] < m_size, - cache_modifier=".cs", - ) + offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c = accumulator.to(c_ptr.dtype.element_ty) + tl.store( + c_ptr + + (M_start_offset + offs_am[:, None]) * N + + offs_bn[None, :], + c, + mask=offs_am[:, None] < m_size, + cache_modifier=".cs", + ) tidx += NUM_SMS iterated_tiles += num_tiles @@ -841,14 +827,13 @@ def _fbgemm_grouped_gemm_fp8_rowwise_ws( num_tiles = num_m_tiles * NUM_N_TILES if USE_TMA_STORE: - with tl.async_task([0]): - c_desc_ptr = tl.make_tensor_descriptor( - c_ptr + M_start_offset * N, - shape=[m_size, N], - # pyre-ignore - strides=[N, 1], - block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N], - ) + c_desc_ptr = 
tl.make_tensor_descriptor( + c_ptr + M_start_offset * N, + shape=[m_size, N], + # pyre-ignore + strides=[N, 1], + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N], + ) # Move across tiles next_iterated_tiles = iterated_tiles + num_tiles @@ -867,107 +852,85 @@ def _fbgemm_grouped_gemm_fp8_rowwise_ws( m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32) n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32) for k_offset in range(0, K, BLOCK_SIZE_K): - with tl.async_task([0]): - a = tl._experimental_descriptor_load( - a_desc_ptr, - [m_offset, k_offset], - [BLOCK_SIZE_M, BLOCK_SIZE_K], - dtype, - ) - b = tl._experimental_descriptor_load( - b_desc_ptr, - [n_offset, k_offset], - [BLOCK_SIZE_N, BLOCK_SIZE_K], - dtype, - ) - with tl.async_task([1, NUM_CONSUMER_GROUPS]): - if USE_FAST_ACCUM: - accumulator = tl.dot(a, b.T, accumulator) - else: - accumulator += tl.dot(a, b.T) + a = tl._experimental_descriptor_load( + a_desc_ptr, + [m_offset, k_offset], + [BLOCK_SIZE_M, BLOCK_SIZE_K], + dtype, + ) + b = tl._experimental_descriptor_load( + b_desc_ptr, + [n_offset, k_offset], + [BLOCK_SIZE_N, BLOCK_SIZE_K], + dtype, + ) + if USE_FAST_ACCUM: + accumulator = tl.dot(a, b.T, accumulator) + else: + accumulator += tl.dot(a, b.T) if USE_TMA_LOAD_ON_SCALES: - with tl.async_task([0]): - b_scale = tl._experimental_descriptor_load( - b_scale_desc_ptr, - [n_offset], - [BLOCK_SIZE_N], - tl.float32, - ) - - with tl.async_task([1, NUM_CONSUMER_GROUPS]): - offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange( - 0, BLOCK_SIZE_M - ) - a_scale = tl.load( - a_scale_ptr + M_start_offset + offs_am[:, None], - mask=offs_am[:, None] < m_size, - cache_modifier=".ca", - ) - c = accumulator.to(tl.float32) * a_scale * b_scale[None, :] + b_scale = tl._experimental_descriptor_load( + b_scale_desc_ptr, + [n_offset], + [BLOCK_SIZE_N], + tl.float32, + ) + + offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + a_scale = tl.load( + a_scale_ptr + M_start_offset + offs_am[:, None], + mask=offs_am[:, None] < m_size, + cache_modifier=".ca", + ) + c = accumulator.to(tl.float32) * a_scale * b_scale[None, :] else: - with tl.async_task([1, NUM_CONSUMER_GROUPS]): - offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange( - 0, BLOCK_SIZE_M - ) - offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange( - 0, BLOCK_SIZE_N - ) - a_scale = tl.load( - a_scale_ptr + M_start_offset + offs_am[:, None], - mask=offs_am[:, None] < m_size, - cache_modifier=".ca", - ) - b_scale = tl.load( - b_scale_ptr + N_start_offset + offs_bn[None, :], - cache_modifier=".ca", - ) - c = accumulator.to(tl.float32) * a_scale * b_scale + offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + a_scale = tl.load( + a_scale_ptr + M_start_offset + offs_am[:, None], + mask=offs_am[:, None] < m_size, + cache_modifier=".ca", + ) + b_scale = tl.load( + b_scale_ptr + N_start_offset + offs_bn[None, :], + cache_modifier=".ca", + ) + c = accumulator.to(tl.float32) * a_scale * b_scale if USE_TMA_STORE: - with tl.async_task([1, NUM_CONSUMER_GROUPS]): - m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32) - n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32) - # pyre-ignore - c_desc_ptr.store( - [m_offset, n_offset], c.to(c_ptr.dtype.element_ty) - ) + m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32) + n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32) + # pyre-ignore + c_desc_ptr.store( + [m_offset, n_offset], c.to(c_ptr.dtype.element_ty) + ) elif FUSE_SCATTER_ADD: - with tl.async_task([1, 
NUM_CONSUMER_GROUPS]): - offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange( - 0, BLOCK_SIZE_M - ) - mask = offs_am < m_size - m_offsets = tl.load( - scatter_add_indices + M_start_offset + offs_am, - mask=mask, - cache_modifier=".ca", - ) - offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange( - 0, BLOCK_SIZE_N - ) - tl.atomic_add( - c_ptr + m_offsets[:, None] * N + offs_bn[None, :], - c, - mask=mask[:, None], - sem="relaxed", - ) + offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + mask = offs_am < m_size + m_offsets = tl.load( + scatter_add_indices + M_start_offset + offs_am, + mask=mask, + cache_modifier=".ca", + ) + offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + tl.atomic_add( + c_ptr + m_offsets[:, None] * N + offs_bn[None, :], + c, + mask=mask[:, None], + sem="relaxed", + ) else: - with tl.async_task([1, NUM_CONSUMER_GROUPS]): - offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange( - 0, BLOCK_SIZE_M - ) - offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange( - 0, BLOCK_SIZE_N - ) - tl.store( - c_ptr - + (M_start_offset + offs_am[:, None]) * N - + offs_bn[None, :], - c, - mask=offs_am[:, None] < m_size, - cache_modifier=".cs", - ) + offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + tl.store( + c_ptr + + (M_start_offset + offs_am[:, None]) * N + + offs_bn[None, :], + c, + mask=offs_am[:, None] < m_size, + cache_modifier=".cs", + ) tidx += NUM_SMS iterated_tiles += num_tiles diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu index 5b618b6526..5227510217 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu @@ -301,7 +301,7 @@ at::Tensor dispatch_fmha_gen_fwd( return DISPATCH_ELEMENT_TYPE(q.scalar_type(), Element, [&] { return DISPATCH_KERNEL_TYPE(static_cast(kernel_type), KType, [&] { - GenRunner, Shape<_1, _1, _1>> + GenRunner, Shape<_1, _1, _1>> runner; return runner.fmha_fwd(q, k, v, seqlen_kv, batch_idx); }); diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/fmha_common.hpp b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/fmha_common.hpp index 2d3e2b166d..1e0ea6d449 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/fmha_common.hpp +++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/fmha_common.hpp @@ -78,10 +78,10 @@ to_tiled_mma_sm100_ts( TiledMMA, cute::C, - cute::integral_constant, - cute::integral_constant, - cute::integral_constant, - cute::integral_constant>, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant>, TAs...>, TMs...>) { return TiledMMA, + a_major, + b_major, + a_neg, + b_neg>, TAs...>, TMs...>) { return TiledMMA +struct kValTyPair { + static constexpr auto key = keyVal; + using valueT = _valueT; +}; + +template +struct kValTyMap { + template + using query = std::conditional_t< + QueryKey == FirstMapping::key, + typename FirstMapping::valueT, + typename kValTyMap::template query>; +}; + +template +struct kValTyMap { + template + using query = std::conditional_t< + QueryKey == LastMapping::key, + typename LastMapping::valueT, + Default>; +}; + 
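The two helper namespaces above implement small compile-time maps: `constexpr_type_map` resolves a key to a type and `constexpr_constexpr_map` resolves a key to a constant, both by folding nested `std::conditional_t` / ternary selections. The snippet below is a deliberately simplified standalone analogue (a `std::conditional_t` chain plus a `constexpr` function) meant only to make the lookup behaviour concrete; the stage counts mirror the `StageCountKV` table used further down, and all names here are illustrative rather than the header's implementation.

```cpp
// Simplified stand-ins for a compile-time key->type and key->value lookup.
#include <cstdio>
#include <type_traits>

// Key -> type: pick an accumulator type from an element size in bytes.
template <int ElementBytes>
using AccumulatorT =
    std::conditional_t<ElementBytes == 1, float,             // fp8  -> float
    std::conditional_t<ElementBytes == 2, float, double>>;   // f16  -> float, else double

// (TileM, element size) -> int: pick a KV stage count, with a fallback that is
// deliberately too large so a bad configuration fails resource checks loudly.
constexpr int kv_stages(int tile_m, int element_bytes) {
  if (tile_m == 64)  return element_bytes == 1 ? 12 : 6;
  if (tile_m == 128) return element_bytes == 1 ? 11 : 5;
  return 65536;  // arbitrarily large default to surface misconfiguration
}

static_assert(kv_stages(128, 1) == 11, "fp8 at TileM=128 keeps 11 KV stages");
static_assert(kv_stages(64, 2) == 6, "bf16/fp16 at TileM=64 keeps 6 KV stages");

int main() {
  std::printf("sizeof(AccumulatorT<1>)=%zu, stages=%d\n",
              sizeof(AccumulatorT<1>), kv_stages(64, 1));
  return 0;
}
```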
+} // namespace constexpr_type_map + +namespace constexpr_constexpr_map { + +template +struct kValValPair { + static constexpr auto key = keyVal; + static constexpr auto value = valueVal; +}; + +template +struct kValValMap { + using ValType = std::add_const_t; + static_assert( + std::is_same_v, + "Map value type mismatch"); + static_assert( + (std::is_same_v && ...), + "Map value type mismatch"); + template + static constexpr decltype(FirstMapping::value) query = + (QueryKey == FirstMapping::key) + ? FirstMapping::value + : kValValMap::template query; +}; + +template +struct kValValMap { + using ValType = std::add_const_t; + static_assert( + std::is_same_v, + "Map value type mismatch"); + template + static constexpr decltype(LastMapping::value) query = + (QueryKey == LastMapping::key) ? LastMapping::value : Default; +}; + +} // namespace constexpr_constexpr_map diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp index e8e9aafceb..1be2e43145 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp +++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp @@ -41,10 +41,13 @@ #include "collective/fmha_common.hpp" #include "collective/fmha_fusion.hpp" #include "collective/sm100_fmha_load_cpasync_warpspecialized.hpp" +#include "cutlass/detail/dependent_false.hpp" namespace cutlass::fmha::collective { using namespace cute; +using namespace constexpr_type_map; +using namespace constexpr_constexpr_map; template< class Element_, @@ -85,10 +88,32 @@ struct Sm100FmhaGenMainloopWarpspecialized { using StrideO = decltype(replace<0>(StrideO_{}, 0)); using Mask = Mask_; + using TileM = decltype(get<0>(TileShape{})); // seq Q dim + static_assert(TileM::value == 64 || TileM::value == 128, "Only expecting TileM to be 64 or 128"); static constexpr int StageCountQ = get<1>(TileShape{}) == 256 ? 1 : 2; - // local changes - static constexpr int StageCountKV = StageCountQ * (sizeof(Element) == 1 ? 
11 : 5) ; - + // Choose StageCountKV based on: + // - Tile shape on the M (i.e., Query) dimension + // - Element size + using StageCountKVSelector = kValTyMap< + void, + kValTyPair<64, + kValValMap< + 65536 /* default, arbitrarily large to trigger smem OOM error */, + kValValPair<1, 12>, // fp8 + kValValPair<2, 6> // bf16/fp16 + >>, + kValTyPair<128, + kValValMap< + 65536 /* default, arbitrarily large to trigger smem OOM error */, + kValValPair<1, 11>, // fp8 + kValValPair<2, 5> // bf16/fp16 + >> + >; + static constexpr int StageCountKV = StageCountQ * + StageCountKVSelector:: + template query:: + template query; + using StagesQ = cutlass::gemm::collective::StageCount; using StagesKV = cutlass::gemm::collective::StageCount; @@ -129,28 +154,52 @@ struct Sm100FmhaGenMainloopWarpspecialized { }; }; + // indices for V0 / V1 + enum : int { + kIdxOldRowMax = 0, + kIdxNewRowMax = 1, + kIdxFinalRowSum = 0, + kIdxFinalRowMax = 1, + kIdxStatsEnd = 2 + }; + + // Each storage reserves kTMEM_V_COLUMNS for row max/sum stats + // TileM=64 uses 16dp64b --> two threads processing a row + // TileM=128 uses 32dp32b --> one thread processing a row + using kTMEM_V_COLUMNS = typename kValTyMap>, + kValTyPair<128, Int> + >::template query; + + // TMEM column allocation, offset will be used to calc the lower 16-bit of tmem addresses. + // TMEM row/lane dimension is for the Q dim. enum class TmemAllocation : uint32_t { - kSizeS = 128, - kSizeO = 128, - kSizeP = 32, + kSizeS = get<1>(TileShapeQK{}), // i.e., KV dim in a tile + kSizeO = get<2>(TileShapeQK{}), // i.e., head dim + // carve kSizeS to two parts: first 1/4 for V0/V1 stats storage; the rest for P0/P1 + // 1/4 is wasting some storage here but there seems to be column-wise address alignment requirements not found in spec. + // Since there is enough storage left for P0/P1, chose to not debug alignment issues. + kSizeV = kSizeS / 2, + // P will be casted to the same type as V + kSizeP = kSizeS * sizeof(Element) / sizeof(float), S0 = 0, S1 = S0 + kSizeS, V0 = S0, // stats storage from softmax to correction V1 = S1, - P0 = S0 + kSizeP, - P1 = S1 + kSizeP, + P0 = V0 + kSizeV, + P1 = V1 + kSizeV, O0 = S1 + kSizeS, O1 = O0 + kSizeO, kEnd = O1 + kSizeO }; - - // indices for V0 / V1 - enum : int { - kIdxOldRowMax = 0, - kIdxNewRowMax = 1, - kIdxFinalRowSum = 0, - kIdxFinalRowMax = 1 - }; + static_assert(static_cast(TmemAllocation::kEnd) <= 512, "Exceeds TMEM 512 columns"); + static_assert( + static_cast(TmemAllocation::kSizeV) + static_cast(TmemAllocation::kSizeP) <= + static_cast(TmemAllocation::kSizeS), + "Not enough storage to carve V and P out of S"); + static_assert( + static_cast(kTMEM_V_COLUMNS::value) <= static_cast(TmemAllocation::kSizeV), + "Not enough storage reserved for V"); // from load to mma warp, protects q in smem using PipelineQ = cutlass::PipelineUmmaConsumerAsync< @@ -526,35 +575,48 @@ struct Sm100FmhaGenMainloopWarpspecialized { PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state, PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state, OrderBarrierSoftmax& order_s) { - - Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + Tensor tScS = + typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); Tensor tStS = partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{})); tStS.data() = uint32_t(stage == _0{} ? 
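// Illustrative restatement of the StageCountKV selection above: KV stages per Q stage as a
// function of TileM and sizeof(Element), plus the TMEM column budget that the static_asserts
// above enforce. The 128-column KV tile and 128-column head dim used in the numbers below are
// assumptions taken from the surrounding configuration, not guaranteed values.
constexpr int kv_stages_per_q_stage(int tile_m, int element_bytes) {
  if (tile_m == 64) {
    return element_bytes == 1 ? 12 : 6;   // fp8 : bf16/fp16
  }
  if (tile_m == 128) {
    return element_bytes == 1 ? 11 : 5;   // fp8 : bf16/fp16
  }
  return 65536;  // arbitrarily large default so an unsupported TileM trips the smem check
}
static_assert(kv_stages_per_q_stage(128, 1) == 11, "");
static_assert(kv_stages_per_q_stage(64, 2) == 6, "");

// TMEM column budget with kSizeS = kSizeO = 128 and fp16 P:
constexpr int kSizeS_sketch = 128;                         // KV columns per tile
constexpr int kSizeO_sketch = 128;                         // head dim (assumed)
constexpr int kSizeV_sketch = kSizeS_sketch / 2;           // stats carve-out
constexpr int kSizeP_fp16_sketch = kSizeS_sketch * 2 / 4;  // P packed into fp32-wide columns
static_assert(kSizeV_sketch + kSizeP_fp16_sketch <= kSizeS_sketch, "");
static_assert(2 * kSizeS_sketch + 2 * kSizeO_sketch <= 512, "");  // kEnd within 512 TMEM columns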
TmemAllocation::S0 : TmemAllocation::S1); - Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{}))); - tStS_v.data() = uint32_t(stage == _0{} ? TmemAllocation::V0 : TmemAllocation::V1); - Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{}))); + Tensor tStS_v = + tStS.compose(make_layout(make_shape(TileM{}, kTMEM_V_COLUMNS{}))); + tStS_v.data() = + uint32_t(stage == _0{} ? TmemAllocation::V0 : TmemAllocation::V1); + Tensor tScS_v = + tScS.compose(make_layout(make_shape(TileM{}, kTMEM_V_COLUMNS{}))); auto tilePlikeFP32 = size<1>(TileShapeQK{}) / Int{} * Int{}; - Tensor tStS_P = tStS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); - tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1)); - Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); - - // local changes - // Each thread owns a single row - using TMEM_LOAD = conditional_t< - size<1>(TileShapeQK{}) < _128{}, - SM100_TMEM_LOAD_32dp32b8x, - SM100_TMEM_LOAD_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem - using TMEM_STORE = conditional_t< - size<1>(TileShapeQK{}) < _128{}, - SM100_TMEM_STORE_32dp32b16x, - SM100_TMEM_STORE_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem - using TMEM_STORE_V = - SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem + Tensor tStS_P = tStS.compose( + make_layout(make_shape(TileM{}, tilePlikeFP32))); + tStS_P.data() = warp_uniform( + uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1)); + Tensor tScS_P = tScS.compose( + make_layout(make_shape(TileM{}, tilePlikeFP32))); + + // needed number of cols to load from tmem to reg + constexpr int kConversionsPerStep = 2; + constexpr int kTmemLoadNcells = cute::min(32, size<1>(TileShapeQK{}) / kConversionsPerStep); + constexpr int kTmemStoreNcells = kTmemLoadNcells * sizeof_bits_v / sizeof_bits_v; + + using TMEM_LOAD_1xOP = typename kValTyMap, + // Each thread owns a single row + kValTyPair<128, SM100_TMEM_LOAD_32dp32b1x> + >::template query; + using TMEM_STORE_1xOP = decltype(TMEM::tmem_load_to_store(TMEM_LOAD_1xOP{})); + using TMEM_LOAD = decltype(TMEM::op_repeater()); + using TMEM_STORE = decltype(TMEM::op_repeater()); + + using TMEM_STORE_V = typename kValTyMap, + kValTyPair<128, SM100_TMEM_STORE_32dp32b2x> // 4x32 threads with 2 cols of 32b elem + >::template query; - int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tStS); auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); @@ -603,11 +665,15 @@ struct Sm100FmhaGenMainloopWarpspecialized { row_max = ::fmax(row_max_0, row_max_1); row_max = ::fmax(row_max, row_max_2); row_max = ::fmax(row_max, row_max_3); + if constexpr (TileM{} == 64) { + ElementQK shuffled_row_max = __shfl_xor_sync(0xffffffff, row_max, 16); + row_max = ::fmax(row_max, shuffled_row_max); + } } - ElementQK row_max_safe = row_max == -INFINITY ? 
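// Illustrative arithmetic for the repeated TMEM op widths above: S is kept as fp32 in TMEM
// and converted two values per step, so the load op covers min(32, kv_cols / 2) cells, and
// the store of P scales that width by sizeof_bits(Element) / sizeof_bits(float). The exact
// sizeof_bits_v argument order is truncated in the hunk above, so that ratio is an assumption.
constexpr int kConversionsPerStepSketch = 2;
constexpr int tmem_load_ncells(int kv_cols) {
  const int n = kv_cols / kConversionsPerStepSketch;
  return n < 32 ? n : 32;
}
constexpr int tmem_store_ncells(int kv_cols, int element_bits) {
  return tmem_load_ncells(kv_cols) * element_bits / 32;
}
static_assert(tmem_load_ncells(128) == 32, "");
static_assert(tmem_store_ncells(128, 16) == 16, "");  // bf16 / fp16 P
static_assert(tmem_store_ncells(128, 8) == 8, "");    // fp8 P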
0 : row_max; Tensor tTMEM_STOREVrS = make_tensor(shape(tTMEM_STOREVcS)); + static_assert(size(tTMEM_STOREVrS) == 2); tTMEM_STOREVrS(kIdxOldRowMax) = old_row_max; tTMEM_STOREVrS(kIdxNewRowMax) = row_max_safe; copy(tiled_tmem_storev, tTMEM_STOREVrS, tTMEM_STOREVtS); @@ -625,54 +691,64 @@ struct Sm100FmhaGenMainloopWarpspecialized { Tensor tTMEM_STORErS_x4 = make_tensor(shape(tTMEM_STOREcS)); - constexpr int kConversionsPerStep = 2; + Tensor tTMEM_STORErS_x4_e = recast>(tTMEM_STORErS_x4); NumericArrayConverter convert; - const int kReleasePipeCount = 10; // must be multiple of 2 order_s.wait(); + static_assert(kReleasePipeCount % kConversionsPerStep == 0); + static_assert(kConversionsPerStep == 2); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tTMEM_LOADrS); i += 2) { - float2 in = make_float2( - tTMEM_LOADrS(i + 0), - tTMEM_LOADrS(i + 1) - ); - float2 out; - cute::fma(out, scale_fp32x2, in, minus_row_max_scale_fp32x2); - tTMEM_LOADrS(i + 0) = out.x; - tTMEM_LOADrS(i + 1) = out.y; + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTMEM_LOADrS); i += kConversionsPerStep) { + float2 in = make_float2( + tTMEM_LOADrS(i + 0), + tTMEM_LOADrS(i + 1) + ); - tTMEM_LOADrS(i+0) = ::exp2f(tTMEM_LOADrS(i+0)); - tTMEM_LOADrS(i+1) = ::exp2f(tTMEM_LOADrS(i+1)); + float2 out; + cute::fma(out, scale_fp32x2, in, minus_row_max_scale_fp32x2); + tTMEM_LOADrS(i + 0) = out.x; + tTMEM_LOADrS(i + 1) = out.y; - Array in_conv; - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < kConversionsPerStep; j++) { - in_conv[j] = tTMEM_LOADrS(i + j); - } - tTMEM_STORErS_x4_e[i / kConversionsPerStep] = convert(in_conv); + tTMEM_LOADrS(i+0) = ::exp2f(tTMEM_LOADrS(i+0)); + tTMEM_LOADrS(i+1) = ::exp2f(tTMEM_LOADrS(i+1)); - - if (i == size(tTMEM_LOADrS) - kReleasePipeCount) { - order_s.arrive(); - } + Array in_conv; + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kConversionsPerStep; j++) { + in_conv[j] = tTMEM_LOADrS(i + j); + } + tTMEM_STORErS_x4_e[i / kConversionsPerStep] = convert(in_conv); + + + if (i == size(tTMEM_LOADrS) - kReleasePipeCount) { + order_s.arrive(); + } - // this prevents register spills in fp16 - if constexpr (size<2>(tTMEM_STORErS_x4) == _2{}) { - if (i == size(tTMEM_LOADrS) - 6) { - copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, 0), tTMEM_STOREtS_x4(_, _, 0)); + if constexpr (TileM::value == 128) { + if constexpr (size<2>(tTMEM_STORErS_x4) == _2{}) { + //this prevents register spills in fp16 + if (i == size(tTMEM_LOADrS) - 6) { + copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, 0), tTMEM_STOREtS_x4(_, _, 0)); + } + } } } - } + } // tmem_store(reg_S8) -> op_P CUTE_STATIC_ASSERT_V(size<2>(tTMEM_STORErS_x4) <= _2{}); CUTE_STATIC_ASSERT_V(size<1>(tTMEM_STORErS_x4) == _1{}); - copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1), tTMEM_STOREtS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1)); + if constexpr (TileM::value == 128) { + copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1), tTMEM_STOREtS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1)); + } else { + copy(tiled_tmem_store, tTMEM_STORErS_x4, tTMEM_STOREtS_x4); + } cutlass::arch::fence_view_async_tmem_store(); @@ -714,8 +790,14 @@ struct Sm100FmhaGenMainloopWarpspecialized { row_sum = local_row_sum; if (final_call) { + if constexpr (TileM{} == 64) { + // Sync threads 0 and 16 to get the sum of row_sum between them + row_sum += __shfl_xor_sync(0xffffffff, row_sum, 16); + } + // re-acquire the S part in the final step pipeline_s.consumer_wait(pipeline_s_consumer_state); + Tensor tTMEM_STOREVrS = 
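// Sketch of the pair-wise combine used for TileM == 64 above: the 16dp64b TMEM ops hand each
// row to a thread pair (lane and lane ^ 16), so every thread folds in its partner's partial
// before the row statistic is written back. Assumes all 32 lanes of the warp participate
// (full 0xffffffff mask, as in the code above).
__device__ inline float combine_row_pair_sum(float partial) {
  return partial + __shfl_xor_sync(0xffffffff, partial, /*laneMask=*/16);
}
__device__ inline float combine_row_pair_max(float partial) {
  return ::fmaxf(partial, __shfl_xor_sync(0xffffffff, partial, /*laneMask=*/16));
}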
make_tensor(shape(tTMEM_STOREVcS)); tTMEM_STOREVrS(kIdxFinalRowMax) = row_max; @@ -802,18 +884,34 @@ struct Sm100FmhaGenMainloopWarpspecialized { // As opposed to the softmax, we do not have enough registers here // to load all of the values (for tile kv = 128), so we loop // good values would be either 32 or 64 - const int kCorrectionTileSize = 32 / sizeof(ElementOut); - - using TMEM_LOAD = std::conditional_t; // 4x32 threads with 64 cols of 32b elem + const int kCorrectionTileSize = 32 / sizeof(ElementOut); + // TODO: load all values + + + // Choose TMEM OP based on + // - TileM shape + // - kCorrectionTileSize + using TMEM_LOAD_OPMAP = kValTyMap + > + >, + kValTyPair<128, + kValTyMap + >> // 4x32 threads with 64 cols of 32b elem + >; + using TMEM_LOAD = typename TMEM_LOAD_OPMAP::template query::template query; typename CollectiveMmaPV::TiledMma mma; Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{})); Tensor tOcO = mma.get_slice(0).partition_C(cO); Tensor tOgO = mma.get_slice(0).partition_C(gO); - Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int{}))); - Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int{}))); - Tensor tOgO_i = tOgO.compose(make_layout(make_shape(_128{}, Int{}))); + Tensor tOtO_i = tOtO.compose(make_layout(make_shape(TileM{}, Int{}))); + Tensor tOcO_i = tOcO.compose(make_layout(make_shape(TileM{}, Int{}))); + Tensor tOgO_i = tOgO.compose(make_layout(make_shape(TileM{}, Int{}))); Tensor tOtO0 = tOtO_i; tOtO0.data() = tOtO0.data().get() + uint32_t(TmemAllocation::O0); @@ -879,13 +977,13 @@ struct Sm100FmhaGenMainloopWarpspecialized { tCd(j) = convert.convert(tCs(j)); } - Tensor tSMgO_i = recast(tTMEM_LOADgO_i); - Tensor tSMrO_i = recast(tSMrO); + Tensor tSMgO_i = recast(tTMEM_LOADgO_i); + Tensor tSMrO_i = recast(tSMrO); - // could use masking do this right for smaller D - if (get<0>(tTMEM_LOADcO(_0{})) < get<0>(g_shape)) { + // could use masking do this right for smaller D + if (get<0>(tTMEM_LOADcO(_0{})) < get<0>(g_shape)) { copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tSMrO_i, tSMgO_i); - } + } } } @@ -901,16 +999,22 @@ struct Sm100FmhaGenMainloopWarpspecialized { // good values would be either 32 or 64 const int kCorrectionTileSize = 32; - using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 64 cols of 32b elem - using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 64 cols of 32b elem + using TMEM_LOAD = typename kValTyMap, + kValTyPair<128, SM100_TMEM_LOAD_32dp32b32x> // 4x32 threads with 64 cols of 32b elem + >::template query; + using TMEM_STORE = typename kValTyMap, + kValTyPair<128, SM100_TMEM_STORE_32dp32b32x> // 4x32 threads with 64 cols of 32b elem + >::template query; typename CollectiveMmaPV::TiledMma mma; Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{})); Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{})); Tensor tOcO = mma.get_slice(0).partition_C(cO); - Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int{}))); - Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int{}))); + Tensor tOtO_i = tOtO.compose(make_layout(make_shape(TileM{}, Int{}))); + Tensor tOcO_i = tOcO.compose(make_layout(make_shape(TileM{}, Int{}))); tOtO_i.data() = tOtO_i.data().get() + tmem_O; @@ -991,12 +1095,15 @@ struct Sm100FmhaGenMainloopWarpspecialized { Tensor cS = make_identity_tensor(select<0,1>(TileShapeQK{})); Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); - - Tensor tStS_v = 
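// The first correction epilogue above walks the O tile in chunks of
// kCorrectionTileSize = 32 / sizeof(ElementOut) columns (16 for 16-bit outputs, 8 for fp32).
// A constexpr restatement of that arithmetic; std::uint16_t stands in for any 16-bit output
// type such as fp16/bf16.
#include <cstdint>
template <typename ElementOut>
constexpr int correction_tile_size() {
  return 32 / static_cast<int>(sizeof(ElementOut));
}
static_assert(correction_tile_size<std::uint16_t>() == 16, "");
static_assert(correction_tile_size<float>() == 8, "");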
tStS.compose(make_layout(make_shape(_128{}, _2{}))); - Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{}))); - using TMEM_LOAD_V = SM100_TMEM_LOAD_32dp32b2x; // 4x32 threads with 2 cols of 32b elem + Tensor tStS_v = tStS.compose(make_layout(make_shape(TileM{}, kTMEM_V_COLUMNS{}))); + Tensor tScS_v = tScS.compose(make_layout(make_shape(TileM{}, kTMEM_V_COLUMNS{}))); + using TMEM_LOAD_V = + typename kValTyMap, + kValTyPair<128, SM100_TMEM_LOAD_32dp32b2x> // 4x32 threads with 2 cols of 32b elem + >::template query; auto tiled_tmem_loadv = make_tmem_copy(TMEM_LOAD_V{}, tStS_v); auto thr_tmem_loadv = tiled_tmem_loadv.get_slice(thread_idx); @@ -1024,6 +1131,7 @@ struct Sm100FmhaGenMainloopWarpspecialized { pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); Tensor tTMEM_LOADVrS = make_tensor(shape(tTMEM_LOADVcS)); + static_assert(size(tTMEM_LOADVrS) == 2); // read row_wise new global max copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS); diff --git a/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py b/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py index 3ce07debff..a3a51d15b4 100644 --- a/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py +++ b/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py @@ -693,7 +693,7 @@ def _execute_cutlass_blackwell_attn_varlen( for is_mqa in [True, False] for window_size in [(-1, -1), (0, 0), (0, 128), (128, 0), (1024, 0)] for head_dim in [128] - for sm_scale in [None, 1.0 / head_dim] + for sm_scale in [None] for num_groups in [1, 2] ] ) @@ -711,6 +711,14 @@ def test_decode( ) -> None: seqlen_q = 1 causal = True + if True: + print( + f"Running test_decode with params: " + f"dtype={dtype}, seqlen_k={seqlen_k}, batch_size={batch_size}, " + f"is_mqa={is_mqa}, window_size={window_size}, head_dim={head_dim}, " + f"sm_scale={sm_scale}, q_heads={q_heads}" + ) + self._execute_cutlass_blackwell_attn_dense( batch_size, seqlen_q, diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py index 4d55ed2738..27c388d716 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py @@ -86,19 +86,19 @@ class EvictionPolicy(NamedTuple): None # feature_score_counter_decay_rates for each table if eviction strategy is feature score ) training_id_eviction_trigger_count: Optional[list[int]] = ( - None # training_id_eviction_trigger_count for each table + None # Number of training IDs that, when exceeded, will trigger eviction for each table. ) training_id_keep_count: Optional[list[int]] = ( - None # training_id_keep_count for each table + None # Target number of training IDs to retain in each table after eviction. ) l2_weight_thresholds: Optional[list[float]] = ( None # l2_weight_thresholds for each table if eviction strategy is feature l2 norm ) threshold_calculation_bucket_stride: Optional[float] = ( - 0.2 # threshold_calculation_bucket_stride if eviction strategy is feature score + 0.2 # The width of each feature score bucket used for threshold calculation in feature score-based eviction. ) threshold_calculation_bucket_num: Optional[int] = ( - 1000000 # 1M, threshold_calculation_bucket_num if eviction strategy is feature score + 1000000 # 1M, Total number of feature score buckets used for threshold calculation in feature score-based eviction. 
) interval_for_insufficient_eviction_s: int = ( # wait at least # seconds before trigger next round of eviction, if last finished eviction is insufficient @@ -114,10 +114,16 @@ class EvictionPolicy(NamedTuple): 24 * 3600 # 1 day, interval for feature statistics decay ) meta_header_lens: Optional[list[int]] = None # metaheader length for each table + eviction_free_mem_threshold_gb: Optional[int] = ( + None # Minimum free memory (in GB) required before triggering eviction when using free_mem trigger mode. + ) + eviction_free_mem_check_interval_batch: Optional[int] = ( + None # Number of batches between checks for free memory threshold when using free_mem trigger mode. + ) def validate(self) -> None: - assert self.eviction_trigger_mode in [0, 1, 2, 3, 4], ( - "eviction_trigger_mode must be 0, 1, 2, 3 or 4 " + assert self.eviction_trigger_mode in [0, 1, 2, 3, 4, 5], ( + "eviction_trigger_mode must be 0, 1, 2, 3, 4, 5" f"actual {self.eviction_trigger_mode}" ) if self.eviction_trigger_mode == 0: @@ -143,6 +149,13 @@ def validate(self) -> None: assert ( self.training_id_eviction_trigger_count is not None ), "training_id_eviction_trigger_count must be set if eviction_trigger_mode is 4" + elif self.eviction_trigger_mode == 5: + assert ( + self.eviction_free_mem_threshold_gb is not None + ), "eviction_free_mem_threshold_gb must be set if eviction_trigger_mode is 5" + assert ( + self.eviction_free_mem_check_interval_batch is not None + ), "eviction_free_mem_check_interval_batch must be set if eviction_trigger_mode is 5" if self.eviction_strategy == 0: assert self.ttls_in_mins is not None, ( @@ -228,6 +241,7 @@ class KVZCHParams(NamedTuple): backend_return_whole_row: bool = False eviction_policy: EvictionPolicy = EvictionPolicy() embedding_cache_mode: bool = False + feature_score_collection_enabled: bool = False def validate(self) -> None: assert len(self.bucket_offsets) == len(self.bucket_sizes), ( @@ -240,6 +254,19 @@ def validate(self) -> None: ), "backend_return_whole_row can only be enabled when enable_optimizer_offloading is enabled" +class KVZCHEvictionTBEConfig(NamedTuple): + # Eviction trigger model for kvzch table: 0: disabled, 1: iteration, 2: mem_util, 3: manual, 4: id count, 5: free_mem + kvzch_eviction_trigger_mode: int = 2 # mem_util + # Minimum free memory (in GB) required before triggering eviction when using free_mem trigger mode. + eviction_free_mem_threshold_gb: int = 200 # 200GB + # Number of batches between checks for free memory threshold when using free_mem trigger mode. + eviction_free_mem_check_interval_batch: int = 1000 + # The width of each feature score bucket used for threshold calculation in feature score-based eviction. + threshold_calculation_bucket_stride: float = 0.2 + # Total number of feature score buckets used for threshold calculation in feature score-based eviction. 
+ threshold_calculation_bucket_num: Optional[int] = 1000000 # 1M + + class BackendType(enum.IntEnum): SSD = 0 DRAM = 1 diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index a572de0738..4f1741b3dc 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -200,6 +200,13 @@ class RESParams: ) # table sizes for the global rows the TBE holds +@dataclass(frozen=True) +class PrefetchedInfo: + linear_unique_indices: torch.Tensor + linear_unique_indices_length: torch.Tensor + hash_zch_identities: Optional[torch.Tensor] + + def construct_split_state( embedding_specs: list[tuple[int, int, EmbeddingLocation, ComputeDevice]], rowwise: bool, @@ -813,7 +820,7 @@ def __init__( # noqa C901 assert ( self.pooling_mode != PoolingMode.NONE ), "Mixed dimension tables only supported for pooling tables." - + self.mixed_D = mixed_D assert all( cd == compute_devices[0] for cd in compute_devices ), "Heterogenous compute_devices are NOT supported!" @@ -1553,10 +1560,10 @@ def get_table_name_for_logging(table_names: Optional[list[str]]) -> str: return "" # Do this because sometimes multiple shards of the same table could appear # in one TBE. - table_name_set = set(table_names) + table_name_set = sorted(set(table_names)) if len(table_name_set) == 1: return next(iter(table_name_set)) - return f"<{len(table_name_set)} tables>" + return f"<{len(table_name_set)} tables>: {table_name_set}" @staticmethod def get_prefetch_passes( @@ -2100,6 +2107,12 @@ def forward( # noqa: C901 requires this information for allocating the weight gradient tensor in the backward pass. + hash_zch_identities (Optional[Tensor]): The original raw IDs before + remapping to ZCH (Zero-Collision Hash) table slots. This tensor is + populated when using Multi-Probe Zero Collision Hash (MPZCH) modules + and is required for Raw Embedding Streaming (RES) to maintain + consistency between training and inference. + Returns: A 2D-tensor containing looked up data. Shape `(B, total_D)` where `B` = batch size and `total_D` = the sum of all embedding dimensions in the @@ -2217,7 +2230,6 @@ def forward( # noqa: C901 # In forward, we don't enable multi-pass prefetch as we want the process # to be as fast as possible and memory usage doesn't matter (will be recycled # by dense fwd/bwd) - # TODO: Properly pass in the hash_zch_identities self._prefetch( indices, offsets, @@ -2505,6 +2517,7 @@ def forward( # noqa: C901 row_counter, iter_int, self.max_counter.item(), + mixed_D=self.mixed_D, ), ) elif self._used_rowwise_adagrad_with_global_weight_decay: @@ -2523,6 +2536,7 @@ def forward( # noqa: C901 # `Optional[Tensor]` but got `Union[Module, Tensor]`. 
prev_iter_dev=self.prev_iter_dev, gwd_lower_bound=self.gwd_lower_bound, + mixed_D=self.mixed_D, ), ) else: @@ -2532,6 +2546,7 @@ def forward( # noqa: C901 common_args, self.optimizer_args, momentum1, + mixed_D=self.mixed_D, ), ) @@ -4140,6 +4155,60 @@ def raw_embedding_stream(self) -> None: False, # blocking_tensor_copy ) + @staticmethod + @torch.jit.ignore + def _get_prefetched_info( + linear_cache_indices_merged: torch.Tensor, + total_cache_hash_size: int, + hash_zch_identities: Optional[torch.Tensor], + ) -> PrefetchedInfo: + compute_inverse_indices = hash_zch_identities is not None + ( + linear_unique_indices, + linear_unique_indices_length, + linear_unique_indices_cnt, + linear_unique_inverse_indices, + ) = torch.ops.fbgemm.get_unique_indices_with_inverse( + linear_cache_indices_merged, + total_cache_hash_size, + compute_count=compute_inverse_indices, + compute_inverse_indices=compute_inverse_indices, + ) + # linear_unique_indices is the result after deduplication and sorting + linear_unique_indices = linear_unique_indices.narrow( + 0, 0, linear_unique_indices_length[0] + ) + + if hash_zch_identities is None: + return PrefetchedInfo( + linear_unique_indices, + linear_unique_indices_length, + None, + ) + + # Compute cumulative sum as indices for selecting unique elements to + # map hash_zch_identities to linear_unique_indices + count_cum_sum = torch.ops.fbgemm.asynchronous_complete_cumsum( + linear_unique_indices_cnt + ) + count_cum_sum = count_cum_sum.narrow(0, 0, linear_unique_indices_length[0]) + + # Select indices corresponding to first occurrence of each unique element + linear_unique_inverse_indices = linear_unique_inverse_indices.index_select( + dim=0, index=count_cum_sum + ) + + # Map hash_zch_identities to unique indices + hash_zch_identities_cpu = hash_zch_identities.index_select( + dim=0, index=linear_unique_inverse_indices + ).to(device=torch.device("cpu")) + + return PrefetchedInfo( + linear_unique_indices, + linear_unique_indices_length, + hash_zch_identities_cpu, + ) + @torch.jit.ignore def _store_prefetched_tensors( self, @@ -4150,35 +4219,26 @@ def _store_prefetched_tensors( NOTE: this needs to be a method with jit.ignore as the identities tensor is conditional. This function stores the prefetched tensors for the raw embedding streaming. 
""" - if self.enable_raw_embedding_streaming: - with record_function( - "## uvm_save_prefetched_rows {} {} ##".format(self.timestep, self.uuid) - ): + if not self.enable_raw_embedding_streaming: + return + + with record_function( + "## uvm_save_prefetched_rows {} {} ##".format(self.timestep, self.uuid) + ): + # Process hash_zch_identities using helper function + prefetched_info = self._get_prefetched_info( + linear_cache_indices_merged, + self.total_cache_hash_size, + hash_zch_identities, + ) + + self.prefetched_info.append( ( - linear_unique_indices, - linear_unique_indices_length, - _, - ) = torch.ops.fbgemm.get_unique_indices( - linear_cache_indices_merged, - self.total_cache_hash_size, - compute_count=False, - ) - linear_unique_indices = linear_unique_indices.narrow( - 0, 0, linear_unique_indices_length[0] - ) - self.prefetched_info.append( - ( - linear_unique_indices, - linear_unique_indices_length, - ( - hash_zch_identities.index_select( - dim=0, index=linear_unique_indices - ).to(device=torch.device("cpu")) - if hash_zch_identities is not None - else None - ), - ) + prefetched_info.linear_unique_indices, + prefetched_info.linear_unique_indices_length, + prefetched_info.hash_zch_identities, ) + ) @torch.jit.ignore def __report_input_params_factory( diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py index 1243f14db4..f0ac6f1a70 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py @@ -153,6 +153,13 @@ def benchmark_cpu_requests_mp( float: The average runtime per iteration in seconds. """ + import os + + strategy = os.environ.get("PYTORCH_SHARE_STRATEGY") + current_strategy = torch.multiprocessing.get_sharing_strategy() + if strategy is not None and current_strategy != strategy: + torch.multiprocessing.set_sharing_strategy(strategy) + cpu_bm_barrier.create_barrier(num_copies) worker_pool = torch.multiprocessing.Pool(num_copies) @@ -699,4 +706,4 @@ def benchmark_vbe( # pyre-ignore[61] bwd_time_sec = statistics.median(bwd_times_sec) - return fwd_time_sec, bwd_time_sec + return fwd_time_sec, bwd_time_sec \ No newline at end of file diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py index a497cf9a5b..4cdbe4a2eb 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py @@ -18,8 +18,9 @@ import time from functools import cached_property from math import floor, log2 -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, ClassVar, Optional, Union import torch # usort:skip +import weakref # @manual=//deeplearning/fbgemm/fbgemm_gpu/codegen:split_embedding_codegen_lookup_invokers import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers @@ -34,6 +35,7 @@ BoundsCheckMode, CacheAlgorithm, EmbeddingLocation, + EvictionPolicy, get_bounds_check_version_for_platform, KVZCHParams, PoolingMode, @@ -54,6 +56,8 @@ from torch import distributed as dist, nn, Tensor # usort:skip from dataclasses import dataclass +import psutil + from torch.autograd.profiler import record_function from ..cache import get_unique_indices_v2 @@ -100,6 +104,9 @@ class SSDTableBatchedEmbeddingBags(nn.Module): _local_instance_index: int = -1 res_params: RESParams table_names: list[str] + _all_tbe_instances: ClassVar[weakref.WeakSet] = weakref.WeakSet() + _first_instance_ref: ClassVar[weakref.ref] = None + _eviction_triggered: ClassVar[bool] = False def __init__( self, @@ -179,6 +186,7 @@ 
def __init__( table_names: Optional[list[str]] = None, use_rowwise_bias_correction: bool = False, # For Adam use optimizer_state_dtypes: dict[str, SparseType] = {}, # noqa: B006 + pg: Optional[dist.ProcessGroup] = None, ) -> None: super(SSDTableBatchedEmbeddingBags, self).__init__() @@ -567,6 +575,10 @@ def __init__( # loading checkpoint flag, set by checkpoint loader, and cleared after weight is applied to backend self.load_state_dict: bool = False + SSDTableBatchedEmbeddingBags._all_tbe_instances.add(self) + if SSDTableBatchedEmbeddingBags._first_instance_ref is None: + SSDTableBatchedEmbeddingBags._first_instance_ref = weakref.ref(self) + # create tbe unique id using rank index | local tbe idx if tbe_unique_id == -1: SSDTableBatchedEmbeddingBags._local_instance_index += 1 @@ -584,6 +596,7 @@ def __init__( self.tbe_unique_id = tbe_unique_id self.l2_cache_size = l2_cache_size logging.info(f"tbe_unique_id: {tbe_unique_id}") + self.enable_free_mem_trigger_eviction: bool = False if self.backend_type == BackendType.SSD: logging.info( f"Logging SSD offloading setup, tbe_unique_id:{tbe_unique_id}, l2_cache_size:{l2_cache_size}GB, " @@ -688,25 +701,31 @@ def __init__( if self.kv_zch_params.eviction_policy.eviction_mem_threshold_gb else self.l2_cache_size ) + kv_zch_params = self.kv_zch_params + eviction_policy = self.kv_zch_params.eviction_policy + if eviction_policy.eviction_trigger_mode == 5: + # If trigger mode is free_mem(5), populate config + self.set_free_mem_eviction_trigger_config(eviction_policy) + # Please refer to https://fburl.com/gdoc/nuupjwqq for the following eviction parameters. eviction_config = torch.classes.fbgemm.FeatureEvictConfig( - self.kv_zch_params.eviction_policy.eviction_trigger_mode, # eviction is disabled, 0: disabled, 1: iteration, 2: mem_util, 3: manual, 4: id count - self.kv_zch_params.eviction_policy.eviction_strategy, # evict_trigger_strategy: 0: timestamp, 1: counter, 2: counter + timestamp, 3: feature l2 norm, 4: timestamp threshold 5: feature score - self.kv_zch_params.eviction_policy.eviction_step_intervals, # trigger_step_interval if trigger mode is iteration + eviction_policy.eviction_trigger_mode, # eviction is disabled, 0: disabled, 1: iteration, 2: mem_util, 3: manual, 4: id count + eviction_policy.eviction_strategy, # evict_trigger_strategy: 0: timestamp, 1: counter, 2: counter + timestamp, 3: feature l2 norm, 4: timestamp threshold 5: feature score + eviction_policy.eviction_step_intervals, # trigger_step_interval if trigger mode is iteration eviction_mem_threshold_gb, # mem_util_threshold_in_GB if trigger mode is mem_util - self.kv_zch_params.eviction_policy.ttls_in_mins, # ttls_in_mins for each table if eviction strategy is timestamp - self.kv_zch_params.eviction_policy.counter_thresholds, # counter_thresholds for each table if eviction strategy is counter - self.kv_zch_params.eviction_policy.counter_decay_rates, # counter_decay_rates for each table if eviction strategy is counter - self.kv_zch_params.eviction_policy.feature_score_counter_decay_rates, # feature_score_counter_decay_rates for each table if eviction strategy is feature score - self.kv_zch_params.eviction_policy.training_id_eviction_trigger_count, # training_id_eviction_trigger_count for each table - self.kv_zch_params.eviction_policy.training_id_keep_count, # training_id_keep_count for each table - self.kv_zch_params.eviction_policy.l2_weight_thresholds, # l2_weight_thresholds for each table if eviction strategy is feature l2 norm + eviction_policy.ttls_in_mins, # ttls_in_mins for each 
table if eviction strategy is timestamp + eviction_policy.counter_thresholds, # counter_thresholds for each table if eviction strategy is counter + eviction_policy.counter_decay_rates, # counter_decay_rates for each table if eviction strategy is counter + eviction_policy.feature_score_counter_decay_rates, # feature_score_counter_decay_rates for each table if eviction strategy is feature score + eviction_policy.training_id_eviction_trigger_count, # training_id_eviction_trigger_count for each table + eviction_policy.training_id_keep_count, # training_id_keep_count for each table + eviction_policy.l2_weight_thresholds, # l2_weight_thresholds for each table if eviction strategy is feature l2 norm table_dims.tolist() if table_dims is not None else None, - self.kv_zch_params.eviction_policy.threshold_calculation_bucket_stride, # threshold_calculation_bucket_stride if eviction strategy is feature score - self.kv_zch_params.eviction_policy.threshold_calculation_bucket_num, # threshold_calculation_bucket_num if eviction strategy is feature score - self.kv_zch_params.eviction_policy.interval_for_insufficient_eviction_s, - self.kv_zch_params.eviction_policy.interval_for_sufficient_eviction_s, - self.kv_zch_params.eviction_policy.interval_for_feature_statistics_decay_s, + eviction_policy.threshold_calculation_bucket_stride, # threshold_calculation_bucket_stride if eviction strategy is feature score + eviction_policy.threshold_calculation_bucket_num, # threshold_calculation_bucket_num if eviction strategy is feature score + eviction_policy.interval_for_insufficient_eviction_s, + eviction_policy.interval_for_sufficient_eviction_s, + eviction_policy.interval_for_feature_statistics_decay_s, ) self._ssd_db = torch.classes.fbgemm.DramKVEmbeddingCacheWrapper( self.cache_row_dim, @@ -1065,6 +1084,8 @@ def __init__( self.bounds_check_version: int = get_bounds_check_version_for_platform() + self._pg = pg + @cached_property def cache_row_dim(self) -> int: """ @@ -2042,6 +2063,9 @@ def _prefetch( # noqa C901 if dist.get_rank() == 0: self._report_kv_backend_stats() + # May trigger eviction if free mem trigger mode enabled before get cuda + self.may_trigger_eviction() + # Fetch data from SSD if linear_cache_indices.numel() > 0: self.record_function_via_dummy_profile( @@ -2089,7 +2113,7 @@ def _prefetch( # noqa C901 torch.tensor( [weights.shape[0]], device="cpu", dtype=torch.long ), - weights.cpu().view(torch.float32).view(-1, 2), + weights.cpu(), ) # Generate row addresses (pointing to either L1 or the current @@ -4650,3 +4674,97 @@ def direct_write_embedding( ) # Return control to the main stream without waiting for the backend operation to complete + + def get_free_cpu_memory_gb(self) -> float: + mem = psutil.virtual_memory() + return mem.available / (1024**3) + + @classmethod + def trigger_evict_in_all_tbes(cls) -> None: + for tbe in cls._all_tbe_instances: + tbe.ssd_db.trigger_feature_evict() + + @classmethod + def tbe_has_ongoing_eviction(cls) -> bool: + for tbe in cls._all_tbe_instances: + if tbe.ssd_db.is_evicting(): + return True + return False + + def set_free_mem_eviction_trigger_config( + self, eviction_policy: EvictionPolicy + ) -> None: + self.enable_free_mem_trigger_eviction = True + self.eviction_trigger_mode: int = eviction_policy.eviction_trigger_mode + assert ( + eviction_policy.eviction_free_mem_check_interval_batch is not None + ), "eviction_free_mem_check_interval_batch is unexpected none for free_mem eviction trigger mode" + self.eviction_free_mem_check_interval_batch: int = ( + 
eviction_policy.eviction_free_mem_check_interval_batch + ) + assert ( + eviction_policy.eviction_free_mem_threshold_gb is not None + ), "eviction_policy.eviction_free_mem_threshold_gb is unexpected none for free_mem eviction trigger mode" + self.eviction_free_mem_threshold_gb: int = ( + eviction_policy.eviction_free_mem_threshold_gb + ) + logging.info( + f"[FREE_MEM Eviction] eviction config, trigger model: FREE_MEM, {self.eviction_free_mem_check_interval_batch=}, {self.eviction_free_mem_threshold_gb=}" + ) + + def may_trigger_eviction(self) -> None: + def is_first_tbe() -> bool: + first = SSDTableBatchedEmbeddingBags._first_instance_ref + return first is not None and first() is self + + # We assume that the eviction time is less than free mem check interval time + # So every time we reach this check, all evictions in all tbes should be finished. + # We only need to check the first tbe because all tbes share the same free mem, + # once the first tbe detect need to trigger eviction, it will call trigger func + # in all tbes from _all_tbe_instances + if ( + self.enable_free_mem_trigger_eviction + and self.step % self.eviction_free_mem_check_interval_batch == 0 + and self.training + and is_first_tbe() + ): + if not SSDTableBatchedEmbeddingBags.tbe_has_ongoing_eviction(): + SSDTableBatchedEmbeddingBags._eviction_triggered = False + + free_cpu_mem_gb = self.get_free_cpu_memory_gb() + local_evict_trigger = int( + free_cpu_mem_gb < self.eviction_free_mem_threshold_gb + ) + tensor_flag = torch.tensor( + local_evict_trigger, + device=self.current_device, + dtype=torch.int, + ) + world_size = dist.get_world_size(self._pg) + if world_size > 1: + dist.all_reduce(tensor_flag, op=dist.ReduceOp.SUM, group=self._pg) + global_evict_trigger = tensor_flag.item() + else: + global_evict_trigger = local_evict_trigger + if ( + global_evict_trigger >= 1 + and SSDTableBatchedEmbeddingBags._eviction_triggered + ): + logging.warning( + f"[FREE_MEM Eviction] {global_evict_trigger} ranks triggered eviction, but SSDTableBatchedEmbeddingBags._eviction_triggered is true" + ) + if ( + global_evict_trigger >= 1 + and not SSDTableBatchedEmbeddingBags._eviction_triggered + ): + SSDTableBatchedEmbeddingBags._eviction_triggered = True + SSDTableBatchedEmbeddingBags.trigger_evict_in_all_tbes() + logging.info( + f"[FREE_MEM Eviction] Evict all at batch {self.step}, {free_cpu_mem_gb} GB free CPU memory, {global_evict_trigger} ranks triggered eviction" + ) + + def reset_inference_mode(self) -> None: + """ + Reset the inference mode + """ + self.eval() diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h index b55fd72fce..447613c5fc 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h @@ -38,7 +38,7 @@ namespace fbgemm_gpu::rocm { [[nodiscard]] inline bool is_supported_cdna() { - const std::set supported_archs{"gfx942", "gfx90a"}; + const std::set supported_archs{"gfx942", "gfx90a", "gfx950"}; int device_id = 0; HIP_CHECK(hipGetDevice(&device_id)); hipDeviceProp_t dev_props; diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index b3a56c4b52..5475f74ddd 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -21,9 +21,14 @@ * ******************************************************************************/ #pragma once + +#include #include 
+#include + #include #include +#include /******************************************************************************/ typedef int32_t int32x4_t __attribute__((ext_vector_type(4))); @@ -46,10 +51,10 @@ union amdgcn_buffer_resource { }; template -__device__ int32x4_t amdgcn_make_buffer_resource(const T* addr) { +__device__ int32x4_t amdgcn_make_buffer_resource(const T* addr, const int32_t size_in_bytes = 0xFFFFFFFF) { amdgcn_buffer_resource buffer_resource; buffer_resource.address = const_cast(addr); - buffer_resource.range = 0xffffffff; + buffer_resource.range = size_in_bytes; buffer_resource.config = AMDGCN_BUFFER_RES_3; // for gfx9 return buffer_resource.content; @@ -59,34 +64,70 @@ __device__ int32x4_t amdgcn_make_buffer_resource(const T* addr) { __device__ half llvm_amdgcn_raw_buffer_load_fp16( int32x4_t srsrc, int32_t voffset, - int32_t soffset, - int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16"); + int32_t soffset = 0, + int32_t glc_slc = 0) +#if ROCM_VERSION_MAJOR >= 7 + __asm("llvm.amdgcn.raw.buffer.load.i16"); +#else + __asm("llvm.amdgcn.raw.buffer.load.f16"); +#endif __device__ float llvm_amdgcn_raw_buffer_load_fp32( int32x4_t srsrc, int32_t voffset, - int32_t soffset, - int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f32"); + int32_t soffset = 0, + int32_t glc_slc = 0) __asm("llvm.amdgcn.raw.buffer.load.f32"); __device__ half2 llvm_amdgcn_raw_buffer_load_fp16x2( int32x4_t srsrc, int32_t voffset, - int32_t soffset, - int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16"); + int32_t soffset = 0, + int32_t glc_slc = 0) +#if ROCM_VERSION_MAJOR >= 7 + __asm("llvm.amdgcn.raw.buffer.load.i32"); +#else + __asm("llvm.amdgcn.raw.buffer.load.v2f16"); +#endif + +__device__ void llvm_amdgcn_raw_buffer_store_fp16( + const half vdata, + int32x4_t rsrc, + int32_t voffset, + int32_t soffset = 0, + int32_t glc_slc = 0 +) +#if ROCM_VERSION_MAJOR >= 7 + __asm("llvm.amdgcn.raw.buffer.store.i16"); +#else + __asm("llvm.amdgcn.raw.buffer.store.f16"); +#endif + +__device__ void llvm_amdgcn_raw_buffer_store_fp16x2( + const half2 vdata, + int32x4_t rsrc, + int32_t voffset, + int32_t soffset = 0, + int32_t glc_slc = 0 +) +#if ROCM_VERSION_MAJOR >= 7 + __asm("llvm.amdgcn.raw.buffer.store.i32"); +#else + __asm("llvm.amdgcn.raw.buffer.store.v2f16"); +#endif __device__ void llvm_amdgcn_raw_buffer_store_fp32( float vdata, int32x4_t rsrc, int32_t voffset, - int32_t soffset, - int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f32"); + int32_t soffset = 0, + int32_t glc_slc = 0) __asm("llvm.amdgcn.raw.buffer.store.f32"); __device__ void llvm_amdgcn_raw_buffer_store_fp32x2( floatx2_t vdata, int32x4_t rsrc, int32_t voffset, - int32_t soffset, - int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f32"); + int32_t soffset = 0, + int32_t glc_slc = 0) __asm("llvm.amdgcn.raw.buffer.store.v2f32"); /******************************************************************************/ @@ -96,35 +137,15 @@ struct load_row_per_warp { emb_t* emb_data, index_t row_index, const emb_t* p_emb_table, - int lane_id) {} -}; - -template -struct load_row_per_warp { - static constexpr int dword_per_row = - (embedding_dim + THREADS_PER_ROW - 1) / THREADS_PER_ROW; - static __device__ void run( - float* emb_data, - index_t row_index, - const float* p_emb_table, int lane_id) { - int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * embedding_dim); -#pragma unroll - for (int i = 0; i < dword_per_row; i++) { - if constexpr (embedding_dim == 160) { - if ((lane_id + i * THREADS_PER_ROW) % 192 < 160) { - 
emb_data[i] = llvm_amdgcn_raw_buffer_load_fp32( - emb_res, (lane_id + i * THREADS_PER_ROW) * sizeof(float), 0, 0); + // Types are not supported, but we need an instance of run method to avoid run-time .so symbol + // failure. Currently, the kernel dispatch for unsupported type is guarded on host side + if constexpr (std::is_same_v || std::is_same_v) { + __builtin_trap(); } else { - emb_data[i] = 0.f; + static_assert(false, "HIP: Optimized load operation is not supported yet"); } - } else { - emb_data[i] = llvm_amdgcn_raw_buffer_load_fp32( - emb_res, (lane_id + i * THREADS_PER_ROW) * sizeof(float), 0, 0); } - } - } }; template @@ -134,7 +155,7 @@ struct load_row_per_warp { int32x4_t emb_res = amdgcn_make_buffer_resource(p_emb_table + row_index * 64); emb_data[0] = - llvm_amdgcn_raw_buffer_load_fp16(emb_res, lane_id * sizeof(half), 0, 0); + llvm_amdgcn_raw_buffer_load_fp16(emb_res, lane_id * sizeof(half)); } }; @@ -145,7 +166,7 @@ struct load_row_per_warp { int32x4_t emb_res = amdgcn_make_buffer_resource(p_emb_table + row_index * 128); *reinterpret_cast(emb_data) = llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, lane_id * sizeof(half2), 0, 0); + emb_res, lane_id * sizeof(half2)); } }; @@ -154,15 +175,11 @@ struct load_row_per_warp { static __device__ void run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * 192); + amdgcn_make_buffer_resource(p_emb_table + row_index * 160, sizeof(half) * 160); *reinterpret_cast(emb_data) = llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, lane_id * sizeof(half2), 0, 0); - if ((lane_id + 128) % 192 < 160) { + emb_res, lane_id * sizeof(half2)); emb_data[2] = llvm_amdgcn_raw_buffer_load_fp16( - emb_res, (lane_id + 128) * sizeof(half), 0, 0); - } else { - emb_data[2] = __float2half(0.0); - } + emb_res, (lane_id + 128) * sizeof(half)); } }; @@ -173,9 +190,9 @@ struct load_row_per_warp { int32x4_t emb_res = amdgcn_make_buffer_resource(p_emb_table + row_index * 192); *reinterpret_cast(emb_data) = llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, lane_id * sizeof(half2), 0, 0); + emb_res, lane_id * sizeof(half2)); emb_data[2] = llvm_amdgcn_raw_buffer_load_fp16( - emb_res, (lane_id + 128) * sizeof(half), 0, 0); + emb_res, (lane_id + 128) * sizeof(half)); } }; @@ -187,31 +204,133 @@ struct load_row_per_warp { amdgcn_make_buffer_resource(p_emb_table + row_index * 256); *reinterpret_cast(&emb_data[0]) = llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, lane_id * sizeof(half2), 0, 0); + emb_res, lane_id * sizeof(half2)); *reinterpret_cast(&emb_data[2]) = llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, (lane_id + 64) * sizeof(half2), 0, 0); + emb_res, (lane_id + 64) * sizeof(half2)); } }; template -struct load_row_per_warp { +struct load_row_per_warp { static __device__ void run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * 512); + amdgcn_make_buffer_resource(p_emb_table + row_index * 320, sizeof(half) * 320); *reinterpret_cast(&emb_data[0]) = llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, lane_id * sizeof(half2), 0, 0); + emb_res, lane_id * sizeof(half2)); *reinterpret_cast(&emb_data[2]) = llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, (lane_id + 64) * sizeof(half2), 0, 0); - *reinterpret_cast(&emb_data[4]) = - llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, (lane_id + 64 * 2) * sizeof(half2), 0, 0); - *reinterpret_cast(&emb_data[6]) = - llvm_amdgcn_raw_buffer_load_fp16x2( - 
emb_res, (lane_id + 64 * 3) * sizeof(half2), 0, 0); + emb_res, (lane_id + 64) * sizeof(half2)); + emb_data[4] = llvm_amdgcn_raw_buffer_load_fp16( + emb_res, (lane_id + 128) * sizeof(half)); + } +}; + +template +struct load_row_per_warp { + static __device__ void run( + c10::Half* emb_data, + index_t row_index, + const c10::Half* p_emb_table, + int lane_id) { + load_row_per_warp::run( + reinterpret_cast(emb_data), + row_index, + reinterpret_cast(p_emb_table), + lane_id + ); + } +}; + +template +struct load_row_per_warp { + static __device__ void + run(float* emb_data, index_t row_index, const float* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 64); + emb_data[0] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); + } +}; + +template +struct load_row_per_warp { + static __device__ void + run(float* emb_data, index_t row_index, const float* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 128); + emb_data[0] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); + emb_data[1] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 64) * sizeof(float)); + } +}; + +template +struct load_row_per_warp { + static __device__ void + run(float* emb_data, index_t row_index, const float* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 160, sizeof(float) * 160); + emb_data[0] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); + emb_data[1] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 64) * sizeof(float)); + emb_data[2] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 128) * sizeof(float)); + } +}; + +template +struct load_row_per_warp { + static __device__ void + run(float* emb_data, index_t row_index, const float* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 192); + emb_data[0] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); + emb_data[1] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 64) * sizeof(float)); + emb_data[2] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 128) * sizeof(float)); + } +}; + +template +struct load_row_per_warp { + static __device__ void + run(float* emb_data, index_t row_index, const float* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 256); + emb_data[0] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); + emb_data[1] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 64) * sizeof(float)); + emb_data[2] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 128) * sizeof(float)); + emb_data[3] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 192) * sizeof(float)); + } +}; + +template +struct load_row_per_warp { + static __device__ void + run(float* emb_data, index_t row_index, const float* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 320, sizeof(float) * 320); + emb_data[0] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); + emb_data[1] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 64) * sizeof(float)); + emb_data[2] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 128) * sizeof(float)); + emb_data[3] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 192) * sizeof(float)); + emb_data[4] = + 
llvm_amdgcn_raw_buffer_load_fp32(emb_res, (lane_id + 256) * sizeof(float)); } }; @@ -233,93 +352,156 @@ struct accumulate_row_per_warp { } else { #pragma unroll for (int i = 0; i < dword_per_row; i++) { - acc[i] += static_cast((float)emb_data[i] * row_weight); + if constexpr (std::is_same_v) + { + acc[i] += static_cast(__half2float(emb_data[i]) * row_weight); + } + else + { + acc[i] += static_cast(static_cast(emb_data[i]) * row_weight); + } } } } }; -template +template struct store_row_per_warp { - static constexpr int dword_per_row = - (embedding_dim + THREADS_PER_ROW - 1) / THREADS_PER_ROW; - static __device__ void run(output_t* acc, output_t* p_output, int lane_id) { - if constexpr (embedding_dim == 160) { - for (int i = 0; i < dword_per_row; i++) { - if ((lane_id + i * THREADS_PER_ROW) % 192 < 160) { - p_output[lane_id + i * THREADS_PER_ROW] = acc[i]; - } - } + static __device__ void run(const emb_t* acc, emb_t* p_output, int lane_id) { + // Types are not supported, but we need an instance of run method to avoid run-time .so symbol + // failure. Currently, the kernel dispatch for unsupported type is guarded on host function + if constexpr (std::is_same_v || std::is_same_v) { + __builtin_trap(); } else { -#pragma unroll - for (int i = 0; i < dword_per_row; i++) { - p_output[lane_id + i * THREADS_PER_ROW] = acc[i]; - } + static_assert(false, "HIP: Optimized load operation is not supported yet"); } } }; template <> -struct store_row_per_warp { - static __device__ void run(float* acc, float* p_output, int lane_id) { +struct store_row_per_warp { + static __device__ void run(const half* acc, half* p_output, int lane_id) { int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(acc), - out_res, - lane_id * sizeof(floatx2_t), - 0, - 0); + llvm_amdgcn_raw_buffer_store_fp16(acc[0], out_res, lane_id * sizeof(half)); } }; template <> -struct store_row_per_warp { - static __device__ void run(float* acc, float* p_output, int lane_id) { +struct store_row_per_warp { + static __device__ void run(const half* acc, half* p_output, int lane_id) { int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(acc), - out_res, - lane_id * sizeof(floatx2_t), - 0, - 0); - if ((lane_id + 128) % 192 < 160) { - llvm_amdgcn_raw_buffer_store_fp32( - acc[2], out_res, (lane_id + 128) * sizeof(float), 0, 0); - } + llvm_amdgcn_raw_buffer_store_fp16x2(*reinterpret_cast(acc), out_res, lane_id * sizeof(half2)); + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(const half* acc, half* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output, 160 * sizeof(half)); + llvm_amdgcn_raw_buffer_store_fp16x2(*reinterpret_cast(acc), out_res, lane_id * sizeof(half2)); + llvm_amdgcn_raw_buffer_store_fp16(acc[2], out_res, (lane_id + 128) * sizeof(half)); + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(const half* acc, half* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output); + llvm_amdgcn_raw_buffer_store_fp16x2(*reinterpret_cast(acc), out_res, lane_id * sizeof(half2)); + llvm_amdgcn_raw_buffer_store_fp16(acc[2], out_res, (lane_id + 128) * sizeof(half)); } }; template <> -struct store_row_per_warp { - static __device__ void run(float* acc, float* p_output, int lane_id) { +struct store_row_per_warp { + static __device__ void run(const half* acc, half* p_output, int lane_id) { int32x4_t out_res = 
amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(acc), - out_res, - lane_id * sizeof(floatx2_t), - 0, - 0); - llvm_amdgcn_raw_buffer_store_fp32( - acc[2], out_res, (lane_id + 128) * sizeof(float), 0, 0); + llvm_amdgcn_raw_buffer_store_fp16x2(*reinterpret_cast(acc), out_res, lane_id * sizeof(half2)); + llvm_amdgcn_raw_buffer_store_fp16x2(*reinterpret_cast(acc + 2), out_res, (lane_id + 64) * sizeof(half2)); } }; template <> -struct store_row_per_warp { - static __device__ void run(float* acc, float* p_output, int lane_id) { +struct store_row_per_warp { + static __device__ void run(const half* acc, half* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output, 320 * sizeof(half)); + llvm_amdgcn_raw_buffer_store_fp16x2(*reinterpret_cast(acc), out_res, lane_id * sizeof(half2)); + llvm_amdgcn_raw_buffer_store_fp16x2(*reinterpret_cast(acc + 2), out_res, (lane_id + 64) * sizeof(half2)); + llvm_amdgcn_raw_buffer_store_fp16(acc[4], out_res, (lane_id + 256) * sizeof(half)); + } +}; + +template +struct store_row_per_warp { + static __device__ void run( + const c10::Half* emb_data, + c10::Half* p_emb_table, + int lane_id) { + store_row_per_warp::run( + reinterpret_cast(emb_data), + reinterpret_cast(p_emb_table), + lane_id + ); + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(const float* acc, float* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output); + llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(const float* acc, float* p_output, int lane_id) { int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(acc), - out_res, - lane_id * sizeof(floatx2_t), - 0, - 0); - llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(&acc[2]), - out_res, - (lane_id + 64) * sizeof(floatx2_t), - 0, - 0); + llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[1], out_res, (lane_id + 64) * sizeof(float)); + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(const float* acc, float* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output, sizeof(float) * 160); + llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[1], out_res, (lane_id + 64) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[2], out_res, (lane_id + 128) * sizeof(float)); + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(const float* acc, float* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output); + llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[1], out_res, (lane_id + 64) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[2], out_res, (lane_id + 128) * sizeof(float)); + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(const float* acc, float* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output); + llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[1], out_res, (lane_id + 64) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[2], out_res, (lane_id + 128) * sizeof(float)); + 
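// Sketch of the bounds-checked access pattern used by the 160- and 320-wide specializations
// in this header: instead of masking lanes with "(lane + k) % 192 < 160", the buffer resource
// is clamped to the row length, and the hardware bounds check makes out-of-range lanes load
// zeros and drop their stores. The helper names and signatures are the ones declared earlier
// in this file; the 64-lane wavefront and one-dword-per-lane-per-step layout are assumptions.
__device__ inline void bounded_row_copy_160(const float* src_row, float* dst_row, int lane_id) {
  // Each resource covers exactly 160 floats; offsets at or past 160 * sizeof(float) are out of range.
  int32x4_t src_res = amdgcn_make_buffer_resource(src_row, sizeof(float) * 160);
  int32x4_t dst_res = amdgcn_make_buffer_resource(dst_row, sizeof(float) * 160);
#pragma unroll
  for (int i = 0; i < 3; i++) {  // ceil(160 / 64) = 3 dwords per lane
    // On the last step, lanes 32..63 address floats 160..191: they read 0 and their stores are dropped.
    const float v = llvm_amdgcn_raw_buffer_load_fp32(src_res, (lane_id + i * 64) * sizeof(float));
    llvm_amdgcn_raw_buffer_store_fp32(v, dst_res, (lane_id + i * 64) * sizeof(float));
  }
}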
llvm_amdgcn_raw_buffer_store_fp32(acc[3], out_res, (lane_id + 192) * sizeof(float)); + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(const float* acc, float* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output, sizeof(float) * 320); + llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[1], out_res, (lane_id + 64) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[2], out_res, (lane_id + 128) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[4], out_res, (lane_id + 192) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32(acc[5], out_res, (lane_id + 256) * sizeof(float)); } }; @@ -471,7 +653,7 @@ __device__ __forceinline__ void generic_dpp_reduction(data_t& result) { // of trivial operation with an option to use custom operation template __device__ __forceinline__ void dpp_reduction(data_t& result) { -#if defined(__gfx942__) || defined(__gfx90a__) +#if defined(__gfx942__) || defined(__gfx90a__) || defined(__gfx950__) if constexpr (std::is_same_v) { DPP_REDUCE_F16_F32(add); return; diff --git a/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.h b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.h index 108b8eba5e..4334efd4b8 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.h +++ b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.h @@ -32,6 +32,7 @@ generate_vbe_metadata( const at::Tensor& D_offsets, const int64_t D, const bool nobag, - const int64_t max_B_feature_rank, + const c10::SymInt max_B_feature_rank, const int64_t info_B_num_bits, - const int64_t total_B); + const c10::SymInt total_B, + const std::optional& vbe_output_offsets); diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh b/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh old mode 100644 new mode 100755 index 0d65c4798a..d51e3fa475 --- a/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh @@ -21,7 +21,9 @@ #include #endif #include - +#ifdef USE_ROCM +#include "fbgemm_gpu/rocm/split_embeddings_common.h" +#endif namespace { inline int get_device_sm_cnt_() { @@ -138,11 +140,19 @@ template DEVICE_INLINE T warpReduceAllSum( T val, unsigned shfl_sync_mask = static_cast(kFullWarpMask)) { -#pragma unroll - for (int mask = ReduceWidth / 2; mask > 0; mask >>= 1) { - val += shfl_xor(val, mask, ReduceWidth, shfl_sync_mask); - } - return val; + #ifdef USE_ROCM + return rocm::wave_reduce< + rocm::reduce_op::sum, // Sum reduction + T, // Data type + ReduceWidth // Wave/Warp size + >(val); + #else + #pragma unroll + for (int mask = ReduceWidth / 2; mask > 0; mask >>= 1) { + val += shfl_xor(val, mask, ReduceWidth, shfl_sync_mask); + } + return val; + #endif } DEVICE_INLINE void syncwarp() { diff --git a/fbgemm_gpu/requirements.txt b/fbgemm_gpu/requirements.txt index c1f0bb92ff..dcd13bfcd9 100644 --- a/fbgemm_gpu/requirements.txt +++ b/fbgemm_gpu/requirements.txt @@ -29,3 +29,4 @@ setuptools_git_versioning tabulate patchelf fairscale +psutil diff --git a/fbgemm_gpu/requirements_genai.txt b/fbgemm_gpu/requirements_genai.txt index 59741362a5..722de8de37 100644 --- a/fbgemm_gpu/requirements_genai.txt +++ b/fbgemm_gpu/requirements_genai.txt @@ -30,3 +30,4 @@ setuptools_git_versioning tabulate patchelf fairscale +psutil diff --git a/fbgemm_gpu/setup.py b/fbgemm_gpu/setup.py index 97600fc0bb..dd3246539a 100644 --- a/fbgemm_gpu/setup.py +++ b/fbgemm_gpu/setup.py @@ -4,8 +4,6 @@ # This source code is 
licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# @licenselint-loose-mode - import argparse import logging import os @@ -655,7 +653,7 @@ def main(argv: list[str]) -> None: ] + [ f"Programming Language :: Python :: {x}" - for x in ["3", "3.9", "3.10", "3.11", "3.12", "3.13"] + for x in ["3", "3.10", "3.11", "3.12", "3.13"] ], ) diff --git a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h index 9738b846cc..4d1d2895a6 100644 --- a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h +++ b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h @@ -770,7 +770,6 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB { CHECK_EQ(indices.size(0), engege_rates.size(0)); auto indices_data_ptr = indices.data_ptr(); auto engage_rate_ptr = engege_rates.data_ptr(); - int64_t stride = 2; { auto before_write_lock_ts = facebook::WallClockUtil::NowInUsecFast(); @@ -785,8 +784,7 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB { index_iter++) { const auto& id_index = *index_iter; auto id = int64_t(indices_data_ptr[id_index]); - float engege_rate = - float(engage_rate_ptr[id_index * stride + 0]); + float engege_rate = float(engage_rate_ptr[id_index]); // use mempool weight_type* block = nullptr; auto before_lookup_cache_ts = @@ -1177,17 +1175,8 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB { void compact() override {} - void trigger_feature_evict( - std::optional inplace_update_ts = std::nullopt) { + void trigger_feature_evict() { if (feature_evict_) { - if (inplace_update_ts.has_value() && - feature_evict_config_.value()->trigger_strategy_ == - EvictTriggerStrategy::BY_TIMESTAMP_THRESHOLD) { - auto* tt_evict = dynamic_cast*>( - feature_evict_.get()); - CHECK(tt_evict != nullptr); - tt_evict->set_eviction_timestamp_threshold(inplace_update_ts.value()); - } feature_evict_->trigger_evict(); } } @@ -1223,6 +1212,11 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB { } break; } + case EvictTriggerMode::FREE_MEM: { + // For free mem eviction, all conditions checked in frontend, no check + // option in backend + return; + } default: break; } @@ -1271,6 +1265,13 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB { } } + bool is_evicting() override { + if (feature_evict_) { + return feature_evict_->is_evicting(); + } + return false; + } + // for inference only, this logs the total hit/miss count // this should be called at the end of full/delta snapshot chunk by chunk // update diff --git a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache_wrapper.h b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache_wrapper.h index 8e70b41b93..11c4e43930 100644 --- a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache_wrapper.h +++ b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache_wrapper.h @@ -179,6 +179,14 @@ class DramKVEmbeddingCacheWrapper : public torch::jit::CustomClassHolder { impl_->set_backend_return_whole_row(backend_return_whole_row); } + void trigger_feature_evict() { + impl_->trigger_feature_evict(); + } + + bool is_evicting() { + return impl_->is_evicting(); + } + void set_feature_score_metadata_cuda( at::Tensor indices, at::Tensor count, diff --git a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_inference_wrapper.cpp b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_inference_wrapper.cpp index 8145a42023..6361c4878a 100644 --- 
a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_inference_wrapper.cpp +++ b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_inference_wrapper.cpp @@ -10,11 +10,17 @@ #include #include #include "deeplearning/fbgemm/fbgemm_gpu/include/fbgemm_gpu/embedding_common.h" // @manual=//deeplearning/fbgemm/fbgemm_gpu:fbgemm_gpu +#include "deeplearning/fbgemm/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_inference_embedding.h" DEFINE_int64( dram_kv_embedding_num_shards, 32, "Number of shards for DRAM KV inference embedding"); +DEFINE_bool( + kv_embedding_async_get_set, + true, + "Whether to use async get/set for DRAM KV inference embedding." + "This should be true for dram but might be different for other non-Dram backends."); namespace fbgemm_gpu { @@ -52,10 +58,10 @@ void DramKVEmbeddingInferenceWrapper::init( << ", row_alignment: " << row_alignment << ", scale_bias_size_in_bytes: " << scale_bias_size_in_bytes << ", max_row_bytes_: " << max_row_bytes_; - if (dram_kv_ != nullptr) { + if (kv_backend_ != nullptr) { return; } - dram_kv_ = std::make_shared>( + kv_backend_ = std::make_shared>( max_row_bytes_, uniform_init_lower_, uniform_init_upper_, @@ -86,14 +92,19 @@ void DramKVEmbeddingInferenceWrapper::init( disable_random_init_); } -std::shared_ptr> -DramKVEmbeddingInferenceWrapper::get_dram_kv() { - return dram_kv_; +int64_t DramKVEmbeddingInferenceWrapper::get_max_row_bytes() const { + return max_row_bytes_; } -void DramKVEmbeddingInferenceWrapper::set_dram_kv( - std::shared_ptr> dram_kv) { - dram_kv_ = std::move(dram_kv); +std::shared_ptr> +DramKVEmbeddingInferenceWrapper::get_kv_backend() { + return kv_backend_; +} + +void DramKVEmbeddingInferenceWrapper::set_kv_backend( + std::shared_ptr> + kv_backend) { + kv_backend_ = std::move(kv_backend); } void DramKVEmbeddingInferenceWrapper::set_embeddings( @@ -106,8 +117,13 @@ void DramKVEmbeddingInferenceWrapper::set_embeddings( inplacee_update_ts = static_cast(inplace_update_ts_opt.value()); } - folly::coro::blockingWait(dram_kv_->inference_set_kv_db_async( - indices, weights, count, inplacee_update_ts)); + + if (FLAGS_kv_embedding_async_get_set) { + folly::coro::blockingWait(kv_backend_->inference_set_kv_db_async( + indices, weights, count, inplacee_update_ts)); + } else { + kv_backend_->set_kv_db_sync(indices, weights, count, inplacee_update_ts); + } } at::Tensor DramKVEmbeddingInferenceWrapper::get_embeddings( @@ -119,24 +135,30 @@ at::Tensor DramKVEmbeddingInferenceWrapper::get_embeddings( max_row_bytes_, }, at::kByte); - folly::coro::blockingWait(dram_kv_->get_kv_db_async(indices, weights, count)); + + if (FLAGS_kv_embedding_async_get_set) { + folly::coro::blockingWait( + kv_backend_->get_kv_db_async(indices, weights, count)); + } else { + kv_backend_->get_kv_db_sync(indices, weights, count); + } return weights; } void DramKVEmbeddingInferenceWrapper::log_inplace_update_stats() { - dram_kv_->log_inplace_update_stats(); + kv_backend_->log_inplace_update_stats(); } void DramKVEmbeddingInferenceWrapper::trigger_evict( int64_t inplace_update_ts_64b) { uint32_t inplace_update_ts_32b = static_cast(inplace_update_ts_64b); - dram_kv_->trigger_feature_evict(inplace_update_ts_32b); - dram_kv_->resume_ongoing_eviction(); + kv_backend_->trigger_feature_evict(inplace_update_ts_32b); + kv_backend_->resume_ongoing_eviction(); } void DramKVEmbeddingInferenceWrapper::wait_evict_completion() { - dram_kv_->wait_until_eviction_done(); + kv_backend_->wait_until_eviction_done(); } c10::List DramKVEmbeddingInferenceWrapper::serialize() const { diff 
--git a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_inference_wrapper.h b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_inference_wrapper.h index 7af9a83d74..1c0af807e9 100644 --- a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_inference_wrapper.h +++ b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_inference_wrapper.h @@ -10,9 +10,10 @@ #include #include -#include "deeplearning/fbgemm/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_inference_embedding.h" +#include "deeplearning/fbgemm/fbgemm_gpu/src/dram_kv_embedding_cache/kv_inference_embedding_interface.h" DECLARE_int64(dram_kv_embedding_num_shards); +DECLARE_bool(kv_embedding_async_get_set); namespace fbgemm_gpu { @@ -46,22 +47,26 @@ class DramKVEmbeddingInferenceWrapper : public torch::jit::CustomClassHolder { void wait_evict_completion(); - std::shared_ptr> get_dram_kv(); + std::shared_ptr> + get_kv_backend(); - void set_dram_kv( - std::shared_ptr> dram_kv); + void set_kv_backend( + std::shared_ptr> + kv_backend); c10::List serialize() const; void deserialize(const c10::List& states); + int64_t get_max_row_bytes() const; + private: int64_t num_shards_ = 32; double uniform_init_lower_ = 0.0; double uniform_init_upper_ = 0.0; bool disable_random_init_ = false; - std::shared_ptr> dram_kv_; + std::shared_ptr> kv_backend_; int64_t max_row_bytes_ = 0; }; diff --git a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_inference_embedding.h b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_inference_embedding.h index e40d0ffd5f..57c1f28160 100644 --- a/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_inference_embedding.h +++ b/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_inference_embedding.h @@ -30,6 +30,7 @@ #include "fbgemm_gpu/utils/dispatch_macros.h" #include "feature_evict.h" #include "fixed_block_pool.h" +#include "kv_inference_embedding_interface.h" namespace kv_mem { @@ -64,7 +65,8 @@ namespace kv_mem { /// @brief An implementation of EmbeddingKVDB for ZCH v.Next /// template -class DramKVInferenceEmbedding { +class DramKVInferenceEmbedding + : public KVInferenceEmbeddingInterface { public: /// DramKVInferenceEmbedding constructor /// @@ -163,7 +165,7 @@ class DramKVInferenceEmbedding { double uniform_init_lower, double uniform_init_upper, int64_t row_storage_bitwidth, - bool disable_random_init) { + bool disable_random_init) override { for (auto i = 0; i < num_shards; ++i) { auto* gen = at::check_generator( at::detail::getDefaultCPUGenerator()); @@ -181,11 +183,26 @@ class DramKVInferenceEmbedding { disable_random_init_ = disable_random_init; } + void set_kv_db_sync( + const at::Tensor& /*indices*/, + const at::Tensor& /*weights*/, + const at::Tensor& /*count*/, + std::optional /*inplace_update_ts*/) override { + throw std::runtime_error("set_kv_db_sync is not implemented for DRAM"); + } + + void get_kv_db_sync( + const at::Tensor& /*indices*/, + const at::Tensor& /*weights*/, + const at::Tensor& /*count*/) override { + throw std::runtime_error("get_kv_db_sync is not implemented for DRAM"); + } + folly::SemiFuture> inference_set_kv_db_async( const at::Tensor& indices, const at::Tensor& weights, const at::Tensor& count, - std::optional inplace_update_ts) { + std::optional inplace_update_ts) override { std::vector>> futures; auto shardid_to_indexes = shard_input(indices, count); @@ -552,15 +569,15 @@ class DramKVInferenceEmbedding { folly::SemiFuture> get_kv_db_async( const at::Tensor& indices, const at::Tensor& weights, - const at::Tensor& count) { + const at::Tensor& count) override { 
current_iter_++; return get_kv_db_async_impl(indices, weights, count); } - void compact() {} + void compact() override {} void trigger_feature_evict( - std::optional inplace_update_ts = std::nullopt) { + std::optional inplace_update_ts = std::nullopt) override { if (feature_evict_) { if (inplace_update_ts.has_value() && feature_evict_config_.value()->trigger_strategy_ == @@ -574,7 +591,7 @@ class DramKVInferenceEmbedding { } } - void maybe_evict() { + void maybe_evict() override { if (!feature_evict_config_.has_value()) { return; } @@ -603,25 +620,25 @@ class DramKVInferenceEmbedding { } // wait until eviction finishes, if any - void wait_until_eviction_done() { + void wait_until_eviction_done() override { if (feature_evict_) { feature_evict_->wait_until_eviction_done(); } } - size_t get_map_used_memsize_in_bytes() const { + size_t get_map_used_memsize_in_bytes() const override { return kv_store_.getUsedMemSizeInBytes(); } - size_t get_map_actual_used_chunk_in_bytes() const { + size_t get_map_actual_used_chunk_in_bytes() const override { return kv_store_.getActualUsedChunkInBytes(); } - size_t get_num_rows() const { + size_t get_num_rows() const override { return kv_store_.getNumRows(); } - void resume_ongoing_eviction(bool force_resume = false) { + void resume_ongoing_eviction(bool force_resume = false) override { if (!force_resume) { return; } @@ -630,7 +647,7 @@ class DramKVInferenceEmbedding { } } - void pause_ongoing_eviction(bool force_pause = false) { + void pause_ongoing_eviction(bool force_pause = false) override { if (!force_pause) { return; } @@ -648,7 +665,7 @@ class DramKVInferenceEmbedding { // for inference only, this logs the total hit/miss count // this should be called at the end of full/delta snapshot chunk by chunk // update - void log_inplace_update_stats() { + void log_inplace_update_stats() override { int reset_val = 0; auto inplace_update_hit_cnt = inplace_update_hit_cnt_.exchange(reset_val); @@ -661,7 +678,8 @@ class DramKVInferenceEmbedding { << (total_cnt > 0 ? 
(double)inplace_update_hit_cnt / total_cnt : 0.0); } - std::optional get_feature_evict_metric() const { + std::optional get_feature_evict_metric() + const override { if (!feature_evict_config_.has_value()) { return std::nullopt; } @@ -789,11 +807,11 @@ class DramKVInferenceEmbedding { return shardid_to_indexes; } - void flush_or_compact(const int64_t timestep) {} + void flush_or_compact(const int64_t timestep) override {} std::vector get_dram_kv_perf( const int64_t step, - const int64_t interval) { + const int64_t interval) override { std::vector ret(23, 0); // num metrics if (step > 0 && step % interval == 0) { int reset_val = 0; diff --git a/fbgemm_gpu/src/dram_kv_embedding_cache/feature_evict.h b/fbgemm_gpu/src/dram_kv_embedding_cache/feature_evict.h index e0443ee640..5637224754 100644 --- a/fbgemm_gpu/src/dram_kv_embedding_cache/feature_evict.h +++ b/fbgemm_gpu/src/dram_kv_embedding_cache/feature_evict.h @@ -34,7 +34,8 @@ enum class EvictTriggerMode { ITERATION, // Trigger based on iteration steps MEM_UTIL, // Trigger based on memory usage MANUAL, // Manually triggered by upstream - ID_COUNT // Trigger based on id count + ID_COUNT, // Trigger based on id count + FREE_MEM, // Trigger based on free memory }; inline std::string to_string(EvictTriggerMode mode) { switch (mode) { @@ -48,6 +49,8 @@ inline std::string to_string(EvictTriggerMode mode) { return "MANUAL"; case EvictTriggerMode::ID_COUNT: return "ID_COUNT"; + case EvictTriggerMode::FREE_MEM: + return "FREE_MEM"; } } @@ -184,6 +187,9 @@ struct FeatureEvictConfig : public torch::jit::CustomClassHolder { eviction_trigger_stats_log += "]"; break; } + case EvictTriggerMode::FREE_MEM: { + break; + } default: throw std::runtime_error("Unknown evict trigger mode"); } @@ -202,7 +208,6 @@ struct FeatureEvictConfig : public torch::jit::CustomClassHolder { case EvictTriggerStrategy::BY_FEATURE_SCORE: { CHECK(feature_score_counter_decay_rates_.has_value()); - CHECK(training_id_eviction_trigger_count_.has_value()); CHECK(training_id_keep_count_.has_value()); CHECK(threshold_calculation_bucket_stride_.has_value()); CHECK(threshold_calculation_bucket_num_.has_value()); @@ -210,8 +215,6 @@ struct FeatureEvictConfig : public torch::jit::CustomClassHolder { LOG(INFO) << "eviction config, trigger mode:" << to_string(trigger_mode_) << eviction_trigger_stats_log << ", strategy: " << to_string(trigger_strategy_) - << ", training_id_eviction_trigger_count: " - << training_id_eviction_trigger_count_.value() << ", training_id_keep_count:" << training_id_keep_count_.value() << ", ttls_in_mins: " << ttls_in_mins_.value() diff --git a/fbgemm_gpu/src/dram_kv_embedding_cache/kv_inference_embedding_interface.h b/fbgemm_gpu/src/dram_kv_embedding_cache/kv_inference_embedding_interface.h new file mode 100644 index 0000000000..0f9090221b --- /dev/null +++ b/fbgemm_gpu/src/dram_kv_embedding_cache/kv_inference_embedding_interface.h @@ -0,0 +1,158 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once +#include +#include +#include +#include + +#include "feature_evict.h" + +namespace kv_mem { + +/// @ingroup KVMemEmbedding +/// +/// @brief Interface for KV Inference Embedding implementations +/// +/// This interface defines the core API that all KV embedding implementations +/// must provide, enabling different backend implementations (DRAM, SSD, etc.) 
+/// to be used interchangeably. +/// +template +class KVInferenceEmbeddingInterface { + public: + virtual ~KVInferenceEmbeddingInterface() = default; + + /// Initialize the initializers for weight initialization + /// + /// @param num_shards number of shards for the kvstore + /// @param max_D the maximum dimension of embedding tensor + /// @param uniform_init_lower the lower bound of the uniform distribution + /// @param uniform_init_upper the upper bound of the uniform distribution + /// @param row_storage_bitwidth storage bitwidth for each row + /// @param disable_random_init whether to disable random initialization + virtual void initialize_initializers( + int64_t num_shards, + int64_t max_D, + double uniform_init_lower, + double uniform_init_upper, + int64_t row_storage_bitwidth, + bool disable_random_init) = 0; + + /// Set embeddings in the KV store (sync version) + /// + /// @param indices The 1D embedding index tensor + /// @param weights The 2D tensor containing embeddings + /// @param count A single element tensor with number of indices to process + /// @param inplace_update_ts Optional timestamp for inplace update + virtual void set_kv_db_sync( + const at::Tensor& indices, + const at::Tensor& weights, + const at::Tensor& count, + std::optional inplace_update_ts) = 0; + + /// Get embeddings from KV store (sync version) + /// + /// @param indices The 1D embedding index tensor + /// @param weights The 2D tensor to be filled with embeddings + /// @param count A single element tensor with number of indices to process + virtual void get_kv_db_sync( + const at::Tensor& indices, + const at::Tensor& weights, + const at::Tensor& count) = 0; + + /// Set embeddings in the KV store (async inference version) + /// + /// @param indices The 1D embedding index tensor + /// @param weights The 2D tensor containing embeddings + /// @param count A single element tensor with number of indices to process + /// @param inplace_update_ts Optional timestamp for inplace update + /// @return SemiFuture for async completion + virtual folly::SemiFuture> inference_set_kv_db_async( + const at::Tensor& indices, + const at::Tensor& weights, + const at::Tensor& count, + std::optional inplace_update_ts) = 0; + + /// Get embeddings from KV store (async) + /// + /// @param indices The 1D embedding index tensor + /// @param weights The 2D tensor to be filled with embeddings + /// @param count A single element tensor with number of indices to process + /// @return SemiFuture for async completion + virtual folly::SemiFuture> get_kv_db_async( + const at::Tensor& indices, + const at::Tensor& weights, + const at::Tensor& count) = 0; + + /// Compact the KV store (placeholder for future implementations) + virtual void compact() = 0; + + /// Trigger feature eviction + /// + /// @param inplace_update_ts Optional timestamp for eviction threshold + virtual void trigger_feature_evict( + std::optional inplace_update_ts = std::nullopt) = 0; + + /// Maybe trigger eviction based on configured trigger mode + virtual void maybe_evict() = 0; + + /// Wait until ongoing eviction completes + virtual void wait_until_eviction_done() = 0; + + /// Get the total memory used by the KV store + /// + /// @return Memory size in bytes + virtual size_t get_map_used_memsize_in_bytes() const = 0; + + /// Get the actual memory used by allocated chunks + /// + /// @return Memory size in bytes + virtual size_t get_map_actual_used_chunk_in_bytes() const = 0; + + /// Get the number of rows in the KV store + /// + /// @return Number of rows + virtual size_t 
get_num_rows() const = 0; + + /// Resume ongoing eviction + /// + /// @param force_resume Force resume even if not paused + virtual void resume_ongoing_eviction(bool force_resume = false) = 0; + + /// Pause ongoing eviction + /// + /// @param force_pause Force pause even if not running + virtual void pause_ongoing_eviction(bool force_pause = false) = 0; + + /// Log statistics for inplace update (inference only) + virtual void log_inplace_update_stats() = 0; + + /// Get feature eviction metrics + /// + /// @return Optional metrics tensors + virtual std::optional get_feature_evict_metric() + const = 0; + + /// Get performance metrics + /// + /// @param step Current step/iteration + /// @param interval Reporting interval + /// @return Vector of performance metrics + virtual std::vector get_dram_kv_perf( + const int64_t step, + const int64_t interval) = 0; + + /// Flush or compact at a specific timestep + /// + /// @param timestep The timestep for flush/compact + virtual void flush_or_compact(const int64_t timestep) = 0; +}; + +} // namespace kv_mem diff --git a/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops.cu b/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops.cu index dbccc6fdfd..bb6e3e2b96 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops.cu +++ b/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops.cu @@ -8,11 +8,6 @@ #include "common.cuh" -FBGEMM_OP_DISPATCH(CUDA, "dense_to_jagged", fbgemm_gpu::dense_to_jagged); -FBGEMM_OP_DISPATCH( - CUDA, - "jagged_to_padded_dense", - fbgemm_gpu::jagged_to_padded_dense); FBGEMM_OP_DISPATCH( CUDA, "jagged_dense_elementwise_add", diff --git a/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp b/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp index 1a20f680b8..ac14fdd975 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp +++ b/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp @@ -48,6 +48,8 @@ class JaggedToPaddedDenseOp const std::vector& offsets, at::ArrayRef max_lengths, const double padding_value)>(); + + at::AutoDispatchBelowAutograd mode; Tensor padded_values = op.call(values, offsets, max_lengths, padding_value); return {padded_values}; @@ -286,6 +288,7 @@ class DenseToJaggedOp : public torch::autograd::Function { const Tensor& dense, const std::vector& offsets, std::optional total_L)>(); + at::AutoDispatchBelowAutograd mode; auto output = op.call(dense, offsets, total_L); return {output}; @@ -785,7 +788,7 @@ class JaggedSliceOp : public torch::autograd::Function { } // namespace ///@ingroup jagged-tensor-ops-cpu -Tensor jagged_to_padded_dense( +Tensor jagged_to_padded_dense_forward_autograd( const Tensor& values, const std::vector& offsets, const c10::SymIntArrayRef max_lengths, @@ -793,6 +796,22 @@ Tensor jagged_to_padded_dense( return JaggedToPaddedDenseOp::apply( values, offsets, max_lengths, padding_value)[0]; } +Tensor jagged_to_padded_dense( + const Tensor& values, + const std::vector& offsets, + const c10::SymIntArrayRef max_lengths, + const double padding_value) { + static auto op = + c10::Dispatcher::singleton() + .findSchemaOrThrow("fbgemm::jagged_to_padded_dense_forward", "") + .typed& offsets, + at::ArrayRef max_lengths, + const double padding_value)>(); + Tensor output = op.call(values, offsets, max_lengths, padding_value); + return output; +} ///@ingroup jagged-tensor-ops-cpu /// Output = x + y where x is jagged, y and output are dense @@ -855,7 +874,20 @@ std::tuple> dense_to_jagged( const Tensor& dense, const std::vector& offsets, std::optional 
total_L) { - return {DenseToJaggedOp::apply(dense, offsets, total_L)[0], offsets}; + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("fbgemm::dense_to_jagged_forward", "") + .typed& offsets, + std::optional total_L)>(); + auto output = op.call(dense, offsets, total_L); + return {output, offsets}; +} +Tensor dense_to_jagged_forward_autograd( + const Tensor& dense, + const std::vector& offsets, + std::optional total_L) { + return DenseToJaggedOp::apply(dense, offsets, total_L)[0]; } ///@ingroup jagged-tensor-ops-cpu @@ -973,6 +1005,12 @@ TORCH_LIBRARY_IMPL(fbgemm, Autograd, m) { m.impl("jagged_jagged_bmm", TORCH_FN(fbgemm_gpu::jagged_jagged_bmm)); m.impl("jagged_dense_bmm", TORCH_FN(fbgemm_gpu::jagged_dense_bmm)); m.impl("jagged_slice", TORCH_FN(fbgemm_gpu::jagged_slice)); + m.impl( + "jagged_to_padded_dense_forward", + TORCH_FN(fbgemm_gpu::jagged_to_padded_dense_forward_autograd)); + m.impl( + "dense_to_jagged_forward", + TORCH_FN(fbgemm_gpu::dense_to_jagged_forward_autograd)); } TORCH_LIBRARY_IMPL(fbgemm, CompositeImplicitAutograd, m) { diff --git a/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp b/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp index eb047b882e..c5512509ff 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp +++ b/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp @@ -1818,13 +1818,11 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { TORCH_LIBRARY_IMPL(fbgemm, CPU, m) { DISPATCH_TO_CPU("jagged_2d_to_dense", fbgemm_gpu::jagged_2d_to_dense); DISPATCH_TO_CPU("jagged_1d_to_dense", fbgemm_gpu::jagged_1d_to_dense); - DISPATCH_TO_CPU("dense_to_jagged", fbgemm_gpu::dense_to_jagged); DISPATCH_TO_CPU( "dense_to_jagged_forward", fbgemm_gpu::dense_to_jagged_forward); - DISPATCH_TO_CPU("jagged_to_padded_dense", fbgemm_gpu::jagged_to_padded_dense); DISPATCH_TO_CPU( "jagged_to_padded_dense_forward", - fbgemm_gpu::jagged_to_padded_dense_forward); + fbgemm_gpu::jagged_to_padded_dense_forward_cpu); DISPATCH_TO_CPU( "jagged_to_padded_dense_backward", fbgemm_gpu::jagged_to_padded_dense_backward); diff --git a/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_meta.cpp b/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_meta.cpp index 43cbb1c9bf..87c2ad23f0 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_meta.cpp +++ b/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_meta.cpp @@ -53,18 +53,21 @@ Tensor jagged_to_padded_dense_meta( Tensor jagged_to_padded_dense_backward_meta( const at::Tensor& grad_output, - const std::vector& /*offsets*/, + const std::vector& offsets, at::SymInt total_L) { const auto& grad_padded_values = grad_output; - at::SymInt D = grad_padded_values.sym_size(-1); + const bool D_folded = grad_padded_values.dim() == offsets.size() + 1; + const auto& grad_padded_values_view = + D_folded ? grad_padded_values.unsqueeze(-1) : grad_padded_values; + at::SymInt D = grad_padded_values_view.sym_size(-1); // Initialize with zeros so output will be zero for the portion truncated // in forward. auto grad_values = at::zeros_symint({std::move(total_L), D}, grad_padded_values.options()); TORCH_CHECK(grad_values.is_meta()); - return grad_values; + return D_folded ? 
grad_values.squeeze(-1) : grad_values; } Tensor jagged_dense_dense_elementwise_add_jagged_output_forward_meta( diff --git a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu index c1ac40dea6..96c57cde68 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_group_index.cu @@ -12,10 +12,18 @@ using Tensor = at::Tensor; namespace fbgemm_gpu { +#ifdef USE_ROCM +// The wave size is forced to be 32 on ROCm devices in favor +// of granularity losses reduction. +constexpr int EMULATED_WARP_SIZE = 32; +#else +constexpr int EMULATED_WARP_SIZE = kWarpSize; +#endif + // TODO: Update UNROLL_FACTOR constexpr int GROUP_INDEX_SELECT_UNROLL_FACTOR = 1; constexpr int GROUP_INDEX_SELECT_COLS_PER_WARP = - GROUP_INDEX_SELECT_UNROLL_FACTOR * kWarpSize; + GROUP_INDEX_SELECT_UNROLL_FACTOR * EMULATED_WARP_SIZE; // GROUP_INDEX_SELECT_COLS_PER_WARP must be power of two constexpr int GROUP_INDEX_SELECT_LOG_COLS_PER_WARP = @@ -43,12 +51,21 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( const int64_t num_work_rows, // number of rows to work on per member const int64_t group_size) { const auto total_num_warps = warp_offsets_group[group_size]; + int32_t num_cols = 0; + int32_t warps_per_row = 0; + + if constexpr (!USE_VAR_COLS) { + num_cols = num_cols_group[0]; + warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; + } + for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x; warp_id < total_num_warps; warp_id += gridDim.x * blockDim.y) { - int32_t member_id, member_warp_id, num_cols, warps_per_row; - if (USE_VAR_COLS) { - __shared__ int member_ids[kMaxThreads / kWarpSize]; + int32_t member_id = 0; + int32_t member_warp_id = 0; + if constexpr (USE_VAR_COLS) { + __shared__ int member_ids[kMaxThreads / EMULATED_WARP_SIZE]; if (threadIdx.x == 0) { binary_search_range( &member_ids[threadIdx.y], @@ -63,8 +80,6 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( member_warp_id = warp_id - warp_offsets_group[member_id]; } else { // All columns are the same - num_cols = num_cols_group[0]; - warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP; member_id = warp_id / (warps_per_row * num_work_rows); member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows); } @@ -82,7 +97,7 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel( #pragma unroll for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) { // Compile time conditional - if (USE_INDEX_SELECT) { + if constexpr (USE_INDEX_SELECT) { output[row * num_cols + i] = LDG(&input[idx * num_cols + i]); } else { gpuAtomicAddNoReturn( @@ -113,13 +128,13 @@ DLL_PUBLIC void group_index_select_or_add_cuda( at::cuda::OptionalCUDAGuard device_guard(device); // Partition work based on num_work_rows - uint32_t num_warps_per_threadblock = kMaxThreads / kWarpSize; + uint32_t num_warps_per_threadblock = kMaxThreads / EMULATED_WARP_SIZE; uint32_t max_grid_size = at::cuda::getCurrentDeviceProperties()->multiProcessorCount * 8; uint32_t grid_size = std::min( cuda_calc_xblock_count(total_num_warps, num_warps_per_threadblock), max_grid_size); - dim3 block_size(kWarpSize, num_warps_per_threadblock, 1); + dim3 block_size(EMULATED_WARP_SIZE, num_warps_per_threadblock, 1); #define INVOKE_GROUP_INDEX_SELECT_OR_ADD(USE_INDEX_SELECT, USE_VAR_COLS) \ FBGEMM_LAUNCH_KERNEL( \ diff --git a/fbgemm_gpu/src/sparse_ops/sparse_permute_2d.cu b/fbgemm_gpu/src/sparse_ops/sparse_permute_2d.cu 
index abb6d8abd4..cf7e23c17e 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_permute_2d.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_permute_2d.cu @@ -20,7 +20,7 @@ template < typename indices_t, typename weights_t> __global__ __launch_bounds__(kMaxThreads) void permute_2D_data_kernel( - int32_t len, + int64_t len, int32_t T, int32_t B, const indices_t* __restrict__ indices, diff --git a/fbgemm_gpu/src/split_embeddings_utils/generate_vbe_metadata.cu b/fbgemm_gpu/src/split_embeddings_utils/generate_vbe_metadata.cu index 17b1fd0edb..ab1ecc7f1d 100644 --- a/fbgemm_gpu/src/split_embeddings_utils/generate_vbe_metadata.cu +++ b/fbgemm_gpu/src/split_embeddings_utils/generate_vbe_metadata.cu @@ -34,7 +34,9 @@ __launch_bounds__(kMaxThreads) void generate_vbe_metadata_foreach_sample_kernel( D_offsets, const int32_t D, const bool nobag, - const int32_t info_B_num_bits) { + const int32_t info_B_num_bits, + const pta::PackedTensorAccessor32 + predefined_vbe_output_offsets) { // Relative sample ID in the rank-table matrix const auto b = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; // Rank ID @@ -50,6 +52,8 @@ __launch_bounds__(kMaxThreads) void generate_vbe_metadata_foreach_sample_kernel( return; } + const bool use_predefined_offsets = predefined_vbe_output_offsets.size(0) > 0; + const auto* __restrict__ output_offsets_feature = &output_offsets_feature_rank[r * T]; @@ -57,8 +61,9 @@ __launch_bounds__(kMaxThreads) void generate_vbe_metadata_foreach_sample_kernel( const auto b_t = static_cast(B_start_t) + static_cast(B_start_r_t) + b; const auto D_ = nobag ? D : (D_offsets[t + 1] - D_offsets[t]); - row_output_offsets[b_t] = - output_offsets_feature[t] + b * static_cast(D_); + auto offset = use_predefined_offsets ? predefined_vbe_output_offsets[r][t] + : output_offsets_feature[t]; + row_output_offsets[b_t] = offset + b * static_cast(D_); // Relative sample ID in the table const auto b_ = B_start_r_t + b; @@ -114,11 +119,15 @@ generate_vbe_metadata( const Tensor& D_offsets, const int64_t D, const bool nobag, - const int64_t max_B_feature_rank, + const c10::SymInt max_B_feature_rank, const int64_t info_B_num_bits, - const int64_t total_B) { + const c10::SymInt total_B, + const std::optional& vbe_output_offsets = std::nullopt) { TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL( - B_offsets, B_offsets_rank_per_feature, output_offsets_feature_rank); + B_offsets, + B_offsets_rank_per_feature, + output_offsets_feature_rank, + vbe_output_offsets); TENSOR_NDIM_EQUALS(B_offsets, 1); TENSOR_NDIM_EQUALS(B_offsets_rank_per_feature, 2); @@ -132,25 +141,53 @@ generate_vbe_metadata( TORCH_CHECK(D_offsets.numel() == T + 1) } + const int64_t total_B_ = total_B.guard_int(__FILE__, __LINE__); + const int64_t max_B_feature_rank_ = + max_B_feature_rank.guard_int(__FILE__, __LINE__); + const auto num_ranks = B_offsets_rank_per_feature.size(1) - 1; TORCH_CHECK( num_ranks > 0, "generate_vbe_metadata: Invalid num_ranks ", num_ranks); TORCH_CHECK(T > 0, "generate_vbe_metadata: Invalid T ", T); TORCH_CHECK( - max_B_feature_rank > 0, + max_B_feature_rank_ > 0, "generate_vbe_metadata: Invalid max_B_feature_rank ", - max_B_feature_rank); + max_B_feature_rank_); TORCH_CHECK(B_offsets_rank_per_feature.size(0) == T); TORCH_CHECK(output_offsets_feature_rank.numel() == num_ranks * T + 1); + Tensor predefined_vbe_output_offsets; + if (vbe_output_offsets.has_value()) { + predefined_vbe_output_offsets = vbe_output_offsets.value(); + TORCH_CHECK( + predefined_vbe_output_offsets.dim() == 2, + "Expected a tensor of 2 dims: [num_ranks, num_features] but 
got ", + predefined_vbe_output_offsets.dim()); + TORCH_CHECK( + predefined_vbe_output_offsets.size(0) == num_ranks, + "Expected predefined_vbe_output_offsets.size(0) to be", + num_ranks, + " but got ", + predefined_vbe_output_offsets.size(0)); + TORCH_CHECK( + predefined_vbe_output_offsets.size(1) == T, + "Expected predefined_vbe_output_offsets.size(1) to be", + T, + " but got ", + predefined_vbe_output_offsets.size(1)); + } else { + predefined_vbe_output_offsets = + at::empty({0, 0}, output_offsets_feature_rank.options()); + } + CUDA_DEVICE_GUARD(B_offsets); Tensor row_output_offsets = - at::empty({total_B}, output_offsets_feature_rank.options()); - Tensor b_t_map = at::empty({total_B}, B_offsets.options()); + at::empty({total_B_}, output_offsets_feature_rank.options()); + Tensor b_t_map = at::empty({total_B_}, B_offsets.options()); - const auto grid_dim_x = div_round_up(max_B_feature_rank, kMaxThreads); + const auto grid_dim_x = div_round_up(max_B_feature_rank_, kMaxThreads); const dim3 grid_size(grid_dim_x, num_ranks, T); const auto& [max_grid_x, max_grid_y, max_grid_z] = get_max_grid_size(); TORCH_CHECK( @@ -181,7 +218,9 @@ generate_vbe_metadata( PTA_B(D_offsets, int32_t, 1, 32), D, nobag, - info_B_num_bits); + info_B_num_bits, + MAKE_PTA_WITH_NAME( + func_name, predefined_vbe_output_offsets, int64_t, 2, 32)); return {row_output_offsets, b_t_map}; } diff --git a/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils_cpu.cpp b/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils_cpu.cpp index 453b097774..df09f5dbd0 100644 --- a/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils_cpu.cpp +++ b/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils_cpu.cpp @@ -119,7 +119,8 @@ generate_vbe_metadata_cpu( const bool /*nobag*/, const c10::SymInt /*max_B_feature_rank*/, const int64_t info_B_num_bits, - const c10::SymInt total_B) { + const c10::SymInt total_B, + const std::optional& vbe_output_offsets = std::nullopt) { TENSOR_ON_CPU(B_offsets); TENSORS_ON_SAME_DEVICE(B_offsets, B_offsets_rank_per_feature); TENSORS_ON_SAME_DEVICE(B_offsets, output_offsets_feature_rank); @@ -139,6 +140,11 @@ generate_vbe_metadata_cpu( Tensor row_output_offsets = at::empty({total_B_}, output_offsets_feature_rank.options()); TORCH_CHECK(B_offsets.dtype() == at::kInt, "B_offsets should be int32"); + + if (vbe_output_offsets.has_value()) { + TORCH_CHECK(vbe_output_offsets->numel() == total_B, "size mismatch"); + } + Tensor b_t_map = at::empty({total_B_}, B_offsets.options()); auto B_offsets_acc = B_offsets.accessor(); auto D_offsets_acc = D_offsets.accessor(); @@ -166,7 +172,8 @@ generate_vbe_metadata_cpu( } } } - return {row_output_offsets, b_t_map}; + auto row_output_offsets_ = vbe_output_offsets.value_or(row_output_offsets); + return {row_output_offsets_, b_t_map}; } std::tuple @@ -204,7 +211,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { " bool nobag, " " SymInt max_B_feature_rank, " " int info_B_num_bits, " - " SymInt total_B" + " SymInt total_B, " + " Tensor? 
vbe_output_offsets=None" ") -> (Tensor, Tensor)"); DISPATCH_TO_CPU("generate_vbe_metadata", generate_vbe_metadata_cpu); DISPATCH_TO_CPU("get_infos_metadata", get_infos_metadata_cpu); diff --git a/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils_meta.cpp b/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils_meta.cpp index 38077da03d..18b4fd7bc4 100644 --- a/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils_meta.cpp +++ b/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils_meta.cpp @@ -23,7 +23,8 @@ generate_vbe_metadata_meta( const bool /*nobag*/, const c10::SymInt /*max_B_feature_rank*/, const int64_t /*info_B_num_bits*/, - const c10::SymInt total_B) { + const c10::SymInt total_B, + const std::optional& /*vbe_output_offsets*/ = std::nullopt) { Tensor row_output_offsets = at::empty_symint({total_B}, output_offsets_feature_rank.options()); Tensor b_t_map = at::empty_symint({total_B}, B_offsets.options()); diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/embedding_rocksdb_wrapper.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/embedding_rocksdb_wrapper.h index 8cebdef1eb..4ca404c157 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/embedding_rocksdb_wrapper.h +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/embedding_rocksdb_wrapper.h @@ -236,6 +236,14 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder { impl_->set_backend_return_whole_row(backend_return_whole_row); } + void trigger_feature_evict() { + impl_->trigger_feature_evict(); + } + + bool is_evicting() { + return impl_->is_evicting(); + } + private: friend class KVTensorWrapper; diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp index eb95d343e6..e0077058ee 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp @@ -383,6 +383,14 @@ void EmbeddingKVDB::set_backend_return_whole_row( return; } +void EmbeddingKVDB::trigger_feature_evict() { + return; +} + +bool EmbeddingKVDB::is_evicting() { + return false; +} + void EmbeddingKVDB::set( const at::Tensor& indices, const at::Tensor& weights, diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h index 94e1a62711..a8082af235 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h @@ -301,6 +301,10 @@ class EmbeddingKVDB : public std::enable_shared_from_this { FBEXCEPTION("Not implemented"); } + virtual void trigger_feature_evict(); + + virtual bool is_evicting(); + /** * @brief need to support set backend_return_whole_row from frontend * if one model changed from SSD to DRAM, or vice versa we need to diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp index 0b95285a8f..64d4dc134c 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp @@ -880,6 +880,10 @@ static auto embedding_rocks_db_wrapper = { torch::arg("backend_return_whole_row"), }) + .def( + "trigger_feature_evict", + &EmbeddingRocksDBWrapper::trigger_feature_evict) + 
.def("is_evicting", &EmbeddingRocksDBWrapper::is_evicting) .def("stream_sync_cuda", &EmbeddingRocksDBWrapper::stream_sync_cuda) .def("get_cuda", &EmbeddingRocksDBWrapper::get_cuda) .def("compact", &EmbeddingRocksDBWrapper::compact) @@ -980,6 +984,10 @@ static auto dram_kv_embedding_cache_wrapper = { torch::arg("backend_return_whole_row"), }) + .def( + "trigger_feature_evict", + &DramKVEmbeddingCacheWrapper::trigger_feature_evict) + .def("is_evicting", &DramKVEmbeddingCacheWrapper::is_evicting) .def("set", &DramKVEmbeddingCacheWrapper::set) .def( "set_range_to_storage", diff --git a/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp b/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp old mode 100644 new mode 100755 diff --git a/fbgemm_gpu/test/dram_kv_embedding_cache/kv_embedding_inference_test.cpp b/fbgemm_gpu/test/dram_kv_embedding_cache/kv_embedding_inference_test.cpp new file mode 100644 index 0000000000..02f0960a8a --- /dev/null +++ b/fbgemm_gpu/test/dram_kv_embedding_cache/kv_embedding_inference_test.cpp @@ -0,0 +1,217 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "deeplearning/fbgemm/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_inference_embedding.h" + +#include +#include +#include +#include +#include +#include + +namespace kv_mem { + +class KVEmbeddingInferenceTest : public ::testing::Test { + protected: + static constexpr int EMBEDDING_DIM = 128; + static constexpr int NUM_SHARDS = 8; + + void SetUp() override { + FLAGS_logtostderr = true; + FLAGS_minloglevel = 0; + FLAGS_v = 1; + + auto feature_evict_config = c10::make_intrusive( + 3, + 4, + std::nullopt, + std::nullopt, + std::vector{1}, + std::nullopt, + std::nullopt, + std::nullopt, + std::nullopt, + std::nullopt, + std::nullopt, + std::vector{EMBEDDING_DIM}, + std::nullopt, + std::nullopt, + 0, + 0, + 0); + + auto hash_size_cumsum = at::tensor({0, 100000}, at::kLong); + + backend_ = std::make_unique>( + EMBEDDING_DIM, + -0.1, + 0.1, + feature_evict_config, + NUM_SHARDS, + 32, + 32, + false, + std::nullopt, + hash_size_cumsum, + false); + } + + void TearDown() override { + backend_.reset(); + } + + static std::vector generateEmbedding(int64_t embedding_id) { + std::vector embedding(EMBEDDING_DIM); + + // Use both embedding_id and current time as seed for randomness + auto now = std::chrono::system_clock::now(); + auto time_seed = std::chrono::duration_cast( + now.time_since_epoch()) + .count(); + uint32_t combined_seed = static_cast(embedding_id ^ time_seed); + + std::mt19937 rng(combined_seed); + std::uniform_real_distribution dist(-0.1f, 0.1f); + for (int i = 0; i < EMBEDDING_DIM; ++i) { + embedding[i] = dist(rng); + } + return embedding; + } + + std::unique_ptr> backend_; +}; + +TEST_F(KVEmbeddingInferenceTest, InferenceLifecycleWithMetadata) { + const int64_t embedding_id = 12345; + + auto now = std::chrono::system_clock::now(); + auto now_seconds = + std::chrono::duration_cast(now.time_since_epoch()) + .count(); + const uint32_t snapshot_timestamp = static_cast(now_seconds - 120); + + auto embedding_data = generateEmbedding(embedding_id); + + LOG(INFO) << "STEP 1: Define test embedding"; + LOG(INFO) << "Embedding ID: " << embedding_id; + LOG(INFO) << "Timestamp: " << snapshot_timestamp + << " (current time - 2 minutes)"; + LOG(INFO) << "Dimension: " << EMBEDDING_DIM; + LOG(INFO) << "First 5 elements: [" << embedding_data[0] << ", " + << 
embedding_data[1] << ", " << embedding_data[2] << ", " + << embedding_data[3] << ", " << embedding_data[4] << "]"; + + auto indices_tensor = at::tensor({embedding_id}, at::kLong); + auto weights_tensor = at::from_blob( + embedding_data.data(), + {1, EMBEDDING_DIM}, + at::TensorOptions().dtype(at::kFloat)); + auto count_tensor = at::tensor({1}, at::kInt); + + LOG(INFO) << "STEP 2: Insert embedding into cache"; + folly::coro::blockingWait(backend_->inference_set_kv_db_async( + indices_tensor, weights_tensor, count_tensor, snapshot_timestamp)); + LOG(INFO) << "Insertion completed"; + + auto retrieved_embedding = at::zeros({1, EMBEDDING_DIM}, at::kFloat); + + LOG(INFO) << "STEP 3: Retrieve embedding from cache"; + folly::coro::blockingWait(backend_->get_kv_db_async( + indices_tensor, retrieved_embedding, count_tensor)); + LOG(INFO) << "Retrieval completed"; + + auto retrieved_ptr = retrieved_embedding.data_ptr(); + bool all_match = true; + int mismatch_count = 0; + + LOG(INFO) << "STEP 4: Verify embedding consistency"; + for (int i = 0; i < EMBEDDING_DIM; ++i) { + if (std::abs(retrieved_ptr[i] - embedding_data[i]) > 1e-5f) { + all_match = false; + mismatch_count++; + } + } + + if (all_match) { + LOG(INFO) << "All " << EMBEDDING_DIM << " dimensions match"; + } else { + LOG(ERROR) << "Found " << mismatch_count << " mismatches out of " + << EMBEDDING_DIM << " dimensions"; + } + + ASSERT_TRUE(all_match) << "Retrieved embedding must match inserted embedding"; + + LOG(INFO) << "STEP 5: Test repeated reads"; + for (int iteration = 1; iteration <= 3; ++iteration) { + auto read_again = at::zeros({1, EMBEDDING_DIM}, at::kFloat); + folly::coro::blockingWait( + backend_->get_kv_db_async(indices_tensor, read_again, count_tensor)); + + auto read_ptr = read_again.data_ptr(); + bool matches = true; + for (int i = 0; i < EMBEDDING_DIM; ++i) { + if (std::abs(read_ptr[i] - embedding_data[i]) > 1e-5f) { + matches = false; + break; + } + } + LOG(INFO) << "Read #" << iteration << ": " + << (matches ? 
"Match" : "Mismatch"); + } + + LOG(INFO) << "STEP 6: Trigger eviction"; + auto eviction_time = std::chrono::system_clock::now(); + auto eviction_seconds = std::chrono::duration_cast( + eviction_time.time_since_epoch()) + .count(); + uint32_t eviction_threshold = static_cast(eviction_seconds - 60); + + LOG(INFO) << "Eviction threshold: " << eviction_threshold; + backend_->trigger_feature_evict(eviction_threshold); + backend_->wait_until_eviction_done(); + LOG(INFO) << "Eviction completed"; + + auto post_eviction_embedding = at::zeros({1, EMBEDDING_DIM}, at::kFloat); + + LOG(INFO) << "STEP 7: Read embedding after eviction"; + folly::coro::blockingWait(backend_->get_kv_db_async( + indices_tensor, post_eviction_embedding, count_tensor)); + + auto post_eviction_ptr = post_eviction_embedding.data_ptr(); + bool values_changed = false; + int differences = 0; + + for (int i = 0; i < EMBEDDING_DIM; ++i) { + if (std::abs(post_eviction_ptr[i] - embedding_data[i]) > 1e-5f) { + values_changed = true; + differences++; + } + } + + LOG(INFO) << "Differences found: " << differences << "/" << EMBEDDING_DIM; + + if (values_changed) { + LOG(INFO) << "Eviction successful - values changed"; + } else { + LOG(ERROR) << "Eviction may have failed - values unchanged"; + } + + LOG(INFO) << "Original (cached): [" << embedding_data[0] << ", " + << embedding_data[1] << ", " << embedding_data[2] << ", " + << embedding_data[3] << ", " << embedding_data[4] << "]"; + LOG(INFO) << "After eviction: [" << post_eviction_ptr[0] << ", " + << post_eviction_ptr[1] << ", " << post_eviction_ptr[2] << ", " + << post_eviction_ptr[3] << ", " << post_eviction_ptr[4] << "]"; + + ASSERT_TRUE(values_changed) << "Embedding should be different after eviction"; + + LOG(INFO) << "Test completed successfully"; +} + +} // namespace kv_mem diff --git a/fbgemm_gpu/test/jagged/common.py b/fbgemm_gpu/test/jagged/common.py index d8838b8447..2cdb9078ec 100644 --- a/fbgemm_gpu/test/jagged/common.py +++ b/fbgemm_gpu/test/jagged/common.py @@ -10,7 +10,6 @@ import itertools import sys -import unittest from typing import Callable import fbgemm_gpu @@ -43,15 +42,7 @@ # Please avoid putting tests here, you should put operator-specific # skips and failures in deeplearning/fbgemm/fbgemm_gpu/test/failures_dict.json # pyre-ignore[24]: Generic type `Callable` expects 2 type parameters. -additional_decorators: dict[str, list[Callable]] = { - "test_pt2_compliant_tag_fbgemm_jagged_dense_elementwise_add": [ - # This operator has been grandfathered in. We need to fix this test failure. 
- unittest.expectedFailure, - ], - "test_pt2_compliant_tag_fbgemm_jagged_to_padded_dense": [ - unittest.expectedFailure, - ], -} +additional_decorators: dict[str, list[Callable]] = {} def lengths_to_segment_ids(lengths: torch.Tensor) -> torch.Tensor: diff --git a/fbgemm_gpu/test/jagged/dense_to_jagged_test.py b/fbgemm_gpu/test/jagged/dense_to_jagged_test.py index 0e6e08e56a..d03823c364 100644 --- a/fbgemm_gpu/test/jagged/dense_to_jagged_test.py +++ b/fbgemm_gpu/test/jagged/dense_to_jagged_test.py @@ -80,6 +80,11 @@ def _test_dense_to_jagged( jagged_values.backward(ref_output_values) torch.testing.assert_close(dense.grad, ref_values) + torch.library.opcheck( + torch.ops.fbgemm.dense_to_jagged, + (dense.detach().requires_grad_(True), offsets), + ) + @given( num_jagged_dim=st.integers(1, 5), outer_dense_size=st.integers(0, 5), diff --git a/fbgemm_gpu/test/jagged/jagged_index_select_2d_test.py b/fbgemm_gpu/test/jagged/jagged_index_select_2d_test.py index 7433edbeb3..70b2ef276a 100644 --- a/fbgemm_gpu/test/jagged/jagged_index_select_2d_test.py +++ b/fbgemm_gpu/test/jagged/jagged_index_select_2d_test.py @@ -158,6 +158,26 @@ def test_jagged_index_select_2d( rtol=1e-2 if jagged_tensor_dtype in [torch.half, torch.bfloat16] else None, atol=1e-2 if jagged_tensor_dtype in [torch.half, torch.bfloat16] else None, ) + if known_shape: + with torch.no_grad(): + tmp_output, _ = torch.ops.fbgemm.jagged_index_select( + values, lengths, indices + ) + num_dense_output_rows = tmp_output.shape[0] + torch.library.opcheck( + torch.ops.fbgemm.jagged_index_select.default, + ( + values.detach().requires_grad_(), + lengths, + indices, + num_dense_output_rows, + ), + ) + else: + torch.library.opcheck( + torch.ops.fbgemm.jagged_index_select.default, + (values.detach().requires_grad_(), lengths, indices), + ) @given( max_seq_length=st.integers(5, 10), diff --git a/fbgemm_gpu/test/jagged/jagged_to_padded_dense_test.py b/fbgemm_gpu/test/jagged/jagged_to_padded_dense_test.py index 1242470d18..24a8567bee 100644 --- a/fbgemm_gpu/test/jagged/jagged_to_padded_dense_test.py +++ b/fbgemm_gpu/test/jagged/jagged_to_padded_dense_test.py @@ -113,6 +113,50 @@ def test_jagged_to_padded_dense( rtol=1e-3, ) + class Mod(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, a, b, c, d): + return torch.ops.fbgemm.jagged_to_padded_dense(a, b, c, d) + + with torch.inference_mode(): + gm = torch.export.export( + Mod(), + ( + x_values.float().requires_grad_(True), + x_offsets, + max_lengths.astype(int).tolist(), + padding_value, + ), + ).run_decompositions() + num_fw_ops = len( + [ + x + for x in gm.graph.nodes + if x.target is torch.ops.fbgemm.jagged_to_padded_dense_forward.default + ] + ) + num_composite_ops = len( + [ + x + for x in gm.graph.nodes + if x.target is torch.ops.fbgemm.jagged_to_padded_dense.default + ] + ) + self.assertEqual(num_fw_ops, 1) + self.assertEqual(num_composite_ops, 0) + + torch.library.opcheck( + torch.ops.fbgemm.jagged_to_padded_dense, + ( + x_values.float().requires_grad_(True), + x_offsets, + max_lengths, + padding_value, + ), + ) + @given( num_jagged_dim=st.integers(1, 5), outer_dense_size=st.integers(0, 5), diff --git a/fbgemm_gpu/test/tbe/training/store_prefetched_tensors_test.py b/fbgemm_gpu/test/tbe/training/store_prefetched_tensors_test.py new file mode 100644 index 0000000000..b74690b2cc --- /dev/null +++ b/fbgemm_gpu/test/tbe/training/store_prefetched_tensors_test.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import unittest + +import torch + +from fbgemm_gpu.split_table_batched_embeddings_ops_training import ( + SplitTableBatchedEmbeddingBagsCodegen, +) + +from ..common import open_source + +if open_source: + # pyre-ignore[21] + from test_utils import gpu_unavailable +else: + from fbgemm_gpu.test.test_utils import gpu_unavailable + + +class StorePrefetchedTensorsTest(unittest.TestCase): + @unittest.skipIf(*gpu_unavailable) + def test_get_prefetched_info(self) -> None: + hash_zch_identities = torch.tensor( + [ + [3350213393928437575], # for index 54 + [6548733451892409412], # for index 27 + [4126118985661274454], # for index 43 + [2565973416302224539], # for index 90 + ], + device=torch.cuda.current_device(), + dtype=torch.int64, + ) + total_cache_hash_size = 100 + linear_cache_indices_merged = torch.tensor( + [54, 27, 43, 90], + device=torch.cuda.current_device(), + dtype=torch.int64, + ) + + prefetched_info = SplitTableBatchedEmbeddingBagsCodegen._get_prefetched_info( + linear_cache_indices_merged, + total_cache_hash_size, + hash_zch_identities, + ) + + self.assertEqual( + [27, 43, 54, 90], + prefetched_info.linear_unique_indices.tolist(), + ) + self.assertEqual( + prefetched_info.linear_unique_indices_length[0].item(), + 4, + ) + assert prefetched_info.hash_zch_identities is not None + self.assertEqual( + prefetched_info.hash_zch_identities.shape[0], + 4, + ) + self.assertEqual( + [ + [6548733451892409412], + [4126118985661274454], + [3350213393928437575], + [2565973416302224539], + ], + prefetched_info.hash_zch_identities.tolist(), + ) + + @unittest.skipIf(*gpu_unavailable) + def test_get_prefetched_info_with_duplicate_hash_zch_identities(self) -> None: + """ + Test that duplicate cache indices are correctly deduplicated. + When the same cache index appears multiple times with the same identity, + only the first occurrence should be kept in the output. 
+ """ + hash_zch_identities = torch.tensor( + [ + [3350213393928437575], # for index 54 (first occurrence) + [6548733451892409412], # for index 27 + [3350213393928437575], # for index 54 (duplicate - same identity) + [4126118985661274454], # for index 43 + [6548733451892409412], # for index 27 (duplicate - same identity) + [3350213393928437575], # for index 54 (duplicate - same identity) + [2565973416302224539], # for index 90 + ], + device=torch.cuda.current_device(), + dtype=torch.int64, + ) + total_cache_hash_size = 100 + linear_cache_indices_merged = torch.tensor( + [54, 27, 54, 43, 27, 54, 90], # Duplicates: 54 appears 3x, 27 appears 2x + device=torch.cuda.current_device(), + dtype=torch.int64, + ) + + prefetched_info = SplitTableBatchedEmbeddingBagsCodegen._get_prefetched_info( + linear_cache_indices_merged, + total_cache_hash_size, + hash_zch_identities, + ) + + self.assertEqual( + [27, 43, 54, 90], + prefetched_info.linear_unique_indices.tolist(), + ) + self.assertEqual( + prefetched_info.linear_unique_indices_length[0].item(), + 4, + ) + assert prefetched_info.hash_zch_identities is not None + self.assertEqual( + prefetched_info.hash_zch_identities.shape[0], + 4, + ) + self.assertEqual( + [ + [6548733451892409412], # for index 27 + [4126118985661274454], # for index 43 + [3350213393928437575], # for index 54 + [2565973416302224539], # for index 90 + ], + prefetched_info.hash_zch_identities.tolist(), + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/fbgemm_gpu/test/tbe/utils/split_embeddings_test.py b/fbgemm_gpu/test/tbe/utils/split_embeddings_test.py index 3dd1bc2cd4..1ebdaddaa5 100644 --- a/fbgemm_gpu/test/tbe/utils/split_embeddings_test.py +++ b/fbgemm_gpu/test/tbe/utils/split_embeddings_test.py @@ -178,17 +178,17 @@ def test_get_table_name_for_logging(self) -> None: SplitTableBatchedEmbeddingBagsCodegen.get_table_name_for_logging( ["t1", "t2"] ), - "<2 tables>", + "<2 tables>: ['t1', 't2']", ) self.assertEqual( SplitTableBatchedEmbeddingBagsCodegen.get_table_name_for_logging( ["t1", "t2", "t1"] ), - "<2 tables>", + "<2 tables>: ['t1', 't2']", ) self.assertEqual( SplitTableBatchedEmbeddingBagsCodegen.get_table_name_for_logging([]), - "<0 tables>", + "<0 tables>: []", ) @unittest.skipIf(*gpu_unavailable) diff --git a/include/fbgemm/FbgemmConvert.h b/include/fbgemm/FbgemmConvert.h index cf404d2056..88dd5e8e30 100644 --- a/include/fbgemm/FbgemmConvert.h +++ b/include/fbgemm/FbgemmConvert.h @@ -47,6 +47,7 @@ FloatToBfloat16_simd(const float* src, bfloat16* dst, size_t size); FBGEMM_API void Bfloat16ToFloat_simd(const bfloat16* src, float* dst, size_t size); +#if !defined(__aarch64__) /** * @brief AVX2 implementation to convert fp32 numbers to bf16 numbers. * @@ -58,10 +59,8 @@ FloatToBfloat16_avx2(const float* src, bfloat16* dst, size_t size); * @brief AVX512 implementation to convert fp32 numbers to bf16 numbers. * */ -#if defined(FBGEMM_FBCODE) || !defined(__aarch64__) FBGEMM_API void FloatToBfloat16_avx512(const float* src, bfloat16* dst, size_t size); -#endif /** * @brief AVX2 implementation to convert bf16 numbers to fp32 numbers. @@ -74,7 +73,6 @@ Bfloat16ToFloat_avx2(const bfloat16* src, float* dst, size_t size); * @brief AVX512 implementation to convert bf16 numbers to fp32 numbers. 
* */ -#if defined(FBGEMM_FBCODE) || !defined(__aarch64__) FBGEMM_API void Bfloat16ToFloat_avx512(const bfloat16* src, float* dst, size_t size); #endif @@ -124,6 +122,7 @@ Float16ToFloat_simd(const float16* src, float* dst, size_t size); * @brief AVX2 implementation to convert fp32 numbers to fp16 numbers. * */ +#if !defined(__aarch64__) FBGEMM_API void FloatToFloat16_avx2( const float* src, float16* dst, @@ -134,7 +133,6 @@ FBGEMM_API void FloatToFloat16_avx2( * @brief AVX512 implementation to convert fp32 numbers to fp16 numbers. * */ -#if defined(FBGEMM_FBCODE) || !defined(__aarch64__) FBGEMM_API void FloatToFloat16_avx512( const float* src, float16* dst, @@ -152,6 +150,7 @@ FBGEMM_API void FloatToFloat16_sve2( size_t size, bool do_clip = false); +#if !defined(__aarch64__) /** * @brief AVX2 implementation to convert fp16 numbers to fp32 numbers. * @@ -163,7 +162,6 @@ Float16ToFloat_avx2(const float16* src, float* dst, size_t size); * @brief AVX512 implementation to convert fp16 numbers to fp32 numbers. * */ -#if defined(FBGEMM_FBCODE) || !defined(__aarch64__) FBGEMM_API void Float16ToFloat_avx512(const float16* src, float* dst, size_t size); #endif diff --git a/include/fbgemm/FbgemmEmbedding.h b/include/fbgemm/FbgemmEmbedding.h index 073b9f8655..12eb8babd6 100644 --- a/include/fbgemm/FbgemmEmbedding.h +++ b/include/fbgemm/FbgemmEmbedding.h @@ -349,7 +349,7 @@ FBGEMM_API bool EmbeddingSpMDMBlockSize1_( bool use_offsets = true, bool is_bf16 = false); -#if defined(FBGEMM_FBCODE) || !defined(__aarch64__) +#if !defined(__aarch64__) template void compressed_indices_remap_avx512( std::int32_t offsets_numel, diff --git a/include/fbgemm/FbgemmSparse.h b/include/fbgemm/FbgemmSparse.h index 82e8f889c6..dc00338fb7 100644 --- a/include/fbgemm/FbgemmSparse.h +++ b/include/fbgemm/FbgemmSparse.h @@ -166,7 +166,7 @@ void SparseDenseMMAvx2( int ldc, bool accum = false); -#if defined(FBGEMM_FBCODE) || !defined(__aarch64__) +#if !defined(__aarch64__) void SparseDenseMMAvx512( int M, int N, diff --git a/include/fbgemm/FloatConversion.h b/include/fbgemm/FloatConversion.h index f2628450e4..b88630d9b1 100644 --- a/include/fbgemm/FloatConversion.h +++ b/include/fbgemm/FloatConversion.h @@ -289,7 +289,7 @@ inline float cpu_half2float_ref(const float16 h) { // Same as the previous function, but use the built-in fp16 to fp32 // conversion provided by the compiler inline float cpu_half2float(const float16 h) { -#if defined(HAS_NATIVE_FP16_TYPE) && not defined(MISSING_GNU_F2H_IEEE) +#if defined(HAS_NATIVE_FP16_TYPE) && !defined(MISSING_GNU_F2H_IEEE) __fp16 h_fp16 = NAN; std::memcpy(&h_fp16, &h, sizeof(__fp16)); return h_fp16; @@ -299,7 +299,7 @@ inline float cpu_half2float(const float16 h) { } inline float16 cpu_float2half(const float f) { -#if defined(HAS_NATIVE_FP16_TYPE) && not defined(MISSING_GNU_F2H_IEEE) +#if defined(HAS_NATIVE_FP16_TYPE) && !defined(MISSING_GNU_F2H_IEEE) __fp16 h = f; float16 res = 0; std::memcpy(&res, &h, sizeof(__fp16)); diff --git a/include/fbgemm/OutputProcessing-inl.h b/include/fbgemm/OutputProcessing-inl.h index 5faabe7eeb..ec70a49aa4 100644 --- a/include/fbgemm/OutputProcessing-inl.h +++ b/include/fbgemm/OutputProcessing-inl.h @@ -125,7 +125,7 @@ ReQuantizeOutput::f( } } -#if defined(FBGEMM_FBCODE) || !defined(__aarch64__) +#if !defined(__aarch64__) } else if constexpr ( instSet == inst_set_t::avx2 || instSet == inst_set_t::avx512) { @@ -249,7 +249,7 @@ inline int ReQuantizeForFloat::f( } } -#if defined(FBGEMM_FBCODE) || !defined(__aarch64__) +#if !defined(__aarch64__) } else if 
constexpr ( instSet == inst_set_t::avx2 || instSet == inst_set_t::avx512) { bool b_symmetric = diff --git a/include/fbgemm/QuantUtilsAvx2.h b/include/fbgemm/QuantUtilsAvx2.h index 0f2859c8ff..7c6fe24396 100644 --- a/include/fbgemm/QuantUtilsAvx2.h +++ b/include/fbgemm/QuantUtilsAvx2.h @@ -52,6 +52,13 @@ struct FBGEMM_API RequantizationParams { TensorQuantizationParams target_qparams; }; +/// @ingroup fbgemm-quant-utils-avx2 +/// +/// @brief Find the min and max value in a float matrix. +void FBGEMM_API FindMinMax(const float* m, float* min, float* max, int64_t len); + +#if !defined(__aarch64__) + //////////////////////////////////////////////////////////////////////////////// // Utility functions //////////////////////////////////////////////////////////////////////////////// @@ -77,11 +84,6 @@ void FusedQuantizeDequantizeAvx2( /// this paper. uint32_t FBGEMM_API Xor128(); -/// @ingroup fbgemm-quant-utils-avx2 -/// -/// @brief Find the min and max value in a float matrix. -void FBGEMM_API FindMinMax(const float* m, float* min, float* max, int64_t len); - void RequantizeFixedPointAvx2( const std::int32_t* src, std::uint8_t* dst, @@ -94,6 +96,8 @@ void RequantizeAvx2( int len, const RequantizationParams& params); +#endif // !defined(__aarch64__) + /// @ingroup fbgemm-quant-utils-avx2 /// /// Requantize with avx2 and bias is fused. @@ -143,6 +147,8 @@ FBGEMM_API void requantizeForFloatAvx2( int ld_in, const requantizationForFloatParams_t& r); +#if !defined(__aarch64__) + template void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfAvx2( const InputType* input, @@ -176,4 +182,6 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2( int input_columns, OutputType* output); +#endif // !defined(__aarch64__) + } // namespace fbgemm diff --git a/include/fbgemm/QuantUtilsAvx512.h b/include/fbgemm/QuantUtilsAvx512.h index c4b01817bd..1ad1efe71e 100644 --- a/include/fbgemm/QuantUtilsAvx512.h +++ b/include/fbgemm/QuantUtilsAvx512.h @@ -9,7 +9,7 @@ #pragma once #include "Types.h" -#if defined(FBGEMM_FBCODE) || !defined(__aarch64__) +#if !defined(__aarch64__) #include #include "./FbgemmBuild.h" // @manual diff --git a/include/fbgemm/QuantUtilsNeon.h b/include/fbgemm/QuantUtilsNeon.h index 63f108b418..13169c8a05 100644 --- a/include/fbgemm/QuantUtilsNeon.h +++ b/include/fbgemm/QuantUtilsNeon.h @@ -22,6 +22,13 @@ namespace fbgemm { // Utility functions //////////////////////////////////////////////////////////////////////////////// +template +void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatNeon( + const InputType* input, + size_t input_rows, + int input_columns, + uint8_t* output); + template void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfNeon( const std::uint8_t* input, diff --git a/src/FbgemmBfloat16Convert.cc b/src/FbgemmBfloat16Convert.cc index 34baed622b..4c7a358d94 100644 --- a/src/FbgemmBfloat16Convert.cc +++ b/src/FbgemmBfloat16Convert.cc @@ -29,7 +29,7 @@ namespace fbgemm { void FloatToBfloat16_simd(const float* src, bfloat16* dst, size_t size) { // Run time CPU detection if (cpuinfo_initialize()) { -#if defined(FBGEMM_FBCODE) || !defined(__aarch64__) +#if !defined(__aarch64__) if (fbgemmHasAvx512Support()) { FloatToBfloat16_avx512(src, dst, size); } else if (fbgemmHasAvx2Support()) { @@ -48,7 +48,7 @@ void FloatToBfloat16_simd(const float* src, bfloat16* dst, size_t size) { void Bfloat16ToFloat_simd(const bfloat16* src, float* dst, size_t size) { // Run time CPU detection if (cpuinfo_initialize()) { -#if defined(FBGEMM_FBCODE) || !defined(__aarch64__) +#if !defined(__aarch64__) if 
(fbgemmHasAvx512Support()) { Bfloat16ToFloat_avx512(src, dst, size); } else if (fbgemmHasAvx2Support()) { diff --git a/src/FbgemmFP16.cc b/src/FbgemmFP16.cc index 106f071953..79eda23712 100644 --- a/src/FbgemmFP16.cc +++ b/src/FbgemmFP16.cc @@ -35,7 +35,7 @@ namespace { // the restrictions of ymm register numbers (16). constexpr kernel_array_t kernel_fp16_avx2 = { nullptr, -#if defined(FBGEMM_FBCODE) || !defined(__aarch64__) +#if !defined(__aarch64__) gemmkernel_1x2_Avx2_fp16_fA0fB0fC0, gemmkernel_2x2_Avx2_fp16_fA0fB0fC0, gemmkernel_3x2_Avx2_fp16_fA0fB0fC0, @@ -79,7 +79,7 @@ constexpr kernel_array_t kernel_fp16_neon = { constexpr kernel_array_t kernel_fp16_avx512_256 = { nullptr, -#if defined(FBGEMM_FBCODE) || !defined(__aarch64__) +#if !defined(__aarch64__) gemmkernel_1x2_Avx2_fp16_fA0fB0fC0, gemmkernel_2x2_Avx2_fp16_fA0fB0fC0, gemmkernel_3x2_Avx2_fp16_fA0fB0fC0, diff --git a/src/FbgemmFP16UKernelsAvx2.h b/src/FbgemmFP16UKernelsAvx2.h index 888bae1833..455c49fdd5 100644 --- a/src/FbgemmFP16UKernelsAvx2.h +++ b/src/FbgemmFP16UKernelsAvx2.h @@ -16,6 +16,8 @@ namespace fbgemm { using GemmParamsFP16 = GemmParams; +#if !defined(__aarch64__) + void NOINLINE gemmkernel_1x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp); void NOINLINE gemmkernel_2x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp); void NOINLINE gemmkernel_3x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp); @@ -23,4 +25,6 @@ void NOINLINE gemmkernel_4x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp); void NOINLINE gemmkernel_5x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp); void NOINLINE gemmkernel_6x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp); +#endif // !defined(__aarch64__) + } // namespace fbgemm diff --git a/src/FbgemmFloat16Convert.cc b/src/FbgemmFloat16Convert.cc index 9519d6cb62..1f76baeafc 100644 --- a/src/FbgemmFloat16Convert.cc +++ b/src/FbgemmFloat16Convert.cc @@ -23,7 +23,7 @@ void FloatToFloat16_simd( bool do_clip) { // Run time CPU detection if (cpuinfo_initialize()) { -#if defined(FBGEMM_FBCODE) || !defined(__aarch64__) +#if !defined(__aarch64__) if (fbgemmHasAvx512Support()) { FloatToFloat16_avx512(src, dst, size, do_clip); } else if (fbgemmHasAvx2Support()) { @@ -42,7 +42,7 @@ void FloatToFloat16_simd( void Float16ToFloat_simd(const float16* src, float* dst, size_t size) { // Run time CPU detection if (cpuinfo_initialize()) { -#if defined(FBGEMM_FBCODE) || !defined(__aarch64__) +#if !defined(__aarch64__) if (fbgemmHasAvx512Support()) { Float16ToFloat_avx512(src, dst, size); } else if (fbgemmHasAvx2Support()) { diff --git a/src/FbgemmSparseDense.cc b/src/FbgemmSparseDense.cc index 1e2122d78f..eb8a82f60c 100644 --- a/src/FbgemmSparseDense.cc +++ b/src/FbgemmSparseDense.cc @@ -193,7 +193,7 @@ void SparseDenseMM( float* C, int ldc, bool accum) { -#if defined(FBGEMM_FBCODE) || !defined(__aarch64__) +#if !defined(__aarch64__) // Run time CPU detection static const auto iset = fbgemmInstructionSet(); @@ -229,7 +229,7 @@ FBGEMM_API void fbgemmSparseDenseInt8MM( return; } -#if defined(FBGEMM_FBCODE) || !defined(__aarch64__) +#if !defined(__aarch64__) // Run time CPU detection static const auto iset = fbgemmInstructionSet(); diff --git a/src/GroupwiseConv.cc b/src/GroupwiseConv.cc index f92408f2ec..38ec4910b0 100644 --- a/src/GroupwiseConv.cc +++ b/src/GroupwiseConv.cc @@ -121,7 +121,7 @@ static jit_conv_kernel_fp getOrCreateConvKernel( accum); if (cpuinfo_initialize()) { -#if defined(FBGEMM_FBCODE) || !defined(__aarch64__) +#if !defined(__aarch64__) if (fbgemmHasAvx512VnniSupport()) { return GenConvKernel::codeCache_ .getOrCreate(kernelSig, [&]() { @@ 
-954,7 +954,7 @@ static void dispatchOutputProcessing( } if (cpuinfo_initialize()) { -#if defined(FBGEMM_FBCODE) || !defined(__aarch64__) +#if !defined(__aarch64__) if (fbgemmHasAvx512Support() || fbgemmHasAvx512VnniSupport()) { REQUANTIZE_C_PER_G(Avx512); } else if (fbgemmHasAvx2Support() || fbgemmHasArmNeonSupport()) { diff --git a/src/PackWeightsForConv.cc b/src/PackWeightsForConv.cc index 040dbe682c..8870e6a903 100644 --- a/src/PackWeightsForConv.cc +++ b/src/PackWeightsForConv.cc @@ -61,7 +61,7 @@ PackWeightsForConv::PackWeightsForConv( break; } case optimized_conv_t::directconv: { -#if defined(__aarch64__) +#if !defined(FBGEMM_FBCODE) && defined(__aarch64__) throw std::runtime_error( "PackWeightsForConv::PackWeightsForConv(): No fallback available for aarch64"); #else diff --git a/src/PackWeightsForDirectConv.cc b/src/PackWeightsForDirectConv.cc index 3be4528642..db33d43d65 100644 --- a/src/PackWeightsForDirectConv.cc +++ b/src/PackWeightsForDirectConv.cc @@ -459,7 +459,7 @@ void fbgemmDirectConv( } } // else SPATIAL_DIM -#endif // defined(FBGEMM_FBCODE) || !defined(__aarch64__) +#endif // !defined(__aarch64__) } #define INSTANTIATE_REQUANTIZE_SPATIAL_DIM( \ diff --git a/src/QuantUtils.cc b/src/QuantUtils.cc index 1c2e58363d..5301909193 100644 --- a/src/QuantUtils.cc +++ b/src/QuantUtils.cc @@ -714,6 +714,10 @@ void FloatOrHalfToFused8BitRowwiseQuantizedSBFloat( int input_columns, std::uint8_t* output, const InputType* rowwise_min_max) { +#if HAVE_SVE + FloatOrHalfToFused8BitRowwiseQuantizedSBFloatNeon( + input, input_rows, input_columns, output); +#else if (cpuinfo_initialize() && fbgemmHasAvx2Support()) { #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 FloatOrHalfToFused8BitRowwiseQuantizedSBFloatAvx2( @@ -723,6 +727,7 @@ void FloatOrHalfToFused8BitRowwiseQuantizedSBFloat( FloatOrHalfToFused8BitRowwiseQuantizedSBFloatRef( input, input_rows, input_columns, output); } +#endif } template diff --git a/src/QuantUtilsAvx2.cc b/src/QuantUtilsAvx2.cc index ab6274d571..89deb44d39 100644 --- a/src/QuantUtilsAvx2.cc +++ b/src/QuantUtilsAvx2.cc @@ -280,6 +280,8 @@ SPECIALIZE_FUSEDDQAVX2(int8_t) #undef SPECIALIZE_FUSEDDQAVX2 +#ifndef __aarch64__ + void FindMinMax(const float* m, float* min, float* max, int64_t len) { if (len <= 0) { *min = 0.0f; @@ -317,6 +319,8 @@ void FindMinMax(const float* m, float* min, float* max, int64_t len) { *max = temp_max; } +#endif + //////////////////////////////////////////////////////////////////////////////// // Requantization (with floats) diff --git a/src/QuantUtilsNeon.cc b/src/QuantUtilsNeon.cc index a8835f0e05..8fef86b94f 100644 --- a/src/QuantUtilsNeon.cc +++ b/src/QuantUtilsNeon.cc @@ -6,15 +6,18 @@ * LICENSE file in the root directory of this source tree. 
*/ -#include "fbgemm/Utils.h" +#if defined(__aarch64__) -#if HAVE_SVE +#include "fbgemm/Utils.h" #define FBGEMM_EXPORTS +#include // @manual #include // @manual +#if HAVE_SVE +#include // @manual #include // @manual +#endif -#include // @manual #include //for std::min/std::max #include //for assert #include // for FLT_MAX @@ -30,6 +33,229 @@ namespace fbgemm { using namespace std; //////////////////////////////////////////////////////////////////////////////// // Utility functions +static inline void +FindMinMaxImpl_f32(const float* m, float* min, float* max, uint64_t count) { + float first = *m; + + float tmp_min_s = first; + float tmp_max_s = first; + + float32x4_t temp_min_0 = vdupq_n_f32(first); + float32x4_t temp_min_1 = vdupq_n_f32(first); + float32x4_t temp_max_0 = vdupq_n_f32(first); + float32x4_t temp_max_1 = vdupq_n_f32(first); + constexpr uint64_t kItemsPerIter = 8; + uint64_t loopIters = count / kItemsPerIter; + uint64_t loopRemainder = count % kItemsPerIter; + + if (__builtin_expect(loopIters > 0, 1)) { + do { + float32x4_t v0 = vld1q_f32(m); + float32x4_t v1 = vld1q_f32(m + 4); + m += kItemsPerIter; + loopIters -= 1; + temp_min_0 = vminq_f32(temp_min_0, v0); + temp_min_1 = vminq_f32(temp_min_1, v1); + temp_max_0 = vmaxq_f32(temp_max_0, v0); + temp_max_1 = vmaxq_f32(temp_max_1, v1); + } while (loopIters > 0); + + temp_min_0 = vminq_f32(temp_min_0, temp_min_1); + temp_max_0 = vmaxq_f32(temp_max_0, temp_max_1); + + tmp_min_s = vminvq_f32(temp_min_0); + tmp_max_s = vmaxvq_f32(temp_max_0); + } + +#ifdef __clang__ +#pragma clang loop vectorize(disable) interleave(disable) unroll(disable) +#elif defined(__GNUC__) +#pragma GCC novector unroll 0 +#endif + while (loopRemainder > 0) { + float tmp = *m++; + loopRemainder -= 1; + tmp_min_s = std::min(tmp_min_s, tmp); + tmp_max_s = std::max(tmp_max_s, tmp); + } + + *min = tmp_min_s; + *max = tmp_max_s; +} + +void FindMinMax(const float* m, float* min, float* max, int64_t len) { + if (__builtin_expect(len <= 0, 0)) { + *min = 0.0f; + *max = 0.0f; + return; + } + + FindMinMaxImpl_f32(m, min, max, static_cast(len)); +} + +#if HAVE_SVE + +static inline void +FindMinMaxImpl_f16(const float16_t* m, float* min, float* max, uint64_t count) { + float16_t first = *m; + + float16_t tmp_min_s = first; + float16_t tmp_max_s = first; + + float16x8_t temp_min_0 = vdupq_n_f16(first); + float16x8_t temp_min_1 = vdupq_n_f16(first); + float16x8_t temp_max_0 = vdupq_n_f16(first); + float16x8_t temp_max_1 = vdupq_n_f16(first); + constexpr uint64_t kItemsPerIter = 16; + uint64_t loopIters = count / kItemsPerIter; + uint64_t loopRemainder = count % kItemsPerIter; + + if (__builtin_expect(loopIters > 0, 1)) { + do { + float16x8_t v0 = vld1q_f16(m); + float16x8_t v1 = vld1q_f16(m + 8); + m += kItemsPerIter; + loopIters -= 1; + temp_min_0 = vminq_f16(temp_min_0, v0); + temp_min_1 = vminq_f16(temp_min_1, v1); + temp_max_0 = vmaxq_f16(temp_max_0, v0); + temp_max_1 = vmaxq_f16(temp_max_1, v1); + } while (loopIters > 0); + + temp_min_0 = vminq_f16(temp_min_0, temp_min_1); + temp_max_0 = vmaxq_f16(temp_max_0, temp_max_1); + + tmp_min_s = vminvq_f16(temp_min_0); + tmp_max_s = vmaxvq_f16(temp_max_0); + } + +#ifdef __clang__ +#pragma clang loop vectorize(disable) interleave(disable) unroll(disable) +#elif defined(__GNUC__) +#pragma GCC novector unroll 0 +#endif + while (loopRemainder > 0) { + float16_t tmp = *m++; + loopRemainder -= 1; + tmp_min_s = vminh_f16(tmp_min_s, tmp); + tmp_max_s = vmaxh_f16(tmp_max_s, tmp); + } + + *min = static_cast(tmp_min_s); + *max = 
static_cast(tmp_max_s); +} + +template +void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatNeon( + const InputType* input, + size_t input_rows, + int input_columns, + uint8_t* output) { + constexpr float kEpsilon = 1e-8f; + + if (input_rows == 0 || input_columns <= 0) { + return; + } + + uint64_t column_count = static_cast(input_columns); + + const uint64_t output_columns = column_count + 2 * sizeof(float); + + for (size_t row = 0; __builtin_expect(row < input_rows, 1); ++row) { + const InputType* input_row = input + row * column_count; + uint8_t* output_row = output + row * output_columns; + + float* output_row_scale_bias = + reinterpret_cast(output_row + column_count); + + float minimum_element; + float maximum_element; + if constexpr (std::is_same()) { + FindMinMaxImpl_f32( + input_row, &minimum_element, &maximum_element, column_count); + } else { + FindMinMaxImpl_f16( + reinterpret_cast(input_row), + &minimum_element, + &maximum_element, + column_count); + } + float range = maximum_element - minimum_element; + + const auto inverse_scale = 255.0f / (range + kEpsilon); + + float32x4_t inverse_scale_v = vdupq_n_f32(inverse_scale); + float32x4_t min_v = vdupq_n_f32(minimum_element); + + constexpr uint64_t kItemsPerIter = 8; + uint64_t loopIters = column_count / kItemsPerIter; + uint64_t loopRemainder = column_count % kItemsPerIter; + + output_row_scale_bias[0] = range / 255.0f; + output_row_scale_bias[1] = minimum_element; + + while (__builtin_expect(loopIters > 0, 1)) { + float32x4_t v0; + float32x4_t v1; + + if constexpr (std::is_same()) { + v0 = vld1q_f32(input_row); + v1 = vld1q_f32(input_row + 4); + } else { + float16x8_t h0 = + vld1q_f16(reinterpret_cast(input_row)); + v0 = vcvt_f32_f16(vget_low_f16(h0)); + v1 = vcvt_high_f32_f16(h0); + } + + input_row += kItemsPerIter; + loopIters -= 1; + + v0 = vsubq_f32(v0, min_v); + v1 = vsubq_f32(v1, min_v); + + v0 = vmulq_f32(v0, inverse_scale_v); + v1 = vmulq_f32(v1, inverse_scale_v); + + int32x4_t i0 = vcvtnq_s32_f32(v0); + int32x4_t i1 = vcvtnq_s32_f32(v1); + + svst1b_s32( + svptrue_b8(), + reinterpret_cast(output_row), + svset_neonq_s32(svundef_s32(), i0)); + svst1b_s32( + svptrue_b8(), + reinterpret_cast(output_row + 4), + svset_neonq_s32(svundef_s32(), i1)); + + output_row += kItemsPerIter; + } + +#ifdef __clang__ +#pragma clang loop vectorize(disable) interleave(disable) unroll(disable) +#elif defined(__GNUC__) +#pragma GCC novector unroll 0 +#endif + while (loopRemainder > 0) { + float32x4_t v0; + if constexpr (std::is_same()) { + v0[0] = *input_row++; + } else { + v0[0] = + static_cast(*reinterpret_cast(input_row)); + input_row += 1; + } + loopRemainder -= 1; + v0 = vsubq_f32(v0, min_v); + v0 = vmulq_f32(v0, inverse_scale_v); + int32x4_t i0 = vcvtnq_s32_f32(v0); + *output_row = i0[0]; + output_row += 1; + } + + } // for each row +} template void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfNeon( @@ -133,7 +359,12 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfNeon( const std::uint8_t* input, \ size_t input_rows, \ int input_columns, \ - type* output); + type* output); \ + template void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatNeon( \ + const type* input, \ + size_t input_rows, \ + int input_columns, \ + uint8_t* output); // clang-format off INSTANTIATE_QuantizationNeonFunctions8Bits(float) @@ -141,6 +372,8 @@ INSTANTIATE_QuantizationNeonFunctions8Bits(float16) // clang-format on #undef INSTANTIATE_QuantizationNeonFunctions8Bits +#endif // HAVE_SVE + } // namespace fbgemm #endif // __aarch64__ diff --git a/src/TransposeUtils.cc 
b/src/TransposeUtils.cc
index aecec554da..cb1cb58d5a 100644
--- a/src/TransposeUtils.cc
+++ b/src/TransposeUtils.cc
@@ -57,14 +57,11 @@ void transpose_simd(
 #else
   static const auto iset = fbgemmInstructionSet();
   // Run time CPU detection
-#if defined(FBGEMM_FBCODE) || !defined(__aarch64__)
   if (isZmm(iset)) {
     internal::transpose_avx512(M, N, src, ld_src, dst, ld_dst);
   } else if (isYmm(iset)) {
     internal::transpose_avx2(M, N, src, ld_src, dst, ld_dst);
-  } else
-#endif
-  {
+  } else {
     transpose_ref(M, N, src, ld_src, dst, ld_dst);
   }
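
Note on the new NEON quantization path added in src/QuantUtilsNeon.cc: FloatOrHalfToFused8BitRowwiseQuantizedSBFloatNeon writes, for each input row, the quantized uint8 payload followed by two float32 values (scale = range / 255 and bias = row minimum), so each output row occupies input_columns + 2 * sizeof(float) bytes. The snippet below is a minimal NumPy sketch of that row layout and of the quantization arithmetic (inverse_scale = 255 / (range + 1e-8), round to nearest), assuming float32 input; the function name and variables are illustrative only, not part of the FBGEMM API.

import numpy as np

def fused_8bit_rowwise_quantize_sketch(x: np.ndarray) -> np.ndarray:
    # x: 2-D float32 matrix; returns a uint8 matrix of shape
    # (rows, cols + 8), mirroring the fused rowwise output layout.
    eps = 1e-8  # same epsilon the kernel uses to avoid division by zero
    rows, cols = x.shape
    out = np.empty((rows, cols + 2 * 4), dtype=np.uint8)
    for r in range(rows):
        row = x[r].astype(np.float32)
        mn, mx = float(row.min()), float(row.max())
        rng = mx - mn
        inv_scale = 255.0 / (rng + eps)
        # round-to-nearest, matching vcvtnq_s32_f32 in the NEON kernel
        out[r, :cols] = np.rint((row - mn) * inv_scale).astype(np.uint8)
        # scale and bias stored as raw float32 bytes after the payload
        out[r, cols:] = np.array([rng / 255.0, mn], dtype=np.float32).view(np.uint8)
    return out

Under this scheme a row of identical values quantizes to all zeros with scale 0 and bias equal to that value, which is consistent with the kernel's epsilon-guarded division.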