forked from bytedance/ByteTransformer
-
Notifications
You must be signed in to change notification settings - Fork 0
/
CMakeLists.txt
executable file
·147 lines (123 loc) · 5.24 KB
/
CMakeLists.txt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
cmake_minimum_required(VERSION 3.8 FATAL_ERROR)
project(ByteTransformer LANGUAGES CXX CUDA)

option(BUILD_THS "Build in TorchScript class mode" OFF)

# Enable the CUTLASS-backed attention code paths project-wide.
add_definitions("-DCUTLASS_ATTENTION")

set(PYTHON_PATH "python" CACHE STRING "Python path")

# Locate the CUDA toolkit. A single REQUIRED call suffices; the original
# called find_package(CUDA) twice (first optional, then REQUIRED), which
# only repeated the same detection work.
find_package(CUDA REQUIRED)
set(CUDA_HOME ${CUDA_TOOLKIT_ROOT_DIR} CACHE STRING "CUDA home path")
list(APPEND CMAKE_MODULE_PATH ${CUDA_HOME}/lib)
list(APPEND CMAKE_MODULE_PATH ${CUDA_HOME}/lib64)
set(CUDA_PATH ${CUDA_HOME})

# CUDA 11+ builds define CUDA11_MODE for version-gated source paths.
# VERSION_GREATER_EQUAL compares component-wise (e.g. "11.10" >= "11.2"),
# which plain numeric GREATER_EQUAL gets wrong; the unquoted
# ${CUDA_VERSION} expansion also risked a double dereference inside if().
if (CUDA_VERSION VERSION_GREATER_EQUAL "11.0")
  message(STATUS "Add DCUDA11_MODE")
  add_definitions("-DCUDA11_MODE")
endif()

set(CUDAARCHS "80" CACHE STRING "CUDA Architectures")
set(CMAKE_CUDA_ARCHITECTURES ${CUDAARCHS})
message(STATUS "CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES}")
# ---- compiler flags -----------------------------------------------------
# NOTE(review): flags are accumulated into the global CMAKE_*_FLAGS strings,
# matching the project's existing style; modern CMake would scope these with
# target_compile_options()/target_compile_definitions() instead.

# Optional half-precision build, selected with -DDataType=FP16.
# Quoted comparison avoids a double variable dereference when DataType
# is unset (CMP0054 semantics).
if ("${DataType}" STREQUAL "FP16")
  set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DFP16")
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DFP16")
  set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -DFP16")
  message("-- Set the data type as FP16 ")
else()
  message("-- Set the data type as FP32 ")
endif()

# Build one -gencode flag per requested SM architecture (SASS + PTX).
set(CUDA_ARCH_FLAGS)
foreach(ARCH ${CMAKE_CUDA_ARCHITECTURES})
  list(APPEND CUDA_ARCH_FLAGS "-gencode=arch=compute_${ARCH},code=\\\"sm_${ARCH},compute_${ARCH}\\\"")
endforeach()
# BUGFIX: string(JOIN " " OUT "${CUDA_ARCH_FLAGS}") passed the whole list as
# a single quoted argument, so the ';' list separators leaked into the nvcc
# command line whenever more than one architecture was configured.
# list(JOIN) splits the list elements correctly before joining.
list(JOIN CUDA_ARCH_FLAGS " " JOINED_CUDA_ARCH_FLAGS)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -rdc=true ${JOINED_CUDA_ARCH_FLAGS}")

# Warnings for host code compiled by nvcc; debug builds disable
# optimization and add device-side debug info (-G).
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler -Wall")
set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} -Wall -O0")
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -Wall -O0")
set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -O0 -G -Xcompiler -Wall")

# C++17 for both host and device translation units.
set(CMAKE_CXX_STANDARD "17")
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --std=c++17")

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler -O3")

# Collect all build products under <build>/lib and <build>/bin.
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
# Include search paths shared by every sub-project: the repository root,
# CUDA toolkit headers, the build tree (for generated files), and the
# bundled CUTLASS sources plus this project's CUTLASS extensions.
set(COMMON_HEADER_DIRS
    ${PROJECT_SOURCE_DIR}
    ${CUDA_PATH}/include
    ${CMAKE_CURRENT_BINARY_DIR}
    ${PROJECT_SOURCE_DIR}/3rdparty/cutlass/include
    ${PROJECT_SOURCE_DIR}/3rdparty/cutlass/tools/util/include
    ${PROJECT_SOURCE_DIR}/cutlass_contrib/include)

