[CodeGen] Move rollback capabilities outside of the rematerializer (#184341)

The rematerializer implements support for rolling back
rematerializations by modifying MIs that should normally be deleted in
an attempt to make them "transparent" to other analyses. This involves:

1. setting their opcode to DBG_VALUE and
2. setting their read register operands to the sentinel register.

This approach has several drawbacks.

1. It forces the rematerializer to support tracking these "dead MIs"
(even if support is optional, these data-structures have to exist).
2. It is not actually clear whether this mechanism will interact well
with all other analyses. This is an issue since the intent of the
rematerializer is to be usable in as many contexts as possible.
3. In practice, it has shown itself to be relatively error-prone.

This commit removes rollback support from the rematerializer and moves
those capabilities to a rematerializer listener than can be instantiated
on-demand and implements the same functionality on top of standard
rematerializer operations. The rematerializer now actually deletes MIs
that are no longer useful after rematerializations, and has support for
re-creating them on-demand without requiring additional tracking on its
part.
This commit is contained in:
Lucas Ramirez 2026-04-06 21:23:19 +02:00 committed by GitHub
parent a2c9146da1
commit 5e1162eebc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 448 additions and 246 deletions

View File

@ -14,6 +14,7 @@
#ifndef LLVM_CODEGEN_REMATERIALIZER_H
#define LLVM_CODEGEN_REMATERIALIZER_H
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/CodeGen/LiveIntervals.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
@ -76,18 +77,7 @@ namespace llvm {
/// The rematerializer supports rematerializing arbitrary complex DAGs of
/// registers to regions where these registers are used, with the option of
/// re-using non-root registers or their previous rematerializations instead of
/// rematerializing them again. It also optionally supports rolling back
/// previous rematerializations (set during analysis phase, see \ref
/// Rematerializer::analyze) to restore the MIR state to what it was
/// pre-rematerialization. When enabled, machine instructions defining
/// rematerializable registers that no longer have any uses following previous
/// rematerializations will not be deleted from the MIR; their opcode will
/// instead be set to a DEBUG_VALUE and their read register operands set to the
/// null register. This maintains their position in the MIR and keeps the
/// original register alive for potential rollback while allowing other
/// passes/analyzes (e.g., machine scheduler, live-interval analysis) to ignore
/// them. \ref Rematerializer::commitRematerializations actually deletes those
/// instructions when rollback is deemed unnecessary.
/// rematerializing them again.
///
/// Throughout its lifetime, the rematerializer tracks new registers it creates
/// (which are rematerializable by construction) and their relations to other
@ -121,10 +111,7 @@ public:
/// arbitrary number of regions, potentially including its own defining
/// region. When rematerializations lead to operand changes in users, a
/// register may find itself without any user left, at which point the
/// rematerializer marks it for deletion. Its defining instruction either
/// becomes nullptr (without rollback support) or its opcode is set to
/// TargetOpcode::DBG_VALUE (with rollback support) until \ref
/// Rematerializer::commitRematerializations is called.
/// rematerializer deletes it (setting its defining MI to nullptr).
struct Reg {
/// Single MI defining the rematerializable register.
MachineInstr *DefMI;
@ -153,7 +140,7 @@ public:
SmallVector<Dependency, 2> Dependencies;
/// Returns the rematerializable register from its defining instruction.
inline Register getDefReg() const {
Register getDefReg() const {
assert(DefMI && "defining instruction was deleted");
assert(DefMI->getOperand(0).isDef() && "not a register def");
return DefMI->getOperand(0).getReg();
@ -174,9 +161,7 @@ public:
std::pair<MachineInstr *, MachineInstr *>
getRegionUseBounds(unsigned UseRegion, const LiveIntervals &LIS) const;
bool isAlive() const {
return DefMI && DefMI->getOpcode() != TargetOpcode::DBG_VALUE;
}
bool isAlive() const { return DefMI; }
private:
void addUser(MachineInstr *MI, unsigned Region);
@ -219,6 +204,8 @@ public:
using RegionBoundaries =
std::pair<MachineBasicBlock::iterator, MachineBasicBlock::iterator>;
using RematsOf = SmallDenseSet<RegisterIdx, 4>;
/// Simply initializes some internal state, does not identify
/// rematerialization candidates.
Rematerializer(MachineFunction &MF,
@ -226,65 +213,72 @@ public:
LiveIntervals &LIS);
/// Goes through the whole MF and identifies all rematerializable registers.
/// When \p SupportRollback is set, rematerializations of original registers
/// can be rolled back and original registers are maintained in the IR even
/// when they longer have any users. Returns whether there is any
/// rematerializable register in regions.
bool analyze(bool SupportRollback);
/// Returns whether there is any rematerializable register in regions.
bool analyze();
/// Adds a new listener to the rematerializer.
void addListener(Listener *Listen) {
assert(Listen && "null listener");
assert(!Listeners.contains(Listen) && "duplicate listener");
Listeners.insert(Listen);
if (!Listeners.insert(Listen).second)
llvm_unreachable("duplicate listener");
}
/// Removes a listener from the rematerializer.
void removeListener(Listener *Listen) {
assert(Listeners.contains(Listen) && "unknown listener");
Listeners.erase(Listen);
if (!Listeners.erase(Listen))
llvm_unreachable("unknown listener");
}
/// Removes all listeners from the rematerializer.
void clearListeners() { Listeners.clear(); }
inline const Reg &getReg(RegisterIdx RegIdx) const {
const Reg &getReg(RegisterIdx RegIdx) const {
assert(RegIdx < Regs.size() && "out of bounds");
return Regs[RegIdx];
};
inline ArrayRef<Reg> getRegs() const { return Regs; };
inline unsigned getNumRegs() const { return Regs.size(); };
ArrayRef<Reg> getRegs() const { return Regs; };
unsigned getNumRegs() const { return Regs.size(); };
inline const RegionBoundaries &getRegion(RegisterIdx RegionIdx) {
const RegionBoundaries &getRegion(RegisterIdx RegionIdx) const {
assert(RegionIdx < Regions.size() && "out of bounds");
return Regions[RegionIdx];
}
inline unsigned getNumRegions() const { return Regions.size(); }
unsigned getNumRegions() const { return Regions.size(); }
/// Whether register \p RegIdx is an original register.
bool isOriginalRegister(RegisterIdx RegIdx) const {
return !isRematerializedRegister(RegIdx);
}
/// Whether register \p RegIdx is a rematerialization of some original
/// register.
inline bool isRematerializedRegister(RegisterIdx RegIdx) const {
bool isRematerializedRegister(RegisterIdx RegIdx) const {
assert(RegIdx < Regs.size() && "out of bounds");
return RegIdx >= UnrematableOprds.size();
}
/// Returns the origin index of rematerializable register \p RegIdx.
inline RegisterIdx getOriginOf(RegisterIdx RematRegIdx) const {
RegisterIdx getOriginOf(RegisterIdx RematRegIdx) const {
assert(isRematerializedRegister(RematRegIdx) && "not a rematerialization");
return Origins[RematRegIdx - UnrematableOprds.size()];
}
/// If \p RegIdx is a rematerialization, returns its origin's index. If it is
/// an original register's index, returns the same index.
inline RegisterIdx getOriginOrSelf(RegisterIdx RegIdx) const {
RegisterIdx getOriginOrSelf(RegisterIdx RegIdx) const {
if (isRematerializedRegister(RegIdx))
return getOriginOf(RegIdx);
return RegIdx;
}
/// Returns operand indices corresponding to unrematerializable operands for
/// any register \p RegIdx.
inline ArrayRef<unsigned> getUnrematableOprds(unsigned RegIdx) const {
ArrayRef<unsigned> getUnrematableOprds(RegisterIdx RegIdx) const {
return UnrematableOprds[getOriginOrSelf(RegIdx)];
}
/// If \p MI's first operand defines a register and that register is a
/// rematerializable register tracked by the rematerializer, returns its
/// index in the \ref Regs vector. Otherwise returns \ref
/// Rematerializer::NoReg.
RegisterIdx getDefRegIdx(const MachineInstr &MI) const;
/// When rematerializating a register (called the "root" register in this
/// context) to a given position, we must decide what to do with all its
/// rematerializable dependencies (for unrematerializable dependencies, we
@ -361,27 +355,27 @@ public:
MachineBasicBlock::iterator InsertPos,
DependencyReuseInfo &DRI);
/// Rolls back all rematerializations of original register \p RootIdx,
/// transfering all their users back to it and permanently deleting them from
/// the MIR. The root register is revived if it was fully rematerialized (this
/// requires that rollback support was set at that time). Transitive
/// dependencies of the root register that were fully rematerialized are
/// re-vived at their original positions; this requires that rollback support
/// was set when they were rematerialized.
void rollbackRematsOf(RegisterIdx RootIdx);
/// Rematerializes register \p RegIdx before \p InsertPos in \p UseRegion,
/// adding the new rematerializable register to the backing vector \ref Regs
/// and returning its index inside the vector. Sets the new register's
/// rematerializable dependencies to \p Dependencies (these are assumed to
/// already exist in the MIR) and its unrematerializable dependencies to the
/// same as \p RegIdx. The new register initially has no user. Since the
/// method appends to \ref Regs, references to elements within it should be
/// considered invalidated across calls to this method unless the vector can
/// be guaranteed to have enough space for an extra element.
RegisterIdx rematerializeReg(RegisterIdx RegIdx, unsigned UseRegion,
MachineBasicBlock::iterator InsertPos,
SmallVectorImpl<Reg::Dependency> &&Dependencies);
/// Rolls back register \p RematIdx (which must be a rematerialization)
/// transfering all its users back to its origin. The latter is revived if it
/// was fully rematerialized (this requires that rollback support was set at
/// that time).
void rollback(RegisterIdx RematIdx);
/// Revives original register \p RootIdx at its original position in the MIR
/// if it was fully rematerialized with rollback support set. Transitive
/// dependencies of the root register that were fully rematerialized are
/// revived at their original positions; this requires that rollback support
/// was set when they were themselves rematerialized.
void reviveRegIfDead(RegisterIdx RootIdx);
/// Re-creates a previously deleted register \p RegIdx before \p InsertPos in
/// \p DefRegion. \p DefReg must be the original virtual register that \p
/// RegIdx used to define. Sets the new register's rematerializable
/// dependencies to \p Dependencies (these are assumed to already exist in the
/// MIR).
void recreateReg(RegisterIdx RegIdx, unsigned DefRegion,
MachineBasicBlock::iterator InsertPos, Register DefReg,
SmallVectorImpl<Reg::Dependency> &&Dependencies);
/// Transfers all users of register \p FromRegIdx in region \p UseRegion to \p
/// ToRegIdx, the latter of which must be a rematerialization of the former or
@ -397,13 +391,15 @@ public:
void transferUser(RegisterIdx FromRegIdx, RegisterIdx ToRegIdx,
unsigned UserRegion, MachineInstr &UserMI);
/// Recomputes all live intervals that have changed as a result of previous
/// rematerializations/rollbacks.
void updateLiveIntervals();
/// Transfers all users of register \p FromRegIdx to register \p ToRegIdx, the
/// latter of which must be a rematerialization of the former or have the same
/// origin register. Users of \p FromRegIdx must be reachable from \p
/// ToRegIdx.
void transferAllUsers(RegisterIdx FromRegIdx, RegisterIdx ToRegIdx);
/// Deletes unused rematerialized registers that were left in the MIR to
/// support rollback.
void commitRematerializations();
/// Recomputes all live intervals that have changed as a result of previous
/// rematerializations.
void updateLiveIntervals();
/// Determines whether (sub-)register operand \p MO has the same value at
/// all \p Uses as at \p MO. This implies that it is also available at all \p
@ -454,9 +450,8 @@ private:
/// Indicates the original register index of each rematerialization, in the
/// order in which they are created. The size of the vector indicates the
/// total number of rematerializations ever created, including those that were
/// deleted or rolled back.
/// deleted.
SmallVector<RegisterIdx> Origins;
using RematsOf = SmallDenseSet<RegisterIdx, 4>;
/// Maps original register indices to their currently alive
/// rematerializations. In practice most registers don't have
/// rematerializations so this is represented as a map to lower memory cost.
@ -469,15 +464,13 @@ private:
/// Parent block of each region, in order.
SmallVector<MachineBasicBlock *> RegionMBB;
/// Set of registers whose live-range may have changed during past
/// rematerializations/rollbacks.
/// rematerializations.
DenseSet<RegisterIdx> LISUpdates;
/// Keys are fully rematerialized registers whose rematerializations are
/// currently rollback-able. Values map register machine operand indices to
/// their original register.
DenseMap<RegisterIdx, DenseMap<unsigned, Register>> Revivable;
/// Whether all rematerializations of registers identified during the last
/// analysis phase will be rollback-able.
bool SupportRollback = false;
/// Common post-processing step after creating a new register \p RematRegIdx
/// at \p InsertPos based on register \p ModelRegIdx.
void postRematerialization(RegisterIdx ModelRegIdx, RegisterIdx RematRegIdx,
MachineBasicBlock::iterator InsertPos);
/// During the analysis phase, creates a \ref Rematerializer::Reg object for
/// virtual register \p VirtRegIdx if it is rematerializable. \p MIRegion maps
@ -494,19 +487,6 @@ private:
/// defined once.
bool isMIRematerializable(const MachineInstr &MI) const;
/// Rematerializes register \p RegIdx at \p InsertPos in \p UseRegion, adding
/// the new rematerializable register to the backing vector \ref Regs and
/// returning its index inside the vector. Sets the new registers'
/// rematerializable dependencies to \p Dependencies (these are assumed to
/// already exist in the MIR) and its unrematerializable dependencies to the
/// same as \p RegIdx. The new register initially has no user. Since the
/// method appends to \ref Regs, references to elements within it should be
/// considered invalidated across calls to this method unless the vector can
/// be guaranteed to have enough space for an extra element.
RegisterIdx rematerializeReg(RegisterIdx RegIdx, unsigned UseRegion,
MachineBasicBlock::iterator InsertPos,
SmallVectorImpl<Reg::Dependency> &&Dependencies);
/// Implementation of \ref Rematerializer::transferUser that doesn't update
/// register users.
void transferUserImpl(RegisterIdx FromRegIdx, RegisterIdx ToRegIdx,
@ -520,12 +500,55 @@ private:
/// Deletes rematerializable register \p RegIdx from the DAG and relevant
/// internal state.
void deleteReg(RegisterIdx RegIdx);
};
/// If \p MI's first operand defines a register and that register is a
/// rematerializable register tracked by the rematerializer, returns its
/// index in the \ref Regs vector. Otherwise returns \ref
/// Rematerializer::NoReg.
RegisterIdx getDefRegIdx(const MachineInstr &MI) const;
/// Rematerializer listener with the ability to re-create deleted registers and
/// rollback rematerializations. Starts recording register deletions and
/// rematerializations as soon as it is attached to the rematerializer.
class Rollbacker : public Rematerializer::Listener {
public:
Rollbacker() = default;
/// Re-creates all deleted registers and rolls back all rematerializations
/// that were recorded.
void rollback(Rematerializer &Remater);
void rematerializerNoteRegCreated(const Rematerializer &Remater,
RegisterIdx RegIdx) override;
void rematerializerNoteRegDeleted(const Rematerializer &Remater,
RegisterIdx RegIdx) override;
private:
struct RollbackInfo {
/// Original register.
Register DefReg;
/// Original defining region.
unsigned DefRegion;
/// Original dependencies.
SmallVector<Rematerializer::Reg::Dependency, 2> Dependencies;
/// Position to re-create the register before in case of rollback. This
/// becomes invalid if it originally points to an MI that is deleted later
/// as a consequence of other rematerializations. In such cases \ref
/// NextRegIdx is guaranteed to be an actual register index from which the
/// rollback logic will determine a valid insert position before which to
/// re-create this register.
MachineBasicBlock::iterator InsertPos;
/// If \ref InsertPos points to an MI defining a rematerializable register,
/// stores its index. Otherwise equals \ref Rematerializer::NoReg.
RegisterIdx NextRegIdx;
RollbackInfo(const Rematerializer &Remater, RegisterIdx RegIdx);
};
/// Original registers that have been deleted, in order of deletion.
MapVector<RegisterIdx, RollbackInfo> DeadRegs;
/// Registers which have been rematerialized (from original index to
/// rematerialized index).
DenseMap<RegisterIdx, Rematerializer::RematsOf> Rematerializations;
/// Used to block further recording of events whenver we are actively rolling
/// back.
bool RollingBack = false;
};
} // namespace llvm

View File

@ -21,7 +21,6 @@
#include "llvm/CodeGen/MachineOperand.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/Register.h"
#include "llvm/CodeGen/TargetOpcodes.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/Support/Debug.h"
#include <optional>
@ -131,86 +130,6 @@ Rematerializer::rematerializeToPos(RegisterIdx RootIdx, unsigned UseRegion,
return LastNewIdx;
}
void Rematerializer::rollbackRematsOf(RegisterIdx RootIdx) {
auto Remats = Rematerializations.find(RootIdx);
if (Remats == Rematerializations.end())
return;
LLVM_DEBUG(dbgs() << "Rolling back rematerializations of " << printID(RootIdx)
<< '\n');
reviveRegIfDead(RootIdx);
// All of the rematerialization's users must use the revived register.
for (RegisterIdx RematRegIdx : Remats->getSecond()) {
for (const auto &[UseRegion, RegionUsers] : Regs[RematRegIdx].Uses)
transferRegionUsers(RematRegIdx, RootIdx, UseRegion);
}
Rematerializations.erase(RootIdx);
LLVM_DEBUG(dbgs() << "** Rolled back rematerializations of "
<< printID(RootIdx) << '\n');
}
void Rematerializer::rollback(RegisterIdx RematIdx) {
assert(getReg(RematIdx).DefMI && !Revivable.contains(RematIdx) &&
"cannot rollback dead register");
const RegisterIdx OriginRegIdx = getOriginOf(RematIdx);
reviveRegIfDead(OriginRegIdx);
for (const auto &[UseRegion, RegionUsers] : Regs[RematIdx].Uses)
transferRegionUsers(RematIdx, OriginRegIdx, UseRegion);
}
void Rematerializer::reviveRegIfDead(RegisterIdx RootIdx) {
if (getReg(RootIdx).isAlive())
return;
assert(Revivable.contains(RootIdx) && "not revivable");
// Traverse the root's dependency DAG depth-first to find the set of
// registers we must revive and a legal order to revive them in.
SmallVector<RegisterIdx, 4> DepDAG{RootIdx};
SmallSetVector<RegisterIdx, 8> ReviveOrder;
ReviveOrder.insert(RootIdx);
do {
// All dependencies of a revived register need to be alive too.
const Reg &ReviveReg = getReg(DepDAG.pop_back_val());
for (const Reg::Dependency &Dep : ReviveReg.Dependencies) {
// We may have already seen the dependency in the dependency DAG.
if (ReviveOrder.contains(Dep.RegIdx))
continue;
// Dead dependencies need to be revived.
Reg &DepReg = Regs[Dep.RegIdx];
if (!DepReg.isAlive()) {
assert(Revivable.contains(Dep.RegIdx) && "not revivable");
ReviveOrder.insert(Dep.RegIdx);
DepDAG.push_back(Dep.RegIdx);
}
// All dependencies get a new user (the revived register).
DepReg.addUser(ReviveReg.DefMI, ReviveReg.DefRegion);
LISUpdates.insert(Dep.RegIdx);
}
} while (!DepDAG.empty());
for (RegisterIdx RegIdx : reverse(ReviveOrder)) {
// Pick any rematerialization to retrieve the original opcode from.
Reg &ReviveReg = Regs[RegIdx];
assert(Rematerializations.contains(RegIdx) && "no remats");
RegisterIdx RematIdx = *Rematerializations.at(RegIdx).begin();
ReviveReg.DefMI->setDesc(getReg(RematIdx).DefMI->getDesc());
for (const auto &[MOIdx, Reg] : Revivable.at(RegIdx))
ReviveReg.DefMI->getOperand(MOIdx).setReg(Reg);
Revivable.erase(RegIdx);
LISUpdates.insert(RegIdx);
LLVM_DEBUG({
dbgs() << "** Revived " << printID(RegIdx) << " @ ";
LIS.getInstructionIndex(*ReviveReg.DefMI).print(dbgs());
dbgs() << '\n';
});
}
}
void Rematerializer::transferUser(RegisterIdx FromRegIdx, RegisterIdx ToRegIdx,
unsigned UserRegion, MachineInstr &UserMI) {
transferUserImpl(FromRegIdx, ToRegIdx, UserMI);
@ -235,6 +154,18 @@ void Rematerializer::transferRegionUsers(RegisterIdx FromRegIdx,
deleteRegIfUnused(FromRegIdx);
}
void Rematerializer::transferAllUsers(RegisterIdx FromRegIdx,
RegisterIdx ToRegIdx) {
Reg &FromReg = Regs[FromRegIdx], &ToReg = Regs[ToRegIdx];
for (const auto &[UseRegion, RegionUsers] : FromReg.Uses) {
for (MachineInstr *UserMI : RegionUsers)
transferUserImpl(FromRegIdx, ToRegIdx, *UserMI);
ToReg.addUsers(RegionUsers, UseRegion);
}
FromReg.Uses.clear();
deleteRegIfUnused(FromRegIdx);
}
void Rematerializer::transferUserImpl(RegisterIdx FromRegIdx,
RegisterIdx ToRegIdx,
MachineInstr &UserMI) {
@ -268,7 +199,7 @@ void Rematerializer::updateLiveIntervals() {
DenseSet<Register> SeenUnrematRegs;
for (RegisterIdx RegIdx : LISUpdates) {
const Reg &UpdateReg = getReg(RegIdx);
assert((UpdateReg.DefMI || Revivable.contains(RegIdx)) && "dead reg");
assert(UpdateReg.isAlive() && "dead register");
Register DefReg = UpdateReg.getDefReg();
if (LIS.hasInterval(DefReg))
@ -299,12 +230,6 @@ void Rematerializer::updateLiveIntervals() {
LISUpdates.clear();
}
void Rematerializer::commitRematerializations() {
for (auto &[RegIdx, _] : Revivable)
deleteReg(RegIdx);
Revivable.clear();
}
bool Rematerializer::isMOIdenticalAtUses(MachineOperand &MO,
ArrayRef<SlotIndex> Uses) const {
if (Uses.empty())
@ -361,7 +286,7 @@ void Rematerializer::deleteRegIfUnused(RegisterIdx RootIdx) {
// A deleted register's dependencies may be deletable too.
const Reg &DeleteReg = getReg(DepDAG.pop_back_val());
for (const Reg::Dependency &Dep : DeleteReg.Dependencies) {
// All dependencies loose a user (the delete register).
// All dependencies loose a user (the deleted register).
Reg &DepReg = Regs[Dep.RegIdx];
DepReg.eraseUser(DeleteReg.DefMI, DeleteReg.DefRegion);
if (DepReg.Uses.empty()) {
@ -373,27 +298,16 @@ void Rematerializer::deleteRegIfUnused(RegisterIdx RootIdx) {
for (RegisterIdx RegIdx : reverse(DeleteOrder)) {
Reg &DeleteReg = Regs[RegIdx];
LIS.removeInterval(DeleteReg.getDefReg());
// It is possible that the defined register we are deleting doesn't have an
// interval yet if the LIS hasn't been updated since it was created.
Register DefReg = DeleteReg.getDefReg();
if (LIS.hasInterval(DefReg))
LIS.removeInterval(DefReg);
LISUpdates.erase(RegIdx);
const bool IsRematerializedReg = isRematerializedRegister(RegIdx);
if (SupportRollback && !IsRematerializedReg) {
// Replace all read registers with the null one to prevent them from
// showing up in use-lists, which is disallowed for debug instructions in
// live interval calculations. Store mappings between operand indices and
// original registers for potential rollback.
DenseMap<unsigned, Register> &RegMap =
Revivable.try_emplace(RegIdx).first->getSecond();
for (auto [Idx, MO] : enumerate(DeleteReg.DefMI->operands())) {
if (MO.isReg() && MO.readsReg()) {
RegMap.insert({Idx, MO.getReg()});
MO.setReg(Register());
}
}
DeleteReg.DefMI->setDesc(TII.get(TargetOpcode::DBG_VALUE));
} else {
deleteReg(RegIdx);
}
if (IsRematerializedReg) {
deleteReg(RegIdx);
if (isRematerializedRegister(RegIdx)) {
// Delete rematerialized register from its origin's rematerializations.
RematsOf &OriginRemats = Rematerializations.at(getOriginOf(RegIdx));
assert(OriginRemats.contains(RegIdx) && "broken remat<->origin link");
@ -443,7 +357,7 @@ Rematerializer::Rematerializer(MachineFunction &MF,
#endif
}
bool Rematerializer::analyze(bool SupportRollback) {
bool Rematerializer::analyze() {
Regs.clear();
UnrematableOprds.clear();
Origins.clear();
@ -451,8 +365,6 @@ bool Rematerializer::analyze(bool SupportRollback) {
RegionMBB.clear();
RegToIdx.clear();
LISUpdates.clear();
Revivable.clear();
this->SupportRollback = SupportRollback;
if (Regions.empty())
return false;
@ -611,17 +523,65 @@ RegisterIdx Rematerializer::rematerializeReg(
*FromReg.DefMI);
NewReg.DefMI = &*std::prev(InsertPos);
RegToIdx.insert({NewDefReg, NewRegIdx});
postRematerialization(RegIdx, NewRegIdx, InsertPos);
// Update the DAG.
RegionBoundaries &Bounds = Regions[UseRegion];
if (Bounds.first == std::next(MachineBasicBlock::iterator(NewReg.DefMI)))
Bounds.first = NewReg.DefMI;
LIS.InsertMachineInstrInMaps(*NewReg.DefMI);
LISUpdates.insert(NewRegIdx);
noteRegCreated(NewRegIdx);
LLVM_DEBUG(dbgs() << "** Rematerialized " << printID(RegIdx) << " as "
<< printRematReg(NewRegIdx) << '\n');
return NewRegIdx;
}
void Rematerializer::recreateReg(
RegisterIdx RegIdx, unsigned DefRegion,
MachineBasicBlock::iterator InsertPos, Register DefReg,
SmallVectorImpl<Reg::Dependency> &&Dependencies) {
assert(RegToIdx.contains(DefReg) && "unknown defined register");
assert(RegToIdx.at(DefReg) == RegIdx && "incorrect defined register");
assert(!getReg(RegIdx).DefMI && "register is still alive");
Reg &OriginReg = Regs[RegIdx];
OriginReg.DefRegion = DefRegion;
OriginReg.Dependencies = std::move(Dependencies);
// Re-establish the link between origin and rematerialization if necessary.
const bool RecreateOriginalReg = isOriginalRegister(RegIdx);
if (!RecreateOriginalReg)
Rematerializations[getOriginOf(RegIdx)].insert(RegIdx);
// Rematerialize from one of the existing rematerializations or from the
// origin. We expect at least one to exist, otherwise it would mean the value
// held by the original register is no longer available anywhere in the MF.
RegisterIdx ModelRegIdx;
if (RecreateOriginalReg) {
assert(Rematerializations.contains(RegIdx) && "expected remats");
ModelRegIdx = *Rematerializations.at(RegIdx).begin();
} else {
assert(getReg(getOriginOf(RegIdx)).DefMI && "expected alive origin");
ModelRegIdx = getOriginOf(RegIdx);
}
const MachineInstr &ModelDefMI = *getReg(ModelRegIdx).DefMI;
TII.reMaterialize(*RegionMBB[DefRegion], InsertPos, DefReg, 0, ModelDefMI);
OriginReg.DefMI = &*std::prev(InsertPos);
postRematerialization(ModelRegIdx, RegIdx, InsertPos);
LLVM_DEBUG(dbgs() << "** Recreated " << printID(RegIdx) << " as "
<< printRematReg(RegIdx) << '\n');
}
void Rematerializer::postRematerialization(
RegisterIdx ModelRegIdx, RegisterIdx RematRegIdx,
MachineBasicBlock::iterator InsertPos) {
// The start of the new register's region may have changed.
Reg &ModelReg = Regs[ModelRegIdx], &RematReg = Regs[RematRegIdx];
LIS.InsertMachineInstrInMaps(*RematReg.DefMI);
MachineBasicBlock::iterator &RegionBegin = Regions[RematReg.DefRegion].first;
if (RegionBegin == std::next(MachineBasicBlock::iterator(RematReg.DefMI)))
RegionBegin = RematReg.DefMI;
// Replace dependencies as needed in the rematerialized MI. All dependencies
// of the latter gain a new user.
auto ZipedDeps = zip_equal(FromReg.Dependencies, NewReg.Dependencies);
auto ZipedDeps = zip_equal(ModelReg.Dependencies, RematReg.Dependencies);
for (const auto &[OldDep, NewDep] : ZipedDeps) {
assert(OldDep.MOIdx == NewDep.MOIdx && "operand mismatch");
LLVM_DEBUG(dbgs() << " Operand #" << OldDep.MOIdx << ": "
@ -630,22 +590,14 @@ RegisterIdx Rematerializer::rematerializeReg(
Reg &NewDepReg = Regs[NewDep.RegIdx];
if (OldDep.RegIdx != NewDep.RegIdx) {
Register OldDefReg = FromReg.DefMI->getOperand(OldDep.MOIdx).getReg();
NewReg.DefMI->substituteRegister(OldDefReg, NewDepReg.getDefReg(), 0,
TRI);
Register OldDefReg = ModelReg.DefMI->getOperand(OldDep.MOIdx).getReg();
RematReg.DefMI->substituteRegister(OldDefReg, NewDepReg.getDefReg(), 0,
TRI);
LISUpdates.insert(OldDep.RegIdx);
}
NewDepReg.addUser(NewReg.DefMI, UseRegion);
NewDepReg.addUser(RematReg.DefMI, RematReg.DefRegion);
LISUpdates.insert(NewDep.RegIdx);
}
noteRegCreated(NewRegIdx);
LLVM_DEBUG({
dbgs() << "** Rematerialized " << printID(RegIdx) << " as "
<< printRematReg(NewRegIdx) << '\n';
});
return NewRegIdx;
}
std::pair<MachineInstr *, MachineInstr *>
@ -807,3 +759,73 @@ Printable Rematerializer::printUser(const MachineInstr *MI,
LIS.getInstructionIndex(*MI).print(OS);
});
}
Rollbacker::RollbackInfo::RollbackInfo(const Rematerializer &Remater,
RegisterIdx RegIdx) {
const Rematerializer::Reg &Reg = Remater.getReg(RegIdx);
DefReg = Reg.getDefReg();
DefRegion = Reg.DefRegion;
Dependencies = Reg.Dependencies;
InsertPos = std::next(Reg.DefMI->getIterator());
if (InsertPos != Reg.DefMI->getParent()->end())
NextRegIdx = Remater.getDefRegIdx(*InsertPos);
else
NextRegIdx = Rematerializer::NoReg;
}
void Rollbacker::rematerializerNoteRegCreated(const Rematerializer &Remater,
RegisterIdx RegIdx) {
if (RollingBack)
return;
Rematerializations[Remater.getOriginOf(RegIdx)].insert(RegIdx);
}
void Rollbacker::rematerializerNoteRegDeleted(const Rematerializer &Remater,
RegisterIdx RegIdx) {
if (RollingBack || Remater.isRematerializedRegister(RegIdx))
return;
DeadRegs.try_emplace(RegIdx, Remater, RegIdx);
}
void Rollbacker::rollback(Rematerializer &Remater) {
RollingBack = true;
// Re-create deleted registers.
for (auto &[RegIdx, Info] : DeadRegs) {
assert(!Remater.getReg(RegIdx).isAlive() && "register should be dead");
// The MI that was originally just after the MI defining the register we
// are trying to re-create may have been deleted. In such cases, we can
// re-create at that MI's own insert position (and apply the same logic
// recursively).
MachineBasicBlock::iterator InsertPos = Info.InsertPos;
RegisterIdx NextRegIdx = Info.NextRegIdx;
while (NextRegIdx != Rematerializer::NoReg) {
const auto *NextRegRollback = DeadRegs.find(NextRegIdx);
if (NextRegRollback == DeadRegs.end())
break;
InsertPos = NextRegRollback->second.InsertPos;
NextRegIdx = NextRegRollback->second.NextRegIdx;
}
Remater.recreateReg(RegIdx, Info.DefRegion, InsertPos, Info.DefReg,
std::move(Info.Dependencies));
}
// Rollback rematerializations.
for (const auto &[RegIdx, RematsOf] : Rematerializations) {
for (RegisterIdx RematRegIdx : RematsOf) {
// It is possible that rematerializations were deleted. Their users would
// have been transfered to some other rematerialization so we can safely
// ignore them. Original registers that were deleted were just re-created
// so we do not need to check for that.
if (Remater.getReg(RematRegIdx).isAlive())
Remater.transferAllUsers(RematRegIdx, RegIdx);
}
}
Remater.updateLiveIntervals();
DeadRegs.clear();
Rematerializations.clear();
RollingBack = false;
}

View File

@ -80,8 +80,7 @@ public:
MAM.registerPass([&] { return MachineModuleAnalysis(*MMI); });
}
bool parseMIRAndInit(StringRef MIRCode, StringRef FunName,
bool SupportRollback) {
bool parseMIRAndInit(StringRef MIRCode, StringRef FunName) {
SMDiagnostic Diagnostic;
std::unique_ptr<MemoryBuffer> MBuffer = MemoryBuffer::getMemBuffer(MIRCode);
MIR = createMIRParser(std::move(MBuffer), Context);
@ -122,7 +121,7 @@ public:
}
Remater = std::make_unique<Rematerializer>(*MF, *Regions, LIS);
Remater->analyze(SupportRollback);
Remater->analyze();
return true;
}
@ -197,10 +196,11 @@ body: |
S_ENDPGM 0
...
)";
ASSERT_TRUE(
parseMIRAndInit(MIR, "TreeRematRollback", /*SupportRollback=*/true));
ASSERT_TRUE(parseMIRAndInit(MIR, "TreeRematRollback"));
Rematerializer &Remater = getRematerializer();
Rematerializer::DependencyReuseInfo DRI;
Rollbacker Rollbacker;
Remater.addListener(&Rollbacker);
// MBB/Region indices.
const unsigned MBB0 = 0, MBB1 = 1;
@ -218,15 +218,16 @@ body: |
Remater.rematerializeToRegion(/*RootIdx=*/Add23, /*UseRegion=*/MBB1, DRI);
Remater.updateLiveIntervals();
// None of the original registers have any users, but they still are in the
// MIR because we enabled rollback support.
// None of the original registers have any users left.
EXPECT_NO_USERS(Cst0);
EXPECT_NO_USERS(Cst1);
EXPECT_NO_USERS(Add01);
EXPECT_NO_USERS(Cst3);
EXPECT_NO_USERS(Add23);
// Copies of all MIs were inserted into the second MBB.
// Copies of all MIs were inserted into the second MBB. Original registers
// were deleted.
RegionSizes[MBB0] -= 5;
RegionSizes[MBB1] += 5;
ASSERT_REGION_SIZES(RegionSizes);
NumRegs += 5;
@ -234,7 +235,8 @@ body: |
}
// After rollback all rematerializations are removed from the MIR.
Remater.rollbackRematsOf(Add23);
Rollbacker.rollback(Remater);
RegionSizes[MBB0] += 5;
RegionSizes[MBB1] -= 5;
ASSERT_REGION_SIZES(RegionSizes);
@ -253,6 +255,7 @@ body: |
EXPECT_NO_USERS(Add23);
// Only immediate dependencies are copied to the second MBB.
RegionSizes[MBB0] -= 3;
RegionSizes[MBB1] += 3;
ASSERT_REGION_SIZES(RegionSizes);
NumRegs += 3;
@ -260,7 +263,8 @@ body: |
}
// After rollback all rematerializations are removed from the MIR.
Remater.rollbackRematsOf(Add23);
Rollbacker.rollback(Remater);
RegionSizes[MBB0] += 3;
RegionSizes[MBB1] -= 3;
ASSERT_REGION_SIZES(RegionSizes);
@ -302,21 +306,15 @@ body: |
EXPECT_NO_USERS(Add23);
EXPECT_NUM_USERS(RematAdd23, 1);
RegionSizes[MBB0] -= 3;
RegionSizes[MBB1] += 3;
ASSERT_REGION_SIZES(RegionSizes);
NumRegs += 3;
ASSERT_EQ(Remater.getNumRegs(), NumRegs);
}
// This time don't rollback; commit the rematerializations. This finally
// deletes unused registers in the first block. However the number of
// registers tracked by the rematerializer doesn't change.
// This time don't rollback.
Remater.updateLiveIntervals();
Remater.commitRematerializations();
RegionSizes[MBB0] -= 3;
ASSERT_REGION_SIZES(RegionSizes);
ASSERT_EQ(Remater.getNumRegs(), NumRegs);
EXPECT_TRUE(getMF().verify());
}
@ -345,8 +343,7 @@ body: |
S_ENDPGM 0
...
)";
ASSERT_TRUE(
parseMIRAndInit(MIR, "MultiRegionsRemat", /*SupportRollback=*/false));
ASSERT_TRUE(parseMIRAndInit(MIR, "MultiRegionsRemat"));
Rematerializer &Remater = getRematerializer();
Rematerializer::DependencyReuseInfo DRI;
@ -416,7 +413,7 @@ body: |
S_ENDPGM 0
...
)";
ASSERT_TRUE(parseMIRAndInit(MIR, "MultiStep", /*SupportRollback=*/false));
ASSERT_TRUE(parseMIRAndInit(MIR, "MultiStep"));
Rematerializer &Remater = getRematerializer();
Rematerializer::DependencyReuseInfo DRI;
@ -497,7 +494,7 @@ body: |
S_ENDPGM 0
...
)";
ASSERT_TRUE(parseMIRAndInit(MIR, "EmptyRegion", /*SupportRollback=*/false));
ASSERT_TRUE(parseMIRAndInit(MIR, "EmptyRegion"));
Rematerializer &Remater = getRematerializer();
Rematerializer::DependencyReuseInfo DRI;
@ -566,7 +563,7 @@ body: |
S_ENDPGM 0
...
)";
ASSERT_TRUE(parseMIRAndInit(MIR, "SubReg", /*SupportRollback=*/false));
ASSERT_TRUE(parseMIRAndInit(MIR, "SubReg"));
Rematerializer &Remater = getRematerializer();
Rematerializer::DependencyReuseInfo DRI;
@ -592,3 +589,163 @@ body: |
Remater.updateLiveIntervals();
EXPECT_TRUE(getMF().verify());
}
/// Checks that rollback works as expected when the rollback listener is added
/// mid-rematerializations.
TEST_F(RematerializerTest, Rollback) {
StringRef MIR = R"(
name: Rollback
tracksRegLiveness: true
machineFunctionInfo:
isEntryFunction: true
body: |
bb.0:
%0:vgpr_32 = nofpexcept V_CVT_I32_F64_e32 0, implicit $exec, implicit $mode
%1:vgpr_32 = nofpexcept V_CVT_I32_F64_e32 1, implicit $exec, implicit $mode
bb.1:
S_NOP 0, implicit %0, implicit %1
bb.2:
S_NOP 0, implicit %0, implicit %1
S_ENDPGM 0
)";
ASSERT_TRUE(parseMIRAndInit(MIR, "Rollback"));
Rematerializer &Remater = getRematerializer();
Rematerializer::DependencyReuseInfo DRI;
// MBB/Region indices.
const unsigned MBB0 = 0, MBB1 = 1, MBB2 = 2;
SmallVector<unsigned, 4> RegionSizes{2, 1, 1};
ASSERT_REGION_SIZES(RegionSizes);
// Indices of rematerializable registers.
unsigned NumRegs = 0;
const RegisterIdx Cst0 = NumRegs++, Cst1 = NumRegs++;
ASSERT_EQ(Remater.getNumRegs(), NumRegs);
// Rematerialize %0 to MBB1, taking one user from the original register.
RegisterIdx RematCst0MBB1 = Remater.rematerializeToRegion(Cst0, MBB1, DRI);
RegionSizes[MBB1] += 1;
ASSERT_REGION_SIZES(RegionSizes);
NumRegs += 1;
ASSERT_EQ(Remater.getNumRegs(), NumRegs);
Rollbacker Rollback;
Remater.addListener(&Rollback);
// Rematerialize %0 to MBB2 amd %1 to MBB1/MBB2; each rematerialization ends
// up with a single user and both original registers are deleted.
RegisterIdx RematCst0MBB2 =
Remater.rematerializeToRegion(Cst0, MBB2, DRI.clear());
RegisterIdx RematCst1MBB1 =
Remater.rematerializeToRegion(Cst1, MBB1, DRI.clear());
RegisterIdx RematCst1MBB2 =
Remater.rematerializeToRegion(Cst1, MBB2, DRI.clear());
RegionSizes[MBB0] -= 2;
RegionSizes[MBB1] += 1;
RegionSizes[MBB2] += 2;
ASSERT_REGION_SIZES(RegionSizes);
NumRegs += 3;
ASSERT_EQ(Remater.getNumRegs(), NumRegs);
EXPECT_NO_USERS(Cst0);
EXPECT_NO_USERS(Cst1);
EXPECT_NUM_USERS(RematCst0MBB1, 1);
EXPECT_NUM_USERS(RematCst0MBB2, 1);
EXPECT_NUM_USERS(RematCst1MBB1, 1);
EXPECT_NUM_USERS(RematCst1MBB2, 1);
// Rollback all changes since the rollbacker was added. The first
// rematerialization of %0 to MBB1 happened before so it is not rolled back.
// However %0 is re-created because it was deleted after.
Rollback.rollback(Remater);
RegionSizes[MBB0] += 2;
RegionSizes[MBB1] -= 1;
RegionSizes[MBB2] -= 2;
ASSERT_REGION_SIZES(RegionSizes);
ASSERT_EQ(Remater.getNumRegs(), NumRegs);
EXPECT_NUM_USERS(Cst0, 1);
EXPECT_NUM_USERS(Cst1, 2);
EXPECT_NUM_USERS(RematCst0MBB1, 1);
EXPECT_NO_USERS(RematCst0MBB2);
EXPECT_NO_USERS(RematCst1MBB1);
EXPECT_NO_USERS(RematCst1MBB2);
EXPECT_TRUE(getMF().verify());
}
/// Checks that rollback re-creates MIs at correct positions when the order of
/// register deletions forces the re-creation logic to iterate through multiple
/// deleted registers' respective insert position to find a valid one.
TEST_F(RematerializerTest, RollbackInvalidInsertPos) {
StringRef MIR = R"(
name: RollbackInvalidInsertPos
tracksRegLiveness: true
machineFunctionInfo:
isEntryFunction: true
body: |
bb.0:
%0:vgpr_32 = nofpexcept V_CVT_I32_F64_e32 0, implicit $exec, implicit $mode
%1:vgpr_32 = nofpexcept V_CVT_I32_F64_e32 1, implicit $exec, implicit $mode
%2:vgpr_32 = nofpexcept V_CVT_I32_F64_e32 2, implicit $exec, implicit $mode
%3:vgpr_32 = nofpexcept V_CVT_I32_F64_e32 3, implicit $exec, implicit $mode
bb.1:
S_NOP 0, implicit %0, implicit %1, implicit %2, implicit %3
S_ENDPGM 0
)";
ASSERT_TRUE(parseMIRAndInit(MIR, "RollbackInvalidInsertPos"));
Rematerializer &Remater = getRematerializer();
Rematerializer::DependencyReuseInfo DRI;
Rollbacker Rollback;
Remater.addListener(&Rollback);
// MBB/Region indices.
const unsigned MBB0 = 0, MBB1 = 1;
SmallVector<unsigned, 4> RegionSizes{4, 1};
ASSERT_REGION_SIZES(RegionSizes);
// Indices of rematerializable registers.
const RegisterIdx Cst0 = 0, Cst1 = 1, Cst2 = 2, Cst3 = 3;
// Rematerialize %0 to MBB1, deleting the original register
Remater.rematerializeToRegion(Cst0, MBB1, DRI);
RegionSizes[MBB0] -= 1;
RegionSizes[MBB1] += 1;
ASSERT_REGION_SIZES(RegionSizes);
// Rematerialize %1 to MBB1, deleting the original register
Remater.rematerializeToRegion(Cst1, MBB1, DRI.clear());
RegionSizes[MBB0] -= 1;
RegionSizes[MBB1] += 1;
ASSERT_REGION_SIZES(RegionSizes);
// Rematerialize %2 to MBB1, deleting the original register
Remater.rematerializeToRegion(Cst2, MBB1, DRI.clear());
RegionSizes[MBB0] -= 1;
RegionSizes[MBB1] += 1;
ASSERT_REGION_SIZES(RegionSizes);
// Now rollback and check for correct instruction order in the original
// defining region. The asserts on region sizes ensure that all original
// registers were indeed deleted and will be re-created in the original
// region.
Rollback.rollback(Remater);
RegionSizes[MBB0] += 3;
RegionSizes[MBB1] -= 3;
ASSERT_REGION_SIZES(RegionSizes);
MachineInstr &DefCst0 = *Remater.getReg(Cst0).DefMI;
MachineInstr &DefCst1 = *Remater.getReg(Cst1).DefMI;
MachineInstr &DefCst2 = *Remater.getReg(Cst2).DefMI;
MachineInstr &DefCst3 = *Remater.getReg(Cst3).DefMI;
EXPECT_EQ(std::next(DefCst0.getIterator()), DefCst1.getIterator());
EXPECT_EQ(std::next(DefCst1.getIterator()), DefCst2.getIterator());
EXPECT_EQ(std::next(DefCst2.getIterator()), DefCst3.getIterator());
EXPECT_TRUE(getMF().verify());
}