Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ struct CollectiveBuilder<
using Sm1xxBlkScaledConfig = cutlass::detail::Sm1xxBlockScaledConfig<SFVectorSize>;

using SmemLayoutAtomA = decltype(detail::sm120_rr_smem_selector<SmemAllocTypeA, decltype(size<2>(TileShape_MNK{}))>());

using SmemLayoutAtomB = decltype(detail::sm120_rr_smem_selector<SmemAllocTypeB, decltype(size<2>(TileShape_MNK{}))>());

using SmemCopyAtomA = Copy_Atom<decltype(detail::sm120_rr_smem_copy_selector_A<ElementA,
Expand Down Expand Up @@ -189,12 +190,15 @@ struct CollectiveBuilder<
using kBasicBlockShape = Shape<Int<SFVectorSize>, Int<MMA_NSF>>;
using kBasicBlockStride = Stride<_0, _1>;

using sSFA_shapeM = decltype(prepend(size<0>(TileShape_MNK{}) / Blk_MN{}, mnBasicBlockShape{}));
// M dimension must be rounded up to at least Blk_MN (128) for TMA and UTCCP to work.
// Use ceil_div to ensure at least 1 block for M < 128 tiles (e.g., 64x128).
static constexpr int SFA_NumBlocks = (cute::size<0>(TileShape_MNK{}) + Blk_MN{} - cute::Int<1>{}) / Blk_MN{};
using sSFA_shapeM = decltype(prepend(cute::Int<SFA_NumBlocks>{}, mnBasicBlockShape{}));
using sSF_strideMN = decltype(prepend( Blk_Elems{}, mnBasicBlockStride{}));
using sSFA_strideM = sSF_strideMN;
using sSF_shapeK = decltype(prepend(make_shape( Blk_SF{}/Int<MMA_NSF>{}, size<2>(TileShape_MNK{}) / Int<SFVectorSize>{} / Blk_SF{}), kBasicBlockShape{}));

using sSFA_strideK = decltype(prepend(make_stride( Int<MMA_NSF>{}, size<0>(TileShape_MNK{}) / Blk_MN{} * Blk_Elems{}), kBasicBlockStride{}));
using sSFA_strideK = decltype(prepend(make_stride( Int<MMA_NSF>{}, cute::Int<SFA_NumBlocks>{} * Blk_Elems{}), kBasicBlockStride{}));
using sSFA_shape = decltype(make_shape( sSFA_shapeM{}, sSF_shapeK{}));
using sSFA_stride = decltype(make_stride(sSFA_strideM{}, sSFA_strideK{}));
using SmemLayoutAtomSFA = decltype(make_layout( sSFA_shape{}, sSFA_stride{}));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ struct CollectiveMma<
stride_a = InternalStrideA{};
stride_b = InternalStrideB{};
layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(init_M, init_N, init_K, 1));
layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(init_M, init_N, init_K, 1));
layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(init_M, init_N, init_K, 1));
}
else {
// Tensor shapes for Ptr-Array are initialized correctly only here.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ struct CollectiveMma<
stride_a = InternalStrideA{};
stride_b = InternalStrideB{};
layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(init_M, init_N, init_K, 1));
layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(init_M, init_N, init_K, 1));
layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(init_M, init_N, init_K, 1));
}
else {
// Tensor shapes for Ptr-Array are initialized correctly only here.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,15 @@ struct CollectiveMma<
static constexpr int SFVecSize = TiledMma::Traits::SFVecSize;
using Sm1xxBlkScaledConfig = cutlass::detail::Sm1xxBlockScaledConfig<SFVecSize>;

// Blk_MN is the minimum block size for scale factors (128)
using Blk_MN = typename Sm1xxBlkScaledConfig::Blk_MN;

// Scale factor A tile shape - M dimension padded to at least 128 for TMA
// When TileShape has M < 128 (e.g., 64), we pad to 128 for TMA descriptor compatibility
static constexpr int TileM_SFA = (cute::size<0>(TileShape{}) + Blk_MN{} - cute::Int<1>{}) / Blk_MN{} * Blk_MN{};
using TileShape_SFA = decltype(cute::make_shape(cute::Int<TileM_SFA>{}, cute::size<2>(TileShape{})));
static constexpr bool IsCtaMSmall = cute::size<0>(TileShape{}) < 128;

// Gmem copies
using GmemTiledCopyPairA = GmemTiledCopyPairA_;
using GmemTiledCopyPairB = GmemTiledCopyPairB_;
Expand Down Expand Up @@ -309,11 +318,12 @@ struct CollectiveMma<
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
_1{})); // No programmatic multicast

