diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index 3875e59c3ede3..0ebe27bcbeeee 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -2923,19 +2923,21 @@ static Value *createLogicFromTable(const std::bitset<4> &Table, Value *Op0, } /// Fold icmp (add X, Y), C. -Instruction *InstCombinerImpl::foldICmpAddConstant(ICmpInst &Cmp, - BinaryOperator *Add, - const APInt &C) { - Value *Y = Add->getOperand(1); - Value *X = Add->getOperand(0); +Instruction *InstCombinerImpl::foldICmpAddLikeConstant(ICmpInst &Cmp, + BinaryOperator *AddLike, + const APInt &C) { + Value *X = nullptr; + Value *Y = nullptr; + if (!match(AddLike, m_AddLike(m_Value(X), m_Value(Y)))) + return nullptr; Value *Op0, *Op1; Instruction *Ext0, *Ext1; const CmpInst::Predicate Pred = Cmp.getPredicate(); - if (match(Add, - m_Add(m_CombineAnd(m_Instruction(Ext0), m_ZExtOrSExt(m_Value(Op0))), - m_CombineAnd(m_Instruction(Ext1), - m_ZExtOrSExt(m_Value(Op1))))) && + if (match(AddLike, m_AddLike(m_CombineAnd(m_Instruction(Ext0), + m_ZExtOrSExt(m_Value(Op0))), + m_CombineAnd(m_Instruction(Ext1), + m_ZExtOrSExt(m_Value(Op1))))) && Op0->getType()->isIntOrIntVectorTy(1) && Op1->getType()->isIntOrIntVectorTy(1)) { unsigned BW = C.getBitWidth(); @@ -2953,8 +2955,8 @@ Instruction *InstCombinerImpl::foldICmpAddConstant(ICmpInst &Cmp, Table[1] = ComputeTable(false, true); Table[2] = ComputeTable(true, false); Table[3] = ComputeTable(true, true); - if (auto *Cond = - createLogicFromTable(Table, Op0, Op1, Builder, Add->hasOneUse())) + if (auto *Cond = createLogicFromTable(Table, Op0, Op1, Builder, + AddLike->hasOneUse())) return replaceInstUsesWith(Cmp, Cond); } const APInt *C2; @@ -2962,14 +2964,16 @@ Instruction *InstCombinerImpl::foldICmpAddConstant(ICmpInst &Cmp, return nullptr; // Fold icmp pred (add X, C2), C. - Type *Ty = Add->getType(); + Type *Ty = AddLike->getType(); // If the add does not wrap, we can always adjust the compare by subtracting // the constants. Equality comparisons are handled elsewhere. SGE/SLE/UGE/ULE // are canonicalized to SGT/SLT/UGT/ULT. - if ((Add->hasNoSignedWrap() && + if (((AddLike->getOpcode() == Instruction::Or || + AddLike->hasNoSignedWrap()) && (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLT)) || - (Add->hasNoUnsignedWrap() && + ((AddLike->getOpcode() == Instruction::Or || + AddLike->hasNoUnsignedWrap()) && (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_ULT))) { bool Overflow; APInt NewC = @@ -3026,7 +3030,7 @@ Instruction *InstCombinerImpl::foldICmpAddConstant(ICmpInst &Cmp, return new ICmpInst(ICmpInst::ICMP_ULE, X, ConstantInt::get(Ty, C)); } - if (!Add->hasOneUse()) + if (!AddLike->hasOneUse()) return nullptr; // X+C (X & -C2) == C @@ -3679,6 +3683,9 @@ InstCombinerImpl::foldICmpInstWithConstantAllowUndef(ICmpInst &Cmp, Instruction *InstCombinerImpl::foldICmpBinOpWithConstant(ICmpInst &Cmp, BinaryOperator *BO, const APInt &C) { + if (Instruction *I = foldICmpAddLikeConstant(Cmp, BO, C)) + return I; + switch (BO->getOpcode()) { case Instruction::Xor: if (Instruction *I = foldICmpXorConstant(Cmp, BO, C)) @@ -3721,10 +3728,6 @@ Instruction *InstCombinerImpl::foldICmpBinOpWithConstant(ICmpInst &Cmp, if (Instruction *I = foldICmpSubConstant(Cmp, BO, C)) return I; break; - case Instruction::Add: - if (Instruction *I = foldICmpAddConstant(Cmp, BO, C)) - return I; - break; default: break; } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h index 21c61bd990184..df2b7e982392c 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -678,8 +678,8 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final const APInt &C); Instruction *foldICmpSubConstant(ICmpInst &Cmp, BinaryOperator *Sub, const APInt &C); - Instruction *foldICmpAddConstant(ICmpInst &Cmp, BinaryOperator *Add, - const APInt &C); + Instruction *foldICmpAddLikeConstant(ICmpInst &Cmp, BinaryOperator *AddLike, + const APInt &C); Instruction *foldICmpAndConstConst(ICmpInst &Cmp, BinaryOperator *And, const APInt &C1); Instruction *foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And, diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index ab55f235920a7..2fe48d7bc3db0 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -1550,7 +1550,7 @@ tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp, if (C0->getType() != Sel.getType()) return nullptr; - // ULT with 'add' of a constant is canonical. See foldICmpAddConstant(). + // ULT with 'add' of a constant is canonical. See foldICmpAddLikeConstant(). // FIXME: Are there more magic icmp predicate+constant pairs we must avoid? // Or should we just abandon this transform entirely? if (Pred == CmpInst::ICMP_ULT && match(X, m_Add(m_Value(), m_Constant()))) diff --git a/llvm/test/Transforms/InstCombine/icmp.ll b/llvm/test/Transforms/InstCombine/icmp.ll index 1f554c7b60256..b1d917422fce2 100644 --- a/llvm/test/Transforms/InstCombine/icmp.ll +++ b/llvm/test/Transforms/InstCombine/icmp.ll @@ -5006,7 +5006,6 @@ define i1 @or_positive_sgt_zero_multi_use(i8 %a) { ret i1 %cmp } - define i1 @disjoint_or_sgt_1(i8 %a, i8 %b) { ; CHECK-LABEL: @disjoint_or_sgt_1( ; CHECK-NEXT: [[B1:%.*]] = add nsw i8 [[B:%.*]], 2 @@ -5138,3 +5137,63 @@ entry: %cmp = icmp eq i8 %add2, %add1 ret i1 %cmp } + +define i1 @icmp_disjoint_or_sgt(i32 %x) { +; CHECK-LABEL: @icmp_disjoint_or_sgt( +; CHECK-NEXT: [[C:%.*]] = icmp sgt i32 [[X:%.*]], 35 +; CHECK-NEXT: ret i1 [[C]] +; + %or_ = or disjoint i32 %x, 6 + %C = icmp sgt i32 %or_, 41 + ret i1 %C +} + +define i1 @icmp_disjoint_or_slt(i32 %x) { +; CHECK-LABEL: @icmp_disjoint_or_slt( +; CHECK-NEXT: [[C:%.*]] = icmp slt i32 [[X:%.*]], 35 +; CHECK-NEXT: ret i1 [[C]] +; + %or_ = or disjoint i32 %x, 6 + %C = icmp slt i32 %or_, 41 + ret i1 %C +} + +define i1 @icmp_disjoint_or_ult(i32 %x) { +; CHECK-LABEL: @icmp_disjoint_or_ult( +; CHECK-NEXT: [[C:%.*]] = icmp ult i32 [[X:%.*]], 35 +; CHECK-NEXT: ret i1 [[C]] +; + %or_ = or disjoint i32 %x, 6 + %C = icmp ult i32 %or_, 41 + ret i1 %C +} + +define i1 @icmp_disjoint_or_ugt(i32 %x) { +; CHECK-LABEL: @icmp_disjoint_or_ugt( +; CHECK-NEXT: [[C:%.*]] = icmp ugt i32 [[X:%.*]], 35 +; CHECK-NEXT: ret i1 [[C]] +; + %or_ = or disjoint i32 %x, 6 + %C = icmp ugt i32 %or_, 41 + ret i1 %C +} + +define i1 @icmp_disjoint_or_eq(i32 %x) { +; CHECK-LABEL: @icmp_disjoint_or_eq( +; CHECK-NEXT: [[C:%.*]] = icmp eq i32 [[X:%.*]], 0 +; CHECK-NEXT: ret i1 [[C]] +; + %or_ = or disjoint i32 %x, 5 + %C = icmp eq i32 %or_, 5 + ret i1 %C +} + +define i1 @icmp_disjoint_or_be(i32 %x) { +; CHECK-LABEL: @icmp_disjoint_or_be( +; CHECK-NEXT: [[C:%.*]] = icmp ne i32 [[X:%.*]], 0 +; CHECK-NEXT: ret i1 [[C]] +; + %or_ = or disjoint i32 %x, 5 + %C = icmp ne i32 %or_, 5 + ret i1 %C +}