
Two bugs here. First, calling `Inst->getFunction()` has undefined behavior if the instruction is not attached to a function. I suspect the `replaceAllUsesWith` was leaving the GEPs in a weird "ghost parent" state. I changed the visitor so that it can call `eraseFromParent` as part of visiting, and then everything started working. The second bug was in `DXILFlattenArrays.cpp`: I was unaware that you can have multidimensional arrays of `zeroinitializer` and `undef`, so I fixed up the initializer handling to cover these two cases. Fixes #117273.
303 lines
11 KiB
C++
303 lines
11 KiB
C++
//===- DXILDataScalarization.cpp - Perform DXIL Data Scalarization --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===---------------------------------------------------------------------===//
|
|
|
|
#include "DXILDataScalarization.h"
#include "DirectX.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstVisitor.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/ReplaceConstant.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/Local.h"
|
|
|
|
#define DEBUG_TYPE "dxil-data-scalarization"

// Inline element capacity for the SmallVectors used throughout this pass.
static constexpr int MaxVecSize = 4;
|
|
|
|
using namespace llvm;
|
|
|
|
class DXILDataScalarizationLegacy : public ModulePass {
|
|
|
|
public:
|
|
bool runOnModule(Module &M) override;
|
|
DXILDataScalarizationLegacy() : ModulePass(ID) {}
|
|
|
|
static char ID; // Pass identification.
|
|
};
|
|
|
|
static bool findAndReplaceVectors(Module &M);
|
|
|
|
class DataScalarizerVisitor : public InstVisitor<DataScalarizerVisitor, bool> {
|
|
public:
|
|
DataScalarizerVisitor() : GlobalMap() {}
|
|
bool visit(Instruction &I);
|
|
// InstVisitor methods. They return true if the instruction was scalarized,
|
|
// false if nothing changed.
|
|
bool visitInstruction(Instruction &I) { return false; }
|
|
bool visitSelectInst(SelectInst &SI) { return false; }
|
|
bool visitICmpInst(ICmpInst &ICI) { return false; }
|
|
bool visitFCmpInst(FCmpInst &FCI) { return false; }
|
|
bool visitUnaryOperator(UnaryOperator &UO) { return false; }
|
|
bool visitBinaryOperator(BinaryOperator &BO) { return false; }
|
|
bool visitGetElementPtrInst(GetElementPtrInst &GEPI);
|
|
bool visitCastInst(CastInst &CI) { return false; }
|
|
bool visitBitCastInst(BitCastInst &BCI) { return false; }
|
|
bool visitInsertElementInst(InsertElementInst &IEI) { return false; }
|
|
bool visitExtractElementInst(ExtractElementInst &EEI) { return false; }
|
|
bool visitShuffleVectorInst(ShuffleVectorInst &SVI) { return false; }
|
|
bool visitPHINode(PHINode &PHI) { return false; }
|
|
bool visitLoadInst(LoadInst &LI);
|
|
bool visitStoreInst(StoreInst &SI);
|
|
bool visitCallInst(CallInst &ICI) { return false; }
|
|
bool visitFreezeInst(FreezeInst &FI) { return false; }
|
|
friend bool findAndReplaceVectors(llvm::Module &M);
|
|
|
|
private:
|
|
GlobalVariable *lookupReplacementGlobal(Value *CurrOperand);
|
|
DenseMap<GlobalVariable *, GlobalVariable *> GlobalMap;
|
|
};
|
|
|
|
bool DataScalarizerVisitor::visit(Instruction &I) {
  // Visiting only makes sense once at least one replacement has been recorded.
  assert(GlobalMap.size() != 0);
  using Base = InstVisitor<DataScalarizerVisitor, bool>;
  const bool Rewritten = Base::visit(I);
  return Rewritten;
}
|
|
|
|
GlobalVariable *
DataScalarizerVisitor::lookupReplacementGlobal(Value *CurrOperand) {
  auto *OldGlobal = dyn_cast<GlobalVariable>(CurrOperand);
  if (!OldGlobal)
    return nullptr; // Not a global at all.
  // DenseMap::lookup yields a value-initialized (null) pointer when the key is
  // absent, i.e. when this global was not scalarized.
  return GlobalMap.lookup(OldGlobal);
}
|
|
|
|
/// Redirects a load whose address is a replaced global, or a constant GEP
/// expression on one, to the array-typed replacement global.
///
/// A constant GEP address is materialized as a GetElementPtrInst right before
/// the load and then rewritten by visitGetElementPtrInst. The load itself is
/// updated in place rather than re-created. As a result, its volatility,
/// atomic ordering, alignment, name and metadata are all kept.
///
/// \returns true if the load was changed.
bool DataScalarizerVisitor::visitLoadInst(LoadInst &LI) {
  const unsigned PtrIdx = LI.getPointerOperandIndex();
  Value *Ptr = LI.getOperand(PtrIdx);

  // Only materialize GEP expressions whose base is actually being replaced.
  auto *CE = dyn_cast<ConstantExpr>(Ptr);
  if (CE && CE->getOpcode() == Instruction::GetElementPtr &&
      lookupReplacementGlobal(CE->getOperand(0))) {
    auto *GEP = cast<GetElementPtrInst>(CE->getAsInstruction());
    GEP->insertBefore(&LI);
    LI.setOperand(PtrIdx, GEP);
    // Rebuilds GEP against the replacement global; its RAUW updates LI.
    visitGetElementPtrInst(*GEP);
    return true;
  }

  if (GlobalVariable *NewGlobal = lookupReplacementGlobal(Ptr)) {
    LI.setOperand(PtrIdx, NewGlobal);
    return true;
  }
  return false;
}
|
|
|
|
/// Redirects every store operand that references a replaced global to the
/// array-typed replacement global.
///
/// Both operands are handled independently. The stored value may itself be a
/// pointer into a replaced global, and so may the address. A constant GEP
/// expression on a replaced global is materialized as a GetElementPtrInst
/// before the store and rewritten by visitGetElementPtrInst. The store is
/// updated in place, so volatility, atomic ordering, alignment and metadata
/// are preserved.
///
/// \returns true if any operand was changed.
bool DataScalarizerVisitor::visitStoreInst(StoreInst &SI) {
  bool Changed = false;
  for (unsigned I = 0, E = SI.getNumOperands(); I != E; ++I) {
    Value *CurrOperand = SI.getOperand(I);

    auto *CE = dyn_cast<ConstantExpr>(CurrOperand);
    if (CE && CE->getOpcode() == Instruction::GetElementPtr &&
        lookupReplacementGlobal(CE->getOperand(0))) {
      auto *GEP = cast<GetElementPtrInst>(CE->getAsInstruction());
      GEP->insertBefore(&SI);
      SI.setOperand(I, GEP);
      // Rebuilds GEP against the replacement global; its RAUW updates SI.
      visitGetElementPtrInst(*GEP);
      Changed = true;
      continue;
    }

    if (GlobalVariable *NewGlobal = lookupReplacementGlobal(CurrOperand)) {
      SI.setOperand(I, NewGlobal);
      Changed = true;
    }
  }
  return Changed;
}
|
|
|
|
bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
|
|
|
|
unsigned NumOperands = GEPI.getNumOperands();
|
|
GlobalVariable *NewGlobal = nullptr;
|
|
for (unsigned I = 0; I < NumOperands; ++I) {
|
|
Value *CurrOpperand = GEPI.getOperand(I);
|
|
NewGlobal = lookupReplacementGlobal(CurrOpperand);
|
|
if (NewGlobal)
|
|
break;
|
|
}
|
|
if (!NewGlobal)
|
|
return false;
|
|
|
|
IRBuilder<> Builder(&GEPI);
|
|
SmallVector<Value *, MaxVecSize> Indices;
|
|
for (auto &Index : GEPI.indices())
|
|
Indices.push_back(Index);
|
|
|
|
Value *NewGEP =
|
|
Builder.CreateGEP(NewGlobal->getValueType(), NewGlobal, Indices,
|
|
GEPI.getName(), GEPI.getNoWrapFlags());
|
|
GEPI.replaceAllUsesWith(NewGEP);
|
|
GEPI.eraseFromParent();
|
|
return true;
|
|
}
|
|
|
|
// Recursively Creates and Array like version of the given vector like type.
|
|
static Type *replaceVectorWithArray(Type *T, LLVMContext &Ctx) {
|
|
if (auto *VecTy = dyn_cast<VectorType>(T))
|
|
return ArrayType::get(VecTy->getElementType(),
|
|
dyn_cast<FixedVectorType>(VecTy)->getNumElements());
|
|
if (auto *ArrayTy = dyn_cast<ArrayType>(T)) {
|
|
Type *NewElementType =
|
|
replaceVectorWithArray(ArrayTy->getElementType(), Ctx);
|
|
return ArrayType::get(NewElementType, ArrayTy->getNumElements());
|
|
}
|
|
// If it's not a vector or array, return the original type.
|
|
return T;
|
|
}
|
|
|
|
/// Rebuilds the constant \p Init (of type \p OrigType) as an equivalent
/// constant of \p NewType, where NewType is replaceVectorWithArray(OrigType).
///
/// - zeroinitializer, poison and undef map directly onto NewType, at any
///   nesting depth.
/// - Vectors and arrays are rebuilt element by element via
///   getAggregateElement. This covers ConstantVector, ConstantDataVector,
///   ConstantArray and ConstantDataArray uniformly.
///
/// Initializers whose elements cannot be extracted (e.g. a ConstantExpr
/// producing a vector) are reported as a fatal error instead of being turned
/// into a malformed constant.
static Constant *transformInitializer(Constant *Init, Type *OrigType,
                                      Type *NewType, LLVMContext &Ctx) {
  // No vectors inside this type: the initializer is already valid.
  if (OrigType == NewType)
    return Init;

  if (isa<ConstantAggregateZero>(Init))
    return ConstantAggregateZero::get(NewType);
  // Check poison before undef: PoisonValue derives from UndefValue, and poison
  // should stay poison.
  if (isa<PoisonValue>(Init))
    return PoisonValue::get(NewType);
  if (isa<UndefValue>(Init))
    return UndefValue::get(NewType);

  // OrigType != NewType means OrigType was a vector or an array containing
  // vectors, so NewType is an array with the same element count.
  auto *NewArrayTy = cast<ArrayType>(NewType);
  Type *OrigElemTy = isa<VectorType>(OrigType)
                         ? cast<VectorType>(OrigType)->getElementType()
                         : cast<ArrayType>(OrigType)->getElementType();
  Type *NewElemTy = NewArrayTy->getElementType();

  SmallVector<Constant *, MaxVecSize> NewElements;
  for (unsigned I = 0, E = NewArrayTy->getNumElements(); I != E; ++I) {
    Constant *Elem = Init->getAggregateElement(I);
    if (!Elem)
      report_fatal_error("DXILDataScalarization: unsupported initializer for "
                         "vector-typed global");
    // Recurse so that nested arrays of vectors are converted as well.
    NewElements.push_back(
        transformInitializer(Elem, OrigElemTy, NewElemTy, Ctx));
  }

  return ConstantArray::get(NewArrayTy, NewElements);
}
|
|
|
|
/// Replaces each global variable whose value type contains fixed-length vectors
/// with a new `<name>.scalarized` global of the equivalent array type.
/// The instructions that reference the old global are rewritten, and the old
/// global is then erased.
///
/// \returns true if at least one global was replaced.
static bool findAndReplaceVectors(Module &M) {
  bool MadeChange = false;
  LLVMContext &Ctx = M.getContext();
  DataScalarizerVisitor Impl;
  for (GlobalVariable &G : M.globals()) {
    Type *OrigType = G.getValueType();
    Type *NewType = replaceVectorWithArray(OrigType, Ctx);
    if (OrigType == NewType)
      continue;

    // Create the replacement right before the original. Its initializer is
    // set below via transformInitializer.
    GlobalVariable *NewGlobal = new GlobalVariable(
        M, NewType, G.isConstant(), G.getLinkage(),
        /*Initializer=*/nullptr, G.getName() + ".scalarized", &G,
        G.getThreadLocalMode(), G.getAddressSpace(),
        G.isExternallyInitialized());

    // Carry over unnamed_addr, alignment, visibility, section, DLL storage,
    // dso_local and the other global attributes of the original.
    NewGlobal->copyAttributesFrom(&G);

    if (G.hasInitializer())
      NewGlobal->setInitializer(
          transformInitializer(G.getInitializer(), OrigType, NewType, Ctx));

    Impl.GlobalMap[&G] = NewGlobal;

    // A plain G.replaceAllUsesWith(NewGlobal) would leave GEPs indexing with
    // the old vector type, so users are rewritten by the visitor instead.
    // Collect the instructions first: the visitor rewrites operands and erases
    // instructions, which would corrupt a live use-list walk.
    SmallSetVector<Instruction *, 16> Worklist;
    for (User *U : G.users()) {
      if (auto *Inst = dyn_cast<Instruction>(U)) {
        Worklist.insert(Inst);
        continue;
      }
      if (auto *CE = dyn_cast<ConstantExpr>(U))
        for (User *CEUser : CE->users())
          if (auto *Inst = dyn_cast<Instruction>(CEUser))
            Worklist.insert(Inst);
    }
    for (Instruction *Inst : Worklist)
      Impl.visit(*Inst);
  }

  // Remove the old globals once every global has been processed.
  for (auto &[Old, New] : Impl.GlobalMap) {
    // Constant expressions left without users still count as uses and would
    // trip "Uses remain when a value is destroyed" on erase.
    Old->removeDeadConstantUsers();
    // Redirect anything the visitor does not rewrite (call arguments, nested
    // constant expressions, other initializers, ...) to the replacement. Both
    // are pointers in the same address space.
    Old->replaceAllUsesWith(New);
    Old->eraseFromParent();
    MadeChange = true;
  }
  return MadeChange;
}
|
|
|
|
PreservedAnalyses DXILDataScalarization::run(Module &M,
                                             ModuleAnalysisManager &) {
  // Rewriting globals invalidates everything; an untouched module keeps all
  // analyses.
  if (findAndReplaceVectors(M))
    return PreservedAnalyses::none();
  return PreservedAnalyses::all();
}
|
|
|
|
bool DXILDataScalarizationLegacy::runOnModule(Module &M) {
  // Report whether the module was modified back to the legacy pass manager.
  const bool Changed = findAndReplaceVectors(M);
  return Changed;
}
|
|
|
|
char DXILDataScalarizationLegacy::ID = 0;
|
|
|
|
INITIALIZE_PASS_BEGIN(DXILDataScalarizationLegacy, DEBUG_TYPE,
|
|
"DXIL Data Scalarization", false, false)
|
|
INITIALIZE_PASS_END(DXILDataScalarizationLegacy, DEBUG_TYPE,
|
|
"DXIL Data Scalarization", false, false)
|
|
|
|
ModulePass *llvm::createDXILDataScalarizationLegacyPass() {
|
|
return new DXILDataScalarizationLegacy();
|
|
}
|