diff --git a/llvm/include/llvm/CodeGen/Rematerializer.h b/llvm/include/llvm/CodeGen/Rematerializer.h index c76c5f06cecc..96c00c59f318 100644 --- a/llvm/include/llvm/CodeGen/Rematerializer.h +++ b/llvm/include/llvm/CodeGen/Rematerializer.h @@ -14,6 +14,7 @@ #ifndef LLVM_CODEGEN_REMATERIALIZER_H #define LLVM_CODEGEN_REMATERIALIZER_H +#include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLExtras.h" #include "llvm/CodeGen/LiveIntervals.h" #include "llvm/CodeGen/MachineBasicBlock.h" @@ -76,18 +77,7 @@ namespace llvm { /// The rematerializer supports rematerializing arbitrary complex DAGs of /// registers to regions where these registers are used, with the option of /// re-using non-root registers or their previous rematerializations instead of -/// rematerializing them again. It also optionally supports rolling back -/// previous rematerializations (set during analysis phase, see \ref -/// Rematerializer::analyze) to restore the MIR state to what it was -/// pre-rematerialization. When enabled, machine instructions defining -/// rematerializable registers that no longer have any uses following previous -/// rematerializations will not be deleted from the MIR; their opcode will -/// instead be set to a DEBUG_VALUE and their read register operands set to the -/// null register. This maintains their position in the MIR and keeps the -/// original register alive for potential rollback while allowing other -/// passes/analyzes (e.g., machine scheduler, live-interval analysis) to ignore -/// them. \ref Rematerializer::commitRematerializations actually deletes those -/// instructions when rollback is deemed unnecessary. +/// rematerializing them again. /// /// Throughout its lifetime, the rematerializer tracks new registers it creates /// (which are rematerializable by construction) and their relations to other @@ -121,10 +111,7 @@ public: /// arbitrary number of regions, potentially including its own defining /// region. 
When rematerializations lead to operand changes in users, a /// register may find itself without any user left, at which point the - /// rematerializer marks it for deletion. Its defining instruction either - /// becomes nullptr (without rollback support) or its opcode is set to - /// TargetOpcode::DBG_VALUE (with rollback support) until \ref - /// Rematerializer::commitRematerializations is called. + /// rematerializer deletes it (setting its defining MI to nullptr). struct Reg { /// Single MI defining the rematerializable register. MachineInstr *DefMI; @@ -153,7 +140,7 @@ public: SmallVector Dependencies; /// Returns the rematerializable register from its defining instruction. - inline Register getDefReg() const { + Register getDefReg() const { assert(DefMI && "defining instruction was deleted"); assert(DefMI->getOperand(0).isDef() && "not a register def"); return DefMI->getOperand(0).getReg(); @@ -174,9 +161,7 @@ public: std::pair getRegionUseBounds(unsigned UseRegion, const LiveIntervals &LIS) const; - bool isAlive() const { - return DefMI && DefMI->getOpcode() != TargetOpcode::DBG_VALUE; - } + bool isAlive() const { return DefMI; } private: void addUser(MachineInstr *MI, unsigned Region); @@ -219,6 +204,8 @@ public: using RegionBoundaries = std::pair; + using RematsOf = SmallDenseSet; + /// Simply initializes some internal state, does not identify /// rematerialization candidates. Rematerializer(MachineFunction &MF, @@ -226,65 +213,72 @@ public: LiveIntervals &LIS); /// Goes through the whole MF and identifies all rematerializable registers. - /// When \p SupportRollback is set, rematerializations of original registers - /// can be rolled back and original registers are maintained in the IR even - /// when they longer have any users. Returns whether there is any - /// rematerializable register in regions. - bool analyze(bool SupportRollback); + /// Returns whether there is any rematerializable register in regions. 
+ bool analyze(); /// Adds a new listener to the rematerializer. void addListener(Listener *Listen) { assert(Listen && "null listener"); - assert(!Listeners.contains(Listen) && "duplicate listener"); - Listeners.insert(Listen); + if (!Listeners.insert(Listen).second) + llvm_unreachable("duplicate listener"); } /// Removes a listener from the rematerializer. void removeListener(Listener *Listen) { - assert(Listeners.contains(Listen) && "unknown listener"); - Listeners.erase(Listen); + if (!Listeners.erase(Listen)) + llvm_unreachable("unknown listener"); } /// Removes all listeners from the rematerializer. void clearListeners() { Listeners.clear(); } - inline const Reg &getReg(RegisterIdx RegIdx) const { + const Reg &getReg(RegisterIdx RegIdx) const { assert(RegIdx < Regs.size() && "out of bounds"); return Regs[RegIdx]; }; - inline ArrayRef getRegs() const { return Regs; }; - inline unsigned getNumRegs() const { return Regs.size(); }; + ArrayRef getRegs() const { return Regs; }; + unsigned getNumRegs() const { return Regs.size(); }; - inline const RegionBoundaries &getRegion(RegisterIdx RegionIdx) { + const RegionBoundaries &getRegion(RegisterIdx RegionIdx) const { assert(RegionIdx < Regions.size() && "out of bounds"); return Regions[RegionIdx]; } - inline unsigned getNumRegions() const { return Regions.size(); } + unsigned getNumRegions() const { return Regions.size(); } + /// Whether register \p RegIdx is an original register. + bool isOriginalRegister(RegisterIdx RegIdx) const { + return !isRematerializedRegister(RegIdx); + } /// Whether register \p RegIdx is a rematerialization of some original /// register. - inline bool isRematerializedRegister(RegisterIdx RegIdx) const { + bool isRematerializedRegister(RegisterIdx RegIdx) const { assert(RegIdx < Regs.size() && "out of bounds"); return RegIdx >= UnrematableOprds.size(); } /// Returns the origin index of rematerializable register \p RegIdx. 
- inline RegisterIdx getOriginOf(RegisterIdx RematRegIdx) const { + RegisterIdx getOriginOf(RegisterIdx RematRegIdx) const { assert(isRematerializedRegister(RematRegIdx) && "not a rematerialization"); return Origins[RematRegIdx - UnrematableOprds.size()]; } /// If \p RegIdx is a rematerialization, returns its origin's index. If it is /// an original register's index, returns the same index. - inline RegisterIdx getOriginOrSelf(RegisterIdx RegIdx) const { + RegisterIdx getOriginOrSelf(RegisterIdx RegIdx) const { if (isRematerializedRegister(RegIdx)) return getOriginOf(RegIdx); return RegIdx; } /// Returns operand indices corresponding to unrematerializable operands for /// any register \p RegIdx. - inline ArrayRef getUnrematableOprds(unsigned RegIdx) const { + ArrayRef getUnrematableOprds(RegisterIdx RegIdx) const { return UnrematableOprds[getOriginOrSelf(RegIdx)]; } + /// If \p MI's first operand defines a register and that register is a + /// rematerializable register tracked by the rematerializer, returns its + /// index in the \ref Regs vector. Otherwise returns \ref + /// Rematerializer::NoReg. + RegisterIdx getDefRegIdx(const MachineInstr &MI) const; + /// When rematerializating a register (called the "root" register in this /// context) to a given position, we must decide what to do with all its /// rematerializable dependencies (for unrematerializable dependencies, we @@ -361,27 +355,27 @@ public: MachineBasicBlock::iterator InsertPos, DependencyReuseInfo &DRI); - /// Rolls back all rematerializations of original register \p RootIdx, - /// transfering all their users back to it and permanently deleting them from - /// the MIR. The root register is revived if it was fully rematerialized (this - /// requires that rollback support was set at that time). Transitive - /// dependencies of the root register that were fully rematerialized are - /// re-vived at their original positions; this requires that rollback support - /// was set when they were rematerialized. 
- void rollbackRematsOf(RegisterIdx RootIdx); + /// Rematerializes register \p RegIdx before \p InsertPos in \p UseRegion, + /// adding the new rematerializable register to the backing vector \ref Regs + /// and returning its index inside the vector. Sets the new register's + /// rematerializable dependencies to \p Dependencies (these are assumed to + /// already exist in the MIR) and its unrematerializable dependencies to the + /// same as \p RegIdx. The new register initially has no user. Since the + /// method appends to \ref Regs, references to elements within it should be + /// considered invalidated across calls to this method unless the vector can + /// be guaranteed to have enough space for an extra element. + RegisterIdx rematerializeReg(RegisterIdx RegIdx, unsigned UseRegion, + MachineBasicBlock::iterator InsertPos, + SmallVectorImpl &&Dependencies); - /// Rolls back register \p RematIdx (which must be a rematerialization) - /// transfering all its users back to its origin. The latter is revived if it - /// was fully rematerialized (this requires that rollback support was set at - /// that time). - void rollback(RegisterIdx RematIdx); - - /// Revives original register \p RootIdx at its original position in the MIR - /// if it was fully rematerialized with rollback support set. Transitive - /// dependencies of the root register that were fully rematerialized are - /// revived at their original positions; this requires that rollback support - /// was set when they were themselves rematerialized. - void reviveRegIfDead(RegisterIdx RootIdx); + /// Re-creates a previously deleted register \p RegIdx before \p InsertPos in + /// \p DefRegion. \p DefReg must be the original virtual register that \p + /// RegIdx used to define. Sets the new register's rematerializable + /// dependencies to \p Dependencies (these are assumed to already exist in the + /// MIR). 
+ void recreateReg(RegisterIdx RegIdx, unsigned DefRegion, + MachineBasicBlock::iterator InsertPos, Register DefReg, + SmallVectorImpl &&Dependencies); /// Transfers all users of register \p FromRegIdx in region \p UseRegion to \p /// ToRegIdx, the latter of which must be a rematerialization of the former or @@ -397,13 +391,15 @@ public: void transferUser(RegisterIdx FromRegIdx, RegisterIdx ToRegIdx, unsigned UserRegion, MachineInstr &UserMI); - /// Recomputes all live intervals that have changed as a result of previous - /// rematerializations/rollbacks. - void updateLiveIntervals(); + /// Transfers all users of register \p FromRegIdx to register \p ToRegIdx, the + /// latter of which must be a rematerialization of the former or have the same + /// origin register. Users of \p FromRegIdx must be reachable from \p + /// ToRegIdx. + void transferAllUsers(RegisterIdx FromRegIdx, RegisterIdx ToRegIdx); - /// Deletes unused rematerialized registers that were left in the MIR to - /// support rollback. - void commitRematerializations(); + /// Recomputes all live intervals that have changed as a result of previous + /// rematerializations. + void updateLiveIntervals(); /// Determines whether (sub-)register operand \p MO has the same value at /// all \p Uses as at \p MO. This implies that it is also available at all \p @@ -454,9 +450,8 @@ private: /// Indicates the original register index of each rematerialization, in the /// order in which they are created. The size of the vector indicates the /// total number of rematerializations ever created, including those that were - /// deleted or rolled back. + /// deleted. SmallVector Origins; - using RematsOf = SmallDenseSet; /// Maps original register indices to their currently alive /// rematerializations. In practice most registers don't have /// rematerializations so this is represented as a map to lower memory cost. @@ -469,15 +464,13 @@ private: /// Parent block of each region, in order. 
SmallVector RegionMBB; /// Set of registers whose live-range may have changed during past - /// rematerializations/rollbacks. + /// rematerializations. DenseSet LISUpdates; - /// Keys are fully rematerialized registers whose rematerializations are - /// currently rollback-able. Values map register machine operand indices to - /// their original register. - DenseMap> Revivable; - /// Whether all rematerializations of registers identified during the last - /// analysis phase will be rollback-able. - bool SupportRollback = false; + + /// Common post-processing step after creating a new register \p RematRegIdx + /// at \p InsertPos based on register \p ModelRegIdx. + void postRematerialization(RegisterIdx ModelRegIdx, RegisterIdx RematRegIdx, + MachineBasicBlock::iterator InsertPos); /// During the analysis phase, creates a \ref Rematerializer::Reg object for /// virtual register \p VirtRegIdx if it is rematerializable. \p MIRegion maps @@ -494,19 +487,6 @@ private: /// defined once. bool isMIRematerializable(const MachineInstr &MI) const; - /// Rematerializes register \p RegIdx at \p InsertPos in \p UseRegion, adding - /// the new rematerializable register to the backing vector \ref Regs and - /// returning its index inside the vector. Sets the new registers' - /// rematerializable dependencies to \p Dependencies (these are assumed to - /// already exist in the MIR) and its unrematerializable dependencies to the - /// same as \p RegIdx. The new register initially has no user. Since the - /// method appends to \ref Regs, references to elements within it should be - /// considered invalidated across calls to this method unless the vector can - /// be guaranteed to have enough space for an extra element. - RegisterIdx rematerializeReg(RegisterIdx RegIdx, unsigned UseRegion, - MachineBasicBlock::iterator InsertPos, - SmallVectorImpl &&Dependencies); - /// Implementation of \ref Rematerializer::transferUser that doesn't update /// register users. 
void transferUserImpl(RegisterIdx FromRegIdx, RegisterIdx ToRegIdx, @@ -520,12 +500,55 @@ private: /// Deletes rematerializable register \p RegIdx from the DAG and relevant /// internal state. void deleteReg(RegisterIdx RegIdx); +}; - /// If \p MI's first operand defines a register and that register is a - /// rematerializable register tracked by the rematerializer, returns its - /// index in the \ref Regs vector. Otherwise returns \ref - /// Rematerializer::NoReg. - RegisterIdx getDefRegIdx(const MachineInstr &MI) const; +/// Rematerializer listener with the ability to re-create deleted registers and +/// rollback rematerializations. Starts recording register deletions and +/// rematerializations as soon as it is attached to the rematerializer. +class Rollbacker : public Rematerializer::Listener { +public: + Rollbacker() = default; + + /// Re-creates all deleted registers and rolls back all rematerializations + /// that were recorded. + void rollback(Rematerializer &Remater); + + void rematerializerNoteRegCreated(const Rematerializer &Remater, + RegisterIdx RegIdx) override; + + void rematerializerNoteRegDeleted(const Rematerializer &Remater, + RegisterIdx RegIdx) override; + +private: + struct RollbackInfo { + /// Original register. + Register DefReg; + /// Original defining region. + unsigned DefRegion; + /// Original dependencies. + SmallVector Dependencies; + /// Position to re-create the register before in case of rollback. This + /// becomes invalid if it originally points to an MI that is deleted later + /// as a consequence of other rematerializations. In such cases \ref + /// NextRegIdx is guaranteed to be an actual register index from which the + /// rollback logic will determine a valid insert position before which to + /// re-create this register. + MachineBasicBlock::iterator InsertPos; + /// If \ref InsertPos points to an MI defining a rematerializable register, + /// stores its index. Otherwise equals \ref Rematerializer::NoReg. 
+ RegisterIdx NextRegIdx; + + RollbackInfo(const Rematerializer &Remater, RegisterIdx RegIdx); + }; + + /// Original registers that have been deleted, in order of deletion. + MapVector DeadRegs; + /// Registers which have been rematerialized (from original index to + /// rematerialized index). + DenseMap Rematerializations; + /// Used to block further recording of events whenever we are actively rolling + /// back. + bool RollingBack = false; }; } // namespace llvm diff --git a/llvm/lib/CodeGen/Rematerializer.cpp b/llvm/lib/CodeGen/Rematerializer.cpp index b0dcd8d502a9..4125c36b4720 100644 --- a/llvm/lib/CodeGen/Rematerializer.cpp +++ b/llvm/lib/CodeGen/Rematerializer.cpp @@ -21,7 +21,6 @@ #include "llvm/CodeGen/MachineOperand.h" #include "llvm/CodeGen/MachineRegisterInfo.h" #include "llvm/CodeGen/Register.h" -#include "llvm/CodeGen/TargetOpcodes.h" #include "llvm/CodeGen/TargetRegisterInfo.h" #include "llvm/Support/Debug.h" #include @@ -131,86 +130,6 @@ Rematerializer::rematerializeToPos(RegisterIdx RootIdx, unsigned UseRegion, return LastNewIdx; } -void Rematerializer::rollbackRematsOf(RegisterIdx RootIdx) { - auto Remats = Rematerializations.find(RootIdx); - if (Remats == Rematerializations.end()) - return; - - LLVM_DEBUG(dbgs() << "Rolling back rematerializations of " << printID(RootIdx) - << '\n'); - - reviveRegIfDead(RootIdx); - // All of the rematerialization's users must use the revived register.
- for (RegisterIdx RematRegIdx : Remats->getSecond()) { - for (const auto &[UseRegion, RegionUsers] : Regs[RematRegIdx].Uses) - transferRegionUsers(RematRegIdx, RootIdx, UseRegion); - } - Rematerializations.erase(RootIdx); - - LLVM_DEBUG(dbgs() << "** Rolled back rematerializations of " - << printID(RootIdx) << '\n'); -} - -void Rematerializer::rollback(RegisterIdx RematIdx) { - assert(getReg(RematIdx).DefMI && !Revivable.contains(RematIdx) && - "cannot rollback dead register"); - const RegisterIdx OriginRegIdx = getOriginOf(RematIdx); - reviveRegIfDead(OriginRegIdx); - for (const auto &[UseRegion, RegionUsers] : Regs[RematIdx].Uses) - transferRegionUsers(RematIdx, OriginRegIdx, UseRegion); -} - -void Rematerializer::reviveRegIfDead(RegisterIdx RootIdx) { - if (getReg(RootIdx).isAlive()) - return; - assert(Revivable.contains(RootIdx) && "not revivable"); - - // Traverse the root's dependency DAG depth-first to find the set of - // registers we must revive and a legal order to revive them in. - SmallVector DepDAG{RootIdx}; - SmallSetVector ReviveOrder; - ReviveOrder.insert(RootIdx); - do { - // All dependencies of a revived register need to be alive too. - const Reg &ReviveReg = getReg(DepDAG.pop_back_val()); - for (const Reg::Dependency &Dep : ReviveReg.Dependencies) { - // We may have already seen the dependency in the dependency DAG. - if (ReviveOrder.contains(Dep.RegIdx)) - continue; - - // Dead dependencies need to be revived. - Reg &DepReg = Regs[Dep.RegIdx]; - if (!DepReg.isAlive()) { - assert(Revivable.contains(Dep.RegIdx) && "not revivable"); - ReviveOrder.insert(Dep.RegIdx); - DepDAG.push_back(Dep.RegIdx); - } - - // All dependencies get a new user (the revived register). - DepReg.addUser(ReviveReg.DefMI, ReviveReg.DefRegion); - LISUpdates.insert(Dep.RegIdx); - } - } while (!DepDAG.empty()); - - for (RegisterIdx RegIdx : reverse(ReviveOrder)) { - // Pick any rematerialization to retrieve the original opcode from. 
- Reg &ReviveReg = Regs[RegIdx]; - assert(Rematerializations.contains(RegIdx) && "no remats"); - RegisterIdx RematIdx = *Rematerializations.at(RegIdx).begin(); - ReviveReg.DefMI->setDesc(getReg(RematIdx).DefMI->getDesc()); - for (const auto &[MOIdx, Reg] : Revivable.at(RegIdx)) - ReviveReg.DefMI->getOperand(MOIdx).setReg(Reg); - Revivable.erase(RegIdx); - LISUpdates.insert(RegIdx); - - LLVM_DEBUG({ - dbgs() << "** Revived " << printID(RegIdx) << " @ "; - LIS.getInstructionIndex(*ReviveReg.DefMI).print(dbgs()); - dbgs() << '\n'; - }); - } -} - void Rematerializer::transferUser(RegisterIdx FromRegIdx, RegisterIdx ToRegIdx, unsigned UserRegion, MachineInstr &UserMI) { transferUserImpl(FromRegIdx, ToRegIdx, UserMI); @@ -235,6 +154,18 @@ void Rematerializer::transferRegionUsers(RegisterIdx FromRegIdx, deleteRegIfUnused(FromRegIdx); } +void Rematerializer::transferAllUsers(RegisterIdx FromRegIdx, + RegisterIdx ToRegIdx) { + Reg &FromReg = Regs[FromRegIdx], &ToReg = Regs[ToRegIdx]; + for (const auto &[UseRegion, RegionUsers] : FromReg.Uses) { + for (MachineInstr *UserMI : RegionUsers) + transferUserImpl(FromRegIdx, ToRegIdx, *UserMI); + ToReg.addUsers(RegionUsers, UseRegion); + } + FromReg.Uses.clear(); + deleteRegIfUnused(FromRegIdx); +} + void Rematerializer::transferUserImpl(RegisterIdx FromRegIdx, RegisterIdx ToRegIdx, MachineInstr &UserMI) { @@ -268,7 +199,7 @@ void Rematerializer::updateLiveIntervals() { DenseSet SeenUnrematRegs; for (RegisterIdx RegIdx : LISUpdates) { const Reg &UpdateReg = getReg(RegIdx); - assert((UpdateReg.DefMI || Revivable.contains(RegIdx)) && "dead reg"); + assert(UpdateReg.isAlive() && "dead register"); Register DefReg = UpdateReg.getDefReg(); if (LIS.hasInterval(DefReg)) @@ -299,12 +230,6 @@ void Rematerializer::updateLiveIntervals() { LISUpdates.clear(); } -void Rematerializer::commitRematerializations() { - for (auto &[RegIdx, _] : Revivable) - deleteReg(RegIdx); - Revivable.clear(); -} - bool 
Rematerializer::isMOIdenticalAtUses(MachineOperand &MO, ArrayRef Uses) const { if (Uses.empty()) @@ -361,7 +286,7 @@ void Rematerializer::deleteRegIfUnused(RegisterIdx RootIdx) { // A deleted register's dependencies may be deletable too. const Reg &DeleteReg = getReg(DepDAG.pop_back_val()); for (const Reg::Dependency &Dep : DeleteReg.Dependencies) { - // All dependencies loose a user (the delete register). + // All dependencies lose a user (the deleted register). Reg &DepReg = Regs[Dep.RegIdx]; DepReg.eraseUser(DeleteReg.DefMI, DeleteReg.DefRegion); if (DepReg.Uses.empty()) { @@ -373,27 +298,16 @@ void Rematerializer::deleteRegIfUnused(RegisterIdx RootIdx) { for (RegisterIdx RegIdx : reverse(DeleteOrder)) { Reg &DeleteReg = Regs[RegIdx]; - LIS.removeInterval(DeleteReg.getDefReg()); + + // It is possible that the defined register we are deleting doesn't have an + // interval yet if the LIS hasn't been updated since it was created. + Register DefReg = DeleteReg.getDefReg(); + if (LIS.hasInterval(DefReg)) + LIS.removeInterval(DefReg); LISUpdates.erase(RegIdx); - const bool IsRematerializedReg = isRematerializedRegister(RegIdx); - if (SupportRollback && !IsRematerializedReg) { - // Replace all read registers with the null one to prevent them from - // showing up in use-lists, which is disallowed for debug instructions in - // live interval calculations. Store mappings between operand indices and - // original registers for potential rollback. - DenseMap &RegMap = - Revivable.try_emplace(RegIdx).first->getSecond(); - for (auto [Idx, MO] : enumerate(DeleteReg.DefMI->operands())) { - if (MO.isReg() && MO.readsReg()) { - RegMap.insert({Idx, MO.getReg()}); - MO.setReg(Register()); - } - } - DeleteReg.DefMI->setDesc(TII.get(TargetOpcode::DBG_VALUE)); - } else { - deleteReg(RegIdx); - } - if (IsRematerializedReg) { + + deleteReg(RegIdx); + if (isRematerializedRegister(RegIdx)) { // Delete rematerialized register from its origin's rematerializations.
RematsOf &OriginRemats = Rematerializations.at(getOriginOf(RegIdx)); assert(OriginRemats.contains(RegIdx) && "broken remat<->origin link"); @@ -443,7 +357,7 @@ Rematerializer::Rematerializer(MachineFunction &MF, #endif } -bool Rematerializer::analyze(bool SupportRollback) { +bool Rematerializer::analyze() { Regs.clear(); UnrematableOprds.clear(); Origins.clear(); @@ -451,8 +365,6 @@ bool Rematerializer::analyze(bool SupportRollback) { RegionMBB.clear(); RegToIdx.clear(); LISUpdates.clear(); - Revivable.clear(); - this->SupportRollback = SupportRollback; if (Regions.empty()) return false; @@ -611,17 +523,65 @@ RegisterIdx Rematerializer::rematerializeReg( *FromReg.DefMI); NewReg.DefMI = &*std::prev(InsertPos); RegToIdx.insert({NewDefReg, NewRegIdx}); + postRematerialization(RegIdx, NewRegIdx, InsertPos); - // Update the DAG. - RegionBoundaries &Bounds = Regions[UseRegion]; - if (Bounds.first == std::next(MachineBasicBlock::iterator(NewReg.DefMI))) - Bounds.first = NewReg.DefMI; - LIS.InsertMachineInstrInMaps(*NewReg.DefMI); - LISUpdates.insert(NewRegIdx); + noteRegCreated(NewRegIdx); + LLVM_DEBUG(dbgs() << "** Rematerialized " << printID(RegIdx) << " as " + << printRematReg(NewRegIdx) << '\n'); + return NewRegIdx; +} + +void Rematerializer::recreateReg( + RegisterIdx RegIdx, unsigned DefRegion, + MachineBasicBlock::iterator InsertPos, Register DefReg, + SmallVectorImpl &&Dependencies) { + assert(RegToIdx.contains(DefReg) && "unknown defined register"); + assert(RegToIdx.at(DefReg) == RegIdx && "incorrect defined register"); + assert(!getReg(RegIdx).DefMI && "register is still alive"); + + Reg &OriginReg = Regs[RegIdx]; + OriginReg.DefRegion = DefRegion; + OriginReg.Dependencies = std::move(Dependencies); + + // Re-establish the link between origin and rematerialization if necessary. 
+ const bool RecreateOriginalReg = isOriginalRegister(RegIdx); + if (!RecreateOriginalReg) + Rematerializations[getOriginOf(RegIdx)].insert(RegIdx); + + // Rematerialize from one of the existing rematerializations or from the + // origin. We expect at least one to exist, otherwise it would mean the value + // held by the original register is no longer available anywhere in the MF. + RegisterIdx ModelRegIdx; + if (RecreateOriginalReg) { + assert(Rematerializations.contains(RegIdx) && "expected remats"); + ModelRegIdx = *Rematerializations.at(RegIdx).begin(); + } else { + assert(getReg(getOriginOf(RegIdx)).DefMI && "expected alive origin"); + ModelRegIdx = getOriginOf(RegIdx); + } + const MachineInstr &ModelDefMI = *getReg(ModelRegIdx).DefMI; + + TII.reMaterialize(*RegionMBB[DefRegion], InsertPos, DefReg, 0, ModelDefMI); + OriginReg.DefMI = &*std::prev(InsertPos); + postRematerialization(ModelRegIdx, RegIdx, InsertPos); + LLVM_DEBUG(dbgs() << "** Recreated " << printID(RegIdx) << " as " + << printRematReg(RegIdx) << '\n'); +} + +void Rematerializer::postRematerialization( + RegisterIdx ModelRegIdx, RegisterIdx RematRegIdx, + MachineBasicBlock::iterator InsertPos) { + + // The start of the new register's region may have changed. + Reg &ModelReg = Regs[ModelRegIdx], &RematReg = Regs[RematRegIdx]; + LIS.InsertMachineInstrInMaps(*RematReg.DefMI); + MachineBasicBlock::iterator &RegionBegin = Regions[RematReg.DefRegion].first; + if (RegionBegin == std::next(MachineBasicBlock::iterator(RematReg.DefMI))) + RegionBegin = RematReg.DefMI; // Replace dependencies as needed in the rematerialized MI. All dependencies // of the latter gain a new user. 
- auto ZipedDeps = zip_equal(FromReg.Dependencies, NewReg.Dependencies); + auto ZipedDeps = zip_equal(ModelReg.Dependencies, RematReg.Dependencies); for (const auto &[OldDep, NewDep] : ZipedDeps) { assert(OldDep.MOIdx == NewDep.MOIdx && "operand mismatch"); LLVM_DEBUG(dbgs() << " Operand #" << OldDep.MOIdx << ": " @@ -630,22 +590,14 @@ RegisterIdx Rematerializer::rematerializeReg( Reg &NewDepReg = Regs[NewDep.RegIdx]; if (OldDep.RegIdx != NewDep.RegIdx) { - Register OldDefReg = FromReg.DefMI->getOperand(OldDep.MOIdx).getReg(); - NewReg.DefMI->substituteRegister(OldDefReg, NewDepReg.getDefReg(), 0, - TRI); + Register OldDefReg = ModelReg.DefMI->getOperand(OldDep.MOIdx).getReg(); + RematReg.DefMI->substituteRegister(OldDefReg, NewDepReg.getDefReg(), 0, + TRI); LISUpdates.insert(OldDep.RegIdx); } - NewDepReg.addUser(NewReg.DefMI, UseRegion); + NewDepReg.addUser(RematReg.DefMI, RematReg.DefRegion); LISUpdates.insert(NewDep.RegIdx); } - - noteRegCreated(NewRegIdx); - - LLVM_DEBUG({ - dbgs() << "** Rematerialized " << printID(RegIdx) << " as " - << printRematReg(NewRegIdx) << '\n'; - }); - return NewRegIdx; } std::pair @@ -807,3 +759,73 @@ Printable Rematerializer::printUser(const MachineInstr *MI, LIS.getInstructionIndex(*MI).print(OS); }); } + +Rollbacker::RollbackInfo::RollbackInfo(const Rematerializer &Remater, + RegisterIdx RegIdx) { + const Rematerializer::Reg &Reg = Remater.getReg(RegIdx); + DefReg = Reg.getDefReg(); + DefRegion = Reg.DefRegion; + Dependencies = Reg.Dependencies; + + InsertPos = std::next(Reg.DefMI->getIterator()); + if (InsertPos != Reg.DefMI->getParent()->end()) + NextRegIdx = Remater.getDefRegIdx(*InsertPos); + else + NextRegIdx = Rematerializer::NoReg; +} + +void Rollbacker::rematerializerNoteRegCreated(const Rematerializer &Remater, + RegisterIdx RegIdx) { + if (RollingBack) + return; + Rematerializations[Remater.getOriginOf(RegIdx)].insert(RegIdx); +} + +void Rollbacker::rematerializerNoteRegDeleted(const Rematerializer &Remater, + 
RegisterIdx RegIdx) { + if (RollingBack || Remater.isRematerializedRegister(RegIdx)) + return; + DeadRegs.try_emplace(RegIdx, Remater, RegIdx); +} + +void Rollbacker::rollback(Rematerializer &Remater) { + RollingBack = true; + + // Re-create deleted registers. + for (auto &[RegIdx, Info] : DeadRegs) { + assert(!Remater.getReg(RegIdx).isAlive() && "register should be dead"); + + // The MI that was originally just after the MI defining the register we + // are trying to re-create may have been deleted. In such cases, we can + // re-create at that MI's own insert position (and apply the same logic + // recursively). + MachineBasicBlock::iterator InsertPos = Info.InsertPos; + RegisterIdx NextRegIdx = Info.NextRegIdx; + while (NextRegIdx != Rematerializer::NoReg) { + const auto *NextRegRollback = DeadRegs.find(NextRegIdx); + if (NextRegRollback == DeadRegs.end()) + break; + InsertPos = NextRegRollback->second.InsertPos; + NextRegIdx = NextRegRollback->second.NextRegIdx; + } + Remater.recreateReg(RegIdx, Info.DefRegion, InsertPos, Info.DefReg, + std::move(Info.Dependencies)); + } + + // Rollback rematerializations. + for (const auto &[RegIdx, RematsOf] : Rematerializations) { + for (RegisterIdx RematRegIdx : RematsOf) { + // It is possible that rematerializations were deleted. Their users would + // have been transferred to some other rematerialization so we can safely + // ignore them. Original registers that were deleted were just re-created + // so we do not need to check for that.
+ if (Remater.getReg(RematRegIdx).isAlive()) + Remater.transferAllUsers(RematRegIdx, RegIdx); + } + } + + Remater.updateLiveIntervals(); + DeadRegs.clear(); + Rematerializations.clear(); + RollingBack = false; +} diff --git a/llvm/unittests/CodeGen/RematerializerTest.cpp b/llvm/unittests/CodeGen/RematerializerTest.cpp index ca2bc3b86d47..6b1373377816 100644 --- a/llvm/unittests/CodeGen/RematerializerTest.cpp +++ b/llvm/unittests/CodeGen/RematerializerTest.cpp @@ -80,8 +80,7 @@ public: MAM.registerPass([&] { return MachineModuleAnalysis(*MMI); }); } - bool parseMIRAndInit(StringRef MIRCode, StringRef FunName, - bool SupportRollback) { + bool parseMIRAndInit(StringRef MIRCode, StringRef FunName) { SMDiagnostic Diagnostic; std::unique_ptr MBuffer = MemoryBuffer::getMemBuffer(MIRCode); MIR = createMIRParser(std::move(MBuffer), Context); @@ -122,7 +121,7 @@ public: } Remater = std::make_unique(*MF, *Regions, LIS); - Remater->analyze(SupportRollback); + Remater->analyze(); return true; } @@ -197,10 +196,11 @@ body: | S_ENDPGM 0 ... )"; - ASSERT_TRUE( - parseMIRAndInit(MIR, "TreeRematRollback", /*SupportRollback=*/true)); + ASSERT_TRUE(parseMIRAndInit(MIR, "TreeRematRollback")); Rematerializer &Remater = getRematerializer(); Rematerializer::DependencyReuseInfo DRI; + Rollbacker Rollbacker; + Remater.addListener(&Rollbacker); // MBB/Region indices. const unsigned MBB0 = 0, MBB1 = 1; @@ -218,15 +218,16 @@ body: | Remater.rematerializeToRegion(/*RootIdx=*/Add23, /*UseRegion=*/MBB1, DRI); Remater.updateLiveIntervals(); - // None of the original registers have any users, but they still are in the - // MIR because we enabled rollback support. + // None of the original registers have any users left. EXPECT_NO_USERS(Cst0); EXPECT_NO_USERS(Cst1); EXPECT_NO_USERS(Add01); EXPECT_NO_USERS(Cst3); EXPECT_NO_USERS(Add23); - // Copies of all MIs were inserted into the second MBB. + // Copies of all MIs were inserted into the second MBB. Original registers + // were deleted. 
+ RegionSizes[MBB0] -= 5; RegionSizes[MBB1] += 5; ASSERT_REGION_SIZES(RegionSizes); NumRegs += 5; @@ -234,7 +235,8 @@ body: | } // After rollback all rematerializations are removed from the MIR. - Remater.rollbackRematsOf(Add23); + Rollbacker.rollback(Remater); + RegionSizes[MBB0] += 5; RegionSizes[MBB1] -= 5; ASSERT_REGION_SIZES(RegionSizes); @@ -253,6 +255,7 @@ body: | EXPECT_NO_USERS(Add23); // Only immediate dependencies are copied to the second MBB. + RegionSizes[MBB0] -= 3; RegionSizes[MBB1] += 3; ASSERT_REGION_SIZES(RegionSizes); NumRegs += 3; @@ -260,7 +263,8 @@ body: | } // After rollback all rematerializations are removed from the MIR. - Remater.rollbackRematsOf(Add23); + Rollbacker.rollback(Remater); + RegionSizes[MBB0] += 3; RegionSizes[MBB1] -= 3; ASSERT_REGION_SIZES(RegionSizes); @@ -302,21 +306,15 @@ body: | EXPECT_NO_USERS(Add23); EXPECT_NUM_USERS(RematAdd23, 1); + RegionSizes[MBB0] -= 3; RegionSizes[MBB1] += 3; ASSERT_REGION_SIZES(RegionSizes); NumRegs += 3; ASSERT_EQ(Remater.getNumRegs(), NumRegs); } - // This time don't rollback; commit the rematerializations. This finally - // deletes unused registers in the first block. However the number of - // registers tracked by the rematerializer doesn't change. + // This time don't rollback. Remater.updateLiveIntervals(); - Remater.commitRematerializations(); - RegionSizes[MBB0] -= 3; - ASSERT_REGION_SIZES(RegionSizes); - ASSERT_EQ(Remater.getNumRegs(), NumRegs); - EXPECT_TRUE(getMF().verify()); } @@ -345,8 +343,7 @@ body: | S_ENDPGM 0 ... )"; - ASSERT_TRUE( - parseMIRAndInit(MIR, "MultiRegionsRemat", /*SupportRollback=*/false)); + ASSERT_TRUE(parseMIRAndInit(MIR, "MultiRegionsRemat")); Rematerializer &Remater = getRematerializer(); Rematerializer::DependencyReuseInfo DRI; @@ -416,7 +413,7 @@ body: | S_ENDPGM 0 ... 
)"; - ASSERT_TRUE(parseMIRAndInit(MIR, "MultiStep", /*SupportRollback=*/false)); + ASSERT_TRUE(parseMIRAndInit(MIR, "MultiStep")); Rematerializer &Remater = getRematerializer(); Rematerializer::DependencyReuseInfo DRI; @@ -497,7 +494,7 @@ body: | S_ENDPGM 0 ... )"; - ASSERT_TRUE(parseMIRAndInit(MIR, "EmptyRegion", /*SupportRollback=*/false)); + ASSERT_TRUE(parseMIRAndInit(MIR, "EmptyRegion")); Rematerializer &Remater = getRematerializer(); Rematerializer::DependencyReuseInfo DRI; @@ -566,7 +563,7 @@ body: | S_ENDPGM 0 ... )"; - ASSERT_TRUE(parseMIRAndInit(MIR, "SubReg", /*SupportRollback=*/false)); + ASSERT_TRUE(parseMIRAndInit(MIR, "SubReg")); Rematerializer &Remater = getRematerializer(); Rematerializer::DependencyReuseInfo DRI; @@ -592,3 +589,163 @@ body: | Remater.updateLiveIntervals(); EXPECT_TRUE(getMF().verify()); } + +/// Checks that rollback works as expected when the rollback listener is added +/// mid-rematerializations. +TEST_F(RematerializerTest, Rollback) { + StringRef MIR = R"( +name: Rollback +tracksRegLiveness: true +machineFunctionInfo: + isEntryFunction: true +body: | + bb.0: + %0:vgpr_32 = nofpexcept V_CVT_I32_F64_e32 0, implicit $exec, implicit $mode + %1:vgpr_32 = nofpexcept V_CVT_I32_F64_e32 1, implicit $exec, implicit $mode + + bb.1: + S_NOP 0, implicit %0, implicit %1 + + bb.2: + S_NOP 0, implicit %0, implicit %1 + S_ENDPGM 0 +)"; + ASSERT_TRUE(parseMIRAndInit(MIR, "Rollback")); + Rematerializer &Remater = getRematerializer(); + Rematerializer::DependencyReuseInfo DRI; + + // MBB/Region indices. + const unsigned MBB0 = 0, MBB1 = 1, MBB2 = 2; + SmallVector RegionSizes{2, 1, 1}; + ASSERT_REGION_SIZES(RegionSizes); + + // Indices of rematerializable registers. + unsigned NumRegs = 0; + const RegisterIdx Cst0 = NumRegs++, Cst1 = NumRegs++; + ASSERT_EQ(Remater.getNumRegs(), NumRegs); + + // Rematerialize %0 to MBB1, taking one user from the original register.
+ RegisterIdx RematCst0MBB1 = Remater.rematerializeToRegion(Cst0, MBB1, DRI); + RegionSizes[MBB1] += 1; + ASSERT_REGION_SIZES(RegionSizes); + NumRegs += 1; + ASSERT_EQ(Remater.getNumRegs(), NumRegs); + + Rollbacker Rollback; + Remater.addListener(&Rollback); + + // Rematerialize %0 to MBB2 and %1 to MBB1/MBB2; each rematerialization ends + // up with a single user and both original registers are deleted. + RegisterIdx RematCst0MBB2 = + Remater.rematerializeToRegion(Cst0, MBB2, DRI.clear()); + RegisterIdx RematCst1MBB1 = + Remater.rematerializeToRegion(Cst1, MBB1, DRI.clear()); + RegisterIdx RematCst1MBB2 = + Remater.rematerializeToRegion(Cst1, MBB2, DRI.clear()); + + RegionSizes[MBB0] -= 2; + RegionSizes[MBB1] += 1; + RegionSizes[MBB2] += 2; + ASSERT_REGION_SIZES(RegionSizes); + NumRegs += 3; + ASSERT_EQ(Remater.getNumRegs(), NumRegs); + + EXPECT_NO_USERS(Cst0); + EXPECT_NO_USERS(Cst1); + EXPECT_NUM_USERS(RematCst0MBB1, 1); + EXPECT_NUM_USERS(RematCst0MBB2, 1); + EXPECT_NUM_USERS(RematCst1MBB1, 1); + EXPECT_NUM_USERS(RematCst1MBB2, 1); + + // Roll back all changes since the rollbacker was added. The first + // rematerialization of %0 to MBB1 happened before, so it is not rolled back. + // However, %0 is re-created because it was deleted afterwards. + Rollback.rollback(Remater); + + RegionSizes[MBB0] += 2; + RegionSizes[MBB1] -= 1; + RegionSizes[MBB2] -= 2; + ASSERT_REGION_SIZES(RegionSizes); + ASSERT_EQ(Remater.getNumRegs(), NumRegs); + + EXPECT_NUM_USERS(Cst0, 1); + EXPECT_NUM_USERS(Cst1, 2); + EXPECT_NUM_USERS(RematCst0MBB1, 1); + EXPECT_NO_USERS(RematCst0MBB2); + EXPECT_NO_USERS(RematCst1MBB1); + EXPECT_NO_USERS(RematCst1MBB2); + + EXPECT_TRUE(getMF().verify()); +} + +/// Checks that rollback re-creates MIs at correct positions when the order of +/// register deletions forces the re-creation logic to iterate through multiple +/// deleted registers' respective insert positions to find a valid one.
+TEST_F(RematerializerTest, RollbackInvalidInsertPos) { + StringRef MIR = R"( +name: RollbackInvalidInsertPos +tracksRegLiveness: true +machineFunctionInfo: + isEntryFunction: true +body: | + bb.0: + %0:vgpr_32 = nofpexcept V_CVT_I32_F64_e32 0, implicit $exec, implicit $mode + %1:vgpr_32 = nofpexcept V_CVT_I32_F64_e32 1, implicit $exec, implicit $mode + %2:vgpr_32 = nofpexcept V_CVT_I32_F64_e32 2, implicit $exec, implicit $mode + %3:vgpr_32 = nofpexcept V_CVT_I32_F64_e32 3, implicit $exec, implicit $mode + + bb.1: + S_NOP 0, implicit %0, implicit %1, implicit %2, implicit %3 + S_ENDPGM 0 +)"; + ASSERT_TRUE(parseMIRAndInit(MIR, "RollbackInvalidInsertPos")); + Rematerializer &Remater = getRematerializer(); + Rematerializer::DependencyReuseInfo DRI; + Rollbacker Rollback; + Remater.addListener(&Rollback); + + // MBB/Region indices. + const unsigned MBB0 = 0, MBB1 = 1; + SmallVector RegionSizes{4, 1}; + ASSERT_REGION_SIZES(RegionSizes); + + // Indices of rematerializable registers. + const RegisterIdx Cst0 = 0, Cst1 = 1, Cst2 = 2, Cst3 = 3; + + // Rematerialize %0 to MBB1, deleting the original register. + Remater.rematerializeToRegion(Cst0, MBB1, DRI); + RegionSizes[MBB0] -= 1; + RegionSizes[MBB1] += 1; + ASSERT_REGION_SIZES(RegionSizes); + + // Rematerialize %1 to MBB1, deleting the original register. + Remater.rematerializeToRegion(Cst1, MBB1, DRI.clear()); + RegionSizes[MBB0] -= 1; + RegionSizes[MBB1] += 1; + ASSERT_REGION_SIZES(RegionSizes); + + // Rematerialize %2 to MBB1, deleting the original register. + Remater.rematerializeToRegion(Cst2, MBB1, DRI.clear()); + RegionSizes[MBB0] -= 1; + RegionSizes[MBB1] += 1; + ASSERT_REGION_SIZES(RegionSizes); + + // Now roll back and check for correct instruction order in the original + // defining region. The asserts on region sizes ensure that all original + // registers were indeed deleted and will be re-created in the original + // region.
+ Rollback.rollback(Remater); + RegionSizes[MBB0] += 3; + RegionSizes[MBB1] -= 3; + ASSERT_REGION_SIZES(RegionSizes); + + MachineInstr &DefCst0 = *Remater.getReg(Cst0).DefMI; + MachineInstr &DefCst1 = *Remater.getReg(Cst1).DefMI; + MachineInstr &DefCst2 = *Remater.getReg(Cst2).DefMI; + MachineInstr &DefCst3 = *Remater.getReg(Cst3).DefMI; + EXPECT_EQ(std::next(DefCst0.getIterator()), DefCst1.getIterator()); + EXPECT_EQ(std::next(DefCst1.getIterator()), DefCst2.getIterator()); + EXPECT_EQ(std::next(DefCst2.getIterator()), DefCst3.getIterator()); + + EXPECT_TRUE(getMF().verify()); +}