diff --git a/src/layout/gemm_layouts.cc b/src/layout/gemm_layouts.cc index 501241fc8..97fcf2069 100644 --- a/src/layout/gemm_layouts.cc +++ b/src/layout/gemm_layouts.cc @@ -757,19 +757,28 @@ Layout makeGemmABLayoutHopper(int mat_stride, int mat_continuous, return makeGemmABLayoutF64_Kouter(mat_stride, mat_continuous); if (k_inner && continuity % 16 == 0) // float64 NxK return makeGemmABLayoutF64_Kinner(mat_stride, mat_continuous); + // fallback for float64 when stride % 8 != 0 + if (mat_stride % 8 != 0) + return makeLinearLayout( + Array{Integer(mat_stride), Integer(mat_continuous)}); return makeQuarterBankSwizzleLayout(mat_stride, mat_continuous, element_size); } int vector_size = 128 / element_size; - if (mat_continuous % (vector_size * 8) == 0) - return makeFullBankSwizzleLayout(mat_stride, mat_continuous, element_size); - else if (mat_continuous % (vector_size * 4) == 0) - return makeHalfBankSwizzleLayout(mat_stride, mat_continuous, element_size); - else if (mat_continuous % (vector_size * 2) == 0) - return makeQuarterBankSwizzleLayout(mat_stride, mat_continuous, - element_size); - else if (mat_continuous % vector_size == 0) + if (mat_stride % 8 == 0) { + if (mat_continuous % (vector_size * 8) == 0) + return makeFullBankSwizzleLayout(mat_stride, mat_continuous, + element_size); + else if (mat_continuous % (vector_size * 4) == 0) + return makeHalfBankSwizzleLayout(mat_stride, mat_continuous, + element_size); + else if (mat_continuous % (vector_size * 2) == 0) + return makeQuarterBankSwizzleLayout(mat_stride, mat_continuous, + element_size); + } + + if (mat_continuous % vector_size == 0) return makeLinearLayout( Array{Integer(mat_stride), Integer(mat_continuous)}); else @@ -784,14 +793,20 @@ Layout makeGemmABLayoutSm100(int mat_stride, int mat_continuous, int continuity, ICHECK(0) << "float64 on sm100 is not supported now"; } int vector_size = 128 / element_size; - if (mat_continuous % (vector_size * 8) == 0) - return makeFullBankSwizzleLayout(mat_stride, mat_continuous, element_size); - else if (mat_continuous % (vector_size * 4) == 0) - return makeHalfBankSwizzleLayout(mat_stride, mat_continuous, element_size); - else if (mat_continuous % (vector_size * 2) == 0) - return makeQuarterBankSwizzleLayout(mat_stride, mat_continuous, - element_size); - else if (mat_continuous % vector_size == 0) + + if (mat_stride % 8 == 0) { + if (mat_continuous % (vector_size * 8) == 0) + return makeFullBankSwizzleLayout(mat_stride, mat_continuous, + element_size); + else if (mat_continuous % (vector_size * 4) == 0) + return makeHalfBankSwizzleLayout(mat_stride, mat_continuous, + element_size); + else if (mat_continuous % (vector_size * 2) == 0) + return makeQuarterBankSwizzleLayout(mat_stride, mat_continuous, + element_size); + } + + if (mat_continuous % vector_size == 0) return makeLinearLayout( Array{Integer(mat_stride), Integer(mat_continuous)}); else