Skip to content

Commit

Permalink
auto parse kernel deps by include (#38438)
Browse files Browse the repository at this point in the history
  • Loading branch information
chenwhql authored Dec 26, 2021
1 parent acef85b commit e5c7ca4
Showing 1 changed file with 27 additions and 8 deletions.
35 changes: 27 additions & 8 deletions cmake/pten_kernel.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ function(kernel_declare TARGET_LIST)
file(READ ${kernel_path} kernel_impl)
# TODO(chenweihang): rename PT_REGISTER_CTX_KERNEL to PT_REGISTER_KERNEL
# NOTE(chenweihang): using digits in kernel names is currently not recommended
string(REGEX MATCH "(PT_REGISTER_CTX_KERNEL|PT_REGISTER_GENERAL_KERNEL)\\([ \t\r\n]*[a-z_]*," first_registry "${kernel_impl}")
string(REGEX MATCH "(PT_REGISTER_CTX_KERNEL|PT_REGISTER_GENERAL_KERNEL)\\([ \t\r\n]*[a-z0-9_]*," first_registry "${kernel_impl}")
if (NOT first_registry STREQUAL "")
# parse the first kernel name
string(REPLACE "PT_REGISTER_CTX_KERNEL(" "" kernel_name "${first_registry}")
Expand Down Expand Up @@ -49,6 +49,9 @@ function(kernel_library TARGET)
set(gpu_srcs)
set(xpu_srcs)
set(npu_srcs)
# parse and save the kernel targets this kernel depends on
set(all_srcs)
set(kernel_deps)

set(oneValueArgs "")
set(multiValueArgs SRCS DEPS)
Expand All @@ -57,7 +60,6 @@ function(kernel_library TARGET)

list(LENGTH kernel_library_SRCS kernel_library_SRCS_len)
# each kernel matches only one impl file per backend
# TODO(chenweihang): parse compile deps by include headers
if (${kernel_library_SRCS_len} EQUAL 0)
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cc)
list(APPEND common_srcs ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cc)
Expand All @@ -84,6 +86,23 @@ function(kernel_library TARGET)
# TODO(chenweihang): impl compile by source later
endif()

list(APPEND all_srcs ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.h)
list(APPEND all_srcs ${common_srcs})
list(APPEND all_srcs ${cpu_srcs})
list(APPEND all_srcs ${gpu_srcs})
list(APPEND all_srcs ${xpu_srcs})
foreach(src ${all_srcs})
file(READ ${src} target_content)
string(REGEX MATCHALL "#include \"paddle\/pten\/kernels\/[a-z0-9_]+_kernel.h\"" include_kernels ${target_content})
foreach(include_kernel ${include_kernels})
string(REGEX REPLACE "#include \"paddle\/pten\/kernels\/" "" kernel_name ${include_kernel})
string(REGEX REPLACE ".h\"" "" kernel_name ${kernel_name})
list(APPEND kernel_deps ${kernel_name})
endforeach()
endforeach()
list(REMOVE_DUPLICATES kernel_deps)
list(REMOVE_ITEM kernel_deps ${TARGET})

list(LENGTH common_srcs common_srcs_len)
list(LENGTH cpu_srcs cpu_srcs_len)
list(LENGTH gpu_srcs gpu_srcs_len)
Expand All @@ -95,11 +114,11 @@ function(kernel_library TARGET)
# we will use this implementation and will not adopt the implementation
# under specific devices
if (WITH_GPU)
nv_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS})
nv_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
elseif (WITH_ROCM)
hip_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS})
hip_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
else()
cc_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS})
cc_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
endif()
else()
# If the kernel has a header file declaration, but no corresponding
Expand All @@ -110,15 +129,15 @@ function(kernel_library TARGET)
else()
if (WITH_GPU)
if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0)
nv_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS})
nv_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
endif()
elseif (WITH_ROCM)
if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0)
hip_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS})
hip_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
endif()
else()
if (${cpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0 OR ${npu_srcs_len} GREATER 0)
cc_library(${TARGET} SRCS ${cpu_srcs} ${xpu_srcs} ${npu_srcs} DEPS ${kernel_library_DEPS})
cc_library(${TARGET} SRCS ${cpu_srcs} ${xpu_srcs} ${npu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps})
endif()
endif()
endif()
Expand Down

0 comments on commit e5c7ca4

Please sign in to comment.