Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 31 additions & 16 deletions src/layout/gemm_layouts.cc
Original file line number Diff line number Diff line change
Expand Up @@ -757,19 +757,28 @@ Layout makeGemmABLayoutHopper(int mat_stride, int mat_continuous,
return makeGemmABLayoutF64_Kouter(mat_stride, mat_continuous);
if (k_inner && continuity % 16 == 0) // float64 NxK
return makeGemmABLayoutF64_Kinner(mat_stride, mat_continuous);
// fallback for float64 when stride % 8 != 0
if (mat_stride % 8 != 0)
return makeLinearLayout(
Array<PrimExpr>{Integer(mat_stride), Integer(mat_continuous)});
return makeQuarterBankSwizzleLayout(mat_stride, mat_continuous,
element_size);
}
int vector_size = 128 / element_size;

if (mat_continuous % (vector_size * 8) == 0)
return makeFullBankSwizzleLayout(mat_stride, mat_continuous, element_size);
else if (mat_continuous % (vector_size * 4) == 0)
return makeHalfBankSwizzleLayout(mat_stride, mat_continuous, element_size);
else if (mat_continuous % (vector_size * 2) == 0)
return makeQuarterBankSwizzleLayout(mat_stride, mat_continuous,
element_size);
else if (mat_continuous % vector_size == 0)
if (mat_stride % 8 == 0) {
if (mat_continuous % (vector_size * 8) == 0)
return makeFullBankSwizzleLayout(mat_stride, mat_continuous,
element_size);
else if (mat_continuous % (vector_size * 4) == 0)
return makeHalfBankSwizzleLayout(mat_stride, mat_continuous,
element_size);
else if (mat_continuous % (vector_size * 2) == 0)
return makeQuarterBankSwizzleLayout(mat_stride, mat_continuous,
element_size);
}

if (mat_continuous % vector_size == 0)
return makeLinearLayout(
Array<PrimExpr>{Integer(mat_stride), Integer(mat_continuous)});
else
Expand All @@ -784,14 +793,20 @@ Layout makeGemmABLayoutSm100(int mat_stride, int mat_continuous, int continuity,
ICHECK(0) << "float64 on sm100 is not supported now";
}
int vector_size = 128 / element_size;
if (mat_continuous % (vector_size * 8) == 0)
return makeFullBankSwizzleLayout(mat_stride, mat_continuous, element_size);
else if (mat_continuous % (vector_size * 4) == 0)
return makeHalfBankSwizzleLayout(mat_stride, mat_continuous, element_size);
else if (mat_continuous % (vector_size * 2) == 0)
return makeQuarterBankSwizzleLayout(mat_stride, mat_continuous,
element_size);
else if (mat_continuous % vector_size == 0)

if (mat_stride % 8 == 0) {
if (mat_continuous % (vector_size * 8) == 0)
return makeFullBankSwizzleLayout(mat_stride, mat_continuous,
element_size);
else if (mat_continuous % (vector_size * 4) == 0)
return makeHalfBankSwizzleLayout(mat_stride, mat_continuous,
element_size);
else if (mat_continuous % (vector_size * 2) == 0)
return makeQuarterBankSwizzleLayout(mat_stride, mat_continuous,
element_size);
}

if (mat_continuous % vector_size == 0)
return makeLinearLayout(
Array<PrimExpr>{Integer(mat_stride), Integer(mat_continuous)});
else
Expand Down
Loading