Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
54 commits
Select commit Hold shift + click to select a range
cd7dfea
Add gfx950 build support + fp16 fix + index type fix
avbokovoy Jul 29, 2025
602b7bf
Change int64_t to index_t as template parameters in load_raw_per_warp
avbokovoy Jul 29, 2025
a587e06
Implement llvm fp16 buffer load for gfx950
avbokovoy Jul 29, 2025
48a10bf
Fix c-style half to float cast
avbokovoy Aug 11, 2025
d4acaba
Patch 256 half stores
avbokovoy Aug 11, 2025
a6636f0
cta_per_row workgroup optim
shbiswas834 Aug 8, 2025
a15fb09
Added mi350 guards
shbiswas834 Aug 11, 2025
6af95e0
Fix index overflow in row load
shbiswas834 Aug 12, 2025
be5f1b8
cta_per_row workgroup reduce by 4 optim
shbiswas834 Aug 12, 2025
acef908
Fix mixed_D frontend to backend connection
avbokovoy Aug 13, 2025
33f4ad9
changed max_segment_length_per_cta to 4096
kudomcho Aug 15, 2025
aaf1966
added rocm guards and removed comment
shbiswas834 Aug 18, 2025
48e7f97
clean debug statements in Hip.cmake
liligwu Aug 20, 2025
750bee4
Merge pull request #121
shbiswas834 Aug 28, 2025
f0acbc3
Guard f16 llvm intrinsics with ROCm >=7.0
avbokovoy Sep 2, 2025
0ee2366
fix the bug in dimension 160 in ROCm optimization
liligwu Sep 18, 2025
e33120d
Cleanup optimized warp_per_raw kernel
avbokovoy Aug 19, 2025
3447ef0
Add 320 embedding dim support for optimized warp_per_row kernel
avbokovoy Aug 20, 2025
a1361ab
changed the max length per warp and cta per row WG size
Sep 8, 2025
9c2fd1d
added DPP and changed max length per warp to 16k
kudomcho Sep 9, 2025
54690c9
guard max segment warp based on emb dim
kudomcho Sep 10, 2025
d666611
added guarding opt of max segment for the case batch size list=1
kudomcho Sep 10, 2025
df863d0
opt for grad_indice_weights kernel
Sep 18, 2025
e0bee9f
added store row per warp on emb 192 and added accuracy test functiona…
kudomcho Sep 23, 2025
ca82950
workgroup tuning and loop unrolled
shbiswas834 Sep 22, 2025
7ad444b
specialize
Hardcode84 Sep 19, 2025
970229b
explicitly link to tbb
liligwu Sep 24, 2025
539985c
added warpReduceAllSum with rocm guards
shbiswas834 Sep 25, 2025
e3d4773
revert unroll and wg tuning
shbiswas834 Oct 13, 2025
9505ffe
Minor update embedding_forward_split_kernel_template.cu
liligwu Oct 13, 2025
8709307
add tbb-devel to the install_build_tools ()
liligwu Oct 17, 2025
6a3d3cb
fix lint issues
liligwu Oct 21, 2025
6351c43
solve lint issues
liligwu Oct 21, 2025
1e9b3f3
applied jinja is_rocm onto optimizations for backward and forward par…
kudomcho Oct 22, 2025
46b9f80
Guard supported grad_t for optimized warp_per_row dispatch
avbokovoy Oct 23, 2025
ab5cf5d
Forward index_t to the optimizer
avbokovoy Oct 23, 2025
5164f6e
Guard f16 llvm intrinsics with ROCm >=7.0
avbokovoy Sep 2, 2025
cde00fc
Fix buffer offset for emb_dim == 160
avbokovoy Oct 23, 2025
5d73b9c
Remove sanity check
avbokovoy Oct 27, 2025
919db74
address the potential lint issues and revert the change in indices_ge…
liligwu Oct 27, 2025
3df3c91
address code style issue
liligwu Oct 27, 2025
6c3a362
Remove general load/store methods
avbokovoy Oct 24, 2025
8cb6838
Move weight type check to compile-time
avbokovoy Oct 24, 2025
ab6fa10
Switch to 256B stores for float type
avbokovoy Oct 27, 2025
c5a915d
removed guard rocm on mixed_D and refactored mixed_D var assignment
kudomcho Oct 28, 2025
570f148
Merge remote-tracking branch 'origin/abokovoi/mi350-remove-general-lo…
liligwu Oct 28, 2025
ca4701f
hack param
Bernard-Liu Nov 2, 2025
5bf0cf6
support opt code_gen
Bernard-Liu Oct 27, 2025
b72bdd8
support subwarp
yadaish Aug 6, 2025
6343a4f
update subwarp kernel
Bernard-Liu Oct 28, 2025
c386072
grad sum kernel unroll improvement
XingerZhu Oct 27, 2025
7bf6dd8
fix performance issue
yadaish Oct 29, 2025
fb7f0a8
fix VBE optimization not being applied
Bernard-Liu Nov 2, 2025
bec6a69
fix symbol bug & remove comment
Bernard-Liu Nov 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/scripts/utils_build.bash
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,7 @@ install_build_tools () {
patchelf \
rhash \
scikit-build \
tbb-devel \
tbb \
wheel \
xz \
Expand Down
12 changes: 12 additions & 0 deletions cmake/modules/CppLibrary.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,18 @@ function(cpp_library)
target_link_libraries(${lib_name} PUBLIC OpenMP::OpenMP_CXX)
endif()

# Link against TBB: prefer the imported TBB::tbb target (carries include
# dirs and transitive usage requirements); fall back to a raw library file
# only when no config/module package is available.
if(NOT TARGET TBB::tbb)
  find_package(TBB QUIET)
endif()
# Test for the target, not TBB_FOUND: if TBB::tbb already existed (e.g. via
# FetchContent or a parent scope), find_package() was skipped above and
# TBB_FOUND would be unset, wrongly routing us into the find_library fallback.
if(TARGET TBB::tbb)
  target_link_libraries(${lib_name} PUBLIC TBB::tbb)
else()
  # Fallback: locate the shared library directly. tbb12 is the SONAME-style
  # name shipped by oneTBB; quote the env expansion so an unset/空 CONDA_PREFIX
  # or a path with spaces does not split into multiple HINTS arguments.
  find_library(TBB_LIB
    NAMES tbb tbb12
    HINTS "$ENV{CONDA_PREFIX}/lib"
          /usr/lib/x86_64-linux-gnu
          /usr/local/lib
          /lib/x86_64-linux-gnu)
  if(TBB_LIB)
    target_link_libraries(${lib_name} PUBLIC ${TBB_LIB})
  endif()
endif()

# Add sanitizer options if needed
if(args_SANITIZER_OPTIONS)
target_link_options(${lib_name} PUBLIC
Expand Down
12 changes: 12 additions & 0 deletions cmake/modules/GpuCppLibrary.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,18 @@ function(gpu_cpp_library)
list(APPEND library_dependencies ${NVML_LIB_PATH})
endif()

# Collect the TBB dependency: prefer the imported TBB::tbb target; fall back
# to a raw library path only when no package configuration is available.
if(NOT TARGET TBB::tbb)
  find_package(TBB QUIET)
endif()
# Test for the target rather than TBB_FOUND: if TBB::tbb already existed
# (FetchContent / parent scope), find_package() was skipped above and
# TBB_FOUND would be unset, wrongly triggering the find_library fallback.
if(TARGET TBB::tbb)
  list(APPEND library_dependencies TBB::tbb)
else()
  # Fallback: locate the shared library directly (tbb12 is the oneTBB
  # SONAME-style name). Quote the env expansion so an unset CONDA_PREFIX or
  # a path containing spaces stays a single HINTS argument.
  find_library(TBB_LIB
    NAMES tbb tbb12
    HINTS "$ENV{CONDA_PREFIX}/lib"
          /usr/lib/x86_64-linux-gnu
          /usr/local/lib
          /lib/x86_64-linux-gnu)
  if(TBB_LIB)
    list(APPEND library_dependencies ${TBB_LIB})
  endif()
endif()

# Link against the external libraries as needed
target_link_libraries(${lib_name} PRIVATE ${library_dependencies})

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1506,4 +1506,4 @@


if __name__ == "__main__":
cli()
cli()

Check failure on line 1509 in fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py

View workflow job for this annotation

GitHub Actions / run-lint (3.14)

W292 no newline at end of file
2 changes: 0 additions & 2 deletions fbgemm_gpu/cmake/tbe_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,6 @@
"_nobag" if nobag else "",
)
for nobag in [
True,
False,
]
for weighted in (
Expand Down Expand Up @@ -495,7 +494,6 @@
"_nobag" if nobag else "",
)
for nobag in [
True,
False,
]
for weighted in (
Expand Down
10 changes: 7 additions & 3 deletions fbgemm_gpu/codegen/genscript/generate_backward_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,11 @@ def render_backward_templates(
return

weighted_options = [True, False]
nobag_options = [True, False] if (not is_gwd) else [False]
nobag_options = (
[True, False]
if (not (is_gwd or kwargs.get("is_hip_optimized_backward")))
else [False]
)
vbe_options = [True, False] if (kwargs.get("has_vbe_support")) else [False]
ssd_options = [True, False] if kwargs.get("has_ssd_support") else [False]
template = CodeTemplate.load(template_filepath)
Expand Down Expand Up @@ -327,8 +331,7 @@ def generate_backward_indices() -> None:

@staticmethod
def generate_rocm_backward_split(**kwargs: Any) -> None:
# Generate backward device kernels based on weighted (True/False), VBE
# (True/False), no bag (True/False)
# Generate backward device kernels based on weighted (True/False)
template_filepath = (
"training/backward/rocm/embedding_backward_split_device_kernel_template.hip"
)
Expand All @@ -343,6 +346,7 @@ def generate_rocm_backward_split(**kwargs: Any) -> None:
"has_ssd_support": False,
"dense": False,
"gen_once": False,
"is_hip_optimized_backward": True,
},
)

Expand Down
36 changes: 36 additions & 0 deletions fbgemm_gpu/codegen/genscript/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,9 @@ def rowwise_adagrad() -> Dict[str, Any]:

at::acc_type<cache_t, true> multiplier = 0.0;
at::acc_type<cache_t, true> correction = 0.0;
"""
split_precomputation_preload = split_precomputation
split_precomputation += """
if (threadIdx.x == 0) {
auto new_sum_square_grads = g_avg_square;

Expand Down Expand Up @@ -228,6 +231,38 @@ def rowwise_adagrad() -> Dict[str, Any]:
multiplier = SHFL_SYNC(multiplier, 0);
correction = SHFL_SYNC(correction, 0);
"""
split_precomputation_preload += """
if (threadIdx.x == 0) {
auto new_sum_square_grads = g_avg_square;

// Update the optimizer state. Use optimizer state offloading only if
// SSD and if enabled by the user
if (enable_optimizer_offloading) {
// Fetch the pointer to the optimizer state along the cache row
auto* optimizer = weight_row_template.template optimizer_state_ptr<OptimizerState>();
new_sum_square_grads += optimizer->momentum;
optimizer->momentum = new_sum_square_grads;

} else {
new_sum_square_grads += momentum1_val;
momentum1[idx] = new_sum_square_grads;
}

multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps);
if (weight_decay_mode == 1) {
// L2 regularization
correction = 1.0 - multiplier * weight_decay;
} else if (weight_decay_mode == 2 || weight_decay_mode == 5) {
// Decoupled weight decay
correction = 1.0 - learning_rate * weight_decay;
} else {
// default value
correction = 1.0;
}
}
multiplier = SHFL_SYNC(multiplier, 0);
correction = SHFL_SYNC(correction, 0);
"""
split_weight_update_cpu = """
at::acc_type<grad_t, true> g_local_sum_square = 0.0;
for (int64_t d = 0; d < D; ++d) {
Expand Down Expand Up @@ -275,6 +310,7 @@ def rowwise_adagrad() -> Dict[str, Any]:
},
),
"split_precomputation": split_precomputation,
"split_precomputation_preload": split_precomputation_preload,
"split_weight_update": split_weight_update,
"split_post_update": split_post_update,
"split_weight_update_cpu": split_weight_update_cpu,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,98 @@

