Skip to content

Commit 81a2e32

Browse files
committed
address comments
1 parent e509eef commit 81a2e32

File tree

1 file changed

+36
-28
lines changed

1 file changed

+36
-28
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 36 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2070,6 +2070,12 @@ static SDValue getPRMT(SDValue A, SDValue B, SDValue Selector, SDLoc DL,
20702070
{A, B, Selector, DAG.getConstant(Mode, DL, MVT::i32)});
20712071
}
20722072

2073+
// Convenience overload of getPRMT for callers that have the selector as a
// plain integer rather than an SDValue. It materializes Selector as an i32
// constant node via DAG.getConstant and forwards every other argument
// unchanged to the getPRMT overload that takes an SDValue selector. Mode
// defaults to NVPTX::PTXPrmtMode::NONE when omitted.
// NOTE(review): Selector is declared uint64_t but is emitted as an MVT::i32
// constant -- presumably callers only pass values that fit in 32 bits; confirm.
static SDValue getPRMT(SDValue A, SDValue B, uint64_t Selector, SDLoc DL,
2074+
SelectionDAG &DAG,
2075+
unsigned Mode = NVPTX::PTXPrmtMode::NONE) {
2076+
return getPRMT(A, B, DAG.getConstant(Selector, DL, MVT::i32), DL, DAG, Mode);
2077+
}
2078+
20732079
SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
20742080
// Handle bitcasting from v2i8 without hitting the default promotion
20752081
// strategy which goes through stack memory.
@@ -2121,8 +2127,7 @@ SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op,
21212127
L = DAG.getAnyExtOrTrunc(L, DL, MVT::i32);
21222128
R = DAG.getAnyExtOrTrunc(R, DL, MVT::i32);
21232129
}
2124-
return getPRMT(L, R, DAG.getConstant(SelectionValue, DL, MVT::i32), DL,
2125-
DAG);
2130+
return getPRMT(L, R, SelectionValue, DL, DAG);
21262131
};
21272132
auto PRMT__10 = GetPRMT(Op->getOperand(0), Op->getOperand(1), true, 0x3340);
21282133
auto PRMT__32 = GetPRMT(Op->getOperand(2), Op->getOperand(3), true, 0x3340);
@@ -2253,9 +2258,8 @@ SDValue NVPTXTargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
22532258
}
22542259

22552260
SDLoc DL(Op);
2256-
SDValue PRMT =
2257-
getPRMT(DAG.getBitcast(MVT::i32, V1), DAG.getBitcast(MVT::i32, V2),
2258-
DAG.getConstant(Selector, DL, MVT::i32), DL, DAG);
2261+
SDValue PRMT = getPRMT(DAG.getBitcast(MVT::i32, V1),
2262+
DAG.getBitcast(MVT::i32, V2), Selector, DL, DAG);
22592263
return DAG.getBitcast(Op.getValueType(), PRMT);
22602264
}
22612265
/// LowerShiftRightParts - Lower SRL_PARTS, SRA_PARTS, which
@@ -5823,9 +5827,9 @@ PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
58235827
SDLoc DL(N);
58245828
auto &DAG = DCI.DAG;
58255829

5826-
auto PRMT = getPRMT(
5827-
DAG.getBitcast(MVT::i32, Op0), DAG.getBitcast(MVT::i32, Op1),
5828-
DAG.getConstant((Op1Bytes << 8) | Op0Bytes, DL, MVT::i32), DL, DAG);
5830+
auto PRMT =
5831+
getPRMT(DAG.getBitcast(MVT::i32, Op0), DAG.getBitcast(MVT::i32, Op1),
5832+
(Op1Bytes << 8) | Op0Bytes, DL, DAG);
58295833
return DAG.getBitcast(VT, PRMT);
58305834
}
58315835

@@ -5844,11 +5848,15 @@ static SDValue combineADDRSPACECAST(SDNode *N,
58445848
return SDValue();
58455849
}
58465850

