@@ -567,21 +567,17 @@ int main(int argc, const char** argv)
567567  using  LayoutC = cutlass::layout::RowMajor;
568568  using  LayoutD = cutlass::layout::RowMajor;
569569
570-   using  GmemTiledCopyA = XE_2D_U16x32x32_LD_N ;
571-   using  GmemTiledCopyB = XE_2D_U16x32x32_LD_V ;
570+   using  GmemTiledCopyA = void ;  // XE_LOAD_2D<16, 32, 32> ;
571+   using  GmemTiledCopyB = void ;  // XE_LOAD_2D_VNNI<16, 32, 32> ;
572572
573573  //  Workgroup-level tile
574574  using  TileShape = Shape<_256, _256, _32>;
575575
576-   using  TiledMma =
577-       TiledMMA<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>,
578-                Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>,
579-                Tile<Layout<Shape<_8, _8, _4>, Stride<_1, _32, _8>>,
580-                     Layout<Shape<_16, _4, _4>, Stride<_1, _64, _16>>, _32>>;
576+   using  TiledMma = typename  TiledMMAHelper<MMA_Atom<XE_DPAS_TT<8 , ElementAccumulator, ElementA>>, Layout<TileShape>, Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
581577
582578  constexpr  int  PipelineStages = 2 ;
583579  //  Dispatch to grouped gemm algorithm
584-   using  GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16Group <PipelineStages>;
580+   using  GEMMDispatchPolicy = cutlass::gemm::MainloopXeL1StagedGroup <PipelineStages>;
585581  using  EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16Group;
586582
587583  using  EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementOutput, ElementComputeEpilogue,
0 commit comments