// TMA for scale factor A - use padded tile shape (TileShape_SFA) for M < 128 compatibility
using TMA_SFA = decltype(make_tma_copy<uint16_t>(
GmemTiledCopySFA{},
make_tensor(static_cast<ElementSF const*>(nullptr), InternalLayoutSFA{}),
SmemLayoutSFA{}(_,_,cute::Int<0>{}),
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
TileShape_SFA{},
_1{})); // No programmatic multicast


Expand Down Expand Up @@ -370,10 +380,40 @@ struct CollectiveMma<

if constexpr (IsGroupedGemmKernel) {
// Strides for Grouped Gemm will be replaced prior to the first access regardless.
stride_a = InternalStrideA{};
stride_b = InternalStrideB{};
// However, TMA descriptor encoding requires valid non-zero strides.
// Use tile dimensions as placeholder values for the runtime stride components.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this change needed?
TMA descriptors that are created on the host (here in to_underlying_arguments) exist mainly to properly initialize kernel-dependent parameters; the global tensor pointer, shapes, and strides will be correctly updated on the device — based on the group — prior to any usage.
Were you running into any issues with this?

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@christopherowen please check

// The compile-time components (Int<1>, Int<0>) are preserved from the type.
//
// For RowMajor A [M,K,L]: stride = (K, Int<1>, Int<0>) → runtime component at index 0
// For RowMajor B [N,K,L]: stride = (Int<1>, N, Int<0>) → runtime component at index 1
//
// We detect which component is runtime by checking if get<i> returns a non-Int type.
stride_a = [&]() {
InternalStrideA s{};
// For A: stride pattern is typically (runtime, Int<1>, Int<0>) for RowMajor
// Set the runtime component to init_K (stride between M rows)
if constexpr (!cute::is_static_v<decltype(cute::get<0>(s))>) {
return InternalStrideA(init_K, cute::get<1>(s), cute::get<2>(s));
} else if constexpr (!cute::is_static_v<decltype(cute::get<1>(s))>) {
return InternalStrideA(cute::get<0>(s), init_M, cute::get<2>(s));
} else {
return s; // All static, use as-is
}
}();
stride_b = [&]() {
InternalStrideB s{};
// For B: stride pattern is typically (Int<1>, runtime, Int<0>) for RowMajor
// Set the runtime component to init_N (stride between K columns)
if constexpr (!cute::is_static_v<decltype(cute::get<0>(s))>) {
return InternalStrideB(init_K, cute::get<1>(s), cute::get<2>(s));
} else if constexpr (!cute::is_static_v<decltype(cute::get<1>(s))>) {
return InternalStrideB(cute::get<0>(s), init_N, cute::get<2>(s));
} else {
return s; // All static, use as-is
}
}();
layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(init_M, init_N, init_K, 1));
layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(init_M, init_N, init_K, 1));
layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(init_M, init_N, init_K, 1));
}
else {
// Tensor shapes for Ptr-Array are initialized correctly only here.
Expand Down Expand Up @@ -410,7 +450,7 @@ struct CollectiveMma<
GmemTiledCopySFA{},
tensor_sfa,
SmemLayoutSFA{}(_,_,cute::Int<0>{}),
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
TileShape_SFA{},
_1{}); // No programmatic multicast

typename Params::TMA_SFB tma_load_sfb = make_tma_copy<uint16_t>(
Expand Down Expand Up @@ -863,7 +903,9 @@ struct CollectiveMma<

CUTE_STATIC_ASSERT_V(size<1>(tCsSFA) == size<1>(tCrSFA_copy_view)); // CPY_M
CUTE_STATIC_ASSERT_V(size<2>(tCsSFA) == size<2>(tCrSFA_copy_view)); // CPY_K
CUTE_STATIC_ASSERT_V(size<1>(tCrSFA) == size<1>(accum)); // MMA_M
// For small CTA M, SFA layout is padded to 128 but accumulator is smaller.
// Skip MMA_M size assertion only when M < 128.
if constexpr (!IsCtaMSmall) { CUTE_STATIC_ASSERT_V(size<1>(tCrSFA) == size<1>(accum)); } // MMA_M
CUTE_STATIC_ASSERT_V(size<1>(tCrSFB) == size<2>(accum)); // MMA_N
CUTE_STATIC_ASSERT_V(size<2>(tCsSFA) == size<2>(tCsSFB)); // CPY_K
CUTE_STATIC_ASSERT_V(size<3>(tCsSFA) == size<3>(tCsSFB)); // PIPE
Expand Down
18 changes: 15 additions & 3 deletions include/cutlass/gemm/collective/sm120_blockscaled_mma_tma.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,15 @@ struct CollectiveMma<
static constexpr int SFVecSize = TiledMma::Traits::SFVecSize;
using Sm1xxBlkScaledConfig = cutlass::detail::Sm1xxBlockScaledConfig<SFVecSize>;

// Blk_MN is the minimum block size for scale factors (128)
using Blk_MN = typename Sm1xxBlkScaledConfig::Blk_MN;

// Scale factor A tile shape - M dimension padded to at least 128 for TMA
// When TileShape has M < 128 (e.g., 64), we pad to 128 for TMA descriptor compatibility
static constexpr int TileM_SFA = (cute::size<0>(TileShape{}) + Blk_MN{} - cute::Int<1>{}) / Blk_MN{} * Blk_MN{};
using TileShape_SFA = decltype(cute::make_shape(cute::Int<TileM_SFA>{}, cute::size<2>(TileShape{})));
static constexpr bool IsCtaMSmall = cute::size<0>(TileShape{}) < 128;

// Gmem copies
using GmemTiledCopyPairA = GmemTiledCopyPairA_;
using GmemTiledCopyPairB = GmemTiledCopyPairB_;
Expand Down Expand Up @@ -295,11 +304,12 @@ struct CollectiveMma<
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
_1{})); // No programmatic multicast

