@@ -295,19 +295,30 @@ foreach(optimizer ${SSD_OPTIMIZERS})
295
295
"gen_embedding_backward_${optimizer} _ssd_${wdesc} _kernel_cta.cu"
296
296
"gen_embedding_backward_${optimizer} _ssd_${wdesc} _kernel_warp.cu" )
297
297
endforeach ()
298
+
298
299
foreach (wdesc weighted unweighted)
299
300
list (APPEND gen_gpu_kernel_source_files
300
301
"gen_embedding_backward_${optimizer} _ssd_${wdesc} _vbe_cuda.cu"
301
302
"gen_embedding_backward_${optimizer} _ssd_${wdesc} _vbe_kernel_cta.cu"
302
303
"gen_embedding_backward_${optimizer} _ssd_${wdesc} _vbe_kernel_warp.cu" )
303
304
endforeach ()
304
-
305
305
endforeach ()
306
306
307
307
list (APPEND gen_defused_optim_py_files
308
308
${CMAKE_BINARY_DIR} /optimizer_args.py)
309
309
310
310
311
+ ################################################################################
312
+ # FBGEMM_GPU Generated HIP-Specific Sources
313
+ ################################################################################
314
+
315
+ set (gen_hip_kernel_source_files)
316
+ foreach (wdesc weighted unweighted unweighted_nobag)
317
+ list (APPEND gen_hip_kernel_source_files
318
+ "gen_embedding_backward_split_${wdesc} _device_kernel_hip.hip" )
319
+ endforeach ()
320
+
321
+
311
322
################################################################################
312
323
# FBGEMM (not FBGEMM_GPU) Sources
313
324
################################################################################
@@ -516,6 +527,9 @@ set(fbgemm_gpu_sources_gpu_gen
516
527
${gen_gpu_host_source_files}
517
528
${gen_defused_optim_source_files} )
518
529
530
+ set (fbgemm_gpu_sources_hip_gen
531
+ ${gen_hip_kernel_source_files} )
532
+
519
533
if (USE_ROCM)
520
534
prepend_filepaths(
521
535
PREFIX ${CMAKE_BINARY_DIR}
@@ -526,6 +540,11 @@ if(USE_ROCM)
526
540
PREFIX ${CMAKE_BINARY_DIR}
527
541
INPUT ${fbgemm_gpu_sources_gpu_gen}
528
542
OUTPUT fbgemm_gpu_sources_gpu_gen)
543
+
544
+ prepend_filepaths(
545
+ PREFIX ${CMAKE_BINARY_DIR}
546
+ INPUT ${fbgemm_gpu_sources_hip_gen}
547
+ OUTPUT fbgemm_gpu_sources_hip_gen)
529
548
endif ()
530
549
531
550
@@ -562,6 +581,8 @@ gpu_cpp_library(
562
581
GPU_SRCS
563
582
${fbgemm_gpu_sources_gpu_static}
564
583
${fbgemm_gpu_sources_gpu_gen}
584
+ HIP_SPECIFIC_SRCS
585
+ ${fbgemm_gpu_sources_hip_gen}
565
586
OTHER_SRCS
566
587
${asmjit_sources}
567
588
${fbgemm_sources}
0 commit comments