diff --git a/include/cutlass/epilogue/collective/xe_epilogue.hpp b/include/cutlass/epilogue/collective/xe_epilogue.hpp index 05437ab180..ff77756f52 100644 --- a/include/cutlass/epilogue/collective/xe_epilogue.hpp +++ b/include/cutlass/epilogue/collective/xe_epilogue.hpp @@ -389,7 +389,10 @@ class CollectiveEpilogue< cst_callbacks.begin(); auto acc_frag = recast>(accumulators); - auto trD_compute_frag = recast>(trD_compute); + using FragmentVisit = decltype(cst_callbacks.visit(acc_frag(0), 0, 0, 0)); + constexpr bool IsDirectR2S = cute::is_same_v>; + using RegisterElementD = cute::conditional_t; + auto trD_compute_frag = recast>(trD_compute); Tensor trD = make_tensor(Shape>{}); auto trD_frag = recast>(trD); @@ -423,7 +426,7 @@ class CollectiveEpilogue< if constexpr (is_destination_supported) { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(trD_compute_frag); ++i) { - trD_frag(i) = cutlass::NumericArrayConverter{}(trD_compute_frag(i)); + trD_frag(i) = cutlass::NumericArrayConverter{}(trD_compute_frag(i)); } copy(params.xe_store_d, trD, tCgD(_, epi_m, epi_n)); }