diff --git a/.github/scripts/fbgemm_gpu_test.bash b/.github/scripts/fbgemm_gpu_test.bash index f1e13b16b4..392a85c104 100644 --- a/.github/scripts/fbgemm_gpu_test.bash +++ b/.github/scripts/fbgemm_gpu_test.bash @@ -102,7 +102,6 @@ __configure_fbgemm_gpu_test_cuda () { ignored_tests=( ) - } __configure_fbgemm_gpu_test_rocm () { diff --git a/fbgemm_gpu/FbgemmGpu.cmake b/fbgemm_gpu/FbgemmGpu.cmake index c6a8d3ab39..62c7819367 100644 --- a/fbgemm_gpu/FbgemmGpu.cmake +++ b/fbgemm_gpu/FbgemmGpu.cmake @@ -279,18 +279,27 @@ foreach(optimizer ${SSD_OPTIMIZERS}) "gen_embedding_backward_${optimizer}_ssd_${wdesc}_kernel_cta.cu" "gen_embedding_backward_${optimizer}_ssd_${wdesc}_kernel_warp.cu") endforeach() + foreach(wdesc weighted unweighted) list(APPEND gen_gpu_kernel_source_files "gen_embedding_backward_${optimizer}_ssd_${wdesc}_vbe_cuda.cu" "gen_embedding_backward_${optimizer}_ssd_${wdesc}_vbe_kernel_cta.cu" "gen_embedding_backward_${optimizer}_ssd_${wdesc}_vbe_kernel_warp.cu") endforeach() - endforeach() list(APPEND gen_defused_optim_py_files ${CMAKE_BINARY_DIR}/optimizer_args.py) +################################################################################ +# FBGEMM_GPU Generated HIP-Specific Sources +################################################################################ + +set(gen_hip_kernel_source_files) +foreach(wdesc weighted unweighted unweighted_nobag) + list(APPEND gen_hip_kernel_source_files + "gen_embedding_backward_split_${wdesc}_device_kernel_hip.hip") +endforeach() ################################################################################ # FBGEMM_GPU Static Sources @@ -426,6 +435,9 @@ set(fbgemm_gpu_sources_gpu_gen ${gen_gpu_host_source_files} ${gen_defused_optim_source_files}) +set(fbgemm_gpu_sources_hip_gen + ${gen_hip_kernel_source_files}) + if(USE_ROCM) prepend_filepaths( PREFIX ${CMAKE_BINARY_DIR} @@ -436,6 +448,11 @@ if(USE_ROCM) PREFIX ${CMAKE_BINARY_DIR} INPUT ${fbgemm_gpu_sources_gpu_gen} OUTPUT fbgemm_gpu_sources_gpu_gen) + + prepend_filepaths( + PREFIX ${CMAKE_BINARY_DIR} + INPUT ${fbgemm_gpu_sources_hip_gen} + OUTPUT fbgemm_gpu_sources_hip_gen) endif() @@ -478,6 +495,8 @@ gpu_cpp_library( GPU_SRCS ${fbgemm_gpu_sources_gpu_static} ${fbgemm_gpu_sources_gpu_gen} + HIP_SPECIFIC_SRCS + ${fbgemm_gpu_sources_hip_gen} GPU_FLAGS ${TORCH_CUDA_OPTIONS} DEPS diff --git a/fbgemm_gpu/codegen/genscript/generate_backward_split.py b/fbgemm_gpu/codegen/genscript/generate_backward_split.py index 5e01defc83..c977148578 100644 --- a/fbgemm_gpu/codegen/genscript/generate_backward_split.py +++ b/fbgemm_gpu/codegen/genscript/generate_backward_split.py @@ -390,7 +390,7 @@ def generate() -> None: BackwardSplitGenerator.generate_backward_split( ssd_tensors=ssd_tensors, **optimizer ) - BackwardSplitGenerator.generate_rocm_backward_split(**optimizer) + BackwardSplitGenerator.generate_rocm_backward_split() # Generate common device kernels for backwards BackwardSplitGenerator.generate_backward_device() diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index 47a6ff2aff..958de07138 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -729,7 +729,7 @@ def __init__( # noqa C901 assert ( not mixed_D ), "OptimType.NONE does not support mixed embedding dimension" - self.mixed_D = mixed_D + self.mixed_D: bool = mixed_D if device is None: self.current_device: torch.device = ( torch.device("cpu") @@ -3551,6 +3551,15 @@ def __init__( torch.tensor(D_offsets, device=self.current_device, dtype=torch.int32), ) assert self.D_offsets.numel() == T + 1 + + mixed_D = False + D = dims[0] + for d in dims: + if d != D: + mixed_D = True + break + self.mixed_D: bool = mixed_D + # Required for VBE self.register_buffer( "feature_dims", diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index 0058548e22..b3a56c4b52 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -21,6 +21,7 @@ * ******************************************************************************/ #pragma once +#include #include #include