llvm-project/llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp
Nikita Popov 92c55a315e
[IR] Only allow lifetime.start/end on allocas (#149310)
lifetime.start and lifetime.end are primarily intended for use on
allocas, to enable stack coloring and other liveness optimizations. This
is necessary because all (static) allocas are hoisted into the entry
block, so lifetime markers are the only way to convey the actual
lifetimes.

However, lifetime.start and lifetime.end are currently *allowed* to be
used on non-alloca pointers. We don't actually do this in practice, but
just the mere fact that this is possible breaks the core purpose of the
lifetime markers, which is stack coloring of allocas. Stack coloring can
only work correctly if all lifetime markers for an alloca are
analyzable.

* If a lifetime marker may operate on multiple allocas via a select/phi,
we don't know which lifetime actually starts/ends and handle it
incorrectly (https://github.com/llvm/llvm-project/issues/104776).
* Stack coloring operates on the assumption that all lifetime markers
are visible, and not, for example, hidden behind a function call or
escaped pointer. It's not possible to change this, as part of the
purpose of lifetime markers is that they work even in the presence of
escaped pointers, where simple use analysis is insufficient.

I don't think there is any way to have coherent semantics for lifetime
markers on allocas, while also permitting them on arbitrary pointer
values.

This PR restricts lifetimes to operate on allocas only. As a followup, I
will also drop the size argument, which is superfluous if we always
operate on an alloca. (This change also renders dead various code that
handles lifetime markers on non-alloca pointers. I plan to clean up that
kind of code after dropping the size argument as well.)

In practice, I've only found a few places that currently produce
lifetimes on non-allocas:

* CoroEarly replaces the promise alloca with the result of an intrinsic,
which will later be replaced back with an alloca. I think this is the
only place where there is some legitimate loss of functionality, but I
don't think this is particularly important (I don't think we'd expect
the promise in a coroutine to admit useful lifetime optimization.)
* SafeStack moves unsafe allocas onto a separate frame. We can safely
drop lifetimes here, as SafeStack performs its own stack coloring.
* Similar for AddressSanitizer, it also moves allocas into separate
memory.
* LSR sometimes replaces the lifetime argument with a GEP chain of the
alloca (where the offsets ultimately cancel out). This is just
unnecessary. (Fixed separately in
https://github.com/llvm/llvm-project/pull/149492.)
* InferAddrSpaces sometimes makes lifetimes operate on an addrspacecast
of an alloca. I don't think this is necessary.
2025-07-21 15:04:50 +02:00

533 lines
20 KiB
C++

//===-- SPIRVPrepareFunctions.cpp - modify function signatures --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass modifies function signatures containing aggregate arguments
// and/or return value before IRTranslator. Information about the original
// signatures is stored in metadata. It is used during call lowering to
// restore correct SPIR-V types of function arguments and return values.
// This pass also substitutes some llvm intrinsic calls with calls to newly
// generated functions (as the Khronos LLVM/SPIR-V Translator does).
//
// NOTE: this pass is a module-level one due to the necessity to modify
// GVs/functions.
//
//===----------------------------------------------------------------------===//
#include "SPIRV.h"
#include "SPIRVSubtarget.h"
#include "SPIRVTargetMachine.h"
#include "SPIRVUtils.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/IntrinsicLowering.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsSPIRV.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
#include <regex>
using namespace llvm;
namespace {
// Module pass that (1) substitutes selected llvm.* intrinsic calls with
// SPIR-V-friendly equivalents and (2) rewrites function signatures containing
// aggregate argument/return types before IRTranslator (see file header).
class SPIRVPrepareFunctions : public ModulePass {
  const SPIRVTargetMachine &TM;
  // Replaces supported llvm.* intrinsic calls in F; returns true if F changed.
  bool substituteIntrinsicCalls(Function *F);
  // Returns F unchanged, or a clone of F with aggregate types replaced by i32
  // (original types recorded in 'spv.cloned_funcs' metadata).
  Function *removeAggregateTypesFromSignature(Function *F);

public:
  static char ID;
  SPIRVPrepareFunctions(const SPIRVTargetMachine &TM)
      : ModulePass(ID), TM(TM) {}
  bool runOnModule(Module &M) override;
  StringRef getPassName() const override { return "SPIRV prepare functions"; }
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    ModulePass::getAnalysisUsage(AU);
  }
};
} // namespace
// Pass identification; the address of ID is the unique pass identifier.
char SPIRVPrepareFunctions::ID = 0;

INITIALIZE_PASS(SPIRVPrepareFunctions, "prepare-functions",
                "SPIRV prepare functions", false, false)
// Builds the "spirv.llvm_<name>_*" identifier used for functions that emulate
// an LLVM intrinsic: dots in the intrinsic's name are replaced with
// underscores and a "spirv." prefix is prepended.
static std::string lowerLLVMIntrinsicName(IntrinsicInst *II) {
  Function *Callee = II->getCalledFunction();
  assert(Callee && "Missing function");
  std::string MangledName = Callee->getName().str();
  llvm::replace(MangledName, '.', '_');
  return "spirv." + MangledName;
}
// Returns a function named Name with the given signature. An existing
// declaration is reused when its type already matches; otherwise a fresh
// external SPIR_FUNC declaration is created (inheriting dso_local from any
// same-named function whose type differs).
static Function *getOrCreateFunction(Module *M, Type *RetTy,
                                     ArrayRef<Type *> ArgTypes,
                                     StringRef Name) {
  FunctionType *Signature = FunctionType::get(RetTy, ArgTypes, false);
  Function *Existing = M->getFunction(Name);
  if (Existing && Existing->getFunctionType() == Signature)
    return Existing;
  Function *Created =
      Function::Create(Signature, GlobalValue::ExternalLinkage, Name, M);
  if (Existing)
    Created->setDSOLocal(Existing->isDSOLocal());
  Created->setCallingConv(CallingConv::SPIR_FUNC);
  return Created;
}
// Redirects an @llvm.memset / @llvm.bswap call to a @spirv.llvm_* wrapper
// function, generating the wrapper's body on first use. Returns true if the
// call was rewritten.
static bool lowerIntrinsicToFunction(IntrinsicInst *Intrinsic) {
  // For @llvm.memset.* intrinsic cases with constant value and length arguments
  // are emulated via "storing" a constant array to the destination. For other
  // cases we wrap the intrinsic in @spirv.llvm_memset_* function and expand the
  // intrinsic to a loop via expandMemSetAsLoop().
  if (auto *MSI = dyn_cast<MemSetInst>(Intrinsic))
    if (isa<Constant>(MSI->getValue()) && isa<ConstantInt>(MSI->getLength()))
      return false; // It is handled later using OpCopyMemorySized.

  Module *M = Intrinsic->getModule();
  std::string FuncName = lowerLLVMIntrinsicName(Intrinsic);
  // Volatile variants get a distinct wrapper name.
  if (Intrinsic->isVolatile())
    FuncName += ".volatile";
  // Redirect @llvm.intrinsic.* call to @spirv.llvm_intrinsic_*
  Function *F = M->getFunction(FuncName);
  if (F) {
    // The wrapper already exists: just retarget the call.
    Intrinsic->setCalledFunction(F);
    return true;
  }
  // TODO copy arguments attributes: nocapture writeonly.
  FunctionCallee FC =
      M->getOrInsertFunction(FuncName, Intrinsic->getFunctionType());
  auto IntrinsicID = Intrinsic->getIntrinsicID();
  Intrinsic->setCalledFunction(FC);
  F = dyn_cast<Function>(FC.getCallee());
  assert(F && "Callee must be a function");

  switch (IntrinsicID) {
  case Intrinsic::memset: {
    // Build a wrapper body that re-creates the memset and immediately expands
    // it into an explicit store loop; the temporary memset is then removed.
    auto *MSI = static_cast<MemSetInst *>(Intrinsic);
    Argument *Dest = F->getArg(0);
    Argument *Val = F->getArg(1);
    Argument *Len = F->getArg(2);
    Argument *IsVolatile = F->getArg(3);
    Dest->setName("dest");
    Val->setName("val");
    Len->setName("len");
    IsVolatile->setName("isvolatile");
    BasicBlock *EntryBB = BasicBlock::Create(M->getContext(), "entry", F);
    IRBuilder<> IRB(EntryBB);
    auto *MemSet = IRB.CreateMemSet(Dest, Val, Len, MSI->getDestAlign(),
                                    MSI->isVolatile());
    IRB.CreateRetVoid();
    expandMemSetAsLoop(cast<MemSetInst>(MemSet));
    MemSet->eraseFromParent();
    break;
  }
  case Intrinsic::bswap: {
    // Build a wrapper body containing a bswap call, then let the generic
    // IntrinsicLowering expand it into shifts/masks.
    BasicBlock *EntryBB = BasicBlock::Create(M->getContext(), "entry", F);
    IRBuilder<> IRB(EntryBB);
    auto *BSwap = IRB.CreateIntrinsic(Intrinsic::bswap, Intrinsic->getType(),
                                      F->getArg(0));
    IRB.CreateRet(BSwap);
    IntrinsicLowering IL(M->getDataLayout());
    IL.LowerIntrinsicCall(BSwap);
    break;
  }
  default:
    break;
  }
  return true;
}
// Extracts the annotation text for @llvm.ptr.annotation-style intrinsics.
// AnnoVal points at the annotation string (possibly behind a GEP into a
// global); OptAnnoVal is the optional extra-arguments operand (possibly behind
// a bitcast) and may be null.
static std::string getAnnotation(Value *AnnoVal, Value *OptAnnoVal) {
  // Look through the common indirections to reach the underlying constants.
  if (auto *Ref = dyn_cast_or_null<GetElementPtrInst>(AnnoVal))
    AnnoVal = Ref->getOperand(0);
  if (auto *Ref = dyn_cast_or_null<BitCastInst>(OptAnnoVal))
    OptAnnoVal = Ref->getOperand(0);

  std::string Anno;
  if (auto *C = dyn_cast_or_null<Constant>(AnnoVal)) {
    StringRef Str;
    if (getConstantStringInfo(C, Str))
      Anno = Str;
  }
  // Handle the optional annotation parameter in the way the Khronos Translator
  // does (collect integers wrapped in a struct).
  if (auto *C = dyn_cast_or_null<Constant>(OptAnnoVal);
      C && C->getNumOperands()) {
    Value *MaybeStruct = C->getOperand(0);
    if (auto *Struct = dyn_cast<ConstantStruct>(MaybeStruct)) {
      // Append the struct's integer fields as ": a, b, c" (i1 fields are
      // printed unsigned, wider fields signed).
      for (unsigned I = 0, E = Struct->getNumOperands(); I != E; ++I) {
        if (auto *CInt = dyn_cast<ConstantInt>(Struct->getOperand(I)))
          Anno += (I == 0 ? ": " : ", ") +
                  std::to_string(CInt->getType()->getIntegerBitWidth() == 1
                                     ? CInt->getZExtValue()
                                     : CInt->getSExtValue());
      }
    } else if (auto *Struct = dyn_cast<ConstantAggregateZero>(MaybeStruct)) {
      // { i32 i32 ... } zeroinitializer
      for (unsigned I = 0, E = Struct->getType()->getStructNumElements();
           I != E; ++I)
        Anno += I == 0 ? ": 0" : ", 0";
    }
  }
  return Anno;
}
// Parses an annotation string into the metadata format expected by
// `spirv.Decorations`. The string must match, with no gaps and no trailing
// text:
//   annotation := ({kind} | {kind:value,value,...})+
//   kind := number
//   value := number | string
// Returns an empty vector on any parse failure, in which case the caller
// falls back to a UserSemantic decoration. The first parameter is unused; it
// is kept to preserve the existing call signature.
static SmallVector<Metadata *> parseAnnotation(Value *I,
                                               const std::string &Anno,
                                               LLVMContext &Ctx,
                                               Type *Int32Ty) {
  static const std::regex R(
      "\\{(\\d+)(?:[:,](\\d+|\"[^\"]*\")(?:,(\\d+|\"[^\"]*\"))*)?\\}");
  SmallVector<Metadata *> MDs;
  int Pos = 0;
  for (std::sregex_iterator
           It = std::sregex_iterator(Anno.begin(), Anno.end(), R),
           ItEnd = std::sregex_iterator();
       It != ItEnd; ++It) {
    // Matches must be adjacent; any gap means an unrecognized format.
    if (It->position() != Pos)
      return SmallVector<Metadata *>{};
    Pos = It->position() + It->length();
    std::smatch Match = *It;
    SmallVector<Metadata *> MDsItem;
    for (std::size_t i = 1; i < Match.size(); ++i) {
      std::ssub_match SMatch = Match[i];
      std::string Item = SMatch.str();
      if (Item.length() == 0)
        break;
      if (Item[0] == '"') {
        // Strip the surrounding quotes.
        Item = Item.substr(1, Item.length() - 2);
        // Acceptable format of the string snippet is:
        static const std::regex RStr("^(\\d+)(?:,(\\d+))*$");
        if (std::smatch MatchStr; std::regex_match(Item, MatchStr, RStr)) {
          // A quoted comma-separated list of integers decays to integer
          // metadata. Use llvm::to_integer rather than std::stoi: LLVM is
          // built without exceptions and std::stoi throws std::out_of_range
          // for values that do not fit; to_integer fails gracefully and we
          // fall back to string metadata for such a token.
          for (std::size_t SubIdx = 1; SubIdx < MatchStr.size(); ++SubIdx)
            if (std::string SubStr = MatchStr[SubIdx].str();
                SubStr.length()) {
              if (int32_t SubNum;
                  llvm::to_integer(StringRef(SubStr), SubNum, 10))
                MDsItem.push_back(ConstantAsMetadata::get(
                    ConstantInt::get(Int32Ty, SubNum)));
              else
                MDsItem.push_back(MDString::get(Ctx, SubStr));
            }
        } else {
          MDsItem.push_back(MDString::get(Ctx, Item));
        }
      } else if (int32_t Num; llvm::to_integer(StringRef(Item), Num, 10)) {
        MDsItem.push_back(
            ConstantAsMetadata::get(ConstantInt::get(Int32Ty, Num)));
      } else {
        MDsItem.push_back(MDString::get(Ctx, Item));
      }
    }
    if (MDsItem.size() == 0)
      return SmallVector<Metadata *>{};
    MDs.push_back(MDNode::get(Ctx, MDsItem));
  }
  // The whole annotation string must have been consumed.
  return Pos == static_cast<int>(Anno.length()) ? MDs
                                                : SmallVector<Metadata *>{};
}
// Lowers @llvm.ptr.annotation: the annotation string (plus optional integer
// arguments) is parsed into decoration metadata which is attached to the
// annotated pointer via llvm.spv.assign.decoration; uses of the annotation
// call are then forwarded to its pointer operand.
static void lowerPtrAnnotation(IntrinsicInst *II) {
  LLVMContext &Ctx = II->getContext();
  Type *Int32Ty = Type::getInt32Ty(Ctx);

  // Retrieve an annotation string from arguments.
  Value *PtrArg = nullptr;
  if (auto *BI = dyn_cast<BitCastInst>(II->getArgOperand(0)))
    PtrArg = BI->getOperand(0);
  else
    PtrArg = II->getOperand(0);
  // Arg 1 is the annotation string; arg 4 (when present) carries the optional
  // extra values.
  std::string Anno =
      getAnnotation(II->getArgOperand(1),
                    4 < II->arg_size() ? II->getArgOperand(4) : nullptr);

  // Parse the annotation.
  SmallVector<Metadata *> MDs = parseAnnotation(II, Anno, Ctx, Int32Ty);

  // If the annotation string is not parsed successfully we don't know the
  // format used and output it as a general UserSemantic decoration.
  // Otherwise MDs is a Metadata tuple (a decoration list) in the format
  // expected by `spirv.Decorations`.
  if (MDs.size() == 0) {
    auto UserSemantic = ConstantAsMetadata::get(ConstantInt::get(
        Int32Ty, static_cast<uint32_t>(SPIRV::Decoration::UserSemantic)));
    MDs.push_back(MDNode::get(Ctx, {UserSemantic, MDString::get(Ctx, Anno)}));
  }

  // Build the internal intrinsic function.
  IRBuilder<> IRB(II->getParent());
  IRB.SetInsertPoint(II);
  IRB.CreateIntrinsic(
      Intrinsic::spv_assign_decoration, {PtrArg->getType()},
      {PtrArg, MetadataAsValue::get(Ctx, MDNode::get(Ctx, MDs))});
  // Forward uses of the annotation call to the (possibly bitcast) pointer
  // operand; the now-unused call is left for later cleanup.
  II->replaceAllUsesWith(II->getOperand(0));
}
// Lowers @llvm.fshl/@llvm.fshr by redirecting the call to a generated
// @spirv.llvm_fsh?_* helper that implements the funnel shift with plain
// shift/or arithmetic (scalar or per-element for vectors).
static void lowerFunnelShifts(IntrinsicInst *FSHIntrinsic) {
  // Get a separate function - otherwise, we'd have to rework the CFG of the
  // current one. Then simply replace the intrinsic uses with a call to the new
  // function.
  // Generate LLVM IR for i* @spirv.llvm_fsh?_i* (i* %a, i* %b, i* %c)
  Module *M = FSHIntrinsic->getModule();
  FunctionType *FSHFuncTy = FSHIntrinsic->getFunctionType();
  Type *FSHRetTy = FSHFuncTy->getReturnType();
  const std::string FuncName = lowerLLVMIntrinsicName(FSHIntrinsic);
  Function *FSHFunc =
      getOrCreateFunction(M, FSHRetTy, FSHFuncTy->params(), FuncName);

  // If the helper already has a body, just retarget the call.
  if (!FSHFunc->empty()) {
    FSHIntrinsic->setCalledFunction(FSHFunc);
    return;
  }
  BasicBlock *RotateBB = BasicBlock::Create(M->getContext(), "rotate", FSHFunc);
  IRBuilder<> IRB(RotateBB);
  Type *Ty = FSHFunc->getReturnType();
  // Build the actual funnel shift rotate logic.
  // In the comments, "int" is used interchangeably with "vector of int
  // elements".
  FixedVectorType *VectorTy = dyn_cast<FixedVectorType>(Ty);
  Type *IntTy = VectorTy ? VectorTy->getElementType() : Ty;
  unsigned BitWidth = IntTy->getIntegerBitWidth();
  // An element-width-sized constant, splatted for vector operands.
  ConstantInt *BitWidthConstant = IRB.getInt({BitWidth, BitWidth});
  Value *BitWidthForInsts =
      VectorTy
          ? IRB.CreateVectorSplat(VectorTy->getNumElements(), BitWidthConstant)
          : BitWidthConstant;
  // Reduce the shift amount modulo the bit width.
  Value *RotateModVal =
      IRB.CreateURem(/*Rotate*/ FSHFunc->getArg(2), BitWidthForInsts);
  Value *FirstShift = nullptr, *SecShift = nullptr;
  if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) {
    // Shift the less significant number right, the "rotate" number of bits
    // will be 0-filled on the left as a result of this regular shift.
    FirstShift = IRB.CreateLShr(FSHFunc->getArg(1), RotateModVal);
  } else {
    // Shift the more significant number left, the "rotate" number of bits
    // will be 0-filled on the right as a result of this regular shift.
    FirstShift = IRB.CreateShl(FSHFunc->getArg(0), RotateModVal);
  }
  // We want the "rotate" number of the more significant int's LSBs (MSBs) to
  // occupy the leftmost (rightmost) "0 space" left by the previous operation.
  // Therefore, subtract the "rotate" number from the integer bitsize...
  Value *SubRotateVal = IRB.CreateSub(BitWidthForInsts, RotateModVal);
  if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) {
    // ...and left-shift the more significant int by this number, zero-filling
    // the LSBs.
    SecShift = IRB.CreateShl(FSHFunc->getArg(0), SubRotateVal);
  } else {
    // ...and right-shift the less significant int by this number, zero-filling
    // the MSBs.
    SecShift = IRB.CreateLShr(FSHFunc->getArg(1), SubRotateVal);
  }
  // A simple binary addition of the shifted ints yields the final result.
  IRB.CreateRet(IRB.CreateOr(FirstShift, SecShift));
  FSHIntrinsic->setCalledFunction(FSHFunc);
}
// Rewrites @llvm.assume / @llvm.expect into the corresponding SPIR-V builtin
// intrinsics; the builtin is later lowered into the actual SPIR-V instruction
// (OpAssumeTrueKHR for assume, OpExpectKHR for expect). Callers only invoke
// this when the SPV_KHR_expect_assume extension is usable; without it the
// original intrinsic is kept and removed later on by LLVM.
static void lowerExpectAssume(IntrinsicInst *II) {
  switch (II->getIntrinsicID()) {
  case Intrinsic::assume: {
    // llvm.spv.assume has no overloaded types.
    Function *Decl = Intrinsic::getOrInsertDeclaration(
        II->getModule(), Intrinsic::SPVIntrinsics::spv_assume);
    II->setCalledFunction(Decl);
    break;
  }
  case Intrinsic::expect: {
    // llvm.spv.expect is overloaded on the expected value's type.
    Function *Decl = Intrinsic::getOrInsertDeclaration(
        II->getModule(), Intrinsic::SPVIntrinsics::spv_expect,
        {II->getOperand(0)->getType()});
    II->setCalledFunction(Decl);
    break;
  }
  default:
    llvm_unreachable("Unknown intrinsic");
  }
}
// Redirects the call II to the SPIR-V intrinsic NewID. OpNos lists the
// operand indices whose types form the overload signature; an empty list
// means the target intrinsic is not overloaded. Always returns true, as the
// call is unconditionally rewritten.
static bool toSpvOverloadedIntrinsic(IntrinsicInst *II, Intrinsic::ID NewID,
                                     ArrayRef<unsigned> OpNos) {
  SmallVector<Type *, 4> OverloadTys;
  for (unsigned OpNo : OpNos)
    OverloadTys.push_back(II->getOperand(OpNo)->getType());
  Function *NewDecl =
      OverloadTys.empty()
          ? Intrinsic::getOrInsertDeclaration(II->getModule(), NewID)
          : Intrinsic::getOrInsertDeclaration(II->getModule(), NewID,
                                              OverloadTys);
  II->setCalledFunction(NewDecl);
  return true;
}
// Substitutes calls to LLVM intrinsics with either calls to SPIR-V intrinsics
// or calls to proper generated functions. Returns True if F was modified.
bool SPIRVPrepareFunctions::substituteIntrinsicCalls(Function *F) {
  bool Changed = false;
  const SPIRVSubtarget &STI = TM.getSubtarget<SPIRVSubtarget>(*F);
  for (BasicBlock &BB : *F) {
    // Early-inc iteration: some cases below erase the current instruction.
    for (Instruction &I : make_early_inc_range(BB)) {
      auto Call = dyn_cast<CallInst>(&I);
      if (!Call)
        continue;
      Function *CF = Call->getCalledFunction();
      if (!CF || !CF->isIntrinsic())
        continue;
      auto *II = cast<IntrinsicInst>(Call);
      switch (II->getIntrinsicID()) {
      case Intrinsic::memset:
      case Intrinsic::bswap:
        Changed |= lowerIntrinsicToFunction(II);
        break;
      case Intrinsic::fshl:
      case Intrinsic::fshr:
        lowerFunnelShifts(II);
        Changed = true;
        break;
      case Intrinsic::assume:
      case Intrinsic::expect:
        // Only lower when SPV_KHR_expect_assume is available; otherwise the
        // intrinsic is left in place for LLVM to remove later. Report a
        // change only when the call was actually rewritten.
        if (STI.canUseExtension(SPIRV::Extension::SPV_KHR_expect_assume)) {
          lowerExpectAssume(II);
          Changed = true;
        }
        break;
      case Intrinsic::lifetime_start:
        if (!STI.isShader()) {
          Changed |= toSpvOverloadedIntrinsic(
              II, Intrinsic::SPVIntrinsics::spv_lifetime_start, {1});
        } else {
          // Shaders have no lifetime support; dropping the marker is still a
          // modification of the IR.
          II->eraseFromParent();
          Changed = true;
        }
        break;
      case Intrinsic::lifetime_end:
        if (!STI.isShader()) {
          Changed |= toSpvOverloadedIntrinsic(
              II, Intrinsic::SPVIntrinsics::spv_lifetime_end, {1});
        } else {
          II->eraseFromParent();
          Changed = true;
        }
        break;
      case Intrinsic::ptr_annotation:
        lowerPtrAnnotation(II);
        Changed = true;
        break;
      }
    }
  }
  return Changed;
}
// Returns F if aggregate argument/return types are not present or cloned F
// function with the types replaced by i32 types. The change in types is
// noted in 'spv.cloned_funcs' metadata for later restoration.
Function *
SPIRVPrepareFunctions::removeAggregateTypesFromSignature(Function *F) {
  bool IsRetAggr = F->getReturnType()->isAggregateType();
  // Allow intrinsics with aggregate return type to reach GlobalISel
  if (F->isIntrinsic() && IsRetAggr)
    return F;

  IRBuilder<> B(F->getContext());

  bool HasAggrArg = llvm::any_of(F->args(), [](Argument &Arg) {
    return Arg.getType()->isAggregateType();
  });
  bool DoClone = IsRetAggr || HasAggrArg;
  if (!DoClone)
    return F;

  // Record each changed position as (index, original type); index -1 denotes
  // the return type.
  SmallVector<std::pair<int, Type *>, 4> ChangedTypes;
  Type *RetType = IsRetAggr ? B.getInt32Ty() : F->getReturnType();
  if (IsRetAggr)
    ChangedTypes.push_back(std::pair<int, Type *>(-1, F->getReturnType()));
  SmallVector<Type *, 4> ArgTypes;
  for (const auto &Arg : F->args()) {
    if (Arg.getType()->isAggregateType()) {
      ArgTypes.push_back(B.getInt32Ty());
      ChangedTypes.push_back(
          std::pair<int, Type *>(Arg.getArgNo(), Arg.getType()));
    } else
      ArgTypes.push_back(Arg.getType());
  }
  FunctionType *NewFTy =
      FunctionType::get(RetType, ArgTypes, F->getFunctionType()->isVarArg());
  Function *NewF =
      Function::Create(NewFTy, F->getLinkage(), F->getName(), *F->getParent());

  // Clone the body into the new function, mapping old arguments to new ones.
  ValueToValueMapTy VMap;
  auto NewFArgIt = NewF->arg_begin();
  for (auto &Arg : F->args()) {
    StringRef ArgName = Arg.getName();
    NewFArgIt->setName(ArgName);
    VMap[&Arg] = &(*NewFArgIt++);
  }
  SmallVector<ReturnInst *, 8> Returns;
  CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly,
                    Returns);
  NewF->takeName(F);

  // Store (function name, {index, null-of-original-type}...) in module-level
  // 'spv.cloned_funcs' metadata so call lowering can restore the original
  // SPIR-V types later.
  NamedMDNode *FuncMD =
      F->getParent()->getOrInsertNamedMetadata("spv.cloned_funcs");
  SmallVector<Metadata *, 2> MDArgs;
  MDArgs.push_back(MDString::get(B.getContext(), NewF->getName()));
  for (auto &ChangedTyP : ChangedTypes)
    MDArgs.push_back(MDNode::get(
        B.getContext(),
        {ConstantAsMetadata::get(B.getInt32(ChangedTyP.first)),
         ValueAsMetadata::get(Constant::getNullValue(ChangedTyP.second))}));
  MDNode *ThisFuncMD = MDNode::get(B.getContext(), MDArgs);
  FuncMD->addOperand(ThisFuncMD);

  // Retarget existing calls/uses of F to NewF (early-inc: the use list is
  // modified while iterating).
  for (auto *U : make_early_inc_range(F->users())) {
    if (auto *CI = dyn_cast<CallInst>(U))
      CI->mutateFunctionType(NewF->getFunctionType());
    U->replaceUsesOfWith(F, NewF);
  }

  // register the mutation
  if (RetType != F->getReturnType())
    TM.getSubtarget<SPIRVSubtarget>(*F).getSPIRVGlobalRegistry()->addMutated(
        NewF, F->getReturnType());
  return NewF;
}
// Runs intrinsic substitution and block sorting over every function, then
// rewrites the signatures of functions with aggregate argument/return types.
bool SPIRVPrepareFunctions::runOnModule(Module &M) {
  bool Changed = false;
  for (Function &F : M) {
    Changed |= substituteIntrinsicCalls(&F);
    Changed |= sortBlocks(F);
  }

  // Snapshot the functions first: signature rewriting creates new functions
  // and erases old ones, which would invalidate direct iteration over M.
  SmallVector<Function *> Worklist;
  for (Function &F : M)
    Worklist.push_back(&F);
  for (Function *F : Worklist) {
    Function *Replacement = removeAggregateTypesFromSignature(F);
    if (Replacement != F) {
      F->eraseFromParent();
      Changed = true;
    }
  }
  return Changed;
}
// Factory entry point; TM is held by reference for later subtarget queries.
ModulePass *
llvm::createSPIRVPrepareFunctionsPass(const SPIRVTargetMachine &TM) {
  return new SPIRVPrepareFunctions(TM);
}