# Library search paths: both lib and lib64 layouts of the CUDA toolkit.
set(COMMON_LIB_DIRS ${CUDA_PATH}/lib ${CUDA_PATH}/lib64)
# TorchScript build support: probe the local PyTorch installation at
# configure time via the configured Python interpreter.
if(BUILD_THS)
# Translate CMAKE_CUDA_ARCHITECTURES entries (e.g. "80") into the dotted
# form PyTorch uses (e.g. "8.0").
# NOTE(review): TORCH_CUDA_ARCH_LIST is built here but never referenced
# again in this file — presumably intended for find_package(Torch)'s
# architecture selection; confirm it actually reaches Torch's config.
set(TORCH_CUDA_ARCH_LIST)
foreach(ARCH ${CMAKE_CUDA_ARCHITECTURES})
if(ARCH STREQUAL "80")
list(APPEND TORCH_CUDA_ARCH_LIST "8.0")
elseif(ARCH STREQUAL "75")
list(APPEND TORCH_CUDA_ARCH_LIST "7.5")
elseif(ARCH STREQUAL "70")
list(APPEND TORCH_CUDA_ARCH_LIST "7.0")
else()
message(WARNING "Unsupported CUDA arch [${ARCH}] for TORCH_CUDA_ARCH")
endif()
endforeach()
# Ask Python where torch is installed so find_package(Torch) can locate
# TorchConfig.cmake under that directory.
execute_process(COMMAND ${PYTHON_PATH} "-c" "from __future__ import print_function; import os; import torch;
print(os.path.dirname(torch.__file__),end='');"
RESULT_VARIABLE _PYTHON_SUCCESS
OUTPUT_VARIABLE TORCH_DIR)
if (NOT _PYTHON_SUCCESS MATCHES 0)
message(FATAL_ERROR "Torch config Error.")
endif()
list(APPEND CMAKE_PREFIX_PATH ${TORCH_DIR})
find_package(Torch REQUIRED)
# Query the Python include directory and the extension-module filename
# suffix (EXT_SUFFIX); the two print() calls arrive newline-separated in
# _PYTHON_VALUES.
execute_process(COMMAND ${PYTHON_PATH} "-c" "from __future__ import print_function; from distutils import sysconfig;
print(sysconfig.get_python_inc());
print(sysconfig.get_config_var('EXT_SUFFIX'));"
RESULT_VARIABLE _PYTHON_SUCCESS
OUTPUT_VARIABLE _PYTHON_VALUES)
if (NOT _PYTHON_SUCCESS MATCHES 0)
message(FATAL_ERROR "Python config Error.")
endif()
# Escape any literal ';' in the output, then turn newlines into CMake list
# separators so list(GET) can index the two printed values.
string(REGEX REPLACE ";" "\\\\;" _PYTHON_VALUES ${_PYTHON_VALUES})
string(REGEX REPLACE "\n" ";" _PYTHON_VALUES ${_PYTHON_VALUES})
list(GET _PYTHON_VALUES 0 PY_INCLUDE_DIR)
list(GET _PYTHON_VALUES 1 PY_SUFFIX)
list(APPEND COMMON_HEADER_DIRS ${PY_INCLUDE_DIR})
# Recover the linker flags torch's cpp_extension machinery would use.
# NOTE(review): _prepare_ldflags is a private torch API; its signature
# gained a fourth argument in torch 1.8, hence the version probe. This may
# break on future torch releases — verify against the installed version.
execute_process(COMMAND ${PYTHON_PATH} "-c"
"from torch.utils import cpp_extension; import re; import torch; \
version = tuple(int(i) for i in re.match('(\\d+)\\.(\\d+)\\.(\\d+)', torch.__version__).groups()); \
args = ([],True,False,False) if version >= (1, 8, 0) else ([],True,False); \
print(' '.join(cpp_extension._prepare_ldflags(*args)),end='');"
RESULT_VARIABLE _PYTHON_SUCCESS
OUTPUT_VARIABLE TORCH_LINK)
message("-- TORCH_LINK ${TORCH_LINK}")
if (NOT _PYTHON_SUCCESS MATCHES 0)
message(FATAL_ERROR "PyTorch link config Error.")
endif()
message("CMAKE_CUDA_FLAGS after torch: ${CMAKE_CUDA_FLAGS}")
endif()
# Directory-scoped include/link paths inherited by the subdirectories below.
include_directories(
${COMMON_HEADER_DIRS}
)
link_directories(
${COMMON_LIB_DIRS}
)

# BUGFIX: these ABI definitions previously came AFTER the add_subdirectory()
# calls, so they never reached the targets defined there —
# add_subdirectory() processes the child immediately, with the flag
# variables as they stand at that point. They must be set first.
# NOTE(review): _GLIBCXX_USE_CXX11_ABI=0 assumes the installed torch build
# uses the pre-cxx11 ABI — confirm against the local wheel; cxx11-ABI
# builds need =1 (or no define) instead.
if(BUILD_THS)
  set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0")
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0")
  set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0")
endif()

# Core library and its unit tests.
add_subdirectory(bytetransformer)
add_subdirectory(unit_test)