Skip to content

Commit 70f3ac9

Browse files
vporpoyuxuanchen1997
authored andcommitted
Reapply "[SandboxIR] Implement BranchInst (#100063)"
Summary: This reverts commit c312a1a. Test Plan: Reviewers: Subscribers: Tasks: Tags: Differential Revision: https://phabricator.intern.facebook.com/D60251151
1 parent 7a70550 commit 70f3ac9

File tree

8 files changed

+381
-2
lines changed

8 files changed

+381
-2
lines changed

llvm/include/llvm/SandboxIR/SandboxIR.h

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ class Context;
7676
class Function;
7777
class Instruction;
7878
class SelectInst;
79+
class BranchInst;
7980
class LoadInst;
8081
class ReturnInst;
8182
class StoreInst;
@@ -179,6 +180,7 @@ class Value {
179180
friend class User; // For getting `Val`.
180181
friend class Use; // For getting `Val`.
181182
friend class SelectInst; // For getting `Val`.
183+
friend class BranchInst; // For getting `Val`.
182184
friend class LoadInst; // For getting `Val`.
183185
friend class StoreInst; // For getting `Val`.
184186
friend class ReturnInst; // For getting `Val`.
@@ -343,6 +345,14 @@ class User : public Value {
343345
virtual unsigned getUseOperandNo(const Use &Use) const = 0;
344346
friend unsigned Use::getOperandNo() const; // For getUseOperandNo()
345347

348+
void swapOperandsInternal(unsigned OpIdxA, unsigned OpIdxB) {
349+
assert(OpIdxA < getNumOperands() && "OpIdxA out of bounds!");
350+
assert(OpIdxB < getNumOperands() && "OpIdxB out of bounds!");
351+
auto UseA = getOperandUse(OpIdxA);
352+
auto UseB = getOperandUse(OpIdxB);
353+
UseA.swap(UseB);
354+
}
355+
346356
#ifndef NDEBUG
347357
void verifyUserOfLLVMUse(const llvm::Use &Use) const;
348358
#endif // NDEBUG
@@ -504,6 +514,7 @@ class Instruction : public sandboxir::User {
504514
/// returns its topmost LLVM IR instruction.
505515
llvm::Instruction *getTopmostLLVMInstruction() const;
506516
friend class SelectInst; // For getTopmostLLVMInstruction().
517+
friend class BranchInst; // For getTopmostLLVMInstruction().
507518
friend class LoadInst; // For getTopmostLLVMInstruction().
508519
friend class StoreInst; // For getTopmostLLVMInstruction().
509520
friend class ReturnInst; // For getTopmostLLVMInstruction().
@@ -617,6 +628,100 @@ class SelectInst : public Instruction {
617628
#endif
618629
};
619630

631+
class BranchInst : public Instruction {
632+
/// Use Context::createBranchInst(). Don't call the constructor directly.
633+
BranchInst(llvm::BranchInst *BI, Context &Ctx)
634+
: Instruction(ClassID::Br, Opcode::Br, BI, Ctx) {}
635+
friend Context; // for BranchInst()
636+
Use getOperandUseInternal(unsigned OpIdx, bool Verify) const final {
637+
return getOperandUseDefault(OpIdx, Verify);
638+
}
639+
SmallVector<llvm::Instruction *, 1> getLLVMInstrs() const final {
640+
return {cast<llvm::Instruction>(Val)};
641+
}
642+
643+
public:
644+
unsigned getUseOperandNo(const Use &Use) const final {
645+
return getUseOperandNoDefault(Use);
646+
}
647+
unsigned getNumOfIRInstrs() const final { return 1u; }
648+
static BranchInst *create(BasicBlock *IfTrue, Instruction *InsertBefore,
649+
Context &Ctx);
650+
static BranchInst *create(BasicBlock *IfTrue, BasicBlock *InsertAtEnd,
651+
Context &Ctx);
652+
static BranchInst *create(BasicBlock *IfTrue, BasicBlock *IfFalse,
653+
Value *Cond, Instruction *InsertBefore,
654+
Context &Ctx);
655+
static BranchInst *create(BasicBlock *IfTrue, BasicBlock *IfFalse,
656+
Value *Cond, BasicBlock *InsertAtEnd, Context &Ctx);
657+
/// For isa/dyn_cast.
658+
static bool classof(const Value *From);
659+
bool isUnconditional() const {
660+
return cast<llvm::BranchInst>(Val)->isUnconditional();
661+
}
662+
bool isConditional() const {
663+
return cast<llvm::BranchInst>(Val)->isConditional();
664+
}
665+
Value *getCondition() const;
666+
void setCondition(Value *V) { setOperand(0, V); }
667+
unsigned getNumSuccessors() const { return 1 + isConditional(); }
668+
BasicBlock *getSuccessor(unsigned SuccIdx) const;
669+
void setSuccessor(unsigned Idx, BasicBlock *NewSucc);
670+
void swapSuccessors() { swapOperandsInternal(1, 2); }
671+
672+
private:
673+
struct LLVMBBToSBBB {
674+
Context &Ctx;
675+
LLVMBBToSBBB(Context &Ctx) : Ctx(Ctx) {}
676+
BasicBlock *operator()(llvm::BasicBlock *BB) const;
677+
};
678+
679+
struct ConstLLVMBBToSBBB {
680+
Context &Ctx;
681+
ConstLLVMBBToSBBB(Context &Ctx) : Ctx(Ctx) {}
682+
const BasicBlock *operator()(const llvm::BasicBlock *BB) const;
683+
};
684+
685+
public:
686+
using sb_succ_op_iterator =
687+
mapped_iterator<llvm::BranchInst::succ_op_iterator, LLVMBBToSBBB>;
688+
iterator_range<sb_succ_op_iterator> successors() {
689+
iterator_range<llvm::BranchInst::succ_op_iterator> LLVMRange =
690+
cast<llvm::BranchInst>(Val)->successors();
691+
LLVMBBToSBBB BBMap(Ctx);
692+
sb_succ_op_iterator MappedBegin = map_iterator(LLVMRange.begin(), BBMap);
693+
sb_succ_op_iterator MappedEnd = map_iterator(LLVMRange.end(), BBMap);
694+
return make_range(MappedBegin, MappedEnd);
695+
}
696+
697+
using const_sb_succ_op_iterator =
698+
mapped_iterator<llvm::BranchInst::const_succ_op_iterator,
699+
ConstLLVMBBToSBBB>;
700+
iterator_range<const_sb_succ_op_iterator> successors() const {
701+
iterator_range<llvm::BranchInst::const_succ_op_iterator> ConstLLVMRange =
702+
static_cast<const llvm::BranchInst *>(cast<llvm::BranchInst>(Val))
703+
->successors();
704+
ConstLLVMBBToSBBB ConstBBMap(Ctx);
705+
const_sb_succ_op_iterator ConstMappedBegin =
706+
map_iterator(ConstLLVMRange.begin(), ConstBBMap);
707+
const_sb_succ_op_iterator ConstMappedEnd =
708+
map_iterator(ConstLLVMRange.end(), ConstBBMap);
709+
return make_range(ConstMappedBegin, ConstMappedEnd);
710+
}
711+
712+
#ifndef NDEBUG
713+
void verify() const final {
714+
assert(isa<llvm::BranchInst>(Val) && "Expected BranchInst!");
715+
}
716+
friend raw_ostream &operator<<(raw_ostream &OS, const BranchInst &BI) {
717+
BI.dump(OS);
718+
return OS;
719+
}
720+
void dump(raw_ostream &OS) const override;
721+
LLVM_DUMP_METHOD void dump() const override;
722+
#endif
723+
};
724+
620725
class LoadInst final : public Instruction {
621726
/// Use LoadInst::create() instead of calling the constructor.
622727
LoadInst(llvm::LoadInst *LI, Context &Ctx)
@@ -870,6 +975,8 @@ class Context {
870975

871976
SelectInst *createSelectInst(llvm::SelectInst *SI);
872977
friend SelectInst; // For createSelectInst()
978+
BranchInst *createBranchInst(llvm::BranchInst *I);
979+
friend BranchInst; // For createBranchInst()
873980
LoadInst *createLoadInst(llvm::LoadInst *LI);
874981
friend LoadInst; // For createLoadInst()
875982
StoreInst *createStoreInst(llvm::StoreInst *SI);

llvm/include/llvm/SandboxIR/SandboxIRValues.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ DEF_USER(Constant, Constant)
2626
// ClassID, Opcode(s), Class
2727
DEF_INSTR(Opaque, OP(Opaque), OpaqueInst)
2828
DEF_INSTR(Select, OP(Select), SelectInst)
29+
DEF_INSTR(Br, OP(Br), BranchInst)
2930
DEF_INSTR(Load, OP(Load), LoadInst)
3031
DEF_INSTR(Store, OP(Store), StoreInst)
3132
DEF_INSTR(Ret, OP(Ret), ReturnInst)

llvm/include/llvm/SandboxIR/Tracker.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,27 @@ class UseSet : public IRChangeBase {
101101
#endif
102102
};
103103

104+
/// Tracks swapping a Use with another Use.
105+
class UseSwap : public IRChangeBase {
106+
Use ThisUse;
107+
Use OtherUse;
108+
109+
public:
110+
UseSwap(const Use &ThisUse, const Use &OtherUse, Tracker &Tracker)
111+
: IRChangeBase(Tracker), ThisUse(ThisUse), OtherUse(OtherUse) {
112+
assert(ThisUse.getUser() == OtherUse.getUser() && "Expected same user!");
113+
}
114+
void revert() final { ThisUse.swap(OtherUse); }
115+
void accept() final {}
116+
#ifndef NDEBUG
117+
void dump(raw_ostream &OS) const final {
118+
dumpCommon(OS);
119+
OS << "UseSwap";
120+
}
121+
LLVM_DUMP_METHOD void dump() const final;
122+
#endif
123+
};
124+
104125
class EraseFromParent : public IRChangeBase {
105126
/// Contains all the data we need to restore an "erased" (i.e., detached)
106127
/// instruction: the instruction itself and its operands in order.

llvm/include/llvm/SandboxIR/Use.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ class Use {
4747
void set(Value *V);
4848
class User *getUser() const { return Usr; }
4949
unsigned getOperandNo() const;
50+
void swap(Use &OtherUse);
5051
Context *getContext() const { return Ctx; }
5152
bool operator==(const Use &Other) const {
5253
assert(Ctx == Other.Ctx && "Contexts differ!");

llvm/lib/SandboxIR/SandboxIR.cpp

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,13 @@ void Use::set(Value *V) { LLVMUse->set(V->Val); }
2020

2121
unsigned Use::getOperandNo() const { return Usr->getUseOperandNo(*this); }
2222

23+
void Use::swap(Use &OtherUse) {
24+
auto &Tracker = Ctx->getTracker();
25+
if (Tracker.isTracking())
26+
Tracker.track(std::make_unique<UseSwap>(*this, OtherUse, Tracker));
27+
LLVMUse->swap(*OtherUse.LLVMUse);
28+
}
29+
2330
#ifndef NDEBUG
2431
void Use::dump(raw_ostream &OS) const {
2532
Value *Def = nullptr;
@@ -500,6 +507,85 @@ void SelectInst::dump() const {
500507
}
501508
#endif // NDEBUG
502509

510+
BranchInst *BranchInst::create(BasicBlock *IfTrue, Instruction *InsertBefore,
511+
Context &Ctx) {
512+
auto &Builder = Ctx.getLLVMIRBuilder();
513+
Builder.SetInsertPoint(cast<llvm::Instruction>(InsertBefore->Val));
514+
llvm::BranchInst *NewBr =
515+
Builder.CreateBr(cast<llvm::BasicBlock>(IfTrue->Val));
516+
return Ctx.createBranchInst(NewBr);
517+
}
518+
519+
BranchInst *BranchInst::create(BasicBlock *IfTrue, BasicBlock *InsertAtEnd,
520+
Context &Ctx) {
521+
auto &Builder = Ctx.getLLVMIRBuilder();
522+
Builder.SetInsertPoint(cast<llvm::BasicBlock>(InsertAtEnd->Val));
523+
llvm::BranchInst *NewBr =
524+
Builder.CreateBr(cast<llvm::BasicBlock>(IfTrue->Val));
525+
return Ctx.createBranchInst(NewBr);
526+
}
527+
528+
BranchInst *BranchInst::create(BasicBlock *IfTrue, BasicBlock *IfFalse,
529+
Value *Cond, Instruction *InsertBefore,
530+
Context &Ctx) {
531+
auto &Builder = Ctx.getLLVMIRBuilder();
532+
Builder.SetInsertPoint(cast<llvm::Instruction>(InsertBefore->Val));
533+
llvm::BranchInst *NewBr =
534+
Builder.CreateCondBr(Cond->Val, cast<llvm::BasicBlock>(IfTrue->Val),
535+
cast<llvm::BasicBlock>(IfFalse->Val));
536+
return Ctx.createBranchInst(NewBr);
537+
}
538+
539+
BranchInst *BranchInst::create(BasicBlock *IfTrue, BasicBlock *IfFalse,
540+
Value *Cond, BasicBlock *InsertAtEnd,
541+
Context &Ctx) {
542+
auto &Builder = Ctx.getLLVMIRBuilder();
543+
Builder.SetInsertPoint(cast<llvm::BasicBlock>(InsertAtEnd->Val));
544+
llvm::BranchInst *NewBr =
545+
Builder.CreateCondBr(Cond->Val, cast<llvm::BasicBlock>(IfTrue->Val),
546+
cast<llvm::BasicBlock>(IfFalse->Val));
547+
return Ctx.createBranchInst(NewBr);
548+
}
549+
550+
bool BranchInst::classof(const Value *From) {
551+
return From->getSubclassID() == ClassID::Br;
552+
}
553+
554+
Value *BranchInst::getCondition() const {
555+
assert(isConditional() && "Cannot get condition of an uncond branch!");
556+
return Ctx.getValue(cast<llvm::BranchInst>(Val)->getCondition());
557+
}
558+
559+
BasicBlock *BranchInst::getSuccessor(unsigned SuccIdx) const {
560+
assert(SuccIdx < getNumSuccessors() &&
561+
"Successor # out of range for Branch!");
562+
return cast_or_null<BasicBlock>(
563+
Ctx.getValue(cast<llvm::BranchInst>(Val)->getSuccessor(SuccIdx)));
564+
}
565+
566+
void BranchInst::setSuccessor(unsigned Idx, BasicBlock *NewSucc) {
567+
assert((Idx == 0 || Idx == 1) && "Out of bounds!");
568+
setOperand(2u - Idx, NewSucc);
569+
}
570+
571+
BasicBlock *BranchInst::LLVMBBToSBBB::operator()(llvm::BasicBlock *BB) const {
572+
return cast<BasicBlock>(Ctx.getValue(BB));
573+
}
574+
const BasicBlock *
575+
BranchInst::ConstLLVMBBToSBBB::operator()(const llvm::BasicBlock *BB) const {
576+
return cast<BasicBlock>(Ctx.getValue(BB));
577+
}
578+
#ifndef NDEBUG
579+
void BranchInst::dump(raw_ostream &OS) const {
580+
dumpCommonPrefix(OS);
581+
dumpCommonSuffix(OS);
582+
}
583+
void BranchInst::dump() const {
584+
dump(dbgs());
585+
dbgs() << "\n";
586+
}
587+
#endif // NDEBUG
588+
503589
LoadInst *LoadInst::create(Type *Ty, Value *Ptr, MaybeAlign Align,
504590
Instruction *InsertBefore, Context &Ctx,
505591
const Twine &Name) {
@@ -758,6 +844,11 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
758844
It->second = std::unique_ptr<SelectInst>(new SelectInst(LLVMSel, *this));
759845
return It->second.get();
760846
}
847+
case llvm::Instruction::Br: {
848+
auto *LLVMBr = cast<llvm::BranchInst>(LLVMV);
849+
It->second = std::unique_ptr<BranchInst>(new BranchInst(LLVMBr, *this));
850+
return It->second.get();
851+
}
761852
case llvm::Instruction::Load: {
762853
auto *LLVMLd = cast<llvm::LoadInst>(LLVMV);
763854
It->second = std::unique_ptr<LoadInst>(new LoadInst(LLVMLd, *this));
@@ -796,6 +887,11 @@ SelectInst *Context::createSelectInst(llvm::SelectInst *SI) {
796887
return cast<SelectInst>(registerValue(std::move(NewPtr)));
797888
}
798889

890+
BranchInst *Context::createBranchInst(llvm::BranchInst *BI) {
891+
auto NewPtr = std::unique_ptr<BranchInst>(new BranchInst(BI, *this));
892+
return cast<BranchInst>(registerValue(std::move(NewPtr)));
893+
}
894+
799895
LoadInst *Context::createLoadInst(llvm::LoadInst *LI) {
800896
auto NewPtr = std::unique_ptr<LoadInst>(new LoadInst(LI, *this));
801897
return cast<LoadInst>(registerValue(std::move(NewPtr)));

llvm/lib/SandboxIR/Tracker.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,11 @@ void UseSet::dump() const {
3535
dump(dbgs());
3636
dbgs() << "\n";
3737
}
38+
39+
void UseSwap::dump() const {
40+
dump(dbgs());
41+
dbgs() << "\n";
42+
}
3843
#endif // NDEBUG
3944

4045
Tracker::~Tracker() {

0 commit comments

Comments
 (0)