using namespace fbgemm_gpu;

// Helper macro: Generate block_size grad_offset_j_i variables (i from 1 to block_size-1)
#define GRAD_OFFSET(i, j) const auto grad_offset_j_##i = SHFL_SYNC(grad_offset, j + i);
#define L(i, j) int32_t l_j_##i = SHFL_SYNC(l, j + i);
#define B(i, j) int32_t b_j_##i = SHFL_SYNC(b, j + i);
#define D_START(i, j) int32_t D_start_j_##i = SHFL_SYNC(D_start, j + i);
#define IDX_WEIGHT(i, j) at::acc_type<cache_t, true> idx_weight_j_##i = SHFL_SYNC(idx_weight, j + i);

#define REPEAT_8(X, j) X(1, j); X(2, j); X(3, j); X(4, j); X(5, j); X(6, j); X(7, j);
#define REPEAT_4(X, j) X(1, j); X(2, j); X(3, j);
#define REPEAT_2(X, j) X(1, j);
#define REPEAT_1(X, j) // No additional variables needed for block size 1

#define REPEAT_I_S_8(X, j, m, n) X(1, j, m, n); X(2, j, m, n); X(3, j, m, n); X(4, j, m, n); X(5, j, m, n); X(6, j, m, n); X(7, j, m, n);
#define REPEAT_I_S_4(X, j, m, n) X(1, j, m, n); X(2, j, m, n); X(3, j, m, n);
#define REPEAT_I_S_2(X, j, m, n) X(1, j, m, n);
#define REPEAT_I_S_1(X, j, m, n) // No additional variables needed for block size 1

