@@ -652,3 +652,140 @@ void VPlanTransforms::attachCheckBlock(VPlan &Plan, Value *Cond,
652652 Term->addMetadata (LLVMContext::MD_prof, BranchWeights);
653653 }
654654}
655+
656+ static VPValue *getMinMaxCompareValue (VPSingleDefRecipe *MinMaxOp,
657+ VPReductionPHIRecipe *RedPhi) {
658+ auto *RepR = dyn_cast<VPReplicateRecipe>(MinMaxOp);
659+ if (!isa<VPWidenIntrinsicRecipe>(MinMaxOp) &&
660+ !(RepR && (isa<IntrinsicInst>(RepR->getUnderlyingInstr ()))))
661+ return nullptr ;
662+
663+ if (MinMaxOp->getOperand (0 ) == RedPhi)
664+ return MinMaxOp->getOperand (1 );
665+ return MinMaxOp->getOperand (0 );
666+ }
667+
668+ // / Returns true if there VPlan is read-only and execution can be resumed at the
669+ // / beginning of the last vector iteration in the scalar loop
670+ static bool canResumeInScalarLoopFromVectorLoop (VPlan &Plan) {
671+ for (VPBlockBase *VPB : vp_depth_first_shallow (
672+ Plan.getVectorLoopRegion ()->getEntryBasicBlock ())) {
673+ auto *VPBB = dyn_cast<VPBasicBlock>(VPB);
674+ if (!VPBB)
675+ return false ;
676+ for (auto &R : *VPBB) {
677+ if (match (&R, m_BranchOnCount (m_VPValue (), m_VPValue ())))
678+ continue ;
679+ if (R.mayWriteToMemory ())
680+ return false ;
681+ }
682+ }
683+ return true ;
684+ }
685+
686+ bool VPlanTransforms::handleMaxMinNumReductionsWithoutFastMath (VPlan &Plan) {
687+ VPRegionBlock *LoopRegion = Plan.getVectorLoopRegion ();
688+ VPValue *AnyNaN = nullptr ;
689+ VPReductionPHIRecipe *RedPhiR = nullptr ;
690+ VPRecipeWithIRFlags *MinMaxOp = nullptr ;
691+ bool HasUnsupportedPhi = false ;
692+ for (auto &R : LoopRegion->getEntryBasicBlock ()->phis ()) {
693+ HasUnsupportedPhi |=
694+ !isa<VPCanonicalIVPHIRecipe, VPWidenIntOrFpInductionRecipe,
695+ VPReductionPHIRecipe>(&R);
696+ auto *Cur = dyn_cast<VPReductionPHIRecipe>(&R);
697+ if (!Cur)
698+ continue ;
699+ if (RedPhiR)
700+ return false ;
701+ if (Cur->getRecurrenceKind () != RecurKind::FMaxNumNoFMFs &&
702+ Cur->getRecurrenceKind () != RecurKind::FMinNumNoFMFs)
703+ continue ;
704+
705+ RedPhiR = Cur;
706+ MinMaxOp = dyn_cast<VPRecipeWithIRFlags>(
707+ RedPhiR->getBackedgeValue ()->getDefiningRecipe ());
708+ if (!MinMaxOp)
709+ return false ;
710+ VPValue *In = getMinMaxCompareValue (MinMaxOp, RedPhiR);
711+ if (!In)
712+ return false ;
713+
714+ auto *IsNaN =
715+ new VPInstruction (Instruction::FCmp, {In, In}, {CmpInst::FCMP_UNO}, {});
716+ IsNaN->insertBefore (MinMaxOp);
717+ AnyNaN = new VPInstruction (VPInstruction::AnyOf, {IsNaN});
718+ AnyNaN->getDefiningRecipe ()->insertAfter (IsNaN);
719+ }
720+
721+ if (!AnyNaN)
722+ return true ;
723+
724+ if (HasUnsupportedPhi || !canResumeInScalarLoopFromVectorLoop (Plan))
725+ return false ;
726+
727+ auto *MiddleVPBB = Plan.getMiddleBlock ();
728+ auto *RdxResult = dyn_cast<VPInstruction>(&MiddleVPBB->front ());
729+ if (!RdxResult ||
730+ RdxResult->getOpcode () != VPInstruction::ComputeReductionResult ||
731+ RdxResult->getOperand (0 ) != RedPhiR)
732+ return false ;
733+
734+ auto *ScalarPH = Plan.getScalarPreheader ();
735+ // Update the resume phis in the scalar preheader. They all must either resume
736+ // from the reduction result or the canonical induction. Bail out if there are
737+ // other resume phis.
738+ for (auto &R : ScalarPH->phis ()) {
739+ auto *ResumeR = cast<VPPhi>(&R);
740+ VPValue *VecV = ResumeR->getOperand (0 );
741+ VPValue *BypassV = ResumeR->getOperand (ResumeR->getNumOperands () - 1 );
742+ if (VecV != RdxResult && VecV != &Plan.getVectorTripCount ())
743+ return false ;
744+ ResumeR->setOperand (
745+ 1 , VecV == &Plan.getVectorTripCount () ? Plan.getCanonicalIV () : VecV);
746+ ResumeR->addOperand (BypassV);
747+ }
748+
749+ // Create a new reduction phi recipe with either FMin/FMax, replacing
750+ // FMinNumNoFMFs/FMaxNumNoFMFs.
751+ RecurKind NewRK = RedPhiR->getRecurrenceKind () != RecurKind::FMinNumNoFMFs
752+ ? RecurKind::FMin
753+ : RecurKind::FMax;
754+ auto *NewRedPhiR = new VPReductionPHIRecipe (
755+ cast<PHINode>(RedPhiR->getUnderlyingValue ()), NewRK,
756+ *RedPhiR->getStartValue (), RedPhiR->isInLoop (), RedPhiR->isOrdered ());
757+ NewRedPhiR->addOperand (RedPhiR->getOperand (1 ));
758+ NewRedPhiR->insertBefore (RedPhiR);
759+ RedPhiR->replaceAllUsesWith (NewRedPhiR);
760+ RedPhiR->eraseFromParent ();
761+
762+ // Update the loop exit condition to exit if either any of the inputs is NaN
763+ // or the vector trip count is reached.
764+ VPBasicBlock *LatchVPBB = LoopRegion->getExitingBasicBlock ();
765+ VPBuilder Builder (LatchVPBB->getTerminator ());
766+ auto *LatchExitingBranch = cast<VPInstruction>(LatchVPBB->getTerminator ());
767+ assert (LatchExitingBranch->getOpcode () == VPInstruction::BranchOnCount &&
768+ " Unexpected terminator" );
769+ auto *IsLatchExitTaken =
770+ Builder.createICmp (CmpInst::ICMP_EQ, LatchExitingBranch->getOperand (0 ),
771+ LatchExitingBranch->getOperand (1 ));
772+ auto *AnyExitTaken =
773+ Builder.createNaryOp (Instruction::Or, {AnyNaN, IsLatchExitTaken});
774+ Builder.createNaryOp (VPInstruction::BranchOnCond, AnyExitTaken);
775+ LatchExitingBranch->eraseFromParent ();
776+
777+ // Split the middle block and introduce a new block, branching to the scalar
778+ // preheader to resume iteration in the scalar loop if any NaNs have been
779+ // encountered.
780+ MiddleVPBB->splitAt (std::prev (MiddleVPBB->end ()));
781+ Builder.setInsertPoint (MiddleVPBB, MiddleVPBB->begin ());
782+ auto *NewSel =
783+ Builder.createSelect (AnyNaN, NewRedPhiR, RdxResult->getOperand (1 ));
784+ RdxResult->setOperand (1 , NewSel);
785+ Builder.setInsertPoint (MiddleVPBB);
786+ Builder.createNaryOp (VPInstruction::BranchOnCond, AnyNaN);
787+ VPBlockUtils::connectBlocks (MiddleVPBB, ScalarPH);
788+ MiddleVPBB->swapSuccessors ();
789+ std::swap (ScalarPH->getPredecessors ()[1 ], ScalarPH->getPredecessors ().back ());
790+ return true ;
791+ }
0 commit comments