Skip to content

Commit 1b4c7e8

Browse files
committed
[CMAKE][CUTLASS] Improve dependency management with different cutlass versions.
* Each cutlass-based submodule library now uses its own cutlass submodule dependency
* TVM's cutlass submodule is decoupled from others and is bumped to v3.4.1 for H100 support
* Add scaffold for new cutlass fp8 dequant gemm interface targeting TVM's cutlass submodule
1 parent 89cc09c commit 1b4c7e8

File tree

4 files changed

+98
-9
lines changed

4 files changed

+98
-9
lines changed

3rdparty/cutlass

Submodule cutlass updated 1843 files

CMakeLists.txt

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,7 @@ tvm_file_glob(GLOB RUNTIME_SRCS
368368
src/runtime/minrpc/*.cc
369369
src/runtime/relax_vm/*.cc
370370
)
371+
set(TVM_RUNTIME_EXT_OBJS "")
371372

372373
if(BUILD_FOR_HEXAGON)
373374
if(NOT BUILD_STATIC_RUNTIME)
@@ -594,26 +595,44 @@ add_library(tvm_libinfo_objs OBJECT ${LIBINFO_FILE})
594595

595596
include(GNUInstallDirs)
596597
if(NOT BUILD_DUMMY_LIBTVM)
597-
add_library(tvm SHARED $<TARGET_OBJECTS:tvm_objs> $<TARGET_OBJECTS:tvm_runtime_objs> $<TARGET_OBJECTS:tvm_libinfo_objs>)
598+
add_library(tvm SHARED
599+
$<TARGET_OBJECTS:tvm_objs>
600+
$<TARGET_OBJECTS:tvm_runtime_objs>
601+
$<TARGET_OBJECTS:tvm_libinfo_objs>
602+
${TVM_RUNTIME_EXT_OBJS}
603+
)
604+
598605
else()
599606
# dummy version of libtvm that can be used by downstream to specify dependencies
600607
# the real runner still need a full version of libtvm
601-
add_library(tvm SHARED $<TARGET_OBJECTS:tvm_runtime_objs> $<TARGET_OBJECTS:tvm_libinfo_objs>)
608+
add_library(tvm SHARED
609+
$<TARGET_OBJECTS:tvm_runtime_objs>
610+
$<TARGET_OBJECTS:tvm_libinfo_objs>
611+
${TVM_RUNTIME_EXT_OBJS}
612+
)
602613
endif()
603614

604615
target_include_directories(tvm PUBLIC "$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}>")
605616
set_property(TARGET tvm APPEND PROPERTY LINK_OPTIONS "${TVM_NO_UNDEFINED_SYMBOLS}")
606617
set_property(TARGET tvm APPEND PROPERTY LINK_OPTIONS "${TVM_VISIBILITY_FLAG}")
607618
if(BUILD_STATIC_RUNTIME)
608-
add_library(tvm_runtime STATIC $<TARGET_OBJECTS:tvm_runtime_objs> $<TARGET_OBJECTS:tvm_libinfo_objs>)
619+
add_library(tvm_runtime STATIC
620+
$<TARGET_OBJECTS:tvm_runtime_objs>
621+
$<TARGET_OBJECTS:tvm_libinfo_objs>
622+
${TVM_RUNTIME_EXT_OBJS}
623+
)
609624
set(NOTICE_MULTILINE
610625
"You have build static version of the TVM runtime library. Make "
611626
"sure to use --whole-archive when linking it into your project.")
612627
string(CONCAT NOTICE ${NOTICE_MULTILINE})
613628
add_custom_command(TARGET tvm_runtime POST_BUILD
614629
COMMAND ${CMAKE_COMMAND} -E cmake_echo_color --yellow --bold ${NOTICE})
615630
else()
616-
add_library(tvm_runtime SHARED $<TARGET_OBJECTS:tvm_runtime_objs> $<TARGET_OBJECTS:tvm_libinfo_objs>)
631+
add_library(tvm_runtime SHARED
632+
$<TARGET_OBJECTS:tvm_runtime_objs>
633+
$<TARGET_OBJECTS:tvm_libinfo_objs>
634+
${TVM_RUNTIME_EXT_OBJS}
635+
)
617636
set_property(TARGET tvm_runtime APPEND PROPERTY LINK_OPTIONS "${TVM_NO_UNDEFINED_SYMBOLS}")
618637
endif()
619638

cmake/modules/contrib/CUTLASS.cmake

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,61 @@
1616
# under the License.
1717

1818
if(USE_CUDA AND USE_CUTLASS)
19-
tvm_file_glob(GLOB CUTLASS_CONTRIB_SRC src/relay/backend/contrib/cutlass/*.cc src/relax/backend/contrib/cutlass/*.cc)
19+
set(CUTLASS_GEN_COND "$<AND:$<BOOL:${USE_CUDA}>,$<BOOL:${USE_CUTLASS}>>")
20+
set(CUTLASS_RUNTIME_OBJS "")
21+
22+
tvm_file_glob(GLOB CUTLASS_CONTRIB_SRC
23+
src/relay/backend/contrib/cutlass/*.cc
24+
src/relax/backend/contrib/cutlass/*.cc
25+
)
2026
list(APPEND COMPILER_SRCS ${CUTLASS_CONTRIB_SRC})
2127

2228
set(FPA_INTB_GEMM_TVM_BINDING ON)
2329
set(FPA_INTB_GEMM_TVM_HOME ${PROJECT_SOURCE_DIR})
2430

25-
set(CUTLASS_DIR ${PROJECT_SOURCE_DIR}/3rdparty/cutlass)
31+
### Build cutlass runtime objects for fpA_intB_gemm using its cutlass submodule
2632
add_subdirectory(${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm)
33+
target_include_directories(fpA_intB_gemm PRIVATE
34+
${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm
35+
${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm/cutlass/include
36+
)
37+
set(CUTLASS_FPA_INTB_RUNTIME_SRCS "")
38+
list(APPEND CUTLASS_FPA_INTB_RUNTIME_SRCS src/runtime/contrib/cutlass/moe_gemm.cc)
39+
list(APPEND CUTLASS_FPA_INTB_RUNTIME_SRCS src/runtime/contrib/cutlass/weight_preprocess.cc)
40+
add_library(fpA_intB_cutlass_objs OBJECT ${CUTLASS_FPA_INTB_RUNTIME_SRCS})
41+
target_include_directories(fpA_intB_cutlass_objs PRIVATE
42+
${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm/cutlass/include
43+
)
44+
list(APPEND CUTLASS_RUNTIME_OBJS "$<${CUTLASS_GEN_COND}:$<TARGET_OBJECTS:fpA_intB_cutlass_objs>>")
45+
46+
### Build cutlass runtime objects for flash attention using its cutlass submodule
2747
add_subdirectory(${PROJECT_SOURCE_DIR}/3rdparty/libflash_attn)
28-
list(APPEND RUNTIME_SRCS src/runtime/contrib/cutlass/weight_preprocess.cc)
48+
target_include_directories(flash_attn PRIVATE
49+
${PROJECT_SOURCE_DIR}/3rdparty/libflash_attn
50+
${PROJECT_SOURCE_DIR}/3rdparty/libflash_attn/cutlass/include
51+
)
52+
set(CUTLASS_FLASH_ATTN_RUNTIME_SRCS "")
53+
list(APPEND CUTLASS_FLASH_ATTN_RUNTIME_SRCS src/runtime/contrib/cutlass/flash_decoding.cu)
54+
add_library(flash_attn_cutlass_objs OBJECT ${CUTLASS_FLASH_ATTN_RUNTIME_SRCS})
55+
target_include_directories(flash_attn_cutlass_objs PRIVATE
56+
${PROJECT_SOURCE_DIR}/3rdparty/libflash_attn/cutlass/include
57+
)
58+
list(APPEND CUTLASS_RUNTIME_OBJS "$<${CUTLASS_GEN_COND}:$<TARGET_OBJECTS:flash_attn_cutlass_objs>>")
59+
60+
### Build cutlass runtime objects using TVM's 3rdparty/cutlass submodule
61+
set(CUTLASS_DIR ${PROJECT_SOURCE_DIR}/3rdparty/cutlass)
62+
set(TVM_CUTLASS_RUNTIME_SRCS "")
63+
if (CMAKE_CUDA_ARCHITECTURES MATCHES "90")
64+
list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp16_fp8_gemm.cu)
65+
endif()
66+
if(TVM_CUTLASS_RUNTIME_SRCS)
67+
add_library(tvm_cutlass_objs OBJECT ${TVM_CUTLASS_RUNTIME_SRCS})
68+
target_include_directories(tvm_cutlass_objs PRIVATE ${CUTLASS_DIR}/include)
69+
list(APPEND CUTLASS_RUNTIME_OBJS "$<${CUTLASS_GEN_COND}:$<TARGET_OBJECTS:tvm_cutlass_objs>>")
70+
endif()
71+
72+
### Add cutlass objects to list of TVM runtime extension objs
73+
list(APPEND TVM_RUNTIME_EXT_OBJS "${CUTLASS_RUNTIME_OBJS}")
2974

3075
message(STATUS "Build with CUTLASS")
31-
endif()
76+
endif()
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
#include <cuda_fp16.h>
21+
#include <tvm/runtime/ndarray.h>
22+
#include <tvm/runtime/packed_func.h>
23+
#include <tvm/runtime/registry.h>
24+
25+
TVM_REGISTER_GLOBAL("cutlass.fp16_fp8_gemm").set_body_typed([]() { return 0; });

0 commit comments

Comments
 (0)