@@ -1027,6 +1027,7 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
10271027 ISD::SCALAR_TO_VECTOR,
10281028 ISD::ZERO_EXTEND,
10291029 ISD::SIGN_EXTEND_INREG,
1030+ ISD::ANY_EXTEND,
10301031 ISD::EXTRACT_VECTOR_ELT,
10311032 ISD::INSERT_VECTOR_ELT,
10321033 ISD::FCOPYSIGN});
@@ -13289,6 +13290,20 @@ static uint32_t getPermuteMask(SDValue V) {
1328913290 return ~0;
1329013291}
1329113292
13293+ static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI);
13294+
13295+ SDValue SITargetLowering::performLeftShiftCombine(SDNode *N,
13296+ DAGCombinerInfo &DCI) const {
13297+ if (DCI.getDAGCombineLevel() < AfterLegalizeTypes)
13298+ return SDValue();
13299+
13300+ EVT VT = N->getValueType(0);
13301+ if (VT != MVT::i32)
13302+ return SDValue();
13303+
13304+ return matchPERM(N, DCI);
13305+ }
13306+
1329213307SDValue SITargetLowering::performAndCombine(SDNode *N,
1329313308 DAGCombinerInfo &DCI) const {
1329413309 if (DCI.isBeforeLegalize())
@@ -14349,10 +14364,11 @@ SDValue SITargetLowering::performXorCombine(SDNode *N,
1434914364 return SDValue();
1435014365}
1435114366
14352- SDValue SITargetLowering::performZeroExtendCombine(SDNode *N,
14353- DAGCombinerInfo &DCI) const {
14367+ SDValue
14368+ SITargetLowering::performZeroOrAnyExtendCombine(SDNode *N,
14369+ DAGCombinerInfo &DCI) const {
1435414370 if (!Subtarget->has16BitInsts() ||
14355- DCI.getDAGCombineLevel() < AfterLegalizeDAG )
14371+ DCI.getDAGCombineLevel() < AfterLegalizeTypes )
1435614372 return SDValue();
1435714373
1435814374 EVT VT = N->getValueType(0);
@@ -14363,7 +14379,41 @@ SDValue SITargetLowering::performZeroExtendCombine(SDNode *N,
1436314379 if (Src.getValueType() != MVT::i16)
1436414380 return SDValue();
1436514381
14366- return SDValue();
14382+ // TODO: We bail out below if SrcOffset is not in the first dword (>= 4). It's
14383+ // possible we're missing out on some combine opportunities, but we'd need to
14384+ // weigh the cost of extracting the byte from the upper dwords.
14385+
14386+ std::optional<ByteProvider<SDValue>> BP0 =
14387+ calculateByteProvider(SDValue(N, 0), 0, 0, 0);
14388+ if (!BP0.has_value() || 4 <= BP0->SrcOffset)
14389+ return SDValue();
14390+ SDValue V0 = BP0->Src.value_or(SDValue());
14391+
14392+ std::optional<ByteProvider<SDValue>> BP1 =
14393+ calculateByteProvider(SDValue(N, 0), 1, 0, 1);
14394+ if (!BP1.has_value() || 4 <= BP1->SrcOffset)
14395+ return SDValue();
14396+ SDValue V1 = BP1->Src.value_or(SDValue());
14397+
14398+ if (!V0 || !V1 || V0 == V1)
14399+ return SDValue();
14400+
14401+ SelectionDAG &DAG = DCI.DAG;
14402+ SDLoc DL(N);
14403+ uint32_t PermMask = 0x0c0c0c0c;
14404+ if (V0) {
14405+ V0 = DAG.getBitcastedAnyExtOrTrunc(V0, DL, MVT::i32);
14406+ PermMask = (PermMask & ~0xFF) | (BP0->SrcOffset + 4);
14407+ }
14408+
14409+ if (V1) {
14410+ V1 = DAG.getBitcastedAnyExtOrTrunc(V1, DL, MVT::i32);
14411+ PermMask = (PermMask & ~(0xFF << 8)) | (BP1->SrcOffset << 8);
14412+ }
14413+
14414+ SDValue P = DAG.getNode(AMDGPUISD::PERM, DL, MVT::i32, V0, V1,
14415+ DAG.getConstant(PermMask, DL, MVT::i32));
14416+ return P;
1436714417}
1436814418
1436914419SDValue
@@ -17031,6 +17081,12 @@ SDValue SITargetLowering::PerformDAGCombine(SDNode *N,
1703117081 return performMinMaxCombine(N, DCI);
1703217082 case ISD::FMA:
1703317083 return performFMACombine(N, DCI);
17084+
17085+ case ISD::SHL:
17086+ if (auto Res = performLeftShiftCombine(N, DCI))
17087+ return Res;
17088+ break;
17089+
1703417090 case ISD::AND:
1703517091 return performAndCombine(N, DCI);
1703617092 case ISD::OR:
@@ -17045,8 +17101,9 @@ SDValue SITargetLowering::PerformDAGCombine(SDNode *N,
1704517101 }
1704617102 case ISD::XOR:
1704717103 return performXorCombine(N, DCI);
17104+ case ISD::ANY_EXTEND:
1704817105 case ISD::ZERO_EXTEND:
17049- return performZeroExtendCombine (N, DCI);
17106+ return performZeroOrAnyExtendCombine (N, DCI);
1705017107 case ISD::SIGN_EXTEND_INREG:
1705117108 return performSignExtendInRegCombine(N, DCI);
1705217109 case AMDGPUISD::FP_CLASS:
0 commit comments