Skip to content

Commit c076a02

Browse files
VALLIS-NERIATom-Zhengdjns99ZhanruiSunChjiaganc
authored
[TRTLLM-4629] [feat] Add support of CUDA13 and sm103 devices (#7568)
Signed-off-by: Xiwen Yu <[email protected]> Signed-off-by: Tian Zheng <[email protected]> Signed-off-by: Daniel Stokes <[email protected]> Signed-off-by: Zhanrui Sun <[email protected]> Signed-off-by: Xiwen Yu <[email protected]> Signed-off-by: Jiagan Cheng <[email protected]> Signed-off-by: Yiqing Yan <[email protected]> Signed-off-by: Bo Deng <[email protected]> Signed-off-by: ZhanruiSunCh <[email protected]> Signed-off-by: xiweny <[email protected]> Co-authored-by: Tian Zheng <[email protected]> Co-authored-by: Daniel Stokes <[email protected]> Co-authored-by: Zhanrui Sun <[email protected]> Co-authored-by: Jiagan Cheng <[email protected]> Co-authored-by: Yiqing Yan <[email protected]> Co-authored-by: Bo Deng <[email protected]> Co-authored-by: Zhanrui Sun <[email protected]>
1 parent 809c4d2 commit c076a02

File tree

97 files changed

+1112
-511
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

97 files changed

+1112
-511
lines changed

3rdparty/cutlass

Submodule cutlass updated 606 files

3rdparty/json

Submodule json updated 856 files

cpp/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,7 @@ endif()
248248
include_directories(
249249
SYSTEM
250250
${CUDAToolkit_INCLUDE_DIRS}
251+
${CUDAToolkit_INCLUDE_DIRS}/cccl
251252
${CUDNN_ROOT_DIR}/include
252253
$<TARGET_PROPERTY:TensorRT::NvInfer,INTERFACE_INCLUDE_DIRECTORIES>
253254
${3RDPARTY_DIR}/cutlass/include
@@ -510,7 +511,6 @@ print(os.path.dirname(torch.__file__),end='');"
510511
endif()
511512
endif()
512513
endif()
513-
514514
else()
515515
if(NOT WIN32)
516516
if(NOT USE_CXX11_ABI)

cpp/cmake/modules/cuda_configuration.cmake

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,9 @@ function(setup_cuda_architectures)
138138
message(FATAL_ERROR "Unrecognized CUDA architecture: ${CUDA_ARCH}")
139139
endif()
140140
endforeach()
141+
if("103" IN_LIST CMAKE_CUDA_ARCHITECTURES_CLEAN)
142+
list(APPEND CMAKE_CUDA_ARCHITECTURES_CLEAN "100")
143+
endif()
141144
list(REMOVE_DUPLICATES CMAKE_CUDA_ARCHITECTURES_CLEAN)
142145
set(CMAKE_CUDA_ARCHITECTURES_RAW ${CMAKE_CUDA_ARCHITECTURES_CLEAN})
143146
endif()
@@ -150,6 +153,9 @@ function(setup_cuda_architectures)
150153
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "12.7")
151154
list(APPEND CMAKE_CUDA_ARCHITECTURES_RAW 100 120)
152155
endif()
156+
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "12.9")
157+
list(APPEND CMAKE_CUDA_ARCHITECTURES_RAW 103)
158+
endif()
153159
endif()
154160

155161
# CMAKE_CUDA_ARCHITECTURES_ORIG contains all architectures enabled, without
@@ -160,7 +166,14 @@ function(setup_cuda_architectures)
160166
${CMAKE_CUDA_ARCHITECTURES_ORIG}
161167
PARENT_SCOPE)
162168

163-
set(ARCHITECTURES_WITH_KERNELS 80 86 89 90 100 120)
169+
set(ARCHITECTURES_WITH_KERNELS
170+
80
171+
86
172+
89
173+
90
174+
100
175+
103
176+
120)
164177
foreach(CUDA_ARCH IN LISTS ARCHITECTURES_WITH_KERNELS)
165178
if(NOT ${CUDA_ARCH} IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG)
166179
add_definitions("-DEXCLUDE_SM_${CUDA_ARCH}")

cpp/include/tensorrt_llm/common/cudaUtils.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,12 @@ inline int getSMVersion()
311311
return sm;
312312
}
313313

