diff --git a/include/cutlass/gemm/collective/builders/sm120_blockscaled_mma_builder.inl b/include/cutlass/gemm/collective/builders/sm120_blockscaled_mma_builder.inl index adff27740c..17c9e7b84f 100755 --- a/include/cutlass/gemm/collective/builders/sm120_blockscaled_mma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm120_blockscaled_mma_builder.inl @@ -158,6 +158,7 @@ struct CollectiveBuilder< using Sm1xxBlkScaledConfig = cutlass::detail::Sm1xxBlockScaledConfig; using SmemLayoutAtomA = decltype(detail::sm120_rr_smem_selector(TileShape_MNK{}))>()); + using SmemLayoutAtomB = decltype(detail::sm120_rr_smem_selector(TileShape_MNK{}))>()); using SmemCopyAtomA = Copy_Atom, Int>; using kBasicBlockStride = Stride<_0, _1>; - using sSFA_shapeM = decltype(prepend(size<0>(TileShape_MNK{}) / Blk_MN{}, mnBasicBlockShape{})); + // M dimension must be rounded up to at least Blk_MN (128) for TMA and UTCCP to work. + // Use ceil_div to ensure at least 1 block for M < 128 tiles (e.g., 64x128). + static constexpr int SFA_NumBlocks = (cute::size<0>(TileShape_MNK{}) + Blk_MN{} - cute::Int<1>{}) / Blk_MN{}; + using sSFA_shapeM = decltype(prepend(cute::Int{}, mnBasicBlockShape{})); using sSF_strideMN = decltype(prepend( Blk_Elems{}, mnBasicBlockStride{})); using sSFA_strideM = sSF_strideMN; using sSF_shapeK = decltype(prepend(make_shape( Blk_SF{}/Int{}, size<2>(TileShape_MNK{}) / Int{} / Blk_SF{}), kBasicBlockShape{})); - using sSFA_strideK = decltype(prepend(make_stride( Int{}, size<0>(TileShape_MNK{}) / Blk_MN{} * Blk_Elems{}), kBasicBlockStride{})); + using sSFA_strideK = decltype(prepend(make_stride( Int{}, cute::Int{} * Blk_Elems{}), kBasicBlockStride{})); using sSFA_shape = decltype(make_shape( sSFA_shapeM{}, sSF_shapeK{})); using sSFA_stride = decltype(make_stride(sSFA_strideM{}, sSFA_strideK{})); using SmemLayoutAtomSFA = decltype(make_layout( sSFA_shape{}, sSFA_stride{})); diff --git a/include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp b/include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp index edacf6a4fb..a2dfea26d9 100644 --- a/include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp @@ -452,7 +452,7 @@ struct CollectiveMma< stride_a = InternalStrideA{}; stride_b = InternalStrideB{}; layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(init_M, init_N, init_K, 1)); - layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(init_M, init_N, init_K, 1)); + layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(init_M, init_N, init_K, 1)); } else { // Tensor shapes for Ptr-Array are initialized correctly only here. diff --git a/include/cutlass/gemm/collective/sm103_blockscaled_mma_array_warpspecialized.hpp b/include/cutlass/gemm/collective/sm103_blockscaled_mma_array_warpspecialized.hpp index 2030b9a5db..51195b2369 100644 --- a/include/cutlass/gemm/collective/sm103_blockscaled_mma_array_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm103_blockscaled_mma_array_warpspecialized.hpp @@ -452,7 +452,7 @@ struct CollectiveMma< stride_a = InternalStrideA{}; stride_b = InternalStrideB{}; layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(init_M, init_N, init_K, 1)); - layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(init_M, init_N, init_K, 1)); + layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(init_M, init_N, init_K, 1)); } else { // Tensor shapes for Ptr-Array are initialized correctly only here. diff --git a/include/cutlass/gemm/collective/sm120_blockscaled_mma_array_tma.hpp b/include/cutlass/gemm/collective/sm120_blockscaled_mma_array_tma.hpp index 458ee1af49..e5cce92b5b 100755 --- a/include/cutlass/gemm/collective/sm120_blockscaled_mma_array_tma.hpp +++ b/include/cutlass/gemm/collective/sm120_blockscaled_mma_array_tma.hpp @@ -132,6 +132,15 @@ struct CollectiveMma< static constexpr int SFVecSize = TiledMma::Traits::SFVecSize; using Sm1xxBlkScaledConfig = cutlass::detail::Sm1xxBlockScaledConfig; + // Blk_MN is the minimum block size for scale factors (128) + using Blk_MN = typename Sm1xxBlkScaledConfig::Blk_MN; + + // Scale factor A tile shape - M dimension padded to at least 128 for TMA + // When TileShape has M < 128 (e.g., 64), we pad to 128 for TMA descriptor compatibility + static constexpr int TileM_SFA = (cute::size<0>(TileShape{}) + Blk_MN{} - cute::Int<1>{}) / Blk_MN{} * Blk_MN{}; + using TileShape_SFA = decltype(cute::make_shape(cute::Int{}, cute::size<2>(TileShape{}))); + static constexpr bool IsCtaMSmall = cute::size<0>(TileShape{}) < 128; + // Gmem copies using GmemTiledCopyPairA = GmemTiledCopyPairA_; using GmemTiledCopyPairB = GmemTiledCopyPairB_; @@ -309,11 +318,12 @@ struct CollectiveMma< make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), _1{})); // No programmatic multicast + // TMA for scale factor A - use padded tile shape (TileShape_SFA) for M < 128 compatibility using TMA_SFA = decltype(make_tma_copy( GmemTiledCopySFA{}, make_tensor(static_cast(nullptr), InternalLayoutSFA{}), SmemLayoutSFA{}(_,_,cute::Int<0>{}), - make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + TileShape_SFA{}, _1{})); // No programmatic multicast @@ -370,10 +380,40 @@ struct CollectiveMma< if constexpr (IsGroupedGemmKernel) { // Strides for Grouped Gemm will be replaced prior to the first access regardless. - stride_a = InternalStrideA{}; - stride_b = InternalStrideB{}; + // However, TMA descriptor encoding requires valid non-zero strides. + // Use tile dimensions as placeholder values for the runtime stride components. + // The compile-time components (Int<1>, Int<0>) are preserved from the type. + // + // For RowMajor A [M,K,L]: stride = (K, Int<1>, Int<0>) → runtime component at index 0 + // For RowMajor B [N,K,L]: stride = (Int<1>, N, Int<0>) → runtime component at index 1 + // + // We detect which component is runtime by checking if get returns a non-Int type. + stride_a = [&]() { + InternalStrideA s{}; + // For A: stride pattern is typically (runtime, Int<1>, Int<0>) for RowMajor + // Set the runtime component to init_K (stride between M rows) + if constexpr (!cute::is_static_v(s))>) { + return InternalStrideA(init_K, cute::get<1>(s), cute::get<2>(s)); + } else if constexpr (!cute::is_static_v(s))>) { + return InternalStrideA(cute::get<0>(s), init_M, cute::get<2>(s)); + } else { + return s; // All static, use as-is + } + }(); + stride_b = [&]() { + InternalStrideB s{}; + // For B: stride pattern is typically (Int<1>, runtime, Int<0>) for RowMajor + // Set the runtime component to init_N (stride between K columns) + if constexpr (!cute::is_static_v(s))>) { + return InternalStrideB(init_K, cute::get<1>(s), cute::get<2>(s)); + } else if constexpr (!cute::is_static_v(s))>) { + return InternalStrideB(cute::get<0>(s), init_N, cute::get<2>(s)); + } else { + return s; // All static, use as-is + } + }(); layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(init_M, init_N, init_K, 1)); - layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(init_M, init_N, init_K, 1)); + layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(init_M, init_N, init_K, 1)); } else { // Tensor shapes for Ptr-Array are initialized correctly only here. @@ -410,7 +450,7 @@ struct CollectiveMma< GmemTiledCopySFA{}, tensor_sfa, SmemLayoutSFA{}(_,_,cute::Int<0>{}), - make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + TileShape_SFA{}, _1{}); // No programmatic multicast typename Params::TMA_SFB tma_load_sfb = make_tma_copy( @@ -863,7 +903,9 @@ struct CollectiveMma< CUTE_STATIC_ASSERT_V(size<1>(tCsSFA) == size<1>(tCrSFA_copy_view)); // CPY_M CUTE_STATIC_ASSERT_V(size<2>(tCsSFA) == size<2>(tCrSFA_copy_view)); // CPY_K - CUTE_STATIC_ASSERT_V(size<1>(tCrSFA) == size<1>(accum)); // MMA_M + // For small CTA M, SFA layout is padded to 128 but accumulator is smaller. + // Skip MMA_M size assertion only when M < 128. + if constexpr (!IsCtaMSmall) { CUTE_STATIC_ASSERT_V(size<1>(tCrSFA) == size<1>(accum)); } // MMA_M CUTE_STATIC_ASSERT_V(size<1>(tCrSFB) == size<2>(accum)); // MMA_N CUTE_STATIC_ASSERT_V(size<2>(tCsSFA) == size<2>(tCsSFB)); // CPY_K CUTE_STATIC_ASSERT_V(size<3>(tCsSFA) == size<3>(tCsSFB)); // PIPE diff --git a/include/cutlass/gemm/collective/sm120_blockscaled_mma_tma.hpp b/include/cutlass/gemm/collective/sm120_blockscaled_mma_tma.hpp index 9cb805188e..747d9d41a3 100755 --- a/include/cutlass/gemm/collective/sm120_blockscaled_mma_tma.hpp +++ b/include/cutlass/gemm/collective/sm120_blockscaled_mma_tma.hpp @@ -127,6 +127,15 @@ struct CollectiveMma< static constexpr int SFVecSize = TiledMma::Traits::SFVecSize; using Sm1xxBlkScaledConfig = cutlass::detail::Sm1xxBlockScaledConfig; + // Blk_MN is the minimum block size for scale factors (128) + using Blk_MN = typename Sm1xxBlkScaledConfig::Blk_MN; + + // Scale factor A tile shape - M dimension padded to at least 128 for TMA + // When TileShape has M < 128 (e.g., 64), we pad to 128 for TMA descriptor compatibility + static constexpr int TileM_SFA = (cute::size<0>(TileShape{}) + Blk_MN{} - cute::Int<1>{}) / Blk_MN{} * Blk_MN{}; + using TileShape_SFA = decltype(cute::make_shape(cute::Int{}, cute::size<2>(TileShape{}))); + static constexpr bool IsCtaMSmall = cute::size<0>(TileShape{}) < 128; + // Gmem copies using GmemTiledCopyPairA = GmemTiledCopyPairA_; using GmemTiledCopyPairB = GmemTiledCopyPairB_; @@ -295,11 +304,12 @@ struct CollectiveMma< make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), _1{})); // No programmatic multicast + // TMA for scale factor A - use padded tile shape (TileShape_SFA) for M < 128 compatibility using TMA_SFA = decltype(make_tma_copy( GmemTiledCopySFA{}, make_tensor(static_cast(nullptr), LayoutSFA{}), SmemLayoutSFA{}(_,_,cute::Int<0>{}), - make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + TileShape_SFA{}, _1{})); // No programmatic multicast @@ -360,7 +370,7 @@ struct CollectiveMma< GmemTiledCopySFA{}, tensor_sfa, SmemLayoutSFA{}(_,_,cute::Int<0>{}), - make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + TileShape_SFA{}, _1{}); // No programmatic multicast typename Params::TMA_SFB tma_load_sfb = make_tma_copy( @@ -778,7 +788,9 @@ struct CollectiveMma< CUTE_STATIC_ASSERT_V(size<1>(tCsSFA) == size<1>(tCrSFA_copy_view)); // CPY_M CUTE_STATIC_ASSERT_V(size<2>(tCsSFA) == size<2>(tCrSFA_copy_view)); // CPY_K - CUTE_STATIC_ASSERT_V(size<1>(tCrSFA) == size<1>(accum)); // MMA_M + // For small CTA M, SFA layout is padded to 128 but accumulator is smaller. + // Skip MMA_M size assertion only when M < 128. + if constexpr (!IsCtaMSmall) { CUTE_STATIC_ASSERT_V(size<1>(tCrSFA) == size<1>(accum)); } // MMA_M CUTE_STATIC_ASSERT_V(size<1>(tCrSFB) == size<2>(accum)); // MMA_N CUTE_STATIC_ASSERT_V(size<2>(tCsSFA) == size<2>(tCsSFB)); // CPY_K CUTE_STATIC_ASSERT_V(size<3>(tCsSFA) == size<3>(tCsSFB)); // PIPE diff --git a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp index e1fa1c86b4..9df2643d41 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp @@ -610,7 +610,7 @@ class GemmUniversal< // Consumer1 is not on the critical path at prologue. if (warp_group_role == WarpGroupRole::Consumer1) [[unlikely]] { // Advance 2nd Math WG to the next work tile for the startup - const auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(work_tile_info, tile_scheduler_pipeline, tile_scheduler_pipe_consumer_state); work_tile_info = next_work_tile_info; diff --git a/test/unit/gemm/device/CMakeLists.txt b/test/unit/gemm/device/CMakeLists.txt index 992b8094ef..afb90ab326 100644 --- a/test/unit/gemm/device/CMakeLists.txt +++ b/test/unit/gemm/device/CMakeLists.txt @@ -59,6 +59,7 @@ add_subdirectory(sm120_blockscaled_sparse_tensorop_gemm) add_subdirectory(sm120_sparse_tensorop_gemm) add_subdirectory(sm120_tensorop_gemm) add_subdirectory(sm120_blockscaled_tensorop_gemm) +add_subdirectory(sm121_blockscaled_tensorop_gemm) cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_simt diff --git a/test/unit/gemm/device/sm121_blockscaled_tensorop_gemm/CMakeLists.txt b/test/unit/gemm/device/sm121_blockscaled_tensorop_gemm/CMakeLists.txt new file mode 100644 index 0000000000..c6bd7da860 --- /dev/null +++ b/test/unit/gemm/device/sm121_blockscaled_tensorop_gemm/CMakeLists.txt @@ -0,0 +1,42 @@ +# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +if (CUTLASS_NVCC_ARCHS MATCHES 121a) + +add_custom_target( + cutlass_test_unit_gemm_device_sm121_bs + DEPENDS + cutlass_test_unit_bs_grouped_gemm_device_tensorop_sm121 +) + +cutlass_test_unit_gemm_device_add_executable( + cutlass_test_unit_bs_grouped_gemm_device_tensorop_sm121 + ../sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_mxf8_mxf4_f32_group_gemm_fusion.cu +) + +endif()