@@ -2070,6 +2070,12 @@ static SDValue getPRMT(SDValue A, SDValue B, SDValue Selector, SDLoc DL,
20702070 {A, B, Selector, DAG.getConstant (Mode, DL, MVT::i32 )});
20712071}
20722072
2073+ static SDValue getPRMT (SDValue A, SDValue B, uint64_t Selector, SDLoc DL,
2074+ SelectionDAG &DAG,
2075+ unsigned Mode = NVPTX::PTXPrmtMode::NONE) {
2076+ return getPRMT (A, B, DAG.getConstant (Selector, DL, MVT::i32 ), DL, DAG, Mode);
2077+ }
2078+
20732079SDValue NVPTXTargetLowering::LowerBITCAST (SDValue Op, SelectionDAG &DAG) const {
20742080 // Handle bitcasting from v2i8 without hitting the default promotion
20752081 // strategy which goes through stack memory.
@@ -2121,8 +2127,7 @@ SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op,
21212127 L = DAG.getAnyExtOrTrunc (L, DL, MVT::i32 );
21222128 R = DAG.getAnyExtOrTrunc (R, DL, MVT::i32 );
21232129 }
2124- return getPRMT (L, R, DAG.getConstant (SelectionValue, DL, MVT::i32 ), DL,
2125- DAG);
2130+ return getPRMT (L, R, SelectionValue, DL, DAG);
21262131 };
21272132 auto PRMT__10 = GetPRMT (Op->getOperand (0 ), Op->getOperand (1 ), true , 0x3340 );
21282133 auto PRMT__32 = GetPRMT (Op->getOperand (2 ), Op->getOperand (3 ), true , 0x3340 );
@@ -2253,9 +2258,8 @@ SDValue NVPTXTargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
22532258 }
22542259
22552260 SDLoc DL (Op);
2256- SDValue PRMT =
2257- getPRMT (DAG.getBitcast (MVT::i32 , V1), DAG.getBitcast (MVT::i32 , V2),
2258- DAG.getConstant (Selector, DL, MVT::i32 ), DL, DAG);
2261+ SDValue PRMT = getPRMT (DAG.getBitcast (MVT::i32 , V1),
2262+ DAG.getBitcast (MVT::i32 , V2), Selector, DL, DAG);
22592263 return DAG.getBitcast (Op.getValueType (), PRMT);
22602264}
22612265// / LowerShiftRightParts - Lower SRL_PARTS, SRA_PARTS, which
@@ -5823,9 +5827,9 @@ PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
58235827 SDLoc DL (N);
58245828 auto &DAG = DCI.DAG ;
58255829
5826- auto PRMT = getPRMT (
5827- DAG.getBitcast (MVT::i32 , Op0), DAG.getBitcast (MVT::i32 , Op1),
5828- DAG. getConstant (( Op1Bytes << 8 ) | Op0Bytes, DL, MVT:: i32 ) , DL, DAG);
5830+ auto PRMT =
5831+ getPRMT ( DAG.getBitcast (MVT::i32 , Op0), DAG.getBitcast (MVT::i32 , Op1),
5832+ ( Op1Bytes << 8 ) | Op0Bytes, DL, DAG);
58295833 return DAG.getBitcast (VT, PRMT);
58305834}
58315835
@@ -5844,11 +5848,15 @@ static SDValue combineADDRSPACECAST(SDNode *N,
58445848 return SDValue ();
58455849}
58465850
5847- static APInt getPRMTSelector (APInt Selector, unsigned Mode) {
5851+ // Given a constant selector value and a prmt mode, return the selector value
5852+ // normalized to the generic prmt mode. See the PTX ISA documentation for more
5853+ // details:
5854+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prmt
5855+ static APInt getPRMTSelector (const APInt &Selector, unsigned Mode) {
58485856 if (Mode == NVPTX::PTXPrmtMode::NONE)
58495857 return Selector;
58505858
5851- unsigned V = Selector.trunc (2 ).getZExtValue ();
5859+ const unsigned V = Selector.trunc (2 ).getZExtValue ();
58525860
58535861 const auto GetSelector = [](unsigned S0, unsigned S1, unsigned S2,
58545862 unsigned S3) {
@@ -5918,24 +5926,32 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
59185926 break ;
59195927 case ISD::ADD:
59205928 return PerformADDCombine (N, DCI, OptLevel);
5929+ case ISD::ADDRSPACECAST:
5930+ return combineADDRSPACECAST (N, DCI);
5931+ case ISD::AND:
5932+ return PerformANDCombine (N, DCI);
5933+ case ISD::BUILD_VECTOR:
5934+ return PerformBUILD_VECTORCombine (N, DCI);
5935+ case ISD::EXTRACT_VECTOR_ELT:
5936+ return PerformEXTRACTCombine (N, DCI);
59215937 case ISD::FADD:
59225938 return PerformFADDCombine (N, DCI, OptLevel);
5939+ case ISD::LOAD:
5940+ case NVPTXISD::LoadParamV2:
5941+ case NVPTXISD::LoadV2:
5942+ case NVPTXISD::LoadV4:
5943+ return combineUnpackingMovIntoLoad (N, DCI);
59235944 case ISD::MUL:
59245945 return PerformMULCombine (N, DCI, OptLevel);
5946+ case NVPTXISD::PRMT:
5947+ return combinePRMT (N, DCI, OptLevel);
5948+ case ISD::SETCC:
5949+ return PerformSETCCCombine (N, DCI, STI.getSmVersion ());
59255950 case ISD::SHL:
59265951 return PerformSHLCombine (N, DCI, OptLevel);
5927- case ISD::AND:
5928- return PerformANDCombine (N, DCI);
5929- case ISD::UREM:
59305952 case ISD::SREM:
5953+ case ISD::UREM:
59315954 return PerformREMCombine (N, DCI, OptLevel);
5932- case ISD::SETCC:
5933- return PerformSETCCCombine (N, DCI, STI.getSmVersion ());
5934- case ISD::LOAD:
5935- case NVPTXISD::LoadParamV2:
5936- case NVPTXISD::LoadV2:
5937- case NVPTXISD::LoadV4:
5938- return combineUnpackingMovIntoLoad (N, DCI);
59395955 case NVPTXISD::StoreParam:
59405956 case NVPTXISD::StoreParamV2:
59415957 case NVPTXISD::StoreParamV4:
@@ -5944,16 +5960,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
59445960 case NVPTXISD::StoreV2:
59455961 case NVPTXISD::StoreV4:
59465962 return PerformStoreCombine (N, DCI);
5947- case ISD::EXTRACT_VECTOR_ELT:
5948- return PerformEXTRACTCombine (N, DCI);
59495963 case ISD::VSELECT:
59505964 return PerformVSELECTCombine (N, DCI);
5951- case ISD::BUILD_VECTOR:
5952- return PerformBUILD_VECTORCombine (N, DCI);
5953- case ISD::ADDRSPACECAST:
5954- return combineADDRSPACECAST (N, DCI);
5955- case NVPTXISD::PRMT:
5956- return combinePRMT (N, DCI, OptLevel);
59575965 }
59585966 return SDValue ();
59595967}
0 commit comments