5847-
static APInt getPRMTSelector(APInt Selector, unsigned Mode) {
5851+
// Given a constant selector value and a prmt mode, return the selector value
5852+
// normalized to the generic prmt mode. See the PTX ISA documentation for more
5853+
// details:
5854+
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prmt
5855+
static APInt getPRMTSelector(const APInt &Selector, unsigned Mode) {
58485856
if (Mode == NVPTX::PTXPrmtMode::NONE)
58495857
return Selector;
58505858

5851-
unsigned V = Selector.trunc(2).getZExtValue();
5859+
const unsigned V = Selector.trunc(2).getZExtValue();
58525860

58535861
const auto GetSelector = [](unsigned S0, unsigned S1, unsigned S2,
58545862
unsigned S3) {
@@ -5918,24 +5926,32 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
59185926
break;
59195927
case ISD::ADD:
59205928
return PerformADDCombine(N, DCI, OptLevel);
5929+
case ISD::ADDRSPACECAST:
5930+
return combineADDRSPACECAST(N, DCI);
5931+
case ISD::AND:
5932+
return PerformANDCombine(N, DCI);
5933+
case ISD::BUILD_VECTOR:
5934+
return PerformBUILD_VECTORCombine(N, DCI);
5935+
case ISD::EXTRACT_VECTOR_ELT:
5936+
return PerformEXTRACTCombine(N, DCI);
59215937
case ISD::FADD:
59225938
return PerformFADDCombine(N, DCI, OptLevel);
5939+
case ISD::LOAD:
5940+
case NVPTXISD::LoadParamV2:
5941+
case NVPTXISD::LoadV2:
5942+
case NVPTXISD::LoadV4:
5943+
return combineUnpackingMovIntoLoad(N, DCI);
59235944
case ISD::MUL:
59245945
return PerformMULCombine(N, DCI, OptLevel);
5946+
case NVPTXISD::PRMT:
5947+
return combinePRMT(N, DCI, OptLevel);
5948+
case ISD::SETCC:
5949+
return PerformSETCCCombine(N, DCI, STI.getSmVersion());
59255950
case ISD::SHL:
59265951
return PerformSHLCombine(N, DCI, OptLevel);
5927-
case ISD::AND:
5928-
return PerformANDCombine(N, DCI);
5929-
case ISD::UREM:
59305952
case ISD::SREM:
5953+
case ISD::UREM:
59315954
return PerformREMCombine(N, DCI, OptLevel);
5932-
case ISD::SETCC:
5933-
return PerformSETCCCombine(N, DCI, STI.getSmVersion());
5934-
case ISD::LOAD:
5935-
case NVPTXISD::LoadParamV2:
5936-
case NVPTXISD::LoadV2:
5937-
case NVPTXISD::LoadV4:
5938-
return combineUnpackingMovIntoLoad(N, DCI);
59395955
case NVPTXISD::StoreParam:
59405956
case NVPTXISD::StoreParamV2:
59415957
case NVPTXISD::StoreParamV4:
@@ -5944,16 +5960,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
59445960
case NVPTXISD::StoreV2:
59455961
case NVPTXISD::StoreV4:
59465962
return PerformStoreCombine(N, DCI);
5947-
case ISD::EXTRACT_VECTOR_ELT:
5948-
return PerformEXTRACTCombine(N, DCI);
59495963
case ISD::VSELECT:
59505964
return PerformVSELECTCombine(N, DCI);
5951-
case ISD::BUILD_VECTOR:
5952-
return PerformBUILD_VECTORCombine(N, DCI);
5953-
case ISD::ADDRSPACECAST:
5954-
return combineADDRSPACECAST(N, DCI);
5955-
case NVPTXISD::PRMT:
5956-
return combinePRMT(N, DCI, OptLevel);
59575965
}
59585966
return SDValue();
59595967
}

0 commit comments

Comments
 (0)