// Helper macro: Generate block_size Vec4TAcc objects (i from 1 to block_size-1)
// if nobag and is_index_select
#define GRAD_VEC_N_I(i, grad_offset, grad_stride, d) Vec4TAcc<grad_t> grad_out_vec_##i(&grad_output[grad_offset + l_j_##i * grad_stride + d]);
// elif nobag
#define GRAD_VEC_N(i, d) Vec4TAcc<grad_t> grad_out_vec_##i(&grad_output[l_j_##i][d]);
// elif vbe
#define GRAD_VEC_V(i, d) Vec4TAcc<grad_t> grad_out_vec_##i(&grad_output[0][grad_offset_j_##i + d]);
// else
#define GRAD_VEC(i, d) Vec4TAcc<grad_t> grad_out_vec_##i(&grad_output[b_j_##i][0] + D_start_j_##i + d);

// Helper macro: Generate block_size fma_ calls (i from 1 to block_size-1)
#define FMA_GRAD(i, vec) grad_sum[vec].fma_(grad_out_vec_##i, idx_weight_j_##i);
// Helper macro: Generate block_size add_ calls (i from 1 to block_size-1)
#define ADD_GRAD(i, vec) grad_sum[vec].add_(grad_out_vec_##i);

// Core macro: Process blocks of specified size (block_size = 8/4/2/1)
// Parameters:
// - block_size: Size of each block to process
// - unroll_count: Number of unroll iterations for the inner loop
#define PROCESS_BLOCK(block_size, unroll_count, grad_sum, grad_output, grad_offset, vec_start, kThreadGroupSize, threadIdx_x, VEC_WIDTH, D, j, sl, sl_end) \
for (; j + (block_size - 1) < kThreadGroupSize && sl + j + (block_size - 1) < sl_end; j += block_size) { \
{%- if nobag %}
int32_t l_j_0 = SHFL_SYNC(l, j); \
REPEAT_##block_size(L, j) \
{%- elif vbe %}
/* Generate block_size grad_offset_j_0 ~ grad_offset_j_(block_size-1) */ \
const auto grad_offset_j_0 = SHFL_SYNC(grad_offset, j); \
/* Generate subsequent grad_offset_j_1 ~ grad_offset_j_(block_size-1) based on block size */ \
REPEAT_##block_size(GRAD_OFFSET, j) \
{%- else %}
int32_t b_j_0 = SHFL_SYNC(b, j); \
REPEAT_##block_size(B, j) \
int32_t D_start_j_0 = SHFL_SYNC(D_start, j); \
REPEAT_##block_size(D_START, j) \
{%- endif %}
{%- if weighted %}
at::acc_type<cache_t, true> idx_weight_j_0 = SHFL_SYNC(idx_weight, j); \
REPEAT_##block_size(IDX_WEIGHT, j) \
{%- endif %}
{%- set d = "(((vec + vec_start) * kThreadGroupSize + threadIdx.x) * VEC_WIDTH)" %}
\
for (int32_t vec = 0; vec < unroll_count && (((vec + vec_start) * kThreadGroupSize + threadIdx_x) * VEC_WIDTH) < D; ++vec) { \
const int32_t d = (((vec + vec_start) * kThreadGroupSize + threadIdx_x) * VEC_WIDTH); \
/* Generate block_size Vec4TAcc objects and accumulate them */ \
Vec4TAcc<grad_t> grad_out_vec_0( \
{%- if nobag and is_index_select %}
&grad_output[grad_offset + l_j_0 * grad_stride + d] \
{%- elif nobag %}
&grad_output[l_j_0][d] \
{%- elif vbe %}
&grad_output[0][grad_offset_j_0 + d] \
{%- else %}
&grad_output[b_j_0][0] + D_start_j_0 + d \
{%- endif %}
); \
{%- if nobag and is_index_select %}
REPEAT_I_S_##block_size(GRAD_VEC_N_I, grad_offset, grad_stride, d) \
{%- elif nobag %}
REPEAT_##block_size(GRAD_VEC_N, d) \
{%- elif vbe %}
REPEAT_##block_size(GRAD_VEC_V, d) \
{%- else %}
REPEAT_##block_size(GRAD_VEC, d) \
{%- endif %}
\
{%- if weighted %}
grad_sum[vec].fma_(grad_out_vec_0, idx_weight_j_0); \
REPEAT_##block_size(FMA_GRAD, vec) \
{%- else %}
grad_sum[vec].add_(grad_out_vec_0); \
REPEAT_##block_size(ADD_GRAD, vec) \
{%- endif %}
} \
}

