diff --git a/3rdparty/cutlass b/3rdparty/cutlass index a49a78ffefc..57e3cfb47a2 160000 --- a/3rdparty/cutlass +++ b/3rdparty/cutlass @@ -1 +1 @@ -Subproject commit a49a78ffefc86a87160dfe0ccc3a3a2d1622c918 +Subproject commit 57e3cfb47a2d9e0d46eb6335c3dc411498efa198 diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h index f5e8655b6fd..70b2a430e48 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h @@ -138,7 +138,7 @@ void dispatchMoeGemmFinalDispatchTmaWarpSpecialized(TmaWarpSpecializedGroupedGem TLLM_THROW("Please recompile with support for hopper by passing 90-real as an arch to build_wheel.py."); } #endif - // #ifndef COMPILE_BLACKWELL_SM103_TMA_GROUPED_GEMMS +#ifndef COMPILE_BLACKWELL_SM103_TMA_GROUPED_GEMMS else if constexpr (Arch::kMinComputeCapability == 103) { static std::once_flag flag; @@ -146,14 +146,15 @@ void dispatchMoeGemmFinalDispatchTmaWarpSpecialized(TmaWarpSpecializedGroupedGem []() { TLLM_LOG_WARNING( - "Falling back to sm100f version due to a bug in cutlass." /*"For best performance please recompile with support for blackwell by " - "passing 103-real as an arch to build_wheel.py."*/); + "For best performance please recompile with support for blackwell by " + "passing 103-real as an arch to build_wheel.py."); }); - return dispatchMoeGemmFinalDispatchTmaWarpSpecialized( + dispatchMoeGemmFinalDispatchTmaWarpSpecialized( hopper_input, num_experts, gemm_config, multi_processor_count, stream, occupancy, workspace_size); + return; } -// #endif +#endif #ifndef COMPILE_BLACKWELL_TMA_GROUPED_GEMMS else if constexpr (Arch::kMinComputeCapability >= 100 && Arch::kMinComputeCapability < 120) { diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/python/generate_kernels.py b/cpp/tensorrt_llm/kernels/cutlass_kernels/python/generate_kernels.py index f406e234592..0ac521e44bb 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/python/generate_kernels.py +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/python/generate_kernels.py @@ -3,19 +3,6 @@ import os from itertools import chain, product -file_to_patch = os.path.abspath( - os.path.join( - os.path.dirname(__file__), - "../../../../../3rdparty/cutlass/python/cutlass_library/heuristics_provider.py" - )) -# replace "from library import" to "from cutlass_library.library import" -with open(file_to_patch, "r") as f: - file_contents = f.read() -with open(file_to_patch, "w") as f: - f.write( - file_contents.replace("from library import", - "from cutlass_library.library import")) - from cutlass_library import *