314+
inline bool isSM100Family()
315+
{
316+
int const sm = getSMVersion();
317+
return sm == 100 || sm == 103; // To be continued...
318+
}
319+
314320
inline int getDevice()
315321
{
316322
int deviceID{0};

cpp/include/tensorrt_llm/deep_gemm/tma_utils.cuh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ constexpr CUtensorMapDataType get_CUtensorMapDataType()
9595
}
9696
}
9797

98-
PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled()
98+
PFN_cuTensorMapEncodeTiled_v12000 get_cuTensorMapEncodeTiled()
9999
{
100100
// Get pointer to `cuTensorMapEncodeTiled`
101101
cudaDriverEntryPointQueryResult driver_status;
@@ -110,12 +110,12 @@ PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled()
110110

111111
if (driver_status != cudaDriverEntryPointSuccess)
112112
throw std::runtime_error("driver_status != cudaDriverEntryPointSuccess");
113-
return reinterpret_cast<PFN_cuTensorMapEncodeTiled>(cuTensorMapEncodeTiled_ptr);
113+
return reinterpret_cast<PFN_cuTensorMapEncodeTiled_v12000>(cuTensorMapEncodeTiled_ptr);
114114
}
115115

116116
template <typename T>
117117
CUtensorMap make_2d_tma_copy_desc(T* global_address, uint64_t gmem_dim[2], uint64_t stride_in_bytes,
118-
uint32_t smem_dim[2], CUtensorMapSwizzle swizzle_type, PFN_cuTensorMapEncodeTiled encode_func = nullptr)
118+
uint32_t smem_dim[2], CUtensorMapSwizzle swizzle_type, PFN_cuTensorMapEncodeTiled_v12000 encode_func = nullptr)
119119
{
120120
CUtensorMap tensor_map{};
121121
constexpr uint32_t rank = 2;

cpp/kernels/fmha_v2/Makefile

Lines changed: 2 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,6 @@ NVCC_FLAGS += $(PREPROCESSOR_FLAGS)
9090
# The include directories.
9191
INCLUDE_DIRS += -I./src -I./generated -I$(CUDA)/include
9292

93-
GENCODE_SM70 = -gencode=arch=compute_70,code=\"sm_70\"
94-
GENCODE_SM72 = -gencode=arch=compute_72,code=\"sm_72\"
95-
GENCODE_SM75 = -gencode=arch=compute_75,code=\"sm_75\"
9693
GENCODE_SM80 = -gencode=arch=compute_80,code=\"sm_80\"
9794
GENCODE_SM86 = -gencode=arch=compute_86,code=\"sm_86\"
9895
GENCODE_SM87 = -gencode=arch=compute_87,code=\"sm_87\"
@@ -125,9 +122,8 @@ endif
125122
CUBIN_CPP = $(patsubst %.cu.cubin, %.cubin.cpp, $(CUBINS))
126123
CUBIN_OBJ = $(patsubst %.cubin.cpp, %.cubin.o, $(CUBIN_CPP))
127124

128-
GENCODES = $(GENCODE_SM70)
129-
GENCODES += $(GENCODE_SM72)
130-
GENCODES += $(GENCODE_SM75)
125+
GENCODES =
126+
131127
GENCODES += $(GENCODE_SM80)
132128
GENCODES += $(GENCODE_SM86)
133129
GENCODES += $(GENCODE_SM89)
@@ -152,20 +148,12 @@ UNIT_TEST_OBJ = $(patsubst %.cu, obj/%.o, $(UNIT_TEST_CPP))
152148
UNIT_TEST_EXE = $(patsubst %.cu, bin/%.exe, $(UNIT_TEST_CPP))
153149

154150
# arch-dependent boilerplate
155-
UNIT_TEST_CPP_SM70 =
156-
ifdef ENABLE_SM70
157-
UNIT_TEST_CPP_SM70 = $(wildcard $(UNIT_TEST_CPP_DIR)/arch/*_sm70.cu)
158-
UNIT_TEST_OBJ_SM70 = $(patsubst %_sm70.cu, obj/%_sm70.o, $(UNIT_TEST_CPP_SM70))
159-
UNIT_TEST_EXE_SM70 = $(patsubst %_sm70.cu, bin/%_sm70.exe, $(UNIT_TEST_CPP_SM70))
160-
endif
161-
162151
UNIT_TEST_CPP_SM80 = $(wildcard $(UNIT_TEST_CPP_DIR)/arch/*_sm80.cu)
163152
UNIT_TEST_OBJ_SM80 = $(patsubst %_sm80.cu, obj/%_sm80.o, $(UNIT_TEST_CPP_SM80))
164153
UNIT_TEST_EXE_SM80 = $(patsubst %_sm80.cu, bin/%_sm80.exe, $(UNIT_TEST_CPP_SM80))
165154

166155
# aggregate exes as prerequisite of build target "test"
167156
UNIT_TEST_EXE_ARCH =
168-
UNIT_TEST_EXE_ARCH += $(UNIT_TEST_EXE_SM70)
169157
UNIT_TEST_EXE_ARCH += $(UNIT_TEST_EXE_SM80)
170158

171159
# #################################################################################################
@@ -248,12 +236,6 @@ bin/libfmha_cubin.a: $(CUBIN_OBJ)
248236

249237
###################################################################################################
250238

251-
obj/%_sm70.cu.o: ./generated/%_sm70.cu ./src/*.h ./src/fmha/*.h
252-
$(NVCC) $(NVCC_FLAGS) $(I2F_F2I_FLAGS) $(GENCODE_SM70) $(INCLUDE_DIRS) -c -o $@ $<
253-
obj/%_sm72.cu.o: ./generated/%_sm72.cu ./src/*.h ./src/fmha/*.h
254-
$(NVCC) $(NVCC_FLAGS) $(I2F_F2I_FLAGS) $(GENCODE_SM72) $(INCLUDE_DIRS) -c -o $@ $<
255-
obj/%_sm75.cu.o: ./generated/%_sm75.cu ./src/*.h ./src/fmha/*.h
256-
$(NVCC) $(NVCC_FLAGS) $(I2F_F2I_FLAGS) $(GENCODE_SM75) $(INCLUDE_DIRS) -c -o $@ $<
257239
obj/%_sm80.cu.o: ./generated/%_sm80.cu ./src/*.h ./src/fmha/*.h
258240
$(NVCC) $(NVCC_FLAGS) $(I2F_F2I_FLAGS) $(GENCODE_SM80) $(INCLUDE_DIRS) -c -o $@ $<
259241
obj/%_sm86.cu.o: ./generated/%_sm86.cu ./src/*.h ./src/fmha/*.h
@@ -269,12 +251,6 @@ obj/%_sm100.cu.o: ./generated/%_sm100.cu ./src/*.h ./src/fmha/*.h ./src/fmha/hop
269251
obj/%_sm120.cu.o: ./generated/%_sm120.cu ./src/*.h ./src/fmha/*.h
270252
$(NVCC) $(NVCC_FLAGS) $(I2F_F2I_FLAGS) $(GENCODE_SM120) $(INCLUDE_DIRS) -c -o $@ $<
271253

272-
obj/%_sm70.no_i2f_f2i.cu.o: ./generated/%_sm70.no_i2f_f2i.cu ./src/*.h ./src/fmha/*.h
273-
$(NVCC) $(NVCC_FLAGS) $(GENCODE_SM70) $(INCLUDE_DIRS) -c -o $@ $<
274-
obj/%_sm72.no_i2f_f2i.cu.o: ./generated/%_sm72.no_i2f_f2i.cu ./src/*.h ./src/fmha/*.h
275-
$(NVCC) $(NVCC_FLAGS) $(GENCODE_SM72) $(INCLUDE_DIRS) -c -o $@ $<
276-
obj/%_sm75.no_i2f_f2i.cu.o: ./generated/%_sm75.no_i2f_f2i.cu ./src/*.h ./src/fmha/*.h
277-
$(NVCC) $(NVCC_FLAGS) $(GENCODE_SM75) $(INCLUDE_DIRS) -c -o $@ $<
278254
obj/%_sm80.no_i2f_f2i.cu.o: ./generated/%_sm80.no_i2f_f2i.cu ./src/*.h ./src/fmha/*.h
279255
$(NVCC) $(NVCC_FLAGS) $(GENCODE_SM80) $(INCLUDE_DIRS) -c -o $@ $<
280256
obj/%_sm86.no_i2f_f2i.cu.o: ./generated/%_sm86.no_i2f_f2i.cu ./src/*.h ./src/fmha/*.h
@@ -314,20 +290,11 @@ $(UNIT_TEST_OBJ): $(UNIT_TEST_OBJ_DIR)/%.o : ${UNIT_TEST_CPP_DIR}/%.cu ./src/*.h
314290
$(NVCC) $(NVCC_FLAGS) $(I2F_F2I_FLAGS) $(GENCODES) -c -o $@ $< -I./src $(GTEST_INC)
315291

316292
# arch-dependent objs
317-
$(UNIT_TEST_OBJ_SM70): %.o : $(UNIT_TEST_CPP_SM70) ./src/*.h ./src/fmha/*.h
318-
$(NVCC) $(NVCC_FLAGS) $(I2F_F2I_FLAGS) $(GENCODE_SM70) -c -o $@ $< -I./src $(GTEST_INC)
319-
320293
$(UNIT_TEST_OBJ_SM80): %.o : $(UNIT_TEST_CPP_SM80) ./src/*.h ./src/fmha/*.h
321294
$(NVCC) $(NVCC_FLAGS) $(I2F_F2I_FLAGS) $(GENCODE_SM80) -c -o $@ $< -I./src $(GTEST_INC)
322295

323296
###################################################################################################
324297

325-
cubin/%_sm70.cu.cubin: ./generated/%_sm70.cu ./src/*.h ./src/fmha/*.h
326-
$(NVCC) $(NVCC_FLAGS) $(I2F_F2I_FLAGS) $(GENCODE_SM70) $(INCLUDE_DIRS) -cubin -o $@ $<
327-
cubin/%_sm72.cu.cubin: ./generated/%_sm72.cu ./src/*.h ./src/fmha/*.h
328-
$(NVCC) $(NVCC_FLAGS) $(I2F_F2I_FLAGS) $(GENCODE_SM72) $(INCLUDE_DIRS) -cubin -o $@ $<
329-
cubin/%_sm75.cu.cubin: ./generated/%_sm75.cu ./src/*.h ./src/fmha/*.h
330-
$(NVCC) $(NVCC_FLAGS) $(I2F_F2I_FLAGS) $(GENCODE_SM75) $(INCLUDE_DIRS) -cubin -o $@ $<
331298
cubin/%_sm80.cu.cubin: ./generated/%_sm80.cu ./src/*.h ./src/fmha/*.h
332299
$(NVCC) $(NVCC_FLAGS) $(I2F_F2I_FLAGS) $(GENCODE_SM80) $(INCLUDE_DIRS) -cubin -o $@ $<
333300
cubin/%_sm86.cu.cubin: ./generated/%_sm86.cu ./src/*.h ./src/fmha/*.h
@@ -343,12 +310,6 @@ cubin/%_sm100.cu.cubin: ./generated/%_sm100.cu ./src/*.h ./src/fmha/*.h
343310
cubin/%_sm120.cu.cubin: ./generated/%_sm120.cu ./src/*.h ./src/fmha/*.h
344311
$(NVCC) $(NVCC_FLAGS) $(I2F_F2I_FLAGS) $(GENCODE_SM120) $(INCLUDE_DIRS) -cubin -o $@ $<
345312

346-
cubin/%_sm70.no_i2f_f2i.cu.cubin: ./generated/%_sm70.no_i2f_f2i.cu ./src/*.h ./src/fmha/*.h
347-
$(NVCC) $(NVCC_FLAGS) $(GENCODE_SM70) $(INCLUDE_DIRS) -cubin -o $@ $<
348-
cubin/%_sm72.no_i2f_f2i.cu.cubin: ./generated/%_sm72.no_i2f_f2i.cu ./src/*.h ./src/fmha/*.h
349-
$(NVCC) $(NVCC_FLAGS) $(GENCODE_SM72) $(INCLUDE_DIRS) -cubin -o $@ $<
350-
cubin/%_sm75.no_i2f_f2i.cu.cubin: ./generated/%_sm75.no_i2f_f2i.cu ./src/*.h ./src/fmha/*.h
351-
$(NVCC) $(NVCC_FLAGS) $(GENCODE_SM75) $(INCLUDE_DIRS) -cubin -o $@ $<
352313
cubin/%_sm80.no_i2f_f2i.cu.cubin: ./generated/%_sm80.no_i2f_f2i.cu ./src/*.h ./src/fmha/*.h
353314
$(NVCC) $(NVCC_FLAGS) $(GENCODE_SM80) $(INCLUDE_DIRS) -cubin -o $@ $<
354315
cubin/%_sm86.no_i2f_f2i.cu.cubin: ./generated/%_sm86.no_i2f_f2i.cu ./src/*.h ./src/fmha/*.h

cpp/tensorrt_llm/common/attentionOp.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2530,22 +2530,22 @@ int AttentionOp::initialize() noexcept
25302530
if (mFP8ContextFMHA)
25312531
{
25322532
TLLM_CHECK_WITH_INFO(mEnableContextFMHA, "FP8 FMHA cannot be enabled because Context FMHA is not supported.");
2533-
TLLM_CHECK_WITH_INFO(mSM == 89 || mSM == 90 || mSM == 100 || mSM == 120 || mSM == 121,
2534-
"FP8 FMHA can only be enabled on sm_89, sm_90, sm_100, sm_120 or sm_121.");
2533+
TLLM_CHECK_WITH_INFO(mSM == 89 || mSM == 90 || mSM == 100 || mSM == 103 || mSM == 120 || mSM == 121,
2534+
"FP8 FMHA can only be enabled on sm_89, sm_90, sm_100f, sm_120 or sm_121.");
25352535
}
25362536

25372537
// Pre-Check of FP8 Generation MLA.
25382538
if (mFP8GenerationMLA)
25392539
{
25402540
TLLM_CHECK_WITH_INFO(mIsMLAEnabled, "FP8 Generation MLA cannot be enabled because MLA is not supported.");
2541-
TLLM_CHECK_WITH_INFO(mSM == 89 || mSM == 90 || mSM == 100 || mSM == 120 || mSM == 121,
2541+
TLLM_CHECK_WITH_INFO(mSM == 89 || mSM == 90 || mSM == 100 || mSM == 103 || mSM == 120 || mSM == 121,
25422542
"FP8 Generation MLA is supported on Ada, Hopper or Blackwell architecture.");
25432543
}
25442544

25452545
// Check requirements for FP4 output.
25462546
TLLM_CHECK_WITH_INFO(!mFuseFp4Quant || mEnableContextFMHA, "Context FMHA must enable if fuse_fp4_quant is enabled");
2547-
TLLM_CHECK_WITH_INFO(!mFuseFp4Quant || mSM == 100 || mSM == 120 || mSM == 121,
2548-
"fuse_fp4_quant only supports SM100 or SM120 or SM121 devices.");
2547+
TLLM_CHECK_WITH_INFO(!mFuseFp4Quant || (mSM == 100 || mSM == 103) || mSM == 120 || mSM == 121,
2548+
"fuse_fp4_quant only supports SM100f or SM120 or SM121 devices.");
25492549

25502550
// Check requirements for FP4 KV cache.
25512551
TLLM_CHECK_WITH_INFO(!mKVCacheQuantMode.hasFp4KvCache() || mFP8ContextFMHA,

cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
#include <type_traits>
2424

2525
#include "cute/tensor.hpp"
26+
#include "tensorrt_llm/common/assert.h"
27+
#include "tensorrt_llm/common/tllmException.h"
2628

2729
namespace tensorrt_llm
2830
{
@@ -155,6 +157,9 @@ enum class CutlassTileConfigSM100 : int
155157
CtaShape128x256x256B = shape_tuple_to_enum(128, 256, 256),
156158
};
157159

160+
// An alias to make the SHAPE_CASE macro work
161+
using CutlassTileConfigSM103 = CutlassTileConfigSM100;
162+
158163
enum class CutlassTileConfigSM120 : int
159164
{
160165
// Signals that we should run heuristics to choose a config
@@ -411,16 +416,17 @@ struct CutlassGemmConfig
411416
CutlassGemmConfig(CutlassTileConfigSM100 tile_config_sm100, MainloopScheduleType mainloop_schedule,
412417
EpilogueScheduleType epilogue_schedule, ClusterShape cluster_shape,
413418
ClusterShape dynamic_cluster_shape = ClusterShape::Undefined,
414-
ClusterShape fallback_cluster_shape = ClusterShape::Undefined)
419+
ClusterShape fallback_cluster_shape = ClusterShape::Undefined, int sm_version = 100)
415420
: tile_config_sm100(tile_config_sm100)
416421
, mainloop_schedule(mainloop_schedule)
417422
, epilogue_schedule(epilogue_schedule)
418423
, cluster_shape(cluster_shape)
419424
, dynamic_cluster_shape(dynamic_cluster_shape)
420425
, fallback_cluster_shape(fallback_cluster_shape)
421-
, sm_version(100)
426+
, sm_version(sm_version)
422427
, is_tma_warp_specialized(true)
423428
{
429+
TLLM_CHECK_WITH_INFO(sm_version >= 100 && sm_version < 120, "Expected SM 10x version");
424430
}
425431

426432
CutlassGemmConfig(CutlassTileConfigSM120 tile_config_sm120, MainloopScheduleType mainloop_schedule,

0 commit comments

Comments
 (0)