@@ -653,102 +653,85 @@ void VPlanTransforms::attachCheckBlock(VPlan &Plan, Value *Cond,
653653 }
654654}
655655
656- static VPValue *getMinMaxCompareValue (VPSingleDefRecipe *MinMaxOp,
657- VPReductionPHIRecipe *RedPhi) {
658- auto *RepR = dyn_cast<VPReplicateRecipe>(MinMaxOp);
659- if (!isa<VPWidenIntrinsicRecipe>(MinMaxOp) &&
660- !(RepR && (isa<IntrinsicInst>(RepR->getUnderlyingInstr ()))))
661- return nullptr ;
662-
663- if (MinMaxOp->getOperand (0 ) == RedPhi)
664- return MinMaxOp->getOperand (1 );
665- return MinMaxOp->getOperand (0 );
666- }
667-
668- // / Returns true if there VPlan is read-only and execution can be resumed at the
669- // / beginning of the last vector iteration in the scalar loop
670- static bool canResumeInScalarLoopFromVectorLoop (VPlan &Plan) {
671- for (VPBlockBase *VPB : vp_depth_first_shallow (
672- Plan.getVectorLoopRegion ()->getEntryBasicBlock ())) {
673- auto *VPBB = dyn_cast<VPBasicBlock>(VPB);
674- if (!VPBB)
675- return false ;
676- for (auto &R : *VPBB) {
677- if (match (&R, m_BranchOnCount (m_VPValue (), m_VPValue ())))
678- continue ;
679- if (R.mayWriteToMemory ())
680- return false ;
681- }
682- }
683- return true ;
684- }
685-
686656bool VPlanTransforms::handleMaxMinNumReductionsWithoutFastMath (VPlan &Plan) {
687657 VPRegionBlock *LoopRegion = Plan.getVectorLoopRegion ();
688- VPValue *AnyNaN = nullptr ;
689658 VPReductionPHIRecipe *RedPhiR = nullptr ;
690- VPRecipeWithIRFlags *MinMaxOp = nullptr ;
659+ VPValue *MinMaxOp = nullptr ;
691660 bool HasUnsupportedPhi = false ;
661+
662+ auto GetMinMaxCompareValue = [](VPSingleDefRecipe *MinMaxOp,
663+ VPReductionPHIRecipe *RedPhi) -> VPValue * {
664+ auto *RepR = dyn_cast<VPReplicateRecipe>(MinMaxOp);
665+ if (!isa<VPWidenIntrinsicRecipe>(MinMaxOp) &&
666+ !(RepR && (isa<IntrinsicInst>(RepR->getUnderlyingInstr ()))))
667+ return nullptr ;
668+
669+ if (MinMaxOp->getOperand (0 ) == RedPhi)
670+ return MinMaxOp->getOperand (1 );
671+ assert (MinMaxOp->getOperand (1 ) == RedPhi &&
672+ " Reduction phi operand expected" );
673+ return MinMaxOp->getOperand (0 );
674+ };
675+
692676 for (auto &R : LoopRegion->getEntryBasicBlock ()->phis ()) {
677+ // TODO: Also support first-order recurrence phis.
693678 HasUnsupportedPhi |=
694679 !isa<VPCanonicalIVPHIRecipe, VPWidenIntOrFpInductionRecipe,
695680 VPReductionPHIRecipe>(&R);
696681 auto *Cur = dyn_cast<VPReductionPHIRecipe>(&R);
697682 if (!Cur)
698683 continue ;
684+ // For now, only a single reduction is supported.
685+ // TODO: Support multiple MaxNum/MinNum reductions and other reductions.
699686 if (RedPhiR)
700687 return false ;
701- if (Cur->getRecurrenceKind () != RecurKind::FMaxNumNoFMFs &&
702- Cur->getRecurrenceKind () != RecurKind::FMinNumNoFMFs )
688+ if (Cur->getRecurrenceKind () != RecurKind::FMaxNum &&
689+ Cur->getRecurrenceKind () != RecurKind::FMinNum )
703690 continue ;
704691
705692 RedPhiR = Cur;
706- MinMaxOp = dyn_cast<VPRecipeWithIRFlags>(
693+ auto *MinMaxR = dyn_cast<VPRecipeWithIRFlags>(
707694 RedPhiR->getBackedgeValue ()->getDefiningRecipe ());
708- if (!MinMaxOp )
695+ if (!MinMaxR )
709696 return false ;
710- VPValue *In = getMinMaxCompareValue (MinMaxOp , RedPhiR);
711- if (!In )
697+ MinMaxOp = GetMinMaxCompareValue (MinMaxR , RedPhiR);
698+ if (!MinMaxOp )
712699 return false ;
713-
714- auto *IsNaN =
715- new VPInstruction (Instruction::FCmp, {In, In}, {CmpInst::FCMP_UNO}, {});
716- IsNaN->insertBefore (MinMaxOp);
717- AnyNaN = new VPInstruction (VPInstruction::AnyOf, {IsNaN});
718- AnyNaN->getDefiningRecipe ()->insertAfter (IsNaN);
719700 }
720701
721- if (!AnyNaN )
702+ if (!RedPhiR )
722703 return true ;
723704
724- if (HasUnsupportedPhi || !canResumeInScalarLoopFromVectorLoop ( Plan))
705+ if (HasUnsupportedPhi || !Plan. hasScalarTail ( ))
725706 return false ;
726707
708+ // / Check if the vector loop of \p Plan can early exit and restart
709+ // / execution of last vector iteration in the scalar loop. This requires all
710+ // / recipes up to early exit point be side-effect free as they are
711+ // / re-executed. Currently we check that the loop is free of any recipe that
712+ // / may write to memory. Expected to operate on an early VPlan w/o nested
713+ // / regions.
714+ for (VPBlockBase *VPB : vp_depth_first_shallow (
715+ Plan.getVectorLoopRegion ()->getEntryBasicBlock ())) {
716+ auto *VPBB = cast<VPBasicBlock>(VPB);
717+ for (auto &R : *VPBB) {
718+ if (match (&R, m_BranchOnCount (m_VPValue (), m_VPValue ())))
719+ continue ;
720+ if (R.mayWriteToMemory ())
721+ return false ;
722+ }
723+ }
724+
727725 auto *MiddleVPBB = Plan.getMiddleBlock ();
728726 auto *RdxResult = dyn_cast<VPInstruction>(&MiddleVPBB->front ());
729727 if (!RdxResult ||
730728 RdxResult->getOpcode () != VPInstruction::ComputeReductionResult ||
731729 RdxResult->getOperand (0 ) != RedPhiR)
732730 return false ;
733731
734- auto *ScalarPH = Plan.getScalarPreheader ();
735- // Update the resume phis in the scalar preheader. They all must either resume
736- // from the reduction result or the canonical induction. Bail out if there are
737- // other resume phis.
738- for (auto &R : ScalarPH->phis ()) {
739- auto *ResumeR = cast<VPPhi>(&R);
740- VPValue *VecV = ResumeR->getOperand (0 );
741- VPValue *BypassV = ResumeR->getOperand (ResumeR->getNumOperands () - 1 );
742- if (VecV != RdxResult && VecV != &Plan.getVectorTripCount ())
743- return false ;
744- ResumeR->setOperand (
745- 1 , VecV == &Plan.getVectorTripCount () ? Plan.getCanonicalIV () : VecV);
746- ResumeR->addOperand (BypassV);
747- }
748-
749732 // Create a new reduction phi recipe with either FMin/FMax, replacing
750- // FMinNumNoFMFs/FMaxNumNoFMFs .
751- RecurKind NewRK = RedPhiR->getRecurrenceKind () != RecurKind::FMinNumNoFMFs
733+ // FMinNum/FMaxNum .
734+ RecurKind NewRK = RedPhiR->getRecurrenceKind () == RecurKind::FMinNum
752735 ? RecurKind::FMin
753736 : RecurKind::FMax;
754737 auto *NewRedPhiR = new VPReductionPHIRecipe (
@@ -769,23 +752,40 @@ bool VPlanTransforms::handleMaxMinNumReductionsWithoutFastMath(VPlan &Plan) {
769752 auto *IsLatchExitTaken =
770753 Builder.createICmp (CmpInst::ICMP_EQ, LatchExitingBranch->getOperand (0 ),
771754 LatchExitingBranch->getOperand (1 ));
755+
756+ VPValue *IsNaN = Builder.createFCmp (CmpInst::FCMP_UNO, MinMaxOp, MinMaxOp);
757+ VPValue *AnyNaN = Builder.createNaryOp (VPInstruction::AnyOf, {IsNaN});
772758 auto *AnyExitTaken =
773759 Builder.createNaryOp (Instruction::Or, {AnyNaN, IsLatchExitTaken});
774760 Builder.createNaryOp (VPInstruction::BranchOnCond, AnyExitTaken);
775761 LatchExitingBranch->eraseFromParent ();
776762
777- // Split the middle block and introduce a new block, branching to the scalar
778- // preheader to resume iteration in the scalar loop if any NaNs have been
779- // encountered.
780- MiddleVPBB->splitAt (std::prev (MiddleVPBB->end ()));
763+ // If we exit early due to NaNs, compute the final reduction result based on
764+ // the reduction phi at the beginning of the last vector iteration.
781765 Builder.setInsertPoint (MiddleVPBB, MiddleVPBB->begin ());
782766 auto *NewSel =
783767 Builder.createSelect (AnyNaN, NewRedPhiR, RdxResult->getOperand (1 ));
784768 RdxResult->setOperand (1 , NewSel);
785- Builder.setInsertPoint (MiddleVPBB);
786- Builder.createNaryOp (VPInstruction::BranchOnCond, AnyNaN);
787- VPBlockUtils::connectBlocks (MiddleVPBB, ScalarPH);
788- MiddleVPBB->swapSuccessors ();
789- std::swap (ScalarPH->getPredecessors ()[1 ], ScalarPH->getPredecessors ().back ());
769+
770+ auto *ScalarPH = Plan.getScalarPreheader ();
771+ // Update the resume phis for inductions in the scalar preheader. If AnyNaN is
772+ // true, the resume from the start of the last vector iteration via the
773+ // canonical IV, otherwise from the original value.
774+ for (auto &R : ScalarPH->phis ()) {
775+ auto *ResumeR = cast<VPPhi>(&R);
776+ VPValue *VecV = ResumeR->getOperand (0 );
777+ if (VecV == RdxResult)
778+ continue ;
779+ if (VecV != &Plan.getVectorTripCount ())
780+ return false ;
781+ auto *NewSel = Builder.createSelect (AnyNaN, Plan.getCanonicalIV (), VecV);
782+ ResumeR->setOperand (0 , NewSel);
783+ }
784+
785+ auto *MiddleTerm = MiddleVPBB->getTerminator ();
786+ Builder.setInsertPoint (MiddleTerm);
787+ VPValue *MiddleCond = MiddleTerm->getOperand (0 );
788+ VPValue *NewCond = Builder.createAnd (MiddleCond, Builder.createNot (AnyNaN));
789+ MiddleTerm->setOperand (0 , NewCond);
790790 return true ;
791791}
0 commit comments