{%- if gen_once %}
{#- /*
The kernels in this section will be generated only once for all TBE configs
Expand Down Expand Up @@ -141,45 +233,23 @@ DEVICE_INLINE void compute_grad_sum_{{ kdesc }}(
? sorted_indice_weights[segment_start + sl_j]
: 0.0;
{%- endif %}
for (int32_t j = 0; j < kThreadGroupSize && sl + j < sl_end; ++j) {
{%- if nobag %}
int32_t l_j = SHFL_SYNC(l, j);
{%- elif vbe %}
const auto grad_offset_j = SHFL_SYNC(grad_offset, j);
{%- else %}
int32_t b_j = SHFL_SYNC(b, j);
int32_t D_start_j = SHFL_SYNC(D_start, j);
{%- endif %}

{%- if weighted %}
at::acc_type<cache_t, true> idx_weight_j = SHFL_SYNC(idx_weight, j);
{%- endif %}

{%- set d = "(((vec + vec_start) * kThreadGroupSize + threadIdx.x) * VEC_WIDTH)" %}
int32_t j = 0;

// Process blocks of different sizes with loop unrolling
if constexpr (sizeof(grad_t) <= 2) {
#pragma unroll kFixedMaxVecsPerThread
for (int32_t vec = 0; vec < kFixedMaxVecsPerThread && {{ d }} < D; ++vec) {
const int32_t d = {{ d }};
Vec4TAcc<grad_t> grad_out_vec(
{%- if nobag and is_index_select %}
// grad_output is 1d
&grad_output[grad_offset + l_j * grad_stride + d]
{%- elif nobag %}
&grad_output[l_j][d]
{%- elif vbe %}
&grad_output[0][grad_offset_j + d]
{%- else %}
&grad_output[b_j][0] + D_start_j + d
{%- endif %} // if nobag
);

{%- if weighted %}
grad_sum[vec].fma_(grad_out_vec, idx_weight_j);
{%- else %}
grad_sum[vec].add_(grad_out_vec);
{%- endif %}
}
PROCESS_BLOCK(8, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \
vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end)
}
#pragma unroll kFixedMaxVecsPerThread
PROCESS_BLOCK(4, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \
vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end)
#pragma unroll kFixedMaxVecsPerThread
PROCESS_BLOCK(2, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \
vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end)
#pragma unroll kFixedMaxVecsPerThread
PROCESS_BLOCK(1, kFixedMaxVecsPerThread, grad_sum, grad_output, grad_offset, \
vec_start, kThreadGroupSize, threadIdx.x, VEC_WIDTH, D, j, sl, sl_end)
}
{%- set d_vec = "((vec + vec_start) * kThreadGroupSize + threadIdx.x)" %}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -949,7 +949,7 @@ class {{ autograd_func }} :

#ifdef USE_ROCM
constexpr int32_t BT_block_size = 64;
constexpr int32_t max_segment_length_per_warp = 64;
constexpr int32_t max_segment_length_per_warp = 16384;
#else
constexpr int32_t BT_block_size = 32;
constexpr int32_t max_segment_length_per_warp = 32;
Expand Down
Loading
Loading