Skip to content

Commit e60bf79

Browse files
committed
Fixing trD compute type in the Xe Epilogue
1 parent 675727d commit e60bf79

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

include/cutlass/epilogue/collective/xe_epilogue.hpp

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -389,7 +389,10 @@ class CollectiveEpilogue<
389 389
cst_callbacks.begin();
390 390

391 391
auto acc_frag = recast<Array<ElementCompute, FragmentSize>>(accumulators);
392-
auto trD_compute_frag = recast<Array<ElementCompute, FragmentSize>>(trD_compute);
392+
using FragmentVisit = decltype(cst_callbacks.visit(acc_frag(0), 0, 0, 0));
393+
constexpr bool IsDirectR2S = cute::is_same_v<FragmentVisit, Array<ElementD, FragmentSize>>;
394+
using RegisterElementD = cute::conditional_t<!IsDirectR2S, ElementCompute, ElementD>;
395+
auto trD_compute_frag = recast<Array<RegisterElementD, FragmentSize>>(trD_compute);
393 396

394 397
Tensor trD = make_tensor<ElementOutput>(Shape<Int<FragmentSize>>{});
395 398
auto trD_frag = recast<Array<ElementOutput, FragmentSize>>(trD);
@@ -423,7 +426,7 @@ class CollectiveEpilogue<
423 426
if constexpr (is_destination_supported) {
424 427
CUTLASS_PRAGMA_UNROLL
425 428
for (int i = 0; i < size(trD_compute_frag); ++i) {
426-
trD_frag(i) = cutlass::NumericArrayConverter<ElementOutput, ElementCompute, FragmentSize>{}(trD_compute_frag(i));
429+
trD_frag(i) = cutlass::NumericArrayConverter<ElementOutput, RegisterElementD, FragmentSize>{}(trD_compute_frag(i));
427 430
}
428 431
copy(params.xe_store_d, trD, tCgD(_, epi_m, epi_n));
429 432
}

0 commit comments

Comments
 (0)