// TMA for scale factor A - use padded tile shape (TileShape_SFA) for M < 128 compatibility
using TMA_SFA = decltype(make_tma_copy<uint16_t>(
GmemTiledCopySFA{},
make_tensor(static_cast<ElementSF const*>(nullptr), LayoutSFA{}),
SmemLayoutSFA{}(_,_,cute::Int<0>{}),
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
TileShape_SFA{},
_1{})); // No programmatic multicast


Expand Down Expand Up @@ -360,7 +370,7 @@ struct CollectiveMma<
GmemTiledCopySFA{},
tensor_sfa,
SmemLayoutSFA{}(_,_,cute::Int<0>{}),
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
TileShape_SFA{},
_1{}); // No programmatic multicast

typename Params::TMA_SFB tma_load_sfb = make_tma_copy<uint16_t>(
Expand Down Expand Up @@ -778,7 +788,9 @@ struct CollectiveMma<

CUTE_STATIC_ASSERT_V(size<1>(tCsSFA) == size<1>(tCrSFA_copy_view)); // CPY_M
CUTE_STATIC_ASSERT_V(size<2>(tCsSFA) == size<2>(tCrSFA_copy_view)); // CPY_K
CUTE_STATIC_ASSERT_V(size<1>(tCrSFA) == size<1>(accum)); // MMA_M
// For small CTA M, SFA layout is padded to 128 but accumulator is smaller.
// Skip MMA_M size assertion only when M < 128.
if constexpr (!IsCtaMSmall) { CUTE_STATIC_ASSERT_V(size<1>(tCrSFA) == size<1>(accum)); } // MMA_M
CUTE_STATIC_ASSERT_V(size<1>(tCrSFB) == size<2>(accum)); // MMA_N
CUTE_STATIC_ASSERT_V(size<2>(tCsSFA) == size<2>(tCsSFB)); // CPY_K
CUTE_STATIC_ASSERT_V(size<3>(tCsSFA) == size<3>(tCsSFB)); // PIPE
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,7 @@ class GemmUniversal<
// Consumer1 is not on the critical path at prologue.
if (warp_group_role == WarpGroupRole::Consumer1) [[unlikely]] {
// Advance 2nd Math WG to the next work tile for the startup
const auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape);
auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape);
auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work(work_tile_info, tile_scheduler_pipeline, tile_scheduler_pipe_consumer_state);
work_tile_info = next_work_tile_info;
Expand Down
1 change: 1 addition & 0 deletions test/unit/gemm/device/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ add_subdirectory(sm120_blockscaled_sparse_tensorop_gemm)
add_subdirectory(sm120_sparse_tensorop_gemm)
add_subdirectory(sm120_tensorop_gemm)
add_subdirectory(sm120_blockscaled_tensorop_gemm)
add_subdirectory(sm121_blockscaled_tensorop_gemm)

cutlass_test_unit_gemm_device_add_executable(
cutlass_test_unit_gemm_device_simt
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

if (CUTLASS_NVCC_ARCHS MATCHES 121a)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add copyright to this file.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for catching this. I’ve added the standard NVIDIA copyright header.


add_custom_target(
cutlass_test_unit_gemm_device_sm121_bs
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Porting some internal review comments:

Let's just spell out blockscaled and not call the kernel bs.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DEPENDS
cutlass_test_unit_bs_grouped_gemm_device_tensorop_sm121
)

cutlass_test_unit_gemm_device_add_executable(
cutlass_test_unit_bs_grouped_gemm_device_tensorop_sm121
../sm120_blockscaled_tensorop_gemm/sm120_bs_gemm_mxf8_mxf4_f32_group_gemm_fusion.cu
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Porting some internal review comments:

It seems like this PR added support for CTA_TILE_M == 64 which is smaller than 128. Can we add one unit test in this file to test this case?

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

)

endif()