Skip to content

Commit

Permalink
New GenISA intrinsic: WaveInterleave
Browse files Browse the repository at this point in the history
Adds a new GenISA intrinsic, WaveInterleave, that performs a subgroup reduction
over every n-th work item. For example, for SIMD8 and an interleave step of 2,
the result is a reduction of work items 0,2,4,6 and a separate reduction of
work items 1,3,5,7.

The change includes a pattern match for interleave reductions implemented with
subgroup shuffles.
  • Loading branch information
pkwasnie-intel authored and igcbot committed Jul 17, 2024
1 parent dff1024 commit 86502fa
Show file tree
Hide file tree
Showing 13 changed files with 460 additions and 38 deletions.
1 change: 1 addition & 0 deletions IGC/Compiler/CISACodeGen/CheckInstrTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ void CheckInstrTypes::visitCallInst(CallInst& C)
case GenISAIntrinsic::GenISA_WaveInverseBallot:
case GenISAIntrinsic::GenISA_WavePrefix:
case GenISAIntrinsic::GenISA_WaveClustered:
case GenISAIntrinsic::GenISA_WaveInterleave:
case GenISAIntrinsic::GenISA_QuadPrefix:
case GenISAIntrinsic::GenISA_simdShuffleDown:
case GenISAIntrinsic::GenISA_simdShuffleXor:
Expand Down
145 changes: 113 additions & 32 deletions IGC/Compiler/CISACodeGen/EmitVISAPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8905,6 +8905,9 @@ void EmitPass::EmitGenIntrinsicMessage(llvm::GenIntrinsicInst* inst)
case GenISAIntrinsic::GenISA_WaveAll:
emitWaveAll(inst);
break;
case GenISAIntrinsic::GenISA_WaveInterleave:
emitWaveInterleave(inst);
break;
case GenISAIntrinsic::GenISA_WaveClustered:
emitWaveClustered(inst);
break;
Expand Down Expand Up @@ -13167,8 +13170,45 @@ CVariable* EmitPass::ScanReducePrepareSrc(VISA_Type type, uint64_t identityValue
}

// Reduction all reduce helper: dst_lane{k} = src_lane{simd + k} OP src_lane{k}, k = 0..(simd-1)
CVariable* EmitPass::ReductionReduceHelper(e_opcode op, VISA_Type type, SIMDMode simd, CVariable* src)
CVariable* EmitPass::ReductionReduceHelper(e_opcode op, VISA_Type type, SIMDMode simd, CVariable* src, CVariable* srcSecondHalf)
{
const bool isInt64Mul = ScanReduceIsInt64Mul(op, type);
const bool int64EmulationNeeded = ScanReduceIsInt64EmulationNeeded(op, type);

if (simd == SIMDMode::SIMD16 && m_currShader->m_numberInstance > 1)
{
IGC_ASSERT(srcSecondHalf);

CVariable* temp = m_currShader->GetNewVariable(
numLanes(simd),
type,
EALIGN_GRF,
false,
CName("reduceDstSecondHalf"));

if (!int64EmulationNeeded)
{
m_encoder->SetNoMask();
m_encoder->SetSimdSize(simd);
m_encoder->GenericAlu(op, temp, src, srcSecondHalf);
m_encoder->Push();
}
else
{
if (isInt64Mul)
{
CVariable* tmpMulSrc[2] = { src, srcSecondHalf };
Mul64(temp, tmpMulSrc, simd, true /* noMask */);
}
else
{
IGC_ASSERT_MESSAGE(0, "Unsupported");
}
}

return temp;
}

const bool is64bitType = ScanReduceIs64BitType(type);
const auto alignment = is64bitType ? IGC::EALIGN_QWORD : IGC::EALIGN_DWORD;
CVariable* temp = m_currShader->GetNewVariable(
Expand All @@ -13178,9 +13218,6 @@ CVariable* EmitPass::ReductionReduceHelper(e_opcode op, VISA_Type type, SIMDMode
false,
CName("reduceDst_SIMD", std::to_string(numLanes(simd)).c_str()));

const bool isInt64Mul = ScanReduceIsInt64Mul(op, type);
const bool int64EmulationNeeded = ScanReduceIsInt64EmulationNeeded(op, type);

if (!int64EmulationNeeded)
{
m_encoder->SetNoMask();
Expand Down Expand Up @@ -13546,34 +13583,7 @@ void EmitPass::emitReductionAll(
CVariable* srcH2 = ScanReducePrepareSrc(type, identityValue, negate, true /* secondHalf */,
src, nullptr /* dst */);

temp = m_currShader->GetNewVariable(
numLanes(simd),
type,
EALIGN_GRF,
false,
CName("reduceDstSecondHalf"));

const bool isInt64Mul = ScanReduceIsInt64Mul(op, type);
const bool int64EmulationNeeded = ScanReduceIsInt64EmulationNeeded(op, type);
if (!int64EmulationNeeded)
{
m_encoder->SetNoMask();
m_encoder->SetSimdSize(simd);
m_encoder->GenericAlu(op, temp, srcH1, srcH2);
m_encoder->Push();
}
else
{
if (isInt64Mul)
{
CVariable* tmpMulSrc[2] = { srcH1, srcH2 };
Mul64(temp, tmpMulSrc, simd, true /* noMask */);
}
else
{
IGC_ASSERT_MESSAGE(0, "Unsupported");
}
}
temp = ReductionReduceHelper(op, type, SIMDMode::SIMD16, temp, srcH2);
}
}
if (m_currShader->m_dispatchSize >= SIMDMode::SIMD16)
Expand Down Expand Up @@ -13723,6 +13733,54 @@ void EmitPass::emitReductionClustered(const e_opcode op, const uint64_t identity
}
}

// Emits an interleaved subgroup reduction: lanes whose index is congruent
// modulo `step` are reduced together with `op`, and each lane of `dst`
// receives the result of its own group. E.g. for SIMD8 and step = 2, lanes
// 0,2,4,6 form one reduction and lanes 1,3,5,7 another.
//
//   op            - ALU opcode used to combine lanes.
//   identityValue - identity element of `op`, forwarded to ScanReducePrepareSrc().
//   type          - element type of the reduction.
//   negate        - forwarded to ScanReducePrepareSrc().
//   step          - interleave distance; 1 degenerates to a full reduction,
//                   otherwise it must be a power of two no larger than half
//                   of the dispatch width.
//   src           - per-lane input value.
//   dst           - per-lane output; must not be uniform.
void EmitPass::emitReductionInterleave(const e_opcode op, const uint64_t identityValue, const VISA_Type type,
    const bool negate, const unsigned int step, CVariable* const src, CVariable* const dst)
{
    if (step == 1)
    {
        // TODO: consider if it is possible to detect and handle this case in frontends
        // and emit GenISA_WaveAll there, to enable optimizations specific to the ReduceAll intrinsic.
        emitReductionAll(op, identityValue, type, negate, src, dst);
        return;
    }

    const uint16_t firstStep = numLanes(m_currShader->m_dispatchSize) / 2;

    IGC_ASSERT_MESSAGE(!dst->IsUniform(), "Unsupported: dst must be non-uniform");
    // The loop below only visits firstStep, firstStep / 2, firstStep / 4, ...,
    // so it ends with exactly `step` lanes only when `step` is one of those
    // values. An evenness check alone is not enough: it lets step == 0 through
    // (the unsigned loop condition `currentStep >= 0` never becomes false) as
    // well as values such as 6 that the halving sequence never reaches.
    IGC_ASSERT_MESSAGE(step != 0 && (step & (step - 1)) == 0 && step <= firstStep, "Invalid reduction interleave step");

    CVariable* srcH1 = ScanReducePrepareSrc(type, identityValue, negate, false /* secondHalf */,
        src, nullptr /* dst */);
    CVariable* temp = srcH1;

    // Implementation is similar to emitReductionAll(), but we stop reduction before reaching SIMD1.
    for (unsigned int currentStep = firstStep; currentStep >= step; currentStep >>= 1)
    {
        if (currentStep == 16 && m_currShader->m_numberInstance > 1)
        {
            // With more than one shader instance the upper 16 lanes come from
            // a separately prepared second-half source; fold both halves together.
            CVariable* srcH2 = ScanReducePrepareSrc(type, identityValue, negate, true /* secondHalf */,
                src, nullptr /* dst */);

            temp = ReductionReduceHelper(op, type, SIMDMode::SIMD16, temp, srcH2);
        }
        else
        {
            temp = ReductionReduceHelper(op, type, lanesToSIMDMode(currentStep), temp);
        }
    }

    // Broadcast result
    // NOTE(review): the region arguments are presumably (srcIndex, vStride = 0,
    // width = step, hStride = 1), i.e. the `step` reduced lanes repeated across
    // dst - confirm against CEncoder::SetSrcRegion.
    m_encoder->SetSimdSize(m_currShader->m_SIMDSize);
    m_encoder->SetSrcRegion(0, 0, step, 1);
    m_encoder->Copy(dst, temp);
    if (m_currShader->m_numberInstance > 1)
    {
        // Repeat the copy for the second instance before Push().
        m_encoder->SetSecondHalf(true);
        m_encoder->Copy(dst, temp);
        m_encoder->SetSecondHalf(false);
    }
    m_encoder->Push();
}

// do prefix op across all active channels
void EmitPass::emitPreOrPostFixOp(
e_opcode op, uint64_t identityValue, VISA_Type type, bool negateSrc,
Expand Down Expand Up @@ -21141,6 +21199,29 @@ void EmitPass::emitWaveClustered(llvm::GenIntrinsicInst* inst)
}
}

// Lowers a GenISA_WaveInterleave call to an interleaved reduction.
// Operands as read below: 0 = value to reduce, 1 = WaveOps selector,
// 2 = interleave step, 3 = helper-lane mode. A helper-lane mode of 2 brackets
// the emission with ForceDMask() / ResetVMask().
void EmitPass::emitWaveInterleave(llvm::GenIntrinsicInst* inst)
{
    const int helperLaneMode = int_cast<int>(cast<ConstantInt>(inst->getArgOperand(3))->getSExtValue());
    const bool helperLanesDisabled = (helperLaneMode == 2);
    if (helperLanesDisabled)
    {
        ForceDMask();
    }

    CVariable* const source = GetSymbol(inst->getOperand(0));
    const WaveOps waveOp = static_cast<WaveOps>(cast<llvm::ConstantInt>(inst->getOperand(1))->getZExtValue());
    const unsigned int interleaveStep = int_cast<uint32_t>(cast<llvm::ConstantInt>(inst->getOperand(2))->getZExtValue());

    // Translate the wave operation and the operand's LLVM type into the
    // opcode, element type and identity value used by the reduction.
    uint64_t identityValue = 0;
    e_opcode reduceOpcode;
    VISA_Type elemType;
    GetReductionOp(waveOp, inst->getOperand(0)->getType(), identityValue, reduceOpcode, elemType);

    CVariable* const destination = m_destination;
    m_encoder->SetSubSpanDestination(false);
    emitReductionInterleave(reduceOpcode, identityValue, elemType, false /* negate */, interleaveStep, source, destination);

    if (helperLanesDisabled)
    {
        ResetVMask();
    }
}

void EmitPass::emitDP4A(GenIntrinsicInst* GII, const SSource* Sources, const DstModifier& modifier, bool isAccSigned) {
GenISAIntrinsic::ID GIID = GII->getIntrinsicID();
CVariable* dst = m_destination;
Expand Down
11 changes: 10 additions & 1 deletion IGC/Compiler/CISACodeGen/EmitVISAPass.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ class EmitPass : public llvm::FunctionPass
bool ScanReduceIsInt64EmulationNeeded(e_opcode op, VISA_Type type);
CVariable* ScanReducePrepareSrc(VISA_Type type, uint64_t identityValue, bool negate, bool secondHalf,
CVariable* src, CVariable* dst, CVariable* flag = nullptr);
CVariable* ReductionReduceHelper(e_opcode op, VISA_Type type, SIMDMode simd, CVariable* src);
CVariable* ReductionReduceHelper(e_opcode op, VISA_Type type, SIMDMode simd, CVariable* src, CVariable* srcSecondHalf = nullptr);
void ReductionExpandHelper(e_opcode op, VISA_Type type, CVariable* src, CVariable* dst);
void ReductionClusteredSrcHelper(CVariable* (&pSrc)[2], CVariable* src, uint16_t numLanes,
VISA_Type type, uint numInst, bool secondHalf);
Expand All @@ -325,6 +325,14 @@ class EmitPass : public llvm::FunctionPass
const unsigned int clusterSize,
CVariable* const src,
CVariable* const dst);
void emitReductionInterleave(
const e_opcode op,
const uint64_t identityValue,
const VISA_Type type,
const bool negate,
const unsigned int step,
CVariable* const src,
CVariable* const dst);
void emitPreOrPostFixOp(
e_opcode op,
uint64_t identityValue,
Expand Down Expand Up @@ -432,6 +440,7 @@ class EmitPass : public llvm::FunctionPass
void emitQuadPrefix(llvm::QuadPrefixIntrinsic* I);
void emitWaveAll(llvm::GenIntrinsicInst* inst);
void emitWaveClustered(llvm::GenIntrinsicInst* inst);
void emitWaveInterleave(llvm::GenIntrinsicInst* inst);

// Those three "vector" version shall be combined with
// non-vector version.
Expand Down
3 changes: 2 additions & 1 deletion IGC/Compiler/CISACodeGen/HalfPromotion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ void IGC::HalfPromotion::handleGenIntrinsic(llvm::GenIntrinsicInst& I)
GenISAIntrinsic::ID id = I.getIntrinsicID();
if (id == GenISAIntrinsic::GenISA_WaveAll ||
id == GenISAIntrinsic::GenISA_WavePrefix ||
id == GenISAIntrinsic::GenISA_WaveClustered)
id == GenISAIntrinsic::GenISA_WaveClustered ||
id == GenISAIntrinsic::GenISA_WaveInterleave)
{
Module* M = I.getParent()->getParent()->getParent();
llvm::IGCIRBuilder<> builder(&I);
Expand Down
2 changes: 2 additions & 0 deletions IGC/Compiler/CISACodeGen/PatternMatchPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1370,6 +1370,7 @@ namespace IGC
case GenISAIntrinsic::GenISA_WaveInverseBallot:
case GenISAIntrinsic::GenISA_WaveAll:
case GenISAIntrinsic::GenISA_WaveClustered:
case GenISAIntrinsic::GenISA_WaveInterleave:
case GenISAIntrinsic::GenISA_WavePrefix:
match = MatchWaveInstruction(*GII);
break;
Expand Down Expand Up @@ -5183,6 +5184,7 @@ namespace IGC
case GenISAIntrinsic::GenISA_WaveInverseBallot:
helperLaneIndex = 1;
break;
case GenISAIntrinsic::GenISA_WaveInterleave:
case GenISAIntrinsic::GenISA_WaveClustered:
helperLaneIndex = 3;
break;
Expand Down
4 changes: 4 additions & 0 deletions IGC/Compiler/CISACodeGen/PromoteInt8Type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1134,6 +1134,7 @@ void PromoteInt8Type::promoteIntrinsic()
else if (
GII->isGenIntrinsic(GenISAIntrinsic::GenISA_WaveAll) ||
GII->isGenIntrinsic(GenISAIntrinsic::GenISA_WaveClustered) ||
GII->isGenIntrinsic(GenISAIntrinsic::GenISA_WaveInterleave) ||
GII->isGenIntrinsic(GenISAIntrinsic::GenISA_WavePrefix) ||
GII->isGenIntrinsic(GenISAIntrinsic::GenISA_QuadPrefix))
{
Expand All @@ -1158,6 +1159,7 @@ void PromoteInt8Type::promoteIntrinsic()
GenISAIntrinsic::ID gid = GII->getIntrinsicID();
if (gid == GenISAIntrinsic::GenISA_WaveAll ||
gid == GenISAIntrinsic::GenISA_WaveClustered ||
gid == GenISAIntrinsic::GenISA_WaveInterleave ||
gid == GenISAIntrinsic::GenISA_WavePrefix ||
gid == GenISAIntrinsic::GenISA_QuadPrefix ||
gid == GenISAIntrinsic::GenISA_WaveShuffleIndex ||
Expand Down Expand Up @@ -1199,9 +1201,11 @@ void PromoteInt8Type::promoteIntrinsic()
break;
}
case GenISAIntrinsic::GenISA_WaveClustered:
case GenISAIntrinsic::GenISA_WaveInterleave:
{
// prototype:
// Ty <clustered> (Ty, char, int, int)
// Ty <interleave> (Ty, char, int, int)
iArgs.push_back(GII->getArgOperand(1));
iArgs.push_back(GII->getArgOperand(2));
iArgs.push_back(GII->getArgOperand(3));
Expand Down
6 changes: 6 additions & 0 deletions IGC/Compiler/CISACodeGen/WIAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1415,6 +1415,7 @@ WIAnalysis::WIDependancy WIAnalysisRunner::calculate_dep(const CallInst* inst)
intrinsic_name == llvm_waveBallot ||
intrinsic_name == llvm_waveAll ||
intrinsic_name == llvm_waveClustered ||
intrinsic_name == llvm_waveInterleave ||
intrinsic_name == llvm_ld_ptr ||
intrinsic_name == llvm_ldlptr ||
(IGC_IS_FLAG_DISABLED(DisableUniformTypedAccess) && intrinsic_name == llvm_typed_read) ||
Expand Down Expand Up @@ -1718,6 +1719,11 @@ WIAnalysis::WIDependancy WIAnalysisRunner::calculate_dep(const CallInst* inst)
}
}

if (intrinsic_name == llvm_waveInterleave)
{
return WIAnalysis::RANDOM;
}

if (intrinsic_name == llvm_URBRead ||
intrinsic_name == llvm_URBReadOutput)
{
Expand Down
1 change: 1 addition & 0 deletions IGC/Compiler/CISACodeGen/helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1782,6 +1782,7 @@ namespace IGC
{
return (opcode == llvm_waveAll ||
opcode == llvm_waveClustered ||
opcode == llvm_waveInterleave ||
opcode == llvm_wavePrefix ||
opcode == llvm_waveShuffleIndex ||
opcode == llvm_waveBroadcast ||
Expand Down
1 change: 1 addition & 0 deletions IGC/Compiler/CISACodeGen/opCode.h
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ DECLARE_OPCODE(GenISA_pair_to_ptr, GenISAIntrinsic, llvm_pair_to_ptr, false, fal
DECLARE_OPCODE(GenISA_WaveBallot, GenISAIntrinsic, llvm_waveBallot, false, false, false, false, false, false, false)
DECLARE_OPCODE(GenISA_WaveAll, GenISAIntrinsic, llvm_waveAll, false, false, false, false, false, false, false)
DECLARE_OPCODE(GenISA_WaveClustered, GenISAIntrinsic, llvm_waveClustered, false, false, false, false, false, false, false)
DECLARE_OPCODE(GenISA_WaveInterleave, GenISAIntrinsic, llvm_waveInterleave, false, false, false, false, false, false, false)
DECLARE_OPCODE(GenISA_WavePrefix, GenISAIntrinsic, llvm_wavePrefix, false, false, false, false, false, false, false)
DECLARE_OPCODE(GenISA_QuadPrefix, GenISAIntrinsic, llvm_quadPrefix, false, false, false, false, false, false, false)
DECLARE_OPCODE(GenISA_WaveShuffleIndex, GenISAIntrinsic, llvm_waveShuffleIndex, false, false, false, false, false, false, false)
Expand Down
3 changes: 2 additions & 1 deletion IGC/Compiler/CodeGenPublicEnums.h
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,8 @@ namespace IGC
{
GroupOperationScan,
GroupOperationReduce,
GroupOperationClusteredReduce
GroupOperationClusteredReduce,
GroupOperationInterleaveReduce
};

enum SGVUsage
Expand Down
Loading

0 comments on commit 86502fa

Please sign in to comment.