Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/cutlass
Submodule cutlass updated 1843 files
27 changes: 23 additions & 4 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,7 @@ tvm_file_glob(GLOB RUNTIME_SRCS
src/runtime/minrpc/*.cc
src/runtime/relax_vm/*.cc
)
set(TVM_RUNTIME_EXT_OBJS "")

if(BUILD_FOR_HEXAGON)
if(NOT BUILD_STATIC_RUNTIME)
Expand Down Expand Up @@ -595,26 +596,44 @@ add_library(tvm_libinfo_objs OBJECT ${LIBINFO_FILE})

include(GNUInstallDirs)
if(NOT BUILD_DUMMY_LIBTVM)
add_library(tvm SHARED $<TARGET_OBJECTS:tvm_objs> $<TARGET_OBJECTS:tvm_runtime_objs> $<TARGET_OBJECTS:tvm_libinfo_objs>)
add_library(tvm SHARED
$<TARGET_OBJECTS:tvm_objs>
$<TARGET_OBJECTS:tvm_runtime_objs>
$<TARGET_OBJECTS:tvm_libinfo_objs>
${TVM_RUNTIME_EXT_OBJS}
)

else()
# dummy version of libtvm that can be used by downstream to specify dependencies
# the real runner still need a full version of libtvm
add_library(tvm SHARED $<TARGET_OBJECTS:tvm_runtime_objs> $<TARGET_OBJECTS:tvm_libinfo_objs>)
add_library(tvm SHARED
$<TARGET_OBJECTS:tvm_runtime_objs>
$<TARGET_OBJECTS:tvm_libinfo_objs>
${TVM_RUNTIME_EXT_OBJS}
)
endif()

target_include_directories(tvm PUBLIC "$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}>")
set_property(TARGET tvm APPEND PROPERTY LINK_OPTIONS "${TVM_NO_UNDEFINED_SYMBOLS}")
set_property(TARGET tvm APPEND PROPERTY LINK_OPTIONS "${TVM_VISIBILITY_FLAG}")
if(BUILD_STATIC_RUNTIME)
add_library(tvm_runtime STATIC $<TARGET_OBJECTS:tvm_runtime_objs> $<TARGET_OBJECTS:tvm_libinfo_objs>)
add_library(tvm_runtime STATIC
$<TARGET_OBJECTS:tvm_runtime_objs>
$<TARGET_OBJECTS:tvm_libinfo_objs>
${TVM_RUNTIME_EXT_OBJS}
)
set(NOTICE_MULTILINE
"You have build static version of the TVM runtime library. Make "
"sure to use --whole-archive when linking it into your project.")
string(CONCAT NOTICE ${NOTICE_MULTILINE})
add_custom_command(TARGET tvm_runtime POST_BUILD
COMMAND ${CMAKE_COMMAND} -E cmake_echo_color --yellow --bold ${NOTICE})
else()
add_library(tvm_runtime SHARED $<TARGET_OBJECTS:tvm_runtime_objs> $<TARGET_OBJECTS:tvm_libinfo_objs>)
add_library(tvm_runtime SHARED
$<TARGET_OBJECTS:tvm_runtime_objs>
$<TARGET_OBJECTS:tvm_libinfo_objs>
${TVM_RUNTIME_EXT_OBJS}
)
set_property(TARGET tvm_runtime APPEND PROPERTY LINK_OPTIONS "${TVM_NO_UNDEFINED_SYMBOLS}")
endif()

Expand Down
49 changes: 46 additions & 3 deletions cmake/modules/contrib/CUTLASS.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,59 @@
# under the License.

if(USE_CUDA AND USE_CUTLASS)
tvm_file_glob(GLOB CUTLASS_CONTRIB_SRC src/relay/backend/contrib/cutlass/*.cc src/relax/backend/contrib/cutlass/*.cc)
set(CUTLASS_GEN_COND "$<AND:$<BOOL:${USE_CUDA}>,$<BOOL:${USE_CUTLASS}>>")
set(CUTLASS_RUNTIME_OBJS "")

tvm_file_glob(GLOB CUTLASS_CONTRIB_SRC
src/relay/backend/contrib/cutlass/*.cc
src/relax/backend/contrib/cutlass/*.cc
)
list(APPEND COMPILER_SRCS ${CUTLASS_CONTRIB_SRC})

set(FPA_INTB_GEMM_TVM_BINDING ON)
set(FPA_INTB_GEMM_TVM_HOME ${PROJECT_SOURCE_DIR})

set(CUTLASS_DIR ${PROJECT_SOURCE_DIR}/3rdparty/cutlass)
### Build cutlass runtime objects for fpA_intB_gemm using its cutlass submodule
add_subdirectory(${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm)
target_include_directories(fpA_intB_gemm PRIVATE
${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm
${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm/cutlass/include
)
set(CUTLASS_FPA_INTB_RUNTIME_SRCS "")
list(APPEND CUTLASS_FPA_INTB_RUNTIME_SRCS src/runtime/contrib/cutlass/weight_preprocess.cc)
add_library(fpA_intB_cutlass_objs OBJECT ${CUTLASS_FPA_INTB_RUNTIME_SRCS})
target_compile_definitions(fpA_intB_cutlass_objs PRIVATE DMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>)
target_include_directories(fpA_intB_cutlass_objs PRIVATE
${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm
${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm/cutlass/include
)
list(APPEND CUTLASS_RUNTIME_OBJS "$<${CUTLASS_GEN_COND}:$<TARGET_OBJECTS:fpA_intB_cutlass_objs>>")

### Build cutlass runtime objects for flash attention
add_subdirectory(${PROJECT_SOURCE_DIR}/3rdparty/libflash_attn)
list(APPEND RUNTIME_SRCS src/runtime/contrib/cutlass/weight_preprocess.cc)
target_include_directories(flash_attn PRIVATE
${PROJECT_SOURCE_DIR}/3rdparty/libflash_attn
${PROJECT_SOURCE_DIR}/3rdparty/libflash_attn/cutlass/include
)

### Build cutlass runtime objects using TVM's 3rdparty/cutlass submodule
set(CUTLASS_DIR ${PROJECT_SOURCE_DIR}/3rdparty/cutlass)
set(TVM_CUTLASS_RUNTIME_SRCS "")

if (CMAKE_CUDA_ARCHITECTURES MATCHES "90a")
list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp16_group_gemm.cu)
list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_group_gemm.cu)
endif()
if(TVM_CUTLASS_RUNTIME_SRCS)
add_library(tvm_cutlass_objs OBJECT ${TVM_CUTLASS_RUNTIME_SRCS})
target_compile_options(tvm_cutlass_objs PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr>)
target_include_directories(tvm_cutlass_objs PRIVATE ${CUTLASS_DIR}/include)
target_compile_definitions(tvm_cutlass_objs PRIVATE DMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>)
list(APPEND CUTLASS_RUNTIME_OBJS "$<${CUTLASS_GEN_COND}:$<TARGET_OBJECTS:tvm_cutlass_objs>>")
endif()

### Add cutlass objects to list of TVM runtime extension objs
list(APPEND TVM_RUNTIME_EXT_OBJS "${CUTLASS_RUNTIME_OBJS}")

message(STATUS "Build with CUTLASS")
endif()
70 changes: 70 additions & 0 deletions src/runtime/contrib/cutlass/fp16_group_gemm.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <cuda_fp16.h>
#include <float.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>

#include "group_gemm_runner.cuh"

#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)

template <>
struct KernelTraits<cutlass::half_t> {
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
using TileShape = Shape<_128, _256, _64>; // Threadblock-level tile size
using ClusterShape = Shape<_2, _2, _1>; // Shape of the threadblocks in a cluster
};

namespace tvm {
namespace runtime {

template <typename ElementA, typename ElementB, typename ElementC>
void tvm_cutlass_group_gemm_sm90(NDArray x, NDArray weight, NDArray indptr, NDArray workspace,
NDArray out) {
// Workspace is used for storing device-side group gemm arguments and cutlass internal workspace.
// Recommened size is 4MB.
auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
ICHECK(func != nullptr);
CHECK_EQ(x->ndim, 2);
CHECK_EQ(weight->ndim, 3);
CHECK_EQ(indptr->ndim, 1);
CHECK_EQ(workspace->ndim, 1);
CHECK_EQ(out->ndim, 2);
int num_groups = weight->shape[0];
int n = weight->shape[1];
int k = weight->shape[2];
float alpha = 1.0f;
float beta = 0.0f;
cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());
cutlass_group_gemm(static_cast<ElementA*>(x->data), static_cast<ElementB*>(weight->data),
static_cast<int64_t*>(indptr->data), static_cast<uint8_t*>(workspace->data),
workspace->shape[0], n, k, num_groups, alpha, beta,
static_cast<ElementC*>(out->data), stream);
}

TVM_REGISTER_GLOBAL("cutlass.group_gemm_fp16_sm90")
.set_body_typed(tvm_cutlass_group_gemm_sm90<cutlass::half_t, cutlass::half_t, cutlass::half_t>);

} // namespace runtime
} // namespace tvm

#endif // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED
83 changes: 83 additions & 0 deletions src/runtime/contrib/cutlass/fp8_group_gemm.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <cuda_fp16.h>
#include <float.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>

#include "group_gemm_runner.cuh"

#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)

template <>
struct KernelTraits<cutlass::float_e4m3_t> {
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum;
using TileShape = Shape<_128, _256, _64>; // Threadblock-level tile size
using ClusterShape = Shape<_2, _2, _1>; // Shape of the threadblocks in a cluster
};

template <>
struct KernelTraits<cutlass::float_e5m2_t> : KernelTraits<cutlass::float_e4m3_t> {};

namespace tvm {
namespace runtime {

template <typename ElementA, typename ElementB, typename ElementC>
void tvm_cutlass_fp8_group_gemm(NDArray x, NDArray weight, NDArray indptr, NDArray workspace,
NDArray alpha, NDArray out) {
// Workspace is used for storing device-side group gemm arguments and cutlass internal workspace.
// Recommened size is 4MB.
auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
ICHECK(func != nullptr);
CHECK_EQ(x->ndim, 2);
CHECK_EQ(weight->ndim, 3);
CHECK_EQ(indptr->ndim, 1);
CHECK_EQ(workspace->ndim, 1);
CHECK_EQ(out->ndim, 2);
CHECK_EQ(alpha->dtype.code, kDLFloat);
CHECK_EQ(alpha->dtype.bits, 32);
int num_groups = weight->shape[0];
int n = weight->shape[1];
int k = weight->shape[2];
const float* beta = nullptr;
cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());
cutlass_group_gemm(static_cast<ElementA*>(x->data), static_cast<ElementB*>(weight->data),
static_cast<int64_t*>(indptr->data), static_cast<uint8_t*>(workspace->data),
workspace->shape[0], n, k, num_groups, static_cast<float*>(alpha->data), beta,
static_cast<ElementC*>(out->data), stream);
}

TVM_REGISTER_GLOBAL("cutlass.group_gemm_e5m2_e5m2_fp16")
.set_body_typed(
tvm_cutlass_fp8_group_gemm<cutlass::float_e5m2_t, cutlass::float_e5m2_t, cutlass::half_t>);

TVM_REGISTER_GLOBAL("cutlass.group_gemm_e5m2_e4m3_fp16")
.set_body_typed(
tvm_cutlass_fp8_group_gemm<cutlass::float_e5m2_t, cutlass::float_e4m3_t, cutlass::half_t>);

TVM_REGISTER_GLOBAL("cutlass.group_gemm_e4m3_e4m3_fp16")
.set_body_typed(
tvm_cutlass_fp8_group_gemm<cutlass::float_e4m3_t, cutlass::float_e4m3_t, cutlass::half_t>);

} // namespace runtime
} // namespace tvm

#endif // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED
Loading