@@ -389,7 +389,10 @@ class CollectiveEpilogue<
389389 cst_callbacks.begin ();
390390
391391 auto acc_frag = recast<Array<ElementCompute, FragmentSize>>(accumulators);
392- auto trD_compute_frag = recast<Array<ElementCompute, FragmentSize>>(trD_compute);
392+ using FragmentVisit = decltype (cst_callbacks.visit (acc_frag (0 ), 0 , 0 , 0 ));
393+ constexpr bool IsDirectR2S = cute::is_same_v<FragmentVisit, Array<ElementD, FragmentSize>>;
394+ using RegisterElementD = cute::conditional_t <!IsDirectR2S, ElementCompute, ElementD>;
395+ auto trD_compute_frag = recast<Array<RegisterElementD, FragmentSize>>(trD_compute);
393396
394397 Tensor trD = make_tensor<ElementOutput>(Shape<Int<FragmentSize>>{});
395398 auto trD_frag = recast<Array<ElementOutput, FragmentSize>>(trD);
@@ -423,7 +426,7 @@ class CollectiveEpilogue<
423426 if constexpr (is_destination_supported) {
424427 CUTLASS_PRAGMA_UNROLL
425428 for (int i = 0 ; i < size (trD_compute_frag); ++i) {
426- trD_frag (i) = cutlass::NumericArrayConverter<ElementOutput, ElementCompute , FragmentSize>{}(trD_compute_frag (i));
429+ trD_frag (i) = cutlass::NumericArrayConverter<ElementOutput, RegisterElementD , FragmentSize>{}(trD_compute_frag (i));
427430 }
428431 copy (params.xe_store_d , trD, tCgD (_, epi_m, epi_n));
429432 }
0 commit comments