aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRoman Lebedev <lebedev.ri@gmail.com>2021-04-10 19:37:59 +0300
committerRoman Lebedev <lebedev.ri@gmail.com>2021-04-10 19:38:55 +0300
commite8c7f43e2c2c6f3581ec1c6489ec21ad9f98958a (patch)
tree85a7b7ec57c83340778733ffba90c16d38ed80f4
parentRevert "[NFC][ConstantRange] Add 'icmp' helper method" (diff)
downloadllvm-project-e8c7f43e2c2c6f3581ec1c6489ec21ad9f98958a.tar.gz
llvm-project-e8c7f43e2c2c6f3581ec1c6489ec21ad9f98958a.tar.bz2
llvm-project-e8c7f43e2c2c6f3581ec1c6489ec21ad9f98958a.zip
[NFC][ConstantRange] Add 'icmp' helper method
"Does the predicate hold between two ranges?" Not very surprisingly, some places were already doing this check, without explicitly naming the algorithm, cleanup them all.
-rw-r--r--llvm/include/llvm/Analysis/ValueLattice.h6
-rw-r--r--llvm/include/llvm/IR/ConstantRange.h4
-rw-r--r--llvm/include/llvm/IR/IntrinsicInst.h41
-rw-r--r--llvm/lib/Analysis/InstructionSimplify.cpp7
-rw-r--r--llvm/lib/Analysis/ScalarEvolution.cpp12
-rw-r--r--llvm/lib/IR/ConstantRange.cpp5
-rw-r--r--llvm/lib/Transforms/IPO/AttributorAttributes.cpp5
-rw-r--r--llvm/unittests/IR/ConstantRangeTest.cpp49
8 files changed, 107 insertions, 22 deletions
diff --git a/llvm/include/llvm/Analysis/ValueLattice.h b/llvm/include/llvm/Analysis/ValueLattice.h
index 5ff9c4a6b080..1b32fca50697 100644
--- a/llvm/include/llvm/Analysis/ValueLattice.h
+++ b/llvm/include/llvm/Analysis/ValueLattice.h
@@ -474,11 +474,9 @@ public:
const auto &CR = getConstantRange();
const auto &OtherCR = Other.getConstantRange();
- if (ConstantRange::makeSatisfyingICmpRegion(Pred, OtherCR).contains(CR))
+ if (CR.icmp(Pred, OtherCR))
return ConstantInt::getTrue(Ty);
- if (ConstantRange::makeSatisfyingICmpRegion(
- CmpInst::getInversePredicate(Pred), OtherCR)
- .contains(CR))
+ if (CR.icmp(CmpInst::getInversePredicate(Pred), OtherCR))
return ConstantInt::getFalse(Ty);
return nullptr;
diff --git a/llvm/include/llvm/IR/ConstantRange.h b/llvm/include/llvm/IR/ConstantRange.h
index 20e8e67436a4..44b8c395c89e 100644
--- a/llvm/include/llvm/IR/ConstantRange.h
+++ b/llvm/include/llvm/IR/ConstantRange.h
@@ -124,6 +124,10 @@ public:
static ConstantRange makeExactICmpRegion(CmpInst::Predicate Pred,
const APInt &Other);
+ /// Does the predicate \p Pred hold between ranges this and \p Other?
+ /// NOTE: false does not mean that inverse predicate holds!
+ bool icmp(CmpInst::Predicate Pred, const ConstantRange &Other) const;
+
/// Produce the largest range containing all X such that "X BinOp Y" is
/// guaranteed not to wrap (overflow) for *all* Y in Other. However, there may
/// be *some* Y in Other for which additional X not contained in the result
diff --git a/llvm/include/llvm/IR/IntrinsicInst.h b/llvm/include/llvm/IR/IntrinsicInst.h
index 6c825d380fc9..b688ece7067e 100644
--- a/llvm/include/llvm/IR/IntrinsicInst.h
+++ b/llvm/include/llvm/IR/IntrinsicInst.h
@@ -458,6 +458,47 @@ public:
}
};
+/// This class represents min/max intrinsics.
+class LimitingIntrinsic : public IntrinsicInst {
+public:
+ static bool classof(const IntrinsicInst *I) {
+ switch (I->getIntrinsicID()) {
+ case Intrinsic::umin:
+ case Intrinsic::umax:
+ case Intrinsic::smin:
+ case Intrinsic::smax:
+ return true;
+ default:
+ return false;
+ }
+ }
+ static bool classof(const Value *V) {
+ return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
+ }
+
+ Value *getLHS() const { return const_cast<Value *>(getArgOperand(0)); }
+ Value *getRHS() const { return const_cast<Value *>(getArgOperand(1)); }
+
+ /// Returns the comparison predicate underlying the intrinsic.
+ ICmpInst::Predicate getPredicate() const {
+ switch (getIntrinsicID()) {
+ case Intrinsic::umin:
+ return ICmpInst::Predicate::ICMP_ULT;
+ case Intrinsic::umax:
+ return ICmpInst::Predicate::ICMP_UGT;
+ case Intrinsic::smin:
+ return ICmpInst::Predicate::ICMP_SLT;
+ case Intrinsic::smax:
+ return ICmpInst::Predicate::ICMP_SGT;
+ default:
+ llvm_unreachable("Invalid intrinsic");
+ }
+ }
+
+ /// Whether the intrinsic is signed or unsigned.
+ bool isSigned() const { return ICmpInst::isSigned(getPredicate()); };
+};
+
/// This class represents an intrinsic that is based on a binary operation.
/// This includes op.with.overflow and saturating add/sub intrinsics.
class BinaryOpIntrinsic : public IntrinsicInst {
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index a6d3ca64189d..b233a0f3eb2d 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -3451,13 +3451,10 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
auto LHS_CR = getConstantRangeFromMetadata(
*LHS_Instr->getMetadata(LLVMContext::MD_range));
- auto Satisfied_CR = ConstantRange::makeSatisfyingICmpRegion(Pred, RHS_CR);
- if (Satisfied_CR.contains(LHS_CR))
+ if (LHS_CR.icmp(Pred, RHS_CR))
return ConstantInt::getTrue(RHS->getContext());
- auto InversedSatisfied_CR = ConstantRange::makeSatisfyingICmpRegion(
- CmpInst::getInversePredicate(Pred), RHS_CR);
- if (InversedSatisfied_CR.contains(LHS_CR))
+ if (LHS_CR.icmp(CmpInst::getInversePredicate(Pred), RHS_CR))
return ConstantInt::getFalse(RHS->getContext());
}
}
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index a481c23c3d05..4630c5562623 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -9843,10 +9843,9 @@ bool ScalarEvolution::isKnownPredicateViaConstantRanges(
// This code is split out from isKnownPredicate because it is called from
// within isLoopEntryGuardedByCond.
- auto CheckRanges =
- [&](const ConstantRange &RangeLHS, const ConstantRange &RangeRHS) {
- return ConstantRange::makeSatisfyingICmpRegion(Pred, RangeRHS)
- .contains(RangeLHS);
+ auto CheckRanges = [&](const ConstantRange &RangeLHS,
+ const ConstantRange &RangeRHS) {
+ return RangeLHS.icmp(Pred, RangeRHS);
};
// The check at the top of the function catches the case where the values are
@@ -11148,12 +11147,9 @@ bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred,
// We can also compute the range of values for `LHS` that satisfy the
// consequent, "`LHS` `Pred` `RHS`":
const APInt &ConstRHS = cast<SCEVConstant>(RHS)->getAPInt();
- ConstantRange SatisfyingLHSRange =
- ConstantRange::makeSatisfyingICmpRegion(Pred, ConstRHS);
-
// The antecedent implies the consequent if every value of `LHS` that
// satisfies the antecedent also satisfies the consequent.
- return SatisfyingLHSRange.contains(LHSRange);
+ return LHSRange.icmp(Pred, ConstRHS);
}
bool ScalarEvolution::doesIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
diff --git a/llvm/lib/IR/ConstantRange.cpp b/llvm/lib/IR/ConstantRange.cpp
index 4dbe1a1b902d..b38599fa7d98 100644
--- a/llvm/lib/IR/ConstantRange.cpp
+++ b/llvm/lib/IR/ConstantRange.cpp
@@ -181,6 +181,11 @@ bool ConstantRange::getEquivalentICmp(CmpInst::Predicate &Pred,
return Success;
}
+bool ConstantRange::icmp(CmpInst::Predicate Pred,
+ const ConstantRange &Other) const {
+ return makeSatisfyingICmpRegion(Pred, Other).contains(*this);
+}
+
/// Exact mul nuw region for single element RHS.
static ConstantRange makeExactMulNUWRegion(const APInt &V) {
unsigned BitWidth = V.getBitWidth();
diff --git a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
index 867dd3118cea..33ed2b4423a8 100644
--- a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
+++ b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
@@ -7328,13 +7328,10 @@ struct AAValueConstantRangeFloating : AAValueConstantRangeImpl {
auto AllowedRegion =
ConstantRange::makeAllowedICmpRegion(CmpI->getPredicate(), RHSAARange);
- auto SatisfyingRegion = ConstantRange::makeSatisfyingICmpRegion(
- CmpI->getPredicate(), RHSAARange);
-
if (AllowedRegion.intersectWith(LHSAARange).isEmptySet())
MustFalse = true;
- if (SatisfyingRegion.contains(LHSAARange))
+ if (LHSAARange.icmp(CmpI->getPredicate(), RHSAARange))
MustTrue = true;
assert((!MustTrue || !MustFalse) &&
diff --git a/llvm/unittests/IR/ConstantRangeTest.cpp b/llvm/unittests/IR/ConstantRangeTest.cpp
index 12362b9460f9..f8816e4d43df 100644
--- a/llvm/unittests/IR/ConstantRangeTest.cpp
+++ b/llvm/unittests/IR/ConstantRangeTest.cpp
@@ -6,9 +6,10 @@
//
//===----------------------------------------------------------------------===//
+#include "llvm/IR/ConstantRange.h"
#include "llvm/ADT/BitVector.h"
+#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallBitVector.h"
-#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Operator.h"
#include "llvm/Support/KnownBits.h"
@@ -1509,6 +1510,52 @@ TEST(ConstantRange, MakeSatisfyingICmpRegion) {
ConstantRange(APInt(8, 4), APInt(8, -128)));
}
+static bool icmp(CmpInst::Predicate Pred, const APInt &LHS, const APInt &RHS) {
+ switch (Pred) {
+ case CmpInst::Predicate::ICMP_EQ:
+ return LHS.eq(RHS);
+ case CmpInst::Predicate::ICMP_NE:
+ return LHS.ne(RHS);
+ case CmpInst::Predicate::ICMP_UGT:
+ return LHS.ugt(RHS);
+ case CmpInst::Predicate::ICMP_UGE:
+ return LHS.uge(RHS);
+ case CmpInst::Predicate::ICMP_ULT:
+ return LHS.ult(RHS);
+ case CmpInst::Predicate::ICMP_ULE:
+ return LHS.ule(RHS);
+ case CmpInst::Predicate::ICMP_SGT:
+ return LHS.sgt(RHS);
+ case CmpInst::Predicate::ICMP_SGE:
+ return LHS.sge(RHS);
+ case CmpInst::Predicate::ICMP_SLT:
+ return LHS.slt(RHS);
+ case CmpInst::Predicate::ICMP_SLE:
+ return LHS.sle(RHS);
+ default:
+ llvm_unreachable("Not an ICmp predicate!");
+ }
+}
+
+void ICmpTestImpl(CmpInst::Predicate Pred) {
+ unsigned Bits = 4;
+ EnumerateTwoConstantRanges(
+ Bits, [&](const ConstantRange &CR1, const ConstantRange &CR2) {
+ bool Exhaustive = true;
+ ForeachNumInConstantRange(CR1, [&](const APInt &N1) {
+ ForeachNumInConstantRange(
+ CR2, [&](const APInt &N2) { Exhaustive &= icmp(Pred, N1, N2); });
+ });
+ EXPECT_EQ(CR1.icmp(Pred, CR2), Exhaustive);
+ });
+}
+
+TEST(ConstantRange, ICmp) {
+ for (auto Pred : seq<unsigned>(CmpInst::Predicate::FIRST_ICMP_PREDICATE,
+ 1 + CmpInst::Predicate::LAST_ICMP_PREDICATE))
+ ICmpTestImpl((CmpInst::Predicate)Pred);
+}
+
TEST(ConstantRange, MakeGuaranteedNoWrapRegion) {
const int IntMin4Bits = 8;
const int IntMax4Bits = 7;