@@ -118,12 +118,12 @@ struct FMHAFwdMainloop<XeDefault<Stages>, CausalMask_,
118118                           make_identity_tensor (select<0 ,1 >(TiledMMA{}.tile_mnk()))));
119119
120120  using  FragS = FragC<TiledMMAQK>;
121-   using  FragSRow = decltype (reduce<1 >(FragS{}, sycl::plus{}));
121+   using  FragSRow = decltype (reduce<1 >(FragS{}, sycl::plus< void > {}));
122122  using  ElementS = typename  TiledMMAQK::ValTypeD;
123123
124124  using  SingleFragA = FragC<TiledMMAPV>;                          //  (atom val,q',v')
125125  using  FragA = expand_sg_fragment_t <SingleFragA, 1 , VTiles>;     //  (atom val,q',v',VV)
126-   using  FragARow = decltype (reduce<1 >(FragA{}, sycl::plus{}));
126+   using  FragARow = decltype (reduce<1 >(FragA{}, sycl::plus< void > {}));
127127  using  ElementA = typename  TiledMMAPV::ValTypeD;
128128
129129  static  constexpr  bool  CausalMask = CausalMask_;
@@ -293,9 +293,11 @@ struct FMHAFwdMainloop<XeDefault<Stages>, CausalMask_,
293293      if  (check_remainder_k && K == blk_k1 - 1 ) {
294294        FragSRow k_rem_mask;
295295        int  k = get<0 >(tKgK (0 ,0 ,0 ,K,0 )) + get_sub_group ().get_local_id ()[0 ];
296+         CUTLASS_PRAGMA_UNROLL
296297        for  (int  i = 0 ; i < k_rem_mask.size (); i++, k += intel::sg_size) {
297298          k_rem_mask (i) = (k < shape<0 >(K_2D)) ? ElementS (sycl::nan (0u )) : ElementS (-INFINITY);
298299        }
300+         CUTLASS_PRAGMA_UNROLL
299301        for  (int  i = 0 ; i < tSrS.size (); i++) {
300302          tSrS (i) = sycl::fmin (tSrS (i), broadcast<1 >(k_rem_mask, tSrS, i));
301303        }
@@ -309,7 +311,7 @@ struct FMHAFwdMainloop<XeDefault<Stages>, CausalMask_,
309311#if  0 
310312      reorder(tSrS, tArP);
311313#else 
312-       for  (int  i = 0 ; i < tArP.size (); i++)
314+       for  (int  i = 0 ; i < tArP.size (); i++)    //  SYCL compiler currently is not correctly handling the above reorder. 
313315        tArP (i) = static_cast <typename  TiledMMAPV::ValTypeA>(tSrS (i));
314316#endif 
315317
@@ -370,7 +372,7 @@ struct FMHAFwdMainloop<XeDefault<Stages>, CausalMask_,
370372    }
371373
372374    /*  Update sums */ 
373-     auto  tS_bsum = reduce<1 >(tS, sycl::plus{});
375+     auto  tS_bsum = reduce<1 >(tS, sycl::plus< void > {});
374376    for  (int  i = 0 ; i < tS_sum.size (); i++)
375377      tS_sum (i) += tS_bsum (i);
376378  }
0 commit comments