From 01b56b8bddaee03ab5261e6bd67b9511dce00cd6 Mon Sep 17 00:00:00 2001 From: Philip Reames Date: Thu, 10 Feb 2022 15:50:50 -0800 Subject: [PATCH] [SCEVPredicateRewriter] Remove assumption top level predicate is a union [NFC] --- llvm/include/llvm/Analysis/ScalarEvolution.h | 2 +- llvm/lib/Analysis/ScalarEvolution.cpp | 26 ++++++++++++-------- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h index 36e4f84f39f6..30e62a640363 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolution.h +++ b/llvm/include/llvm/Analysis/ScalarEvolution.h @@ -1176,7 +1176,7 @@ public: /// Re-writes the SCEV according to the Predicates in \p A. const SCEV *rewriteUsingPredicate(const SCEV *S, const Loop *L, - const SCEVUnionPredicate &A); + const SCEVPredicate &A); /// Tries to convert the \p S expression to an AddRec expression, /// adding additional predicates to \p Preds as required. const SCEVAddRecExpr *convertSCEVToAddRecWithPredicates( diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index 1a31f96c1c24..a9e102a1056a 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -13634,19 +13634,25 @@ public: /// \p NewPreds such that the result will be an AddRecExpr. static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE, SmallPtrSetImpl *NewPreds, - const SCEVUnionPredicate *Pred) { + const SCEVPredicate *Pred) { SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred); return Rewriter.visit(S); } const SCEV *visitUnknown(const SCEVUnknown *Expr) { if (Pred) { - auto ExprPreds = Pred->getPredicatesForExpr(Expr); - for (auto *Pred : ExprPreds) - if (const auto *IPred = dyn_cast(Pred)) - if (IPred->getLHS() == Expr && - IPred->getPredicate() == ICmpInst::ICMP_EQ) - return IPred->getRHS(); + if (auto *U = dyn_cast(Pred)) { + auto ExprPreds = U->getPredicatesForExpr(Expr); + for (auto *Pred : ExprPreds) + if (const auto *IPred = dyn_cast(Pred)) + if (IPred->getLHS() == Expr && + IPred->getPredicate() == ICmpInst::ICMP_EQ) + return IPred->getRHS(); + } else if (const auto *IPred = dyn_cast(Pred)) { + if (IPred->getLHS() == Expr && + IPred->getPredicate() == ICmpInst::ICMP_EQ) + return IPred->getRHS(); + } } return convertToAddRecWithPreds(Expr); } @@ -13686,7 +13692,7 @@ public: private: explicit SCEVPredicateRewriter(const Loop *L, ScalarEvolution &SE, SmallPtrSetImpl *NewPreds, - const SCEVUnionPredicate *Pred) + const SCEVPredicate *Pred) : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {} bool addOverflowAssumption(const SCEVPredicate *P) { @@ -13731,7 +13737,7 @@ private: } SmallPtrSetImpl *NewPreds; - const SCEVUnionPredicate *Pred; + const SCEVPredicate *Pred; const Loop *L; }; @@ -13739,7 +13745,7 @@ private: const SCEV * ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L, - const SCEVUnionPredicate &Preds) { + const SCEVPredicate &Preds) { return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds); }