Skip to content

Commit 4931e87

Browse files
committed
grouped gemm with new APIs
1 parent 1db79a9 commit 4931e87

File tree

8 files changed

+1056 -63 lines changed

8 files changed

+1056 -63 lines changed

examples/04_bmg_grouped_gemm/04_bmg_grouped_gemm.cpp

Lines changed: 4 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -567,21 +567,17 @@ int main(int argc, const char** argv)
567567
using LayoutC = cutlass::layout::RowMajor;
568568
using LayoutD = cutlass::layout::RowMajor;
569569

570-
using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
571-
using GmemTiledCopyB = XE_2D_U16x32x32_LD_V;
570+
using GmemTiledCopyA = void; //XE_LOAD_2D<16, 32, 32>;
571+
using GmemTiledCopyB = void; //XE_LOAD_2D_VNNI<16, 32, 32>;
572572

573573
// Workgroup-level tile
574574
using TileShape = Shape<_256, _256, _32>;
575575

576-
using TiledMma =
577-
TiledMMA<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>,
578-
Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>,
579-
Tile<Layout<Shape<_8, _8, _4>, Stride<_1, _32, _8>>,
580-
Layout<Shape<_16, _4, _4>, Stride<_1, _64, _16>>, _32>>;
576+
using TiledMma = typename TiledMMAHelper<MMA_Atom<XE_DPAS_TT<8, ElementAccumulator, ElementA>>, Layout<TileShape>, Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
581577

582578
constexpr int PipelineStages = 2;
583579
// Dispatch to grouped gemm algorithm
584-
using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16Group<PipelineStages>;
580+
using GEMMDispatchPolicy = cutlass::gemm::MainloopXeL1StagedGroup<PipelineStages>;
585581
using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16Group;
586582

587583
using EpilogueOp = cutlass::epilogue::fusion::LinearCombination<ElementOutput, ElementComputeEpilogue,

0 commit comments

Comments
 (0)