Skip to content

Commit

Permalink
gpu: jit: gemm: only QW-align widths for QW-aligned data
Browse files Browse the repository at this point in the history
  • Loading branch information
petercad committed Apr 3, 2024
1 parent f5ff0a6 commit 5587f08
Showing 1 changed file with 19 additions and 14 deletions.
33 changes: 19 additions & 14 deletions src/gpu/jit/gemm/gen_gemm_kernel_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5770,9 +5770,12 @@ static inline bool canRelAddr(const RegisterBlock &blockSrc,
}

// Select the width alignment to use for a 2D block access to `block`.
//
// Returns either 4 (DW) or 8 (QW). The unit is presumably bytes, since
// needsRemask() folds this value into a granularity that is compared with
// T.size() -- TODO confirm.
// The parameter T is accepted but not read in the body shown here.
//
// NOTE(review): this listing is a rendered diff whose +/- markers were lost,
// so both versions of the changed lines appear back to back. Going by the
// commit title ("only QW-align widths for QW-aligned data"), the lines that
// mention `atype` look like the post-commit version -- confirm against the
// repository before relying on either.
static inline int block2DWidthAlignment(Type T, const RegisterBlock &block,
const MatrixAddressing &atype,
const MatrixAddressingStrategy &astrategy) {
// Block 2D width must be DW-aligned, but generally use QW alignment for better performance for reads.
// Presumably the pre-commit form: 4 if padding is disallowed (noExtraPad) or
// the block is writable, otherwise 8.
return ((astrategy.noExtraPad || block.writable) ? 4 : 8);
// Presumably the post-commit form: same as above, but also returns 4 when
// atype.alignment is not a multiple of 8.
return ((astrategy.noExtraPad || block.writable || atype.alignment % 8)
? 4
: 8);
}

// Output code for setting up address/header GRFs for a single block, given
Expand Down Expand Up @@ -5989,7 +5992,7 @@ void gemm_kernel_generator_t<hw>::setupAddr(Type T, const GRFRange &addr,
if (doBaseAdjust && !astrategy.address2D) stub();
Subregister baStorage, baseAdjust, baseAdjustElems;

int widthAlign = block2DWidthAlignment(T, block, astrategy);
int widthAlign = block2DWidthAlignment(T, block, atype, astrategy);

if (!astrategy.address2D) mov(4, addr[0].ud(4)(1), 0u);

Expand Down Expand Up @@ -6729,6 +6732,7 @@ void gemm_kernel_generator_t<hw>::remaskLayout(Type T, int index, bool column,
}

static bool needsRemask(Type T, bool column, const RegisterBlock &block,
const MatrixAddressing &atype,
const MatrixAddressingStrategy &astrategy, bool ignoreMasks = false) {
if (!ignoreMasks)
if (column ? !block.remainderC : !block.remainderR) return false;
Expand All @@ -6740,19 +6744,20 @@ static bool needsRemask(Type T, bool column, const RegisterBlock &block,
int maskGranularity = block.ebytes;
if (block.ebytes >= 16) maskGranularity = 4;
if (block2DRemask)
maskGranularity = std::max(
maskGranularity, block2DWidthAlignment(T, block, astrategy));
maskGranularity = std::max(maskGranularity,
block2DWidthAlignment(T, block, atype, astrategy));
if (ignoreMasks && !(block2DRemask && astrategy.address2D))
maskGranularity = 256;

return (T.size() < maskGranularity);
}

// Layout-wide overload of needsRemask(): returns true as soon as the
// per-block needsRemask() overload returns true for any block in `layout`,
// and false if it does so for none (including an empty layout).
// T, column, astrategy and ignoreMasks are forwarded unchanged to the
// per-block overload; ignoreMasks defaults to false.
//
// NOTE(review): this listing is a rendered diff whose +/- markers were lost,
// so old and new versions of the changed lines appear back to back. The
// lines that mention `atype` look like the post-commit version, which also
// forwards `atype` to the per-block overload -- confirm against the
// repository.
static bool needsRemask(Type T, bool column,
// Presumably the pre-commit parameter line (no atype).
const vector<RegisterBlock> &layout,
// Presumably the post-commit parameter line (adds atype).
const vector<RegisterBlock> &layout, const MatrixAddressing &atype,
const MatrixAddressingStrategy &astrategy, bool ignoreMasks = false) {
for (auto &block : layout)
// Presumably the pre-commit call (no atype).
if (needsRemask(T, column, block, astrategy, ignoreMasks)) return true;
// Presumably the post-commit call (passes atype through).
if (needsRemask(T, column, block, atype, astrategy, ignoreMasks))
return true;
return false;
}

Expand Down Expand Up @@ -13613,11 +13618,11 @@ void gemm_kernel_generator_t<hw>::kLoopActivateSLMRemainder(bool active,
bool asIfMaskedAi = Ai_lateKRem && state.Ai_strategy.padded;
bool asIfMaskedBi = Bi_lateKRem && state.Bi_strategy.padded;
slmRemaskA = slmA && mayAccessAllK && !Ai_remIncrCopy
&& needsRemask(Ta_ext, true, state.Ai_layoutRem, state.Ai_strategy,
asIfMaskedAi);
&& needsRemask(Ta_ext, true, state.Ai_layoutRem, state.Ai,
state.Ai_strategy, asIfMaskedAi);
slmRemaskB = slmB && mayAccessAllK && !Bi_remIncrCopy
&& needsRemask(Tb_ext, false, state.Bi_layoutRem, state.Bi_strategy,
asIfMaskedBi);
&& needsRemask(Tb_ext, false, state.Bi_layoutRem, state.Bi,
state.Bi_strategy, asIfMaskedBi);
}

static inline void kLoopModifiedFlagAP(GEMMState &state) {
Expand Down Expand Up @@ -14376,11 +14381,11 @@ void gemm_kernel_generator_t<hw>::kLoop(KLoop type, const GEMMProblem &problem,

// A/B remasking in k dimension, during remainder handling.
bool remaskA = !slmA && readA && (minOPCount > 1)
&& needsRemask(Ta_load, true, state.A_layoutRem, strategy.A,
state.A_lateKRem);
&& needsRemask(Ta_load, true, state.A_layoutRem, problem.A,
strategy.A, state.A_lateKRem);
bool remaskB = !slmB && readB && (minOPCount > 1)
&& needsRemask(Tb_load, false, state.B_layoutRem, strategy.B,
state.B_lateKRem);
&& needsRemask(Tb_load, false, state.B_layoutRem, problem.B,
strategy.B, state.B_lateKRem);

if (Ta.isInteger() && Tb.isInteger() && !calcASums && !calcBSums) {
// Only need to remask one operand for integer A/B. Choose the smaller one.
Expand Down

0 comments on commit 5587f08

Please sign in to comment.