Skip to content
Merged
Show file tree
Hide file tree
Changes from 83 commits
Commits
Show all changes
108 commits
Select commit Hold shift + click to select a range
303604f
upgrade to base image and new TRT, fix many dependency issues
VALLIS-NERIA Jun 17, 2025
5c09dc8
CUDA13 breaking changes: c++ compile successful
VALLIS-NERIA Jun 17, 2025
1b84604
fix kernel select code to recognize sm103/sm100f
VALLIS-NERIA Jul 2, 2025
3a94d80
Update SM100f cubins
Tom-Zheng Jul 2, 2025
469a38d
feat: Add support for SM103 3xFP4 tile shapes
djns99 Jul 8, 2025
52ad443
disable 3xfp4
VALLIS-NERIA Jul 21, 2025
345c2bc
update trtllm-gen sm100f cubins of gemm kernels
VALLIS-NERIA Aug 4, 2025
e27cbb5
Ampere moe kernel should build to all arch
VALLIS-NERIA Aug 4, 2025
78a55b8
fix vicuna dependency
VALLIS-NERIA Aug 4, 2025
271916d
fix deep_gemm & CUDA13
VALLIS-NERIA Aug 5, 2025
886437d
merge existing env fix
VALLIS-NERIA Aug 6, 2025
b782b6e
fix sm check of kv reuse and chunked context
VALLIS-NERIA Aug 6, 2025
84f96b4
update triton and fix deepgemm pip
VALLIS-NERIA Aug 6, 2025
759e7a0
Merge remote-tracking branch 'gitlab/main' into feat/gb110_bringup
VALLIS-NERIA Aug 6, 2025
bee1df9
remove deepgemm war
VALLIS-NERIA Aug 6, 2025
97a3788
update triton image
VALLIS-NERIA Aug 6, 2025
ebec4ea
infra: upgrade to DLFW 25.08-pre and TRT 10.13.2.4
ZhanruiSunCh Aug 12, 2025
36f2e88
Merge branch 'user/zhanruis/update_dlfw_and_cu13' into 'feat/b300_cu13'
ZhanruiSunCh Aug 12, 2025
0bf6a18
Fix and waive to clean L0
VALLIS-NERIA Aug 15, 2025
f12a90b
Merge branch 'feat/gb110_bringup' into 'feat/b300_cu13'
VALLIS-NERIA Aug 15, 2025
8c99853
infra: Support build for both CU12 and CU13
ZhanruiSunCh Aug 18, 2025
c1014e8
Merge branch 'user/zhanruis/update_dlfw_and_cu13_2' into 'feat/b300_c…
ZhanruiSunCh Aug 18, 2025
4a95d88
revert tlg kernels for ease of merge
VALLIS-NERIA Aug 19, 2025
8b53236
Merge remote-tracking branch 'gitlab/main' into user/xiweny/merge_mai…
VALLIS-NERIA Aug 19, 2025
5391191
update tg cubins (temp ver)
VALLIS-NERIA Aug 21, 2025
f4de884
Merge remote-tracking branch 'gitlab/main' into user/xiweny/merge_mai…
VALLIS-NERIA Aug 21, 2025
b7cc06c
disable merge waive list stage
VALLIS-NERIA Aug 21, 2025
fa8b52e
fix more sm version check
VALLIS-NERIA Aug 22, 2025
808059d
Merge remote-tracking branch 'gitlab/main' into user/xiweny/merge_mai…
VALLIS-NERIA Aug 23, 2025
90a9bc4
fix build error
VALLIS-NERIA Aug 23, 2025
80ea062
fix cubins
VALLIS-NERIA Aug 24, 2025
66b1d8d
Update flashinfer
VALLIS-NERIA Aug 25, 2025
ab7febd
Merge commit '31979aefacbf80d2742c98ef30385db162788c84' into feat/b30…
VALLIS-NERIA Aug 26, 2025
b1c6f6a
update cutlass and DeepGEMM
VALLIS-NERIA Aug 27, 2025
9ad68de
Merge branch 'user/xiweny/update_cutlass_4.2' into 'feat/b300_cu13'
VALLIS-NERIA Aug 27, 2025
ee37589
infra: update DLFW 25.08 GA, triton 25.08 GA
ZhanruiSunCh Aug 28, 2025
c2e1ad9
Merge branch 'user/zhanruis/update_dlfw_and_cu13_3' into 'feat/b300_c…
ZhanruiSunCh Aug 28, 2025
6fd765f
[None][fix] fix trtllm moe backend error when running gptoss on b300
jiaganc Aug 28, 2025
f14c740
Merge branch 'dev-jiaganc-fix-b300-gptoss-trtllm' into 'feat/b300_cu13'
VALLIS-NERIA Aug 28, 2025
3c06303
[TRTLLM-7755][infra] Add DGX_B300 and GB300 tests in CI
yiqingy0 Aug 29, 2025
c425c12
Merge branch 'user/yiqingy/add_b300_tests' into 'feat/b300_cu13'
yiqingy0 Aug 29, 2025
0fb835d
fix cutlass moe not falling back
VALLIS-NERIA Aug 30, 2025
8d5a7ea
[https://nvbugs/5443053][fix] Disable finalize fusion when Lora is used
jiaganc Sep 1, 2025
3cc2591
Merge branch 'dev-jiaganc-fix-b300-moe-lora' into 'feat/b300_cu13'
VALLIS-NERIA Sep 1, 2025
3805f61
[https://nvbugs/5453949][infra] unwaive test_llama_eagle3
bo-nv Aug 27, 2025
a765ee4
Merge branch 'feat/b300_cu13-latest' into 'feat/b300_cu13'
VALLIS-NERIA Sep 1, 2025
14154ec
disable sm103 moe kernel
VALLIS-NERIA Sep 1, 2025
38ef850
Merge remote-tracking branch 'gitlab/main' into user/xiweny/merge_0901
VALLIS-NERIA Sep 1, 2025
62a7897
Merge remote-tracking branch 'origin/main' into user/xiweny/merge_0901
VALLIS-NERIA Sep 2, 2025
90ce786
Fix arg name in _test_trtllm_serve_multimodal_benchmark.py
VALLIS-NERIA Sep 2, 2025
5bd50d4
update mha cubins and support 103a
VALLIS-NERIA Sep 3, 2025
1978227
Merge branch 'user/xiweny/mha_103' into 'feat/b300_cu13'
VALLIS-NERIA Sep 3, 2025
5ca3376
Support DLFW sanity check use CU13 image
ZhanruiSunCh Sep 5, 2025
9ae01a8
Merge branch 'user/zhanruis/0828_support_cuda_13_for_sanity_check' in…
ZhanruiSunCh Sep 5, 2025
973fd37
add 3xfp4 cutlass gemm
VALLIS-NERIA Sep 5, 2025
fcf413e
Merge branch 'user/xiweny/3xfp4_gemm' into 'feat/b300_cu13'
VALLIS-NERIA Sep 5, 2025
5d4f7f4
update flashinfer and waive bug
VALLIS-NERIA Sep 5, 2025
22219bc
Add B300 & GB300 CI
VALLIS-NERIA Sep 5, 2025
2c3f4cb
Merge remote-tracking branch 'origin/main' into feat/b300_cu13
VALLIS-NERIA Sep 5, 2025
f8864b9
update trtllm gemm
VALLIS-NERIA Sep 5, 2025
cca347e
[TRTLLM-4629] [feat] Step1: trtllm-gen kernels support sm103
VALLIS-NERIA Sep 5, 2025
5e7aa76
Merge branch 'user/sm103_trtllmgen' into feat/b300_cu13
VALLIS-NERIA Sep 5, 2025
10af4f4
[TRTLLM-4629] [feat] Step1: trtllm-gen kernels support sm103
VALLIS-NERIA Sep 5, 2025
1d7979a
fix
VALLIS-NERIA Sep 5, 2025
3e71ec7
Merge branch 'user/sm103_trtllmgen' into feat/b300_cu13
VALLIS-NERIA Sep 5, 2025
65f8478
fix trtllm-gen interface change
VALLIS-NERIA Sep 5, 2025
bec1e71
fix
VALLIS-NERIA Sep 5, 2025
0b0781f
fix
VALLIS-NERIA Sep 6, 2025
3d4f49e
fix missing gemm kernels
VALLIS-NERIA Sep 6, 2025
1150def
Merge branch 'user/sm103_trtllmgen' into feat/b300_cu13
VALLIS-NERIA Sep 6, 2025
d12eb4b
fix CI build archs
VALLIS-NERIA Sep 6, 2025
322db71
Merge remote-tracking branch 'origin/main' into feat/b300_cu13
VALLIS-NERIA Sep 6, 2025
8f8766a
waive
VALLIS-NERIA Sep 7, 2025
2912908
Merge remote-tracking branch 'origin/main' into feat/b300_cu13
VALLIS-NERIA Sep 7, 2025
e6bb1fe
remove non-exist cases
VALLIS-NERIA Sep 7, 2025
77657de
fix build args
VALLIS-NERIA Sep 8, 2025
d42201e
remove waivers and cleanup
VALLIS-NERIA Sep 8, 2025
caea58a
increase build memory
VALLIS-NERIA Sep 8, 2025
d4d9e77
reset build memory
VALLIS-NERIA Sep 8, 2025
019b1db
fix 5505835
VALLIS-NERIA Sep 8, 2025
fdaf4e2
Merge remote-tracking branch 'origin/main' into feat/b300_cu13
VALLIS-NERIA Sep 8, 2025
e30e0c8
waive
VALLIS-NERIA Sep 8, 2025
4cf9fed
Merge commit 'ed27a72bcf71f7ab0e7137f7999988c9de82386f' into feat/b30…
VALLIS-NERIA Sep 8, 2025
b573e07
[None][infra] Disable CU12 build to save build time (cost > 5 hours o…
ZhanruiSunCh Sep 9, 2025
82833fa
address comments
VALLIS-NERIA Sep 9, 2025
8cc5ea3
add comment
VALLIS-NERIA Sep 9, 2025
a8b630f
Merge remote-tracking branch 'origin/main' into feat/b300_cu13
VALLIS-NERIA Sep 9, 2025
2c287d5
don't throw in ctor
VALLIS-NERIA Sep 9, 2025
11d603b
fix
VALLIS-NERIA Sep 9, 2025
d16d98c
fix missing change
VALLIS-NERIA Sep 9, 2025
5f508b7
Merge remote-tracking branch 'origin/main' into feat/b300_cu13
VALLIS-NERIA Sep 9, 2025
2e61526
fix
VALLIS-NERIA Sep 10, 2025
0b73a57
refine sm version check
VALLIS-NERIA Sep 10, 2025
27c73de
add a line of comment
VALLIS-NERIA Sep 10, 2025
b8d1ee6
exclude sm70
VALLIS-NERIA Sep 10, 2025
6133354
fix sm check
VALLIS-NERIA Sep 11, 2025
41d3cf6
Merge remote-tracking branch 'origin/main' into feat/b300_cu13
VALLIS-NERIA Sep 11, 2025
ced6e74
[None][infra] Remove WAR on feat branch (#7642)
ZhanruiSunCh Sep 11, 2025
98cbab0
[None][infra] Update images (#7690)
ZhanruiSunCh Sep 11, 2025
514ebc2
remove sm70 from fmha_v2 completely
VALLIS-NERIA Sep 12, 2025
9bd8df7
Merge remote-tracking branch 'origin/main' into feat/b300_cu13
VALLIS-NERIA Sep 12, 2025
ad20048
remove sm72 & 75
VALLIS-NERIA Sep 14, 2025
93195ec
waive
VALLIS-NERIA Sep 15, 2025
98d42f9
Merge remote-tracking branch 'origin/main' into feat/b300_cu13
VALLIS-NERIA Sep 15, 2025
cf74f40
fix testdb
VALLIS-NERIA Sep 15, 2025
d48e82a
fix testdb
VALLIS-NERIA Sep 15, 2025
7657d83
fix
VALLIS-NERIA Sep 15, 2025
0192299
Merge remote-tracking branch 'origin/main' into feat/b300_cu13
VALLIS-NERIA Sep 16, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/DeepGEMM
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Barry-Delaney could you help check whether this DeepGEMM version is the one we want? Thanks!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The DeepGEMM version seems fine.
I tried compiling locally for verification, but the build failed here.
Once it is fixed, I can double-check this.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just fixed it in latest commit.

Submodule DeepGEMM updated 36 files
+4 −2 README.md
+471 −0 csrc/apis/gemm.hpp
+85 −0 csrc/apis/layout.hpp
+28 −0 csrc/apis/runtime.hpp
+6 −4 csrc/jit/compiler.hpp
+4 −2 csrc/jit/device_runtime.hpp
+1 −1 csrc/jit/handle.hpp
+2 −2 csrc/jit/kernel_runtime.hpp
+6 −3 csrc/jit_kernels/heuristics/common.hpp
+2 −2 csrc/jit_kernels/heuristics/sm100.hpp
+7 −3 csrc/jit_kernels/heuristics/sm90.hpp
+143 −0 csrc/jit_kernels/impls/sm100_bf16_gemm.hpp
+3 −2 csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp
+3 −2 csrc/jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp
+229 −0 csrc/jit_kernels/impls/sm90_bf16_gemm.hpp
+3 −2 csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp
+55 −8 csrc/jit_kernels/impls/smxx_layout.hpp
+6 −399 csrc/python_api.cpp
+10 −3 csrc/utils/exception.hpp
+38 −10 deep_gemm/__init__.py
+6 −5 deep_gemm/include/deep_gemm/common/scheduler.cuh
+76 −0 deep_gemm/include/deep_gemm/common/sm90_utils.cuh
+18 −0 deep_gemm/include/deep_gemm/common/utils.cuh
+495 −1 deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh
+3 −4 deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh
+8 −5 deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh
+341 −1 deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh
+1 −1 deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh
+39 −0 deep_gemm/include/deep_gemm/impls/smxx_layout.cuh
+0 −3 pyproject.toml
+4 −0 setup.py
+34 −22 tests/generators.py
+125 −0 tests/test_bf16.py
+3 −3 tests/test_fp8.py
+29 −17 tests/test_layout.py
+15 −0 tests/test_lazy_init.py
2 changes: 1 addition & 1 deletion 3rdparty/cutlass
Submodule cutlass updated 606 files
2 changes: 1 addition & 1 deletion 3rdparty/json
Submodule json updated 856 files
2 changes: 1 addition & 1 deletion cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ endif()
include_directories(
SYSTEM
${CUDAToolkit_INCLUDE_DIRS}
${CUDAToolkit_INCLUDE_DIRS}/cccl
${CUDNN_ROOT_DIR}/include
$<TARGET_PROPERTY:TensorRT::NvInfer,INTERFACE_INCLUDE_DIRECTORIES>
${3RDPARTY_DIR}/cutlass/include
Expand Down Expand Up @@ -477,7 +478,6 @@ print(os.path.dirname(torch.__file__),end='');"
endif()
endif()
endif()

else()
if(NOT WIN32)
if(NOT USE_CXX11_ABI)
Expand Down
15 changes: 14 additions & 1 deletion cpp/cmake/modules/cuda_configuration.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,9 @@ function(setup_cuda_architectures)
message(FATAL_ERROR "Unrecognized CUDA architecture: ${CUDA_ARCH}")
endif()
endforeach()
if("103" IN_LIST CMAKE_CUDA_ARCHITECTURES_CLEAN)
list(APPEND CMAKE_CUDA_ARCHITECTURES_CLEAN "100")
endif()
list(REMOVE_DUPLICATES CMAKE_CUDA_ARCHITECTURES_CLEAN)
set(CMAKE_CUDA_ARCHITECTURES_RAW ${CMAKE_CUDA_ARCHITECTURES_CLEAN})
endif()
Expand All @@ -150,6 +153,9 @@ function(setup_cuda_architectures)
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "12.7")
list(APPEND CMAKE_CUDA_ARCHITECTURES_RAW 100 120)
endif()
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "12.9")
list(APPEND CMAKE_CUDA_ARCHITECTURES_RAW 103)
endif()
endif()

# CMAKE_CUDA_ARCHITECTURES_ORIG contains all architectures enabled, without
Expand All @@ -160,7 +166,14 @@ function(setup_cuda_architectures)
${CMAKE_CUDA_ARCHITECTURES_ORIG}
PARENT_SCOPE)

set(ARCHITECTURES_WITH_KERNELS 80 86 89 90 100 120)
set(ARCHITECTURES_WITH_KERNELS
80
86
89
90
100
103
120)
foreach(CUDA_ARCH IN LISTS ARCHITECTURES_WITH_KERNELS)
if(NOT ${CUDA_ARCH} IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG)
add_definitions("-DEXCLUDE_SM_${CUDA_ARCH}")
Expand Down
10 changes: 10 additions & 0 deletions cpp/include/tensorrt_llm/common/cudaUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,16 @@ inline int getSMVersion()
return sm;
}

// Collapse an exact SM version into its architecture family.
// SM 100 and SM 103 share the same kernel family (Blackwell SM100f),
// so both report as 100; every other version is returned unchanged.
inline int getSMFamily()
{
    int const version = getSMVersion();
    return (version == 100 || version == 103) ? 100 : version;
}

inline int getDevice()
{
int deviceID{0};
Expand Down
6 changes: 3 additions & 3 deletions cpp/include/tensorrt_llm/deep_gemm/tma_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ constexpr CUtensorMapDataType get_CUtensorMapDataType()
}
}

PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled()
PFN_cuTensorMapEncodeTiled_v12000 get_cuTensorMapEncodeTiled()
{
// Get pointer to `cuTensorMapEncodeTiled`
cudaDriverEntryPointQueryResult driver_status;
Expand All @@ -110,12 +110,12 @@ PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled()

if (driver_status != cudaDriverEntryPointSuccess)
throw std::runtime_error("driver_status != cudaDriverEntryPointSuccess");
return reinterpret_cast<PFN_cuTensorMapEncodeTiled>(cuTensorMapEncodeTiled_ptr);
return reinterpret_cast<PFN_cuTensorMapEncodeTiled_v12000>(cuTensorMapEncodeTiled_ptr);
}

template <typename T>
CUtensorMap make_2d_tma_copy_desc(T* global_address, uint64_t gmem_dim[2], uint64_t stride_in_bytes,
uint32_t smem_dim[2], CUtensorMapSwizzle swizzle_type, PFN_cuTensorMapEncodeTiled encode_func = nullptr)
uint32_t smem_dim[2], CUtensorMapSwizzle swizzle_type, PFN_cuTensorMapEncodeTiled_v12000 encode_func = nullptr)
{
CUtensorMap tensor_map{};
constexpr uint32_t rank = 2;
Expand Down
8 changes: 4 additions & 4 deletions cpp/tensorrt_llm/common/attentionOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2531,22 +2531,22 @@ int AttentionOp::initialize() noexcept
if (mFP8ContextFMHA)
{
TLLM_CHECK_WITH_INFO(mEnableContextFMHA, "FP8 FMHA cannot be enabled because Context FMHA is not supported.");
TLLM_CHECK_WITH_INFO(mSM == 89 || mSM == 90 || mSM == 100 || mSM == 120 || mSM == 121,
TLLM_CHECK_WITH_INFO(mSM == 89 || mSM == 90 || mSM == 100 || mSM == 103 || mSM == 120 || mSM == 121,
"FP8 FMHA can only be enabled on sm_89, sm_90, sm_100, sm_120 or sm_121.");
}

// Pre-Check of FP8 Generation MLA.
if (mFP8GenerationMLA)
{
TLLM_CHECK_WITH_INFO(mIsMLAEnabled, "FP8 Generation MLA cannot be enabled because MLA is not supported.");
TLLM_CHECK_WITH_INFO(mSM == 89 || mSM == 90 || mSM == 100 || mSM == 120 || mSM == 121,
TLLM_CHECK_WITH_INFO(mSM == 89 || mSM == 90 || mSM == 100 || mSM == 103 || mSM == 120 || mSM == 121,
"FP8 Generation MLA is supported on Ada, Hopper or Blackwell architecture.");
}

// Check requirements for FP4 output.
TLLM_CHECK_WITH_INFO(!mFuseFp4Quant || mEnableContextFMHA, "Context FMHA must enable if fuse_fp4_quant is enabled");
TLLM_CHECK_WITH_INFO(!mFuseFp4Quant || mSM == 100 || mSM == 120 || mSM == 121,
"fuse_fp4_quant only supports SM100 or SM120 or SM121 devices.");
TLLM_CHECK_WITH_INFO(!mFuseFp4Quant || (mSM == 100 || mSM == 103) || mSM == 120 || mSM == 121,
"fuse_fp4_quant only supports SM100f or SM120 or SM121 devices.");

// Check requirements for FP4 KV cache.
TLLM_CHECK_WITH_INFO(!mKVCacheQuantMode.hasFp4KvCache() || mFP8ContextFMHA,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ enum class CutlassTileConfigSM100 : int
CtaShape128x256x256B = shape_tuple_to_enum(128, 256, 256),
};

using CutlassTileConfigSM103 = CutlassTileConfigSM100;

enum class CutlassTileConfigSM120 : int
{
// Signals that we should run heuristics do choose a config
Expand Down Expand Up @@ -411,16 +413,17 @@ struct CutlassGemmConfig
CutlassGemmConfig(CutlassTileConfigSM100 tile_config_sm100, MainloopScheduleType mainloop_schedule,
EpilogueScheduleType epilogue_schedule, ClusterShape cluster_shape,
ClusterShape dynamic_cluster_shape = ClusterShape::Undefined,
ClusterShape fallback_cluster_shape = ClusterShape::Undefined)
ClusterShape fallback_cluster_shape = ClusterShape::Undefined, int sm_version = 100)
: tile_config_sm100(tile_config_sm100)
, mainloop_schedule(mainloop_schedule)
, epilogue_schedule(epilogue_schedule)
, cluster_shape(cluster_shape)
, dynamic_cluster_shape(dynamic_cluster_shape)
, fallback_cluster_shape(fallback_cluster_shape)
, sm_version(100)
, sm_version(sm_version)
, is_tma_warp_specialized(true)
{
assert(sm_version >= 100 && sm_version < 120 && "Expected SM 10x version");
}

CutlassGemmConfig(CutlassTileConfigSM120 tile_config_sm120, MainloopScheduleType mainloop_schedule,
Expand Down
11 changes: 10 additions & 1 deletion cpp/tensorrt_llm/deep_ep/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,14 @@ foreach(CUDA_ARCH IN LISTS CMAKE_CUDA_ARCHITECTURES)
if(${CUDA_ARCH_MAJOR} GREATER_EQUAL 9)
# The FP4-related conversion instructions in DeepEP require SM100a, SM110a,
# or SM120a.
if(${CUDA_ARCH_MAJOR} GREATER_EQUAL 10 AND ${CUDA_ARCH_MINOR} EQUAL 0)
if(${CUDA_ARCH_MAJOR} EQUAL 10 AND ${CUDA_ARCH_MINOR} EQUAL 0)
if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.31)
list(APPEND DEEP_EP_CUDA_ARCHITECTURES "100f${CUDA_ARCH_POSTFIX}")
else()
list(APPEND DEEP_EP_CUDA_ARCHITECTURES "100a${CUDA_ARCH_POSTFIX}"
"103a${CUDA_ARCH_POSTFIX}")
endif()
elseif(${CUDA_ARCH_MAJOR} GREATER_EQUAL 10 AND ${CUDA_ARCH_MINOR} EQUAL 0)
list(APPEND DEEP_EP_CUDA_ARCHITECTURES
"${CUDA_ARCH_MAJOR}${CUDA_ARCH_MINOR}a${CUDA_ARCH_POSTFIX}")
else()
Expand Down Expand Up @@ -134,6 +141,8 @@ ExternalProject_Add(
${DEEP_EP_SOURCE_DIR}/third-party/nvshmem.patch
COMMAND sed "s/TRANSPORT_VERSION_MAJOR 3/TRANSPORT_VERSION_MAJOR 103/" -i
src/CMakeLists.txt
COMMAND sed "s/_STANDARD 11/_STANDARD 17/" -i src/device/CMakeLists.txt
COMMAND sed "s/_STANDARD 11/_STANDARD 17/" -i src/CMakeLists.txt
COMMAND patch -p1 --forward --batch -i
${CMAKE_CURRENT_SOURCE_DIR}/nvshmem_fast_build.patch
CMAKE_CACHE_ARGS
Expand Down
4 changes: 2 additions & 2 deletions cpp/tensorrt_llm/executor/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,9 @@ DataType Tensor::getDataType() const
case nvinfer1::DataType::kBF16: return DataType::kBF16;
case nvinfer1::DataType::kINT64: return DataType::kINT64;
case nvinfer1::DataType::kINT4: [[fallthrough]] /* do nothing */;
case nvinfer1::DataType::kFP4: /* do nothing */;
case nvinfer1::DataType::kFP4: [[fallthrough]] /* do nothing */;
default: TLLM_THROW("Unsupported data type");
}
TLLM_THROW("Unsupported data type");
}

MemoryType Tensor::getMemoryType() const
Expand Down
32 changes: 24 additions & 8 deletions cpp/tensorrt_llm/kernels/beamSearchKernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -134,15 +134,14 @@ void invokeUpdateCacheIndirection(int* tgtCI, int const* srcCI, BeamHypotheses&
sync_check_cuda_error(stream);
}

template <typename T>
__global__ void addCumLogProbs(T* __restrict pStage1LogProbs, float const* __restrict cumLogProbs,
__global__ void addCumLogProbs(float* __restrict pStage1LogProbs, float const* __restrict cumLogProbs,
FinishedState const* finished, int const* endIds, float const* diversityRates,
runtime::SizeType32 const* batchSlots, size_t const nBS, size_t const nBMIn, size_t const nBMOut, size_t const nBM)
{
int const bid = blockIdx.x; // Index of request in batch
runtime::SizeType32 const slot = batchSlots[bid];
float const diversityRate{diversityRates[slot]};
T* pLocalLogProbs = pStage1LogProbs + bid * nBMIn * nBMOut * 2;
float* pLocalLogProbs = pStage1LogProbs + bid * nBMIn * nBMOut * 2;

for (int i = threadIdx.x; i < nBMIn * nBMOut * 2; i += blockDim.x)
{
Expand All @@ -160,13 +159,30 @@ __global__ void addCumLogProbs(T* __restrict pStage1LogProbs, float const* __res
return;
}

template __global__ void addCumLogProbs<float>(float* __restrict pStage1LogProbs, float const* __restrict cumLogProbs,
__global__ void addCumLogProbs(half* __restrict pStage1LogProbs, float const* __restrict cumLogProbs,
FinishedState const* finished, int const* endIds, float const* diversityRates,
runtime::SizeType32 const* batchSlots, size_t const nBS, size_t const nBMIn, size_t const nBMOut, size_t const nBM);
runtime::SizeType32 const* batchSlots, size_t const nBS, size_t const nBMIn, size_t const nBMOut, size_t const nBM)
{
int const bid = blockIdx.x; // Index of request in batch
runtime::SizeType32 const slot = batchSlots[bid];
float const diversityRate{diversityRates[slot]};
half* pLocalLogProbs = pStage1LogProbs + bid * nBMIn * nBMOut * 2;

template __global__ void addCumLogProbs<half>(half* __restrict pStage1LogProbs, float const* __restrict cumLogProbs,
FinishedState const* finished, int const* endIds, float const* diversityRates,
runtime::SizeType32 const* batchSlots, size_t const nBS, size_t const nBMIn, size_t const nBMOut, size_t const nBM);
for (int i = threadIdx.x; i < nBMIn * nBMOut * 2; i += blockDim.x)
{
int const iBMIn = i / (nBMOut * 2);
if (finished[slot * nBMIn + iBMIn].isFinished())
{
pLocalLogProbs[i] += (i == endIds[slot]) ? 1.0f : 0.0f;
}
else
{
// nBM is used in VBWS since `cumLogProbs` is initialized with kMaxBeamWidth earlier than BeamSearchLayer
pLocalLogProbs[i] += cumLogProbs[slot * nBM + iBMIn] + diversityRate * iBMIn;
}
}
return;
}

__global__ void gatherId(int const* __restrict pStage1Id, int* __restrict pStage2Id, size_t const nBS,
size_t const nBMIn, size_t const nBMOut, size_t const nV)
Expand Down
7 changes: 5 additions & 2 deletions cpp/tensorrt_llm/kernels/beamSearchKernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,11 @@ void invokeTopkBeamSearch(T const* logProbs, T const* bias, void* workspace, Bea
void invokeUpdateCacheIndirection(int* tgtCI, int const* srcCI, BeamHypotheses& bh,
runtime::SizeType32 const maxAttentionWindow, runtime::SizeType32 sinkTokenLength, cudaStream_t stream);

template <typename T>
__global__ void addCumLogProbs(T* __restrict pStage1Probs, float const* __restrict cumLogProbs,
__global__ void addCumLogProbs(float* __restrict pStage1LogProbs, float const* __restrict cumLogProbs,
FinishedState const* finished, int const* endIds, float const* diversityRates,
runtime::SizeType32 const* batchSlots, size_t const nBS, size_t const nBMIn, size_t const nBMOut, size_t const nBM);

__global__ void addCumLogProbs(half* __restrict pStage1LogProbs, float const* __restrict cumLogProbs,
FinishedState const* finished, int const* endIds, float const* diversityRates,
runtime::SizeType32 const* batchSlots, size_t const nBS, size_t const nBMIn, size_t const nBMOut, size_t const nBM);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ FusedMHARunnerV2::FusedMHARunnerV2(MHARunnerFixedParams fixedParams)
: mFixedParams(fixedParams)
{
TLLM_CHECK_WITH_INFO((mSM == kSM_80 || mSM == kSM_86 || mSM == kSM_89 || mSM == kSM_90 || mSM == kSM_100
|| mSM == kSM_120 || mSM == kSM_121),
|| mSM == kSM_103 || mSM == kSM_120 || mSM == kSM_121),
"Unsupported architecture");
TLLM_CHECK_WITH_INFO((mFixedParams.dataType == DATA_TYPE_FP16 || mFixedParams.dataType == DATA_TYPE_BF16
|| mFixedParams.dataType == DATA_TYPE_E4M3),
Expand Down Expand Up @@ -347,7 +347,7 @@ void FusedMHARunnerV2::setupLaunchParams(MHARunnerParams runnerParams)
bool const isSm8x = (mSM == kSM_86 || mSM == kSM_89);
bool const isSm80 = (mSM == kSM_80);
bool const isSm89 = (mSM == kSM_89);
bool const isSm100 = (mSM == kSM_100);
bool const isSm100f = (mSM == kSM_100 || mSM == kSM_103);
bool const isSm120f = (mSM == kSM_120 || mSM == kSM_121);

// Sliding_or_chunked_causal mask.
Expand Down Expand Up @@ -416,7 +416,7 @@ void FusedMHARunnerV2::setupLaunchParams(MHARunnerParams runnerParams)
// flash attention tiled kernel is faster on Ada and Ampere derivatives when head_size>=256
mLaunchParams.granular_tiling = false;
}
else if (isSm80 || isSm8x || isSm100 || isSm120f)
else if (isSm80 || isSm8x || isSm100f || isSm120f)
{
// otherwise, choose tiled kernel for Ampere/Ada/Gb20x
mLaunchParams.granular_tiling = true;
Expand Down
22 changes: 18 additions & 4 deletions cpp/tensorrt_llm/kernels/cutlass_kernels/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -72,17 +72,28 @@ function(process_target target_name enable_hopper enable_blackwell)

if(${enable_blackwell}
AND ("100" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG
OR "103" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG
OR "120" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG
OR "121" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG))
OR "121" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG
))

if("100" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG)
# Both 100 and 103 support these kernels
if("100" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG
OR "103" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG)
# No kernels should be parsed, unless blackwell is specified. This is a
# build time improvement
target_compile_definitions(${target_name}
PUBLIC COMPILE_BLACKWELL_TMA_GEMMS)
target_compile_definitions(${target_name}
PUBLIC COMPILE_BLACKWELL_TMA_GROUPED_GEMMS)
endif()
# SM103 only kernels
if("103" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG)
target_compile_definitions(${target_name}
PUBLIC COMPILE_BLACKWELL_SM103_TMA_GEMMS)
target_compile_definitions(
${target_name} PUBLIC COMPILE_BLACKWELL_SM103_TMA_GROUPED_GEMMS)
endif()
if("120" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG
OR "121" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG)
target_compile_definitions(${target_name}
Expand Down Expand Up @@ -113,6 +124,8 @@ function(add_instantiations library base_dir)
list(LENGTH INSTANTIATIONS_GENERATED_${ARCH} n)
if(${n} GREATER 0)
set(TARGET_NAME "_${library}_instantiations_${ARCH}")
message(
STATUS "Adding target ${TARGET_NAME} with instantiations for ${ARCH}")
add_library(${TARGET_NAME} OBJECT ${INSTANTIATIONS_GENERATED_${ARCH}})
target_link_libraries(${library} PRIVATE ${TARGET_NAME})
set_cuda_architectures(${TARGET_NAME} ${BUILD_ARCHS})
Expand All @@ -125,9 +138,10 @@ function(add_instantiations library base_dir)
endif()
endmacro()

glob_src_create_target(80 "80;86")
glob_src_create_target(80 "80;86;90;100f;120f")
glob_src_create_target(90 90)
glob_src_create_target(100 100f)
glob_src_create_target(103 103)
glob_src_create_target(120 120f)
endfunction()

Expand Down Expand Up @@ -240,7 +254,7 @@ if(USING_OSS_CUTLASS_MOE_GEMM)
process_target(_moe_gemm_hopper_fp4 true false)

add_library(_moe_gemm_fp4 OBJECT ${MOE_GEMM_SRC_CU_FP4})
set_cuda_architectures(_moe_gemm_fp4 100f 120f)
set_cuda_architectures(_moe_gemm_fp4 100f 103 120f)
process_target(_moe_gemm_fp4 false true)

add_library(_moe_gemm_fp8 OBJECT ${MOE_GEMM_SRC_CU_FP8})
Expand Down
Loading
Loading