Merged
244 changes: 184 additions & 60 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -1048,9 +1048,12 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
MVT::v32i32, MVT::v64i32, MVT::v128i32},
Custom);

setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::Other, Custom);
// Enable custom lowering for the i128 bit operand with clusterlaunchcontrol
setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i128, Custom);
// Enable custom lowering for the following:
// * MVT::i128 - clusterlaunchcontrol
// * MVT::i32 - prmt
// * MVT::Other - internal.addrspace.wrap
setOperationAction(ISD::INTRINSIC_WO_CHAIN, {MVT::i32, MVT::i128, MVT::Other},
Custom);
}

const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
@@ -2060,6 +2063,19 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
return DAG.getBuildVector(Node->getValueType(0), dl, Ops);
}

static SDValue getPRMT(SDValue A, SDValue B, SDValue Selector, SDLoc DL,
SelectionDAG &DAG,
unsigned Mode = NVPTX::PTXPrmtMode::NONE) {
Comment on lines +2066 to +2068

Member: I'd add another overload with Selector provided as an integer. That seems to be a common pattern that forces us to sprinkle DAG.getConstant(X, DL, MVT::i32) in numerous places.

Member Author: Added

return DAG.getNode(NVPTXISD::PRMT, DL, MVT::i32,
{A, B, Selector, DAG.getConstant(Mode, DL, MVT::i32)});
}

static SDValue getPRMT(SDValue A, SDValue B, uint64_t Selector, SDLoc DL,
SelectionDAG &DAG,
unsigned Mode = NVPTX::PTXPrmtMode::NONE) {
return getPRMT(A, B, DAG.getConstant(Selector, DL, MVT::i32), DL, DAG, Mode);
}
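
(Aside, not part of the patch: the integer overload answers the review comment above. A hypothetical call site that previously had to wrap the selector manually, e.g.

    SDValue PRMT = getPRMT(L, R, DAG.getConstant(0x3340, DL, MVT::i32), DL, DAG);

can now pass the immediate directly, as LowerBUILD_VECTOR does below:

    SDValue PRMT = getPRMT(L, R, 0x3340, DL, DAG);

Both forms build the same NVPTXISD::PRMT node with the default NONE mode.)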

SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
// Handle bitcasting from v2i8 without hitting the default promotion
// strategy which goes through stack memory.
Expand Down Expand Up @@ -2111,15 +2127,12 @@ SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op,
L = DAG.getAnyExtOrTrunc(L, DL, MVT::i32);
R = DAG.getAnyExtOrTrunc(R, DL, MVT::i32);
}
return DAG.getNode(
NVPTXISD::PRMT, DL, MVT::v4i8,
{L, R, DAG.getConstant(SelectionValue, DL, MVT::i32),
DAG.getConstant(NVPTX::PTXPrmtMode::NONE, DL, MVT::i32)});
return getPRMT(L, R, SelectionValue, DL, DAG);
};
auto PRMT__10 = GetPRMT(Op->getOperand(0), Op->getOperand(1), true, 0x3340);
auto PRMT__32 = GetPRMT(Op->getOperand(2), Op->getOperand(3), true, 0x3340);
auto PRMT3210 = GetPRMT(PRMT__10, PRMT__32, false, 0x5410);
return DAG.getNode(ISD::BITCAST, DL, VT, PRMT3210);
return DAG.getBitcast(VT, PRMT3210);
}

// Get value or the Nth operand as an APInt(32). Undef values treated as 0.
@@ -2176,11 +2189,14 @@ SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
SDValue Selector = DAG.getNode(ISD::OR, DL, MVT::i32,
DAG.getZExtOrTrunc(Index, DL, MVT::i32),
DAG.getConstant(0x7770, DL, MVT::i32));
SDValue PRMT = DAG.getNode(
NVPTXISD::PRMT, DL, MVT::i32,
{DAG.getBitcast(MVT::i32, Vector), DAG.getConstant(0, DL, MVT::i32),
Selector, DAG.getConstant(NVPTX::PTXPrmtMode::NONE, DL, MVT::i32)});
return DAG.getAnyExtOrTrunc(PRMT, DL, Op->getValueType(0));
SDValue PRMT = getPRMT(DAG.getBitcast(MVT::i32, Vector),
DAG.getConstant(0, DL, MVT::i32), Selector, DL, DAG);
SDValue Ext = DAG.getAnyExtOrTrunc(PRMT, DL, Op->getValueType(0));
SDNodeFlags Flags;
Flags.setNoSignedWrap(Ext.getScalarValueSizeInBits() > 8);
Flags.setNoUnsignedWrap(Ext.getScalarValueSizeInBits() >= 8);
Ext->setFlags(Flags);
return Ext;
}

// Constant index will be matched by tablegen.
@@ -2242,9 +2258,9 @@ SDValue NVPTXTargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
}

SDLoc DL(Op);
return DAG.getNode(NVPTXISD::PRMT, DL, MVT::v4i8, V1, V2,
DAG.getConstant(Selector, DL, MVT::i32),
DAG.getConstant(NVPTX::PTXPrmtMode::NONE, DL, MVT::i32));
SDValue PRMT = getPRMT(DAG.getBitcast(MVT::i32, V1),
DAG.getBitcast(MVT::i32, V2), Selector, DL, DAG);
return DAG.getBitcast(Op.getValueType(), PRMT);
}
/// LowerShiftRightParts - Lower SRL_PARTS, SRA_PARTS, which
/// 1) returns two i32 values and take a 2 x i32 value to shift plus a shift
@@ -2729,10 +2745,46 @@ static SDValue LowerClusterLaunchControlQueryCancel(SDValue Op,
{TryCancelResponse0, TryCancelResponse1});
}

static SDValue lowerPrmtIntrinsic(SDValue Op, SelectionDAG &DAG) {
const unsigned Mode = [&]() {
switch (Op->getConstantOperandVal(0)) {
case Intrinsic::nvvm_prmt:
return NVPTX::PTXPrmtMode::NONE;
case Intrinsic::nvvm_prmt_b4e:
return NVPTX::PTXPrmtMode::B4E;
case Intrinsic::nvvm_prmt_ecl:
return NVPTX::PTXPrmtMode::ECL;
case Intrinsic::nvvm_prmt_ecr:
return NVPTX::PTXPrmtMode::ECR;
case Intrinsic::nvvm_prmt_f4e:
return NVPTX::PTXPrmtMode::F4E;
case Intrinsic::nvvm_prmt_rc16:
return NVPTX::PTXPrmtMode::RC16;
case Intrinsic::nvvm_prmt_rc8:
return NVPTX::PTXPrmtMode::RC8;
default:
llvm_unreachable("unsupported/unhandled intrinsic");
}
}();
SDLoc DL(Op);
SDValue A = Op->getOperand(1);
SDValue B = Op.getNumOperands() == 4 ? Op.getOperand(2)
: DAG.getConstant(0, DL, MVT::i32);
SDValue Selector = (Op->op_end() - 1)->get();
return getPRMT(A, B, Selector, DL, DAG, Mode);
}
static SDValue lowerIntrinsicWOChain(SDValue Op, SelectionDAG &DAG) {
switch (Op->getConstantOperandVal(0)) {
default:
return Op;
case Intrinsic::nvvm_prmt:
case Intrinsic::nvvm_prmt_b4e:
case Intrinsic::nvvm_prmt_ecl:
case Intrinsic::nvvm_prmt_ecr:
case Intrinsic::nvvm_prmt_f4e:
case Intrinsic::nvvm_prmt_rc16:
case Intrinsic::nvvm_prmt_rc8:
return lowerPrmtIntrinsic(Op, DAG);
case Intrinsic::nvvm_internal_addrspace_wrap:
return Op.getOperand(1);
case Intrinsic::nvvm_clusterlaunchcontrol_query_cancel_is_canceled:
@@ -5775,11 +5827,10 @@ PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
SDLoc DL(N);
auto &DAG = DCI.DAG;

auto PRMT = DAG.getNode(
NVPTXISD::PRMT, DL, MVT::v4i8,
{Op0, Op1, DAG.getConstant((Op1Bytes << 8) | Op0Bytes, DL, MVT::i32),
DAG.getConstant(NVPTX::PTXPrmtMode::NONE, DL, MVT::i32)});
return DAG.getNode(ISD::BITCAST, DL, VT, PRMT);
auto PRMT =
getPRMT(DAG.getBitcast(MVT::i32, Op0), DAG.getBitcast(MVT::i32, Op1),
(Op1Bytes << 8) | Op0Bytes, DL, DAG);
return DAG.getBitcast(VT, PRMT);
}

static SDValue combineADDRSPACECAST(SDNode *N,
@@ -5797,47 +5848,120 @@ static SDValue combineADDRSPACECAST(SDNode *N,
return SDValue();
}

// Given a constant selector value and a prmt mode, return the selector value
// normalized to the generic prmt mode. See the PTX ISA documentation for more
// details:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prmt
static APInt getPRMTSelector(const APInt &Selector, unsigned Mode) {
if (Mode == NVPTX::PTXPrmtMode::NONE)
return Selector;

const unsigned V = Selector.trunc(2).getZExtValue();

const auto GetSelector = [](unsigned S0, unsigned S1, unsigned S2,
unsigned S3) {
return APInt(32, S0 | (S1 << 4) | (S2 << 8) | (S3 << 12));
};

switch (Mode) {
case NVPTX::PTXPrmtMode::F4E:
return GetSelector(V, V + 1, V + 2, V + 3);
case NVPTX::PTXPrmtMode::B4E:
return GetSelector(V, (V - 1) & 7, (V - 2) & 7, (V - 3) & 7);
case NVPTX::PTXPrmtMode::RC8:
return GetSelector(V, V, V, V);
case NVPTX::PTXPrmtMode::ECL:
return GetSelector(V, std::max(V, 1U), std::max(V, 2U), 3U);
case NVPTX::PTXPrmtMode::ECR:
return GetSelector(0, std::min(V, 1U), std::min(V, 2U), V);
case NVPTX::PTXPrmtMode::RC16: {
unsigned V1 = (V & 1) << 1;
return GetSelector(V1, V1 + 1, V1, V1 + 1);
}
default:
llvm_unreachable("Invalid PRMT mode");
}
}
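
(Aside, not part of the patch: a quick worked example of the normalization, assuming the static helper were callable in isolation. For F4E only the low two bits of the selector are consulted; with those bits equal to 1, the generic selector picks bytes 1..4 of the 64-bit {b, a} field:

    // Selector = 0x5 -> V = 0x5 & 0x3 = 1
    // GetSelector(1, 2, 3, 4) = 0x1 | (0x2 << 4) | (0x3 << 8) | (0x4 << 12) = 0x4321
    APInt Sel = getPRMTSelector(APInt(32, 0x5), NVPTX::PTXPrmtMode::F4E);
    assert(Sel == 0x4321);

which matches the prmt.b32.f4e behavior described in the PTX ISA reference linked above.)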

static APInt computePRMT(APInt A, APInt B, APInt Selector, unsigned Mode) {
// {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}
APInt BitField = B.concat(A);
APInt SelectorVal = getPRMTSelector(Selector, Mode);
APInt Result(32, 0);
for (unsigned I : llvm::seq(4U)) {
APInt Sel = SelectorVal.extractBits(4, I * 4);
unsigned Idx = Sel.getLoBits(3).getZExtValue();
unsigned Sign = Sel.getHiBits(1).getZExtValue();
APInt Byte = BitField.extractBits(8, Idx * 8);
if (Sign)
Byte = Byte.ashr(8);
Result.insertBits(Byte, I * 8);
}
return Result;
}
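
(Aside, not part of the patch: a small sanity sketch of the fold with assumed input values. With a = 0x03020100, b = 0x07060504, and generic selector 0x7531, each selector nibble picks one byte of {b, a}, so the fold yields the odd-numbered bytes:

    APInt Folded = computePRMT(APInt(32, 0x03020100), APInt(32, 0x07060504),
                               APInt(32, 0x7531), NVPTX::PTXPrmtMode::NONE);
    assert(Folded == 0x07050301); // bytes 7, 5, 3, 1 of {b, a}

combinePRMT below invokes exactly this routine when all three data operands are constants.)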

static SDValue combinePRMT(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
CodeGenOptLevel OptLevel) {
if (OptLevel == CodeGenOptLevel::None)
return SDValue();

// Constant fold PRMT
if (isa<ConstantSDNode>(N->getOperand(0)) &&
isa<ConstantSDNode>(N->getOperand(1)) &&
isa<ConstantSDNode>(N->getOperand(2)))
return DCI.DAG.getConstant(computePRMT(N->getConstantOperandAPInt(0),
N->getConstantOperandAPInt(1),
N->getConstantOperandAPInt(2),
N->getConstantOperandVal(3)),
SDLoc(N), N->getValueType(0));

return SDValue();
}

SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
switch (N->getOpcode()) {
default: break;
case ISD::ADD:
return PerformADDCombine(N, DCI, OptLevel);
case ISD::FADD:
return PerformFADDCombine(N, DCI, OptLevel);
case ISD::MUL:
return PerformMULCombine(N, DCI, OptLevel);
case ISD::SHL:
return PerformSHLCombine(N, DCI, OptLevel);
case ISD::AND:
return PerformANDCombine(N, DCI);
case ISD::UREM:
case ISD::SREM:
return PerformREMCombine(N, DCI, OptLevel);
case ISD::SETCC:
return PerformSETCCCombine(N, DCI, STI.getSmVersion());
case ISD::LOAD:
case NVPTXISD::LoadParamV2:
case NVPTXISD::LoadV2:
case NVPTXISD::LoadV4:
return combineUnpackingMovIntoLoad(N, DCI);
case NVPTXISD::StoreParam:
case NVPTXISD::StoreParamV2:
case NVPTXISD::StoreParamV4:
return PerformStoreParamCombine(N, DCI);
case ISD::STORE:
case NVPTXISD::StoreV2:
case NVPTXISD::StoreV4:
return PerformStoreCombine(N, DCI);
case ISD::EXTRACT_VECTOR_ELT:
return PerformEXTRACTCombine(N, DCI);
case ISD::VSELECT:
return PerformVSELECTCombine(N, DCI);
case ISD::BUILD_VECTOR:
return PerformBUILD_VECTORCombine(N, DCI);
case ISD::ADDRSPACECAST:
return combineADDRSPACECAST(N, DCI);
default:
break;
case ISD::ADD:
return PerformADDCombine(N, DCI, OptLevel);
case ISD::ADDRSPACECAST:
return combineADDRSPACECAST(N, DCI);
case ISD::AND:
return PerformANDCombine(N, DCI);
case ISD::BUILD_VECTOR:
return PerformBUILD_VECTORCombine(N, DCI);
case ISD::EXTRACT_VECTOR_ELT:
return PerformEXTRACTCombine(N, DCI);
case ISD::FADD:
return PerformFADDCombine(N, DCI, OptLevel);
case ISD::LOAD:
case NVPTXISD::LoadParamV2:
case NVPTXISD::LoadV2:
case NVPTXISD::LoadV4:
return combineUnpackingMovIntoLoad(N, DCI);
case ISD::MUL:
return PerformMULCombine(N, DCI, OptLevel);
case NVPTXISD::PRMT:
return combinePRMT(N, DCI, OptLevel);
case ISD::SETCC:
return PerformSETCCCombine(N, DCI, STI.getSmVersion());
case ISD::SHL:
return PerformSHLCombine(N, DCI, OptLevel);
case ISD::SREM:
case ISD::UREM:
return PerformREMCombine(N, DCI, OptLevel);
case NVPTXISD::StoreParam:
case NVPTXISD::StoreParamV2:
case NVPTXISD::StoreParamV4:
return PerformStoreParamCombine(N, DCI);
case ISD::STORE:
case NVPTXISD::StoreV2:
case NVPTXISD::StoreV4:
return PerformStoreCombine(N, DCI);
case ISD::VSELECT:
return PerformVSELECTCombine(N, DCI);
}
return SDValue();
}
@@ -6387,7 +6511,7 @@ static void computeKnownBitsForPRMT(const SDValue Op, KnownBits &Known,
ConstantSDNode *Selector = dyn_cast<ConstantSDNode>(Op.getOperand(2));
unsigned Mode = Op.getConstantOperandVal(3);

if (Mode != NVPTX::PTXPrmtMode::NONE || !Selector)
if (!Selector)
return;

KnownBits AKnown = DAG.computeKnownBits(A, Depth);
@@ -6396,7 +6520,7 @@ static void computeKnownBitsForPRMT(const SDValue Op, KnownBits &Known,
// {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}
KnownBits BitField = BKnown.concat(AKnown);

APInt SelectorVal = Selector->getAPIntValue();
APInt SelectorVal = getPRMTSelector(Selector->getAPIntValue(), Mode);
for (unsigned I : llvm::seq(std::min(4U, Known.getBitWidth() / 8))) {
APInt Sel = SelectorVal.extractBits(4, I * 4);
unsigned Idx = Sel.getLoBits(3).getZExtValue();
23 changes: 19 additions & 4 deletions llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -1453,18 +1453,33 @@ let hasSideEffects = false in {
(ins PrmtMode:$mode),
"prmt.b32$mode",
[(set i32:$d, (prmt i32:$a, i32:$b, imm:$c, imm:$mode))]>;
def PRMT_B32rir
: BasicFlagsNVPTXInst<(outs B32:$d),
(ins B32:$a, i32imm:$b, B32:$c),
(ins PrmtMode:$mode),
"prmt.b32$mode",
[(set i32:$d, (prmt i32:$a, imm:$b, i32:$c, imm:$mode))]>;
def PRMT_B32rii
: BasicFlagsNVPTXInst<(outs B32:$d),
(ins B32:$a, i32imm:$b, Hexu32imm:$c),
(ins PrmtMode:$mode),
"prmt.b32$mode",
[(set i32:$d, (prmt i32:$a, imm:$b, imm:$c, imm:$mode))]>;
def PRMT_B32rir
def PRMT_B32irr
: BasicFlagsNVPTXInst<(outs B32:$d),
(ins B32:$a, i32imm:$b, B32:$c),
(ins PrmtMode:$mode),
(ins i32imm:$a, B32:$b, B32:$c), (ins PrmtMode:$mode),
"prmt.b32$mode",
[(set i32:$d, (prmt imm:$a, i32:$b, i32:$c, imm:$mode))]>;
def PRMT_B32iri
: BasicFlagsNVPTXInst<(outs B32:$d),
(ins i32imm:$a, B32:$b, Hexu32imm:$c), (ins PrmtMode:$mode),
"prmt.b32$mode",
[(set i32:$d, (prmt imm:$a, i32:$b, imm:$c, imm:$mode))]>;
def PRMT_B32iir
: BasicFlagsNVPTXInst<(outs B32:$d),
(ins i32imm:$a, i32imm:$b, B32:$c), (ins PrmtMode:$mode),
"prmt.b32$mode",
[(set i32:$d, (prmt i32:$a, imm:$b, i32:$c, imm:$mode))]>;
[(set i32:$d, (prmt imm:$a, imm:$b, i32:$c, imm:$mode))]>;

}

18 changes: 0 additions & 18 deletions llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1047,24 +1047,6 @@ class F_MATH_3<string OpcStr, NVPTXRegClass t_regclass,
// MISC
//

class PRMT3Pat<Intrinsic prmt_intrinsic, PatLeaf prmt_mode>
: Pat<(prmt_intrinsic i32:$a, i32:$b, i32:$c),
(PRMT_B32rrr $a, $b, $c, prmt_mode)>;

class PRMT2Pat<Intrinsic prmt_intrinsic, PatLeaf prmt_mode>
: Pat<(prmt_intrinsic i32:$a, i32:$c),
(PRMT_B32rir $a, (i32 0), $c, prmt_mode)>;

def : PRMT3Pat<int_nvvm_prmt, PrmtNONE>;
def : PRMT3Pat<int_nvvm_prmt_f4e, PrmtF4E>;
def : PRMT3Pat<int_nvvm_prmt_b4e, PrmtB4E>;

def : PRMT2Pat<int_nvvm_prmt_rc8, PrmtRC8>;
def : PRMT2Pat<int_nvvm_prmt_ecl, PrmtECL>;
def : PRMT2Pat<int_nvvm_prmt_ecr, PrmtECR>;
def : PRMT2Pat<int_nvvm_prmt_rc16, PrmtRC16>;


def INT_NVVM_NANOSLEEP_I : BasicNVPTXInst<(outs), (ins i32imm:$i), "nanosleep.u32",
[(int_nvvm_nanosleep imm:$i)]>,
Requires<[hasPTX<63>, hasSM<70>]>;