Skip to content

Commit

Permalink
Patch D66310520 to make it build in OSS (pytorch#3409)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#3409

X-link: facebookresearch/FBGEMM#497

- Patch D66310520 to make the code build in OSS

Reviewed By: sryap

Differential Revision: D66399304
  • Loading branch information
q10 authored and facebook-github-bot committed Dec 5, 2024
1 parent 252944a commit 115c5ca
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 4 deletions.
1 change: 0 additions & 1 deletion .github/scripts/fbgemm_gpu_test.bash
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ __configure_fbgemm_gpu_test_cuda () {

ignored_tests=(
)

}

__configure_fbgemm_gpu_test_rocm () {
Expand Down
21 changes: 20 additions & 1 deletion fbgemm_gpu/FbgemmGpu.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -279,18 +279,27 @@ foreach(optimizer ${SSD_OPTIMIZERS})
"gen_embedding_backward_${optimizer}_ssd_${wdesc}_kernel_cta.cu"
"gen_embedding_backward_${optimizer}_ssd_${wdesc}_kernel_warp.cu")
endforeach()

foreach(wdesc weighted unweighted)
list(APPEND gen_gpu_kernel_source_files
"gen_embedding_backward_${optimizer}_ssd_${wdesc}_vbe_cuda.cu"
"gen_embedding_backward_${optimizer}_ssd_${wdesc}_vbe_kernel_cta.cu"
"gen_embedding_backward_${optimizer}_ssd_${wdesc}_vbe_kernel_warp.cu")
endforeach()

endforeach()

list(APPEND gen_defused_optim_py_files
${CMAKE_BINARY_DIR}/optimizer_args.py)

################################################################################
# FBGEMM_GPU Generated HIP-Specific Sources
################################################################################

set(gen_hip_kernel_source_files)
foreach(wdesc weighted unweighted unweighted_nobag)
list(APPEND gen_hip_kernel_source_files
"gen_embedding_backward_split_${wdesc}_device_kernel_hip.hip")
endforeach()

################################################################################
# FBGEMM_GPU Static Sources
Expand Down Expand Up @@ -426,6 +435,9 @@ set(fbgemm_gpu_sources_gpu_gen
${gen_gpu_host_source_files}
${gen_defused_optim_source_files})

set(fbgemm_gpu_sources_hip_gen
${gen_hip_kernel_source_files})

if(USE_ROCM)
prepend_filepaths(
PREFIX ${CMAKE_BINARY_DIR}
Expand All @@ -436,6 +448,11 @@ if(USE_ROCM)
PREFIX ${CMAKE_BINARY_DIR}
INPUT ${fbgemm_gpu_sources_gpu_gen}
OUTPUT fbgemm_gpu_sources_gpu_gen)

prepend_filepaths(
PREFIX ${CMAKE_BINARY_DIR}
INPUT ${fbgemm_gpu_sources_hip_gen}
OUTPUT fbgemm_gpu_sources_hip_gen)
endif()


Expand Down Expand Up @@ -478,6 +495,8 @@ gpu_cpp_library(
GPU_SRCS
${fbgemm_gpu_sources_gpu_static}
${fbgemm_gpu_sources_gpu_gen}
HIP_SPECIFIC_SRCS
${fbgemm_gpu_sources_hip_gen}
GPU_FLAGS
${TORCH_CUDA_OPTIONS}
DEPS
Expand Down
2 changes: 1 addition & 1 deletion fbgemm_gpu/codegen/genscript/generate_backward_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ def generate() -> None:
BackwardSplitGenerator.generate_backward_split(
ssd_tensors=ssd_tensors, **optimizer
)
BackwardSplitGenerator.generate_rocm_backward_split(**optimizer)
BackwardSplitGenerator.generate_rocm_backward_split()

# Generate common device kernels for backwards
BackwardSplitGenerator.generate_backward_device()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -729,7 +729,7 @@ def __init__( # noqa C901
assert (
not mixed_D
), "OptimType.NONE does not support mixed embedding dimension"
self.mixed_D = mixed_D
self.mixed_D: bool = mixed_D
if device is None:
self.current_device: torch.device = (
torch.device("cpu")
Expand Down Expand Up @@ -3551,6 +3551,15 @@ def __init__(
torch.tensor(D_offsets, device=self.current_device, dtype=torch.int32),
)
assert self.D_offsets.numel() == T + 1

mixed_D = False
D = dims[0]
for d in dims:
if d != D:
mixed_D = True
break
self.mixed_D: bool = mixed_D

# Required for VBE
self.register_buffer(
"feature_dims",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
*
******************************************************************************/
#pragma once
#include <c10/util/Half.h>
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>

Expand Down

0 comments on commit 115c5ca

Please sign in to comment.