Skip to content

Commit 29cae4e

Browse files
authored
[Refactor] Update accumulation handling in gemm_sm90.h (#603)
- Replaced `tiled_mma.accumulate_ = GMMA::ScaleOut::Zero` with a call to `clear(acc)` to make the accumulation logic clearer and easier to maintain.
- This change improves readability by standardizing how accumulation values are cleared across multiple sections of the file.
1 parent 19b2252 commit 29cae4e

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/tl_templates/cuda/gemm_sm90.h

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -415,7 +415,7 @@ class GemmTensorOp {
415415
auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));
416416
auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout()));
417417
if constexpr (clear_accum) {
418-
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
418+
clear(acc);
419419
}
420420
CUTE_UNROLL
421421
for (int k = 0; k < size<2>(tCrA); ++k) {
@@ -448,7 +448,7 @@ class GemmTensorOp {
448448
partition_shape_A(tiled_mma, Shape<Int<M>, Int<K>>{}));
449449
auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout()));
450450
if constexpr (clear_accum) {
451-
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
451+
clear(acc);
452452
}
453453
copy(tiled_copy_B, tCsB(_, _, 0), tCrB_copy_view(_, _, 0));
454454
CUTE_UNROLL
@@ -483,7 +483,7 @@ class GemmTensorOp {
483483
partition_shape_B(tiled_mma, Shape<Int<N>, Int<K>>{}));
484484
auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));
485485
if constexpr (clear_accum) {
486-
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
486+
clear(acc);
487487
}
488488
copy(tiled_copy_A, tCsA(_, _, 0), tCrA_copy_view(_, _, 0));
489489
CUTE_UNROLL

0 commit comments

Comments
 (0)