llvm-project/llvm/lib/Target/SPIRV/SPIRVRegularizer.cpp
Manuel Carrasco 90b8a481e4
[SPIRV] Support additional comparison predicates for i1 types (#174585)
Previously, the SPIRV BE only handled equality and inequality
comparisons (ICMP_EQ, ICMP_NE) for i1 types using logical operations
(OpLogicalEqual, OpLogicalNotEqual). Other comparison predicates
(signed/unsigned less than, greater than, etc.) triggered an unreachable
assertion.

This patch adds support for the missing predicates. The backend treats
i1 values as booleans, and in SPIR-V only logical operations can operate
on booleans. This patch therefore lowers the missing predicates into
supported logical operations.

The lowering has been validated with instcombine to avoid introducing an
unsound transformation (using the new test case).
2026-01-08 11:53:45 +00:00

221 lines
7.6 KiB
C++

//===-- SPIRVRegularizer.cpp - regularize IR for SPIR-V ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass implements regularization of LLVM IR for SPIR-V. The prototype of
// the pass was taken from SPIRV-LLVM translator.
//
//===----------------------------------------------------------------------===//
#include "SPIRV.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/PassManager.h"
#include <list>
#define DEBUG_TYPE "spirv-regularizer"
using namespace llvm;
namespace {
// Function pass that regularizes LLVM IR for SPIR-V. Each run performs two
// rewrites on the function (see runOnFunction): i1 comparisons with ordering
// predicates are turned into logical operations, and constant expressions
// used as operands are expanded into instructions.
struct SPIRVRegularizer : public FunctionPass {
  // Pass identifier; its address is what gets handed to the FunctionPass
  // base-class constructor below.
  static char ID;

  SPIRVRegularizer() : FunctionPass(ID) {}

  StringRef getPassName() const override { return "SPIR-V Regularizer"; }

  // No analyses are requested beyond what the base class already declares.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    FunctionPass::getAnalysisUsage(AU);
  }

  bool runOnFunction(Function &F) override;

private:
  // Rewrites ordering comparisons on i1 operands into and/or/not.
  void runLowerI1Comparisons(Function &F);
  // Expands constant-expression operands into instructions.
  void runLowerConstExpr(Function &F);
};
} // namespace
// Out-of-class definition of the static pass identifier declared in
// SPIRVRegularizer. The constructor passes ID to FunctionPass.
char SPIRVRegularizer::ID = 0;
// Pass registration: DEBUG_TYPE ("spirv-regularizer") is used as the pass
// argument string and "SPIR-V Regularizer" as its description. The trailing
// `false, false` are the macro's CFG-only and is-analysis flags (per LLVM's
// INITIALIZE_PASS macro definition).
INITIALIZE_PASS(SPIRVRegularizer, DEBUG_TYPE, "SPIR-V Regularizer", false,
false)
// Since SPIR-V cannot represent constant expressions, constant expressions
// in LLVM IR need to be lowered to instructions. For each function,
// the constant expressions used by instructions of the function are replaced
// by instructions placed in the entry block, since it dominates all other BBs.
// Each constant expression only needs to be lowered once in each function:
// all uses of it by instructions in that function are replaced by
// one instruction.
//
// The function is driven by a worklist that initially holds every instruction
// of F. Instructions created while lowering are pushed to the front of the
// worklist so that their own operands are examined as well (this is how
// nested constant expressions get expanded).
//
// Three operand shapes are handled:
//  * a ConstantExpr operand,
//  * a ConstantVector operand whose elements are all ConstantExprs or
//    Functions (rebuilt with a chain of insertelement instructions),
//  * a MetadataAsValue operand wrapping either of the above.
// TODO: remove redundant instructions for common subexpression.
void SPIRVRegularizer::runLowerConstExpr(Function &F) {
LLVMContext &Ctx = F.getContext();
std::list<Instruction *> WorkList;
// Seed the worklist with every instruction currently in the function.
for (auto &II : instructions(F))
WorkList.push_back(&II);
// Iterator to the first (entry) basic block of F.
auto FBegin = F.begin();
while (!WorkList.empty()) {
// II is the instruction whose operands are examined in this iteration. Both
// lambdas below capture it by reference.
Instruction *II = WorkList.front();
// Lowers a single value: a Function is returned unchanged; anything else
// must be a ConstantExpr (cast<> asserts otherwise) and is materialized as
// an instruction. Returns the replacement value.
auto LowerOp = [&II, &FBegin, &F](Value *V) -> Value * {
if (isa<Function>(V))
return V;
auto *CE = cast<ConstantExpr>(V);
LLVM_DEBUG(dbgs() << "[lowerConstantExpressions] " << *CE);
auto ReplInst = CE->getAsInstruction();
// Insert into the entry block: right before II when II itself lives in the
// entry block, otherwise right before the entry block's last instruction.
auto InsPoint = II->getParent() == &*FBegin ? II : &FBegin->back();
ReplInst->insertBefore(InsPoint->getIterator());
LLVM_DEBUG(dbgs() << " -> " << *ReplInst << '\n');
std::vector<Instruction *> Users;
// Do not replace use during iteration of use. Do it in another loop.
for (auto U : CE->users()) {
LLVM_DEBUG(dbgs() << "[lowerConstantExpressions] Use: " << *U << '\n');
auto InstUser = dyn_cast<Instruction>(U);
// Only replace users in scope of current function.
if (InstUser && InstUser->getParent()->getParent() == &F)
Users.push_back(InstUser);
}
for (auto &User : Users) {
// If a user in the same block precedes the new instruction, hoist the new
// instruction above that user so the definition comes before the use.
if (ReplInst->getParent() == User->getParent() &&
User->comesBefore(ReplInst))
ReplInst->moveBefore(User->getIterator());
User->replaceUsesOfWith(CE, ReplInst);
}
return ReplInst;
};
// The front entry is removed only now; II still points at the instruction.
WorkList.pop_front();
// Lowers a constant vector operand (operand index NumOfOp of II). Returns
// the value of the rebuilt vector, or nullptr when the vector is left alone
// because not every element is a ConstantExpr or a Function.
auto LowerConstantVec = [&II, &LowerOp, &WorkList,
&Ctx](ConstantVector *Vec,
unsigned NumOfOp) -> Value * {
if (std::all_of(Vec->op_begin(), Vec->op_end(), [](Value *V) {
return isa<ConstantExpr>(V) || isa<Function>(V);
})) {
// Expand a vector of constexprs and construct it back with
// series of insertelement instructions.
std::list<Value *> OpList;
std::transform(Vec->op_begin(), Vec->op_end(),
std::back_inserter(OpList),
[LowerOp](Value *V) { return LowerOp(V); });
Value *Repl = nullptr;
unsigned Idx = 0;
// For a PHI user the insertelement chain is placed before the last
// instruction of the incoming block that corresponds to this operand;
// for any other user it is placed directly before the user.
auto *PhiII = dyn_cast<PHINode>(II);
Instruction *InsPoint =
PhiII ? &PhiII->getIncomingBlock(NumOfOp)->back() : II;
// Lowered elements that are instructions; they are queued for processing
// below. The insertelement instructions themselves are not queued.
std::list<Instruction *> ReplList;
for (auto V : OpList) {
if (auto *Inst = dyn_cast<Instruction>(V))
ReplList.push_back(Inst);
// The chain starts from a poison vector and inserts element Idx (as an
// i32 index) on each step.
Repl = InsertElementInst::Create(
(Repl ? Repl : PoisonValue::get(Vec->getType())), V,
ConstantInt::get(Type::getInt32Ty(Ctx), Idx++), "",
InsPoint->getIterator());
}
WorkList.splice(WorkList.begin(), ReplList);
return Repl;
}
return nullptr;
};
// Examine every operand of II. The operand count is read once up front.
for (unsigned OI = 0, OE = II->getNumOperands(); OI != OE; ++OI) {
auto *Op = II->getOperand(OI);
if (auto *Vec = dyn_cast<ConstantVector>(Op)) {
Value *ReplInst = LowerConstantVec(Vec, OI);
if (ReplInst)
II->replaceUsesOfWith(Op, ReplInst);
} else if (auto CE = dyn_cast<ConstantExpr>(Op)) {
// LowerOp already rewrote the uses (including the one in II); the new
// instruction is queued so its own operands are lowered next.
WorkList.push_front(cast<Instruction>(LowerOp(CE)));
} else if (auto MDAsVal = dyn_cast<MetadataAsValue>(Op)) {
// Metadata operand: only constants wrapped as metadata are of interest.
auto ConstMD = dyn_cast<ConstantAsMetadata>(MDAsVal->getMetadata());
if (!ConstMD)
continue;
Constant *C = ConstMD->getValue();
Value *ReplInst = nullptr;
if (auto *Vec = dyn_cast<ConstantVector>(C))
ReplInst = LowerConstantVec(Vec, OI);
if (auto *CE = dyn_cast<ConstantExpr>(C))
ReplInst = LowerOp(CE);
if (!ReplInst)
continue;
// Re-wrap the replacement as metadata and install it as the operand.
Metadata *RepMD = ValueAsMetadata::get(ReplInst);
Value *RepMDVal = MetadataAsValue::get(Ctx, RepMD);
II->setOperand(OI, RepMDVal);
WorkList.push_front(cast<Instruction>(ReplInst));
}
}
}
}
// Rewrite ordering comparisons on i1 operands as logical operations.
// The backend treats i1 as a boolean, and SPIR-V only allows logical
// operations on booleans, so an `icmp` with any predicate in the range
// ICMP_UGT..ICMP_SLE is replaced by an and/or of one operand with the
// negation of the other. Comparisons with predicates outside that range and
// comparisons on non-i1 operands are left untouched.
//
// With p = operand 0 and q = operand 1:
//   ugt, slt  ->  p & !q
//   ult, sgt  ->  q & !p
//   ule, sge  ->  q | !p
//   uge, sle  ->  p | !q
void SPIRVRegularizer::runLowerI1Comparisons(Function &F) {
  // Early-inc iteration: the visited compare is erased inside the loop.
  for (Instruction &Inst : make_early_inc_range(instructions(F))) {
    auto *Compare = dyn_cast<ICmpInst>(&Inst);
    if (!Compare || !Compare->getOperand(0)->getType()->isIntegerTy(1))
      continue;

    const auto Kind = Compare->getPredicate();
    if (Kind < ICmpInst::ICMP_UGT || Kind > ICmpInst::ICMP_SLE)
      continue;

    Value *Lhs = Compare->getOperand(0);
    Value *Rhs = Compare->getOperand(1);

    // Choose which operand is used as-is, which one is negated, and whether
    // the two are combined with `and` (otherwise `or`).
    Value *Kept = nullptr;
    Value *Negated = nullptr;
    bool CombineWithAnd = false;
    switch (Kind) {
    case ICmpInst::ICMP_UGT:
    case ICmpInst::ICMP_SLT:
      Kept = Lhs;
      Negated = Rhs;
      CombineWithAnd = true;
      break;
    case ICmpInst::ICMP_ULT:
    case ICmpInst::ICMP_SGT:
      Kept = Rhs;
      Negated = Lhs;
      CombineWithAnd = true;
      break;
    case ICmpInst::ICMP_ULE:
    case ICmpInst::ICMP_SGE:
      Kept = Rhs;
      Negated = Lhs;
      break;
    case ICmpInst::ICMP_UGE:
    case ICmpInst::ICMP_SLE:
      Kept = Lhs;
      Negated = Rhs;
      break;
    default:
      llvm_unreachable("Unexpected predicate");
    }

    // Emit the replacement right before the compare: the negation first,
    // then the combining operation with the kept operand on the left.
    IRBuilder<> Builder(Compare);
    Value *Inverted = Builder.CreateNot(Negated);
    Value *Lowered = CombineWithAnd ? Builder.CreateAnd(Kept, Inverted)
                                    : Builder.CreateOr(Kept, Inverted);

    // Hand the compare's name and uses over to the replacement, then drop it.
    Lowered->takeName(Compare);
    Compare->replaceAllUsesWith(Lowered);
    Compare->eraseFromParent();
  }
}
// Pass entry point. Runs the i1-comparison lowering first and the
// constant-expression lowering second.
bool SPIRVRegularizer::runOnFunction(Function &F) {
  runLowerI1Comparisons(F);
  runLowerConstExpr(F);

  // The function is always reported as modified; neither step reports
  // whether it actually changed anything.
  return true;
}
// Factory for the regularizer: returns a newly heap-allocated pass instance
// through its FunctionPass base pointer.
FunctionPass *llvm::createSPIRVRegularizerPass() {
  FunctionPass *Regularizer = new SPIRVRegularizer();
  return Regularizer;
}