From 1ad7ff982a767f7a510c9eb7e1048fdb23214912 Mon Sep 17 00:00:00 2001 From: Mikhail Gudim Date: Tue, 19 Dec 2023 01:58:12 -0500 Subject: [PATCH 1/2] [InstCombine] Extend `foldICmpAddConstant` to disjoint `or`. --- .../InstCombine/InstCombineCompares.cpp | 40 ++++++++++--------- .../InstCombine/InstCombineInternal.h | 4 +- .../InstCombine/InstCombineSelect.cpp | 2 +- llvm/test/Transforms/InstCombine/icmp.ll | 11 ++++- 4 files changed, 34 insertions(+), 23 deletions(-) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index 3875e59c3ede3..74c8f7f496f9b 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -2923,19 +2923,21 @@ static Value *createLogicFromTable(const std::bitset<4> &Table, Value *Op0, } /// Fold icmp (add X, Y), C. -Instruction *InstCombinerImpl::foldICmpAddConstant(ICmpInst &Cmp, - BinaryOperator *Add, - const APInt &C) { - Value *Y = Add->getOperand(1); - Value *X = Add->getOperand(0); +Instruction *InstCombinerImpl::foldICmpAddLikeConstant(ICmpInst &Cmp, + BinaryOperator *AddLike, + const APInt &C) { + Value *X = nullptr; + Value *Y = nullptr; + if (!match(AddLike, m_AddLike(m_Value(X), m_Value(Y)))) + return nullptr; Value *Op0, *Op1; Instruction *Ext0, *Ext1; const CmpInst::Predicate Pred = Cmp.getPredicate(); - if (match(Add, - m_Add(m_CombineAnd(m_Instruction(Ext0), m_ZExtOrSExt(m_Value(Op0))), - m_CombineAnd(m_Instruction(Ext1), - m_ZExtOrSExt(m_Value(Op1))))) && + if (match(AddLike, m_AddLike(m_CombineAnd(m_Instruction(Ext0), + m_ZExtOrSExt(m_Value(Op0))), + m_CombineAnd(m_Instruction(Ext1), + m_ZExtOrSExt(m_Value(Op1))))) && Op0->getType()->isIntOrIntVectorTy(1) && Op1->getType()->isIntOrIntVectorTy(1)) { unsigned BW = C.getBitWidth(); @@ -2953,8 +2955,8 @@ Instruction *InstCombinerImpl::foldICmpAddConstant(ICmpInst &Cmp, Table[1] = ComputeTable(false, true); Table[2] = ComputeTable(true, false); Table[3] = ComputeTable(true, true); - if (auto *Cond = - createLogicFromTable(Table, Op0, Op1, Builder, Add->hasOneUse())) + if (auto *Cond = createLogicFromTable(Table, Op0, Op1, Builder, + AddLike->hasOneUse())) return replaceInstUsesWith(Cmp, Cond); } const APInt *C2; @@ -2962,14 +2964,15 @@ Instruction *InstCombinerImpl::foldICmpAddConstant(ICmpInst &Cmp, return nullptr; // Fold icmp pred (add X, C2), C. - Type *Ty = Add->getType(); + Type *Ty = AddLike->getType(); // If the add does not wrap, we can always adjust the compare by subtracting // the constants. Equality comparisons are handled elsewhere. SGE/SLE/UGE/ULE // are canonicalized to SGT/SLT/UGT/ULT. - if ((Add->hasNoSignedWrap() && + if (AddLike->getOpcode() == Instruction::Or || + (AddLike->hasNoSignedWrap() && (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLT)) || - (Add->hasNoUnsignedWrap() && + (AddLike->hasNoUnsignedWrap() && (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_ULT))) { bool Overflow; APInt NewC = @@ -3026,7 +3029,7 @@ Instruction *InstCombinerImpl::foldICmpAddConstant(ICmpInst &Cmp, return new ICmpInst(ICmpInst::ICMP_ULE, X, ConstantInt::get(Ty, C)); } - if (!Add->hasOneUse()) + if (!AddLike->hasOneUse()) return nullptr; // X+C (X & -C2) == C @@ -3679,6 +3682,9 @@ InstCombinerImpl::foldICmpInstWithConstantAllowUndef(ICmpInst &Cmp, Instruction *InstCombinerImpl::foldICmpBinOpWithConstant(ICmpInst &Cmp, BinaryOperator *BO, const APInt &C) { + if (Instruction *I = foldICmpAddLikeConstant(Cmp, BO, C)) + return I; + switch (BO->getOpcode()) { case Instruction::Xor: if (Instruction *I = foldICmpXorConstant(Cmp, BO, C)) @@ -3721,10 +3727,6 @@ Instruction *InstCombinerImpl::foldICmpBinOpWithConstant(ICmpInst &Cmp, if (Instruction *I = foldICmpSubConstant(Cmp, BO, C)) return I; break; - case Instruction::Add: - if (Instruction *I = foldICmpAddConstant(Cmp, BO, C)) - return I; - break; default: break; } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h index 21c61bd990184..df2b7e982392c 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -678,8 +678,8 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final const APInt &C); Instruction *foldICmpSubConstant(ICmpInst &Cmp, BinaryOperator *Sub, const APInt &C); - Instruction *foldICmpAddConstant(ICmpInst &Cmp, BinaryOperator *Add, - const APInt &C); + Instruction *foldICmpAddLikeConstant(ICmpInst &Cmp, BinaryOperator *AddLike, + const APInt &C); Instruction *foldICmpAndConstConst(ICmpInst &Cmp, BinaryOperator *And, const APInt &C1); Instruction *foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And, diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index ab55f235920a7..2fe48d7bc3db0 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -1550,7 +1550,7 @@ tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp, if (C0->getType() != Sel.getType()) return nullptr; - // ULT with 'add' of a constant is canonical. See foldICmpAddConstant(). + // ULT with 'add' of a constant is canonical. See foldICmpAddLikeConstant(). // FIXME: Are there more magic icmp predicate+constant pairs we must avoid? // Or should we just abandon this transform entirely? if (Pred == CmpInst::ICMP_ULT && match(X, m_Add(m_Value(), m_Constant()))) diff --git a/llvm/test/Transforms/InstCombine/icmp.ll b/llvm/test/Transforms/InstCombine/icmp.ll index 1f554c7b60256..ac8d653041b62 100644 --- a/llvm/test/Transforms/InstCombine/icmp.ll +++ b/llvm/test/Transforms/InstCombine/icmp.ll @@ -5006,7 +5006,6 @@ define i1 @or_positive_sgt_zero_multi_use(i8 %a) { ret i1 %cmp } - define i1 @disjoint_or_sgt_1(i8 %a, i8 %b) { ; CHECK-LABEL: @disjoint_or_sgt_1( ; CHECK-NEXT: [[B1:%.*]] = add nsw i8 [[B:%.*]], 2 @@ -5138,3 +5137,13 @@ entry: %cmp = icmp eq i8 %add2, %add1 ret i1 %cmp } + +define i1 @icmp_disjoint_or(i32 %x) { +; CHECK-LABEL: @icmp_disjoint_or( +; CHECK-NEXT: [[C:%.*]] = icmp sgt i32 [[X:%.*]], 35 +; CHECK-NEXT: ret i1 [[C]] +; + %or_ = or disjoint i32 %x, 6 + %C = icmp sgt i32 %or_, 41 + ret i1 %C +} From 6b7e41cf12bb9e7d7a0a36b99498203ee9d99e78 Mon Sep 17 00:00:00 2001 From: Mikhail Gudim Date: Tue, 23 Jan 2024 00:50:18 -0500 Subject: [PATCH 2/2] Added some checks to preserve the logic of existing code exactly. Added more tests. --- .../InstCombine/InstCombineCompares.cpp | 7 +-- llvm/test/Transforms/InstCombine/icmp.ll | 54 ++++++++++++++++++- 2 files changed, 56 insertions(+), 5 deletions(-) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index 74c8f7f496f9b..0ebe27bcbeeee 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -2969,10 +2969,11 @@ Instruction *InstCombinerImpl::foldICmpAddLikeConstant(ICmpInst &Cmp, // If the add does not wrap, we can always adjust the compare by subtracting // the constants. Equality comparisons are handled elsewhere. SGE/SLE/UGE/ULE // are canonicalized to SGT/SLT/UGT/ULT. - if (AddLike->getOpcode() == Instruction::Or || - (AddLike->hasNoSignedWrap() && + if (((AddLike->getOpcode() == Instruction::Or || + AddLike->hasNoSignedWrap()) && (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLT)) || - (AddLike->hasNoUnsignedWrap() && + ((AddLike->getOpcode() == Instruction::Or || + AddLike->hasNoUnsignedWrap()) && (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_ULT))) { bool Overflow; APInt NewC = diff --git a/llvm/test/Transforms/InstCombine/icmp.ll b/llvm/test/Transforms/InstCombine/icmp.ll index ac8d653041b62..b1d917422fce2 100644 --- a/llvm/test/Transforms/InstCombine/icmp.ll +++ b/llvm/test/Transforms/InstCombine/icmp.ll @@ -5138,8 +5138,8 @@ entry: ret i1 %cmp } -define i1 @icmp_disjoint_or(i32 %x) { -; CHECK-LABEL: @icmp_disjoint_or( +define i1 @icmp_disjoint_or_sgt(i32 %x) { +; CHECK-LABEL: @icmp_disjoint_or_sgt( ; CHECK-NEXT: [[C:%.*]] = icmp sgt i32 [[X:%.*]], 35 ; CHECK-NEXT: ret i1 [[C]] ; @@ -5147,3 +5147,53 @@ define i1 @icmp_disjoint_or(i32 %x) { %C = icmp sgt i32 %or_, 41 ret i1 %C } + +define i1 @icmp_disjoint_or_slt(i32 %x) { +; CHECK-LABEL: @icmp_disjoint_or_slt( +; CHECK-NEXT: [[C:%.*]] = icmp slt i32 [[X:%.*]], 35 +; CHECK-NEXT: ret i1 [[C]] +; + %or_ = or disjoint i32 %x, 6 + %C = icmp slt i32 %or_, 41 + ret i1 %C +} + +define i1 @icmp_disjoint_or_ult(i32 %x) { +; CHECK-LABEL: @icmp_disjoint_or_ult( +; CHECK-NEXT: [[C:%.*]] = icmp ult i32 [[X:%.*]], 35 +; CHECK-NEXT: ret i1 [[C]] +; + %or_ = or disjoint i32 %x, 6 + %C = icmp ult i32 %or_, 41 + ret i1 %C +} + +define i1 @icmp_disjoint_or_ugt(i32 %x) { +; CHECK-LABEL: @icmp_disjoint_or_ugt( +; CHECK-NEXT: [[C:%.*]] = icmp ugt i32 [[X:%.*]], 35 +; CHECK-NEXT: ret i1 [[C]] +; + %or_ = or disjoint i32 %x, 6 + %C = icmp ugt i32 %or_, 41 + ret i1 %C +} + +define i1 @icmp_disjoint_or_eq(i32 %x) { +; CHECK-LABEL: @icmp_disjoint_or_eq( +; CHECK-NEXT: [[C:%.*]] = icmp eq i32 [[X:%.*]], 0 +; CHECK-NEXT: ret i1 [[C]] +; + %or_ = or disjoint i32 %x, 5 + %C = icmp eq i32 %or_, 5 + ret i1 %C +} + +define i1 @icmp_disjoint_or_be(i32 %x) { +; CHECK-LABEL: @icmp_disjoint_or_be( +; CHECK-NEXT: [[C:%.*]] = icmp ne i32 [[X:%.*]], 0 +; CHECK-NEXT: ret i1 [[C]] +; + %or_ = or disjoint i32 %x, 5 + %C = icmp ne i32 %or_, 5 + ret i1 %C +}