From e5c7ca48f12be4526f1f67a039764c08961d6487 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Sun, 26 Dec 2021 04:05:51 -0600 Subject: [PATCH] auto parse kernel deps by include (#38438) --- cmake/pten_kernel.cmake | 35 +++++++++++++++++++++++++++-------- 1 file changed, 27 insertions(+), 8 deletions(-) diff --git a/cmake/pten_kernel.cmake b/cmake/pten_kernel.cmake index 8ec81dfa5bc81..23d8ca7dffc8d 100644 --- a/cmake/pten_kernel.cmake +++ b/cmake/pten_kernel.cmake @@ -18,7 +18,7 @@ function(kernel_declare TARGET_LIST) file(READ ${kernel_path} kernel_impl) # TODO(chenweihang): rename PT_REGISTER_CTX_KERNEL to PT_REGISTER_KERNEL # NOTE(chenweihang): now we don't recommend to use digit in kernel name - string(REGEX MATCH "(PT_REGISTER_CTX_KERNEL|PT_REGISTER_GENERAL_KERNEL)\\([ \t\r\n]*[a-z_]*," first_registry "${kernel_impl}") + string(REGEX MATCH "(PT_REGISTER_CTX_KERNEL|PT_REGISTER_GENERAL_KERNEL)\\([ \t\r\n]*[a-z0-9_]*," first_registry "${kernel_impl}") if (NOT first_registry STREQUAL "") # parse the first kernel name string(REPLACE "PT_REGISTER_CTX_KERNEL(" "" kernel_name "${first_registry}") @@ -49,6 +49,9 @@ function(kernel_library TARGET) set(gpu_srcs) set(xpu_srcs) set(npu_srcs) + # parse and save the deps kerenl targets + set(all_srcs) + set(kernel_deps) set(oneValueArgs "") set(multiValueArgs SRCS DEPS) @@ -57,7 +60,6 @@ function(kernel_library TARGET) list(LENGTH kernel_library_SRCS kernel_library_SRCS_len) # one kernel only match one impl file in each backend - # TODO(chenweihang): parse compile deps by include headers if (${kernel_library_SRCS_len} EQUAL 0) if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cc) list(APPEND common_srcs ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cc) @@ -84,6 +86,23 @@ function(kernel_library TARGET) # TODO(chenweihang): impl compile by source later endif() + list(APPEND all_srcs ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.h) + list(APPEND all_srcs ${common_srcs}) + list(APPEND all_srcs ${cpu_srcs}) + list(APPEND all_srcs ${gpu_srcs}) + list(APPEND all_srcs ${xpu_srcs}) + foreach(src ${all_srcs}) + file(READ ${src} target_content) + string(REGEX MATCHALL "#include \"paddle\/pten\/kernels\/[a-z0-9_]+_kernel.h\"" include_kernels ${target_content}) + foreach(include_kernel ${include_kernels}) + string(REGEX REPLACE "#include \"paddle\/pten\/kernels\/" "" kernel_name ${include_kernel}) + string(REGEX REPLACE ".h\"" "" kernel_name ${kernel_name}) + list(APPEND kernel_deps ${kernel_name}) + endforeach() + endforeach() + list(REMOVE_DUPLICATES kernel_deps) + list(REMOVE_ITEM kernel_deps ${TARGET}) + list(LENGTH common_srcs common_srcs_len) list(LENGTH cpu_srcs cpu_srcs_len) list(LENGTH gpu_srcs gpu_srcs_len) @@ -95,11 +114,11 @@ function(kernel_library TARGET) # we will use this implementation and will not adopt the implementation # under specific devices if (WITH_GPU) - nv_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS}) + nv_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) elseif (WITH_ROCM) - hip_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS}) + hip_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) else() - cc_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS}) + cc_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) endif() else() # If the kernel has a header file declaration, but no corresponding @@ -110,15 +129,15 @@ function(kernel_library TARGET) else() if (WITH_GPU) if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0) - nv_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS}) + nv_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) endif() elseif (WITH_ROCM) if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0) - hip_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS}) + hip_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) endif() else() if (${cpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0 OR ${npu_srcs_len} GREATER 0) - cc_library(${TARGET} SRCS ${cpu_srcs} ${xpu_srcs} ${npu_srcs} DEPS ${kernel_library_DEPS}) + cc_library(${TARGET} SRCS ${cpu_srcs} ${xpu_srcs} ${npu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) endif() endif() endif()