Skip to content

Commit 718827f

Browse files
VALLIS-NERIA authored and dominicshanshan committed
[TRTLLM-6286] [feat] Update CUTLASS to 4.2 and enable SM103 group gemm (NVIDIA#7832)
Signed-off-by: Xiwen Yu <[email protected]>
1 parent f97b46c commit 718827f

File tree

3 files changed

+8
-20
lines changed

3 files changed

+8
-20
lines changed

3rdparty/cutlass

Submodule cutlass updated 175 files

cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -138,22 +138,23 @@ void dispatchMoeGemmFinalDispatchTmaWarpSpecialized(TmaWarpSpecializedGroupedGem
138138
TLLM_THROW("Please recompile with support for hopper by passing 90-real as an arch to build_wheel.py.");
139139
}
140140
#endif
141-
// #ifndef COMPILE_BLACKWELL_SM103_TMA_GROUPED_GEMMS
141+
#ifndef COMPILE_BLACKWELL_SM103_TMA_GROUPED_GEMMS
142142
else if constexpr (Arch::kMinComputeCapability == 103)
143143
{
144144
static std::once_flag flag;
145145
std::call_once(flag,
146146
[]()
147147
{
148148
TLLM_LOG_WARNING(
149-
"Falling back to sm100f version due to a bug in cutlass." /*"For best performance please recompile with support for blackwell by "
150-
"passing 103-real as an arch to build_wheel.py."*/);
149+
"For best performance please recompile with support for blackwell by "
150+
"passing 103-real as an arch to build_wheel.py.");
151151
});
152-
return dispatchMoeGemmFinalDispatchTmaWarpSpecialized<cutlass::arch::Sm100, T, WeightType, OutputType,
153-
EpilogueTag, FUSION, TileShape, ClusterShape>(
152+
dispatchMoeGemmFinalDispatchTmaWarpSpecialized<cutlass::arch::Sm100, T, WeightType, OutputType, EpilogueTag,
153+
FUSION, TileShape, ClusterShape>(
154154
hopper_input, num_experts, gemm_config, multi_processor_count, stream, occupancy, workspace_size);
155+
return;
155156
}
156-
// #endif
157+
#endif
157158
#ifndef COMPILE_BLACKWELL_TMA_GROUPED_GEMMS
158159
else if constexpr (Arch::kMinComputeCapability >= 100 && Arch::kMinComputeCapability < 120)
159160
{

cpp/tensorrt_llm/kernels/cutlass_kernels/python/generate_kernels.py

Lines changed: 0 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -3,19 +3,6 @@
33
import os
44
from itertools import chain, product
55

6-
file_to_patch = os.path.abspath(
7-
os.path.join(
8-
os.path.dirname(__file__),
9-
"../../../../../3rdparty/cutlass/python/cutlass_library/heuristics_provider.py"
10-
))
11-
# replace "from library import" to "from cutlass_library.library import"
12-
with open(file_to_patch, "r") as f:
13-
file_contents = f.read()
14-
with open(file_to_patch, "w") as f:
15-
f.write(
16-
file_contents.replace("from library import",
17-
"from cutlass_library.library import"))
18-
196
from cutlass_library import *
207

218

0 commit comments

Comments
 (0)