Fixes a logic issue in the `faceforward` pattern matcher in `SPIRVCombinerHelper.cpp`. Previously when `mi_match` failed, we would still go through the nested `Pred == CmpInst::FCMP_OGT || Pred == CmpInst::FCMP_UGT` check. It was possible that whatever garbage was in Pred could randomly pass this check and make us continue through the code. This change fixes that logic issue by returning false as soon as `mi_match` fails. Likely fixes #177803. Can't confirm since it seems another change has obscured the crash.
403 lines
15 KiB
C++
403 lines
15 KiB
C++
//===-- SPIRVCombinerHelper.cpp -------------------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "SPIRVCombinerHelper.h"
|
|
#include "SPIRVGlobalRegistry.h"
|
|
#include "SPIRVUtils.h"
|
|
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
|
|
#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
|
|
#include "llvm/IR/DerivedTypes.h"
|
|
#include "llvm/IR/IntrinsicsSPIRV.h"
|
|
#include "llvm/IR/LLVMContext.h" // Explicitly include for LLVMContext
|
|
#include "llvm/Target/TargetMachine.h"
|
|
|
|
using namespace llvm;
|
|
using namespace MIPatternMatch;
|
|
|
|
/// Builds the SPIR-V combiner helper. Every generic argument is forwarded
/// unchanged to the CombinerHelper base class; the SPIR-V subtarget reference
/// is kept in STI for use by the target-specific matchers.
SPIRVCombinerHelper::SPIRVCombinerHelper(
    GISelChangeObserver &Observer, MachineIRBuilder &B, bool IsPreLegalize,
    GISelValueTracking *VT, MachineDominatorTree *MDT, const LegalizerInfo *LI,
    const SPIRVSubtarget &STI)
    : CombinerHelper(Observer, B, IsPreLegalize, VT, MDT, LI), STI(STI) {}
|
|
|
|
/// This match is part of a combine that
|
|
/// rewrites length(X - Y) to distance(X, Y)
|
|
/// (f32 (g_intrinsic length
|
|
/// (g_fsub (vXf32 X) (vXf32 Y))))
|
|
/// ->
|
|
/// (f32 (g_intrinsic distance
|
|
/// (vXf32 X) (vXf32 Y)))
|
|
///
|
|
bool SPIRVCombinerHelper::matchLengthToDistance(MachineInstr &MI) const {
|
|
if (MI.getOpcode() != TargetOpcode::G_INTRINSIC ||
|
|
cast<GIntrinsic>(MI).getIntrinsicID() != Intrinsic::spv_length)
|
|
return false;
|
|
|
|
// First operand of MI is `G_INTRINSIC` so start at operand 2.
|
|
Register SubReg = MI.getOperand(2).getReg();
|
|
MachineInstr *SubInstr = MRI.getVRegDef(SubReg);
|
|
if (SubInstr->getOpcode() != TargetOpcode::G_FSUB)
|
|
return false;
|
|
|
|
return true;
|
|
}
|
|
|
|
/// Apply step paired with matchLengthToDistance: emits an spv_distance
/// intrinsic that writes \p MI's result register and takes the two sources of
/// the subtraction feeding \p MI, then erases \p MI. The subtraction itself is
/// left in place.
void SPIRVCombinerHelper::applySPIRVDistance(MachineInstr &MI) const {
  // X and Y are the two sources (operands 1 and 2) of the instruction that
  // defines MI's operand 2.
  MachineInstr *FSub = MRI.getVRegDef(MI.getOperand(2).getReg());
  Register X = FSub->getOperand(1).getReg();
  Register Y = FSub->getOperand(2).getReg();
  Register Dst = MI.getOperand(0).getReg();

  Builder.setInstrAndDebugLoc(MI);
  auto Distance = Builder.buildIntrinsic(Intrinsic::spv_distance, Dst);
  Distance.addUse(X);
  Distance.addUse(Y);

  MI.eraseFromParent();
}
|
|
|
|
/// This match is part of a combine that
|
|
/// rewrites select(fcmp(dot(I, Ng), 0), N, -N) to faceforward(N, I, Ng)
|
|
/// (vXf32 (g_select
|
|
/// (g_fcmp
|
|
/// (g_intrinsic dot(vXf32 I) (vXf32 Ng)
|
|
/// 0)
|
|
/// (vXf32 N)
|
|
/// (vXf32 g_fneg (vXf32 N))))
|
|
/// ->
|
|
/// (vXf32 (g_intrinsic faceforward
|
|
/// (vXf32 N) (vXf32 I) (vXf32 Ng)))
|
|
///
|
|
/// This only works for Vulkan shader targets.
|
|
///
|
|
bool SPIRVCombinerHelper::matchSelectToFaceForward(MachineInstr &MI) const {
|
|
if (!STI.isShader())
|
|
return false;
|
|
|
|
// Match overall select pattern.
|
|
Register CondReg, TrueReg, FalseReg;
|
|
if (!mi_match(MI.getOperand(0).getReg(), MRI,
|
|
m_GISelect(m_Reg(CondReg), m_Reg(TrueReg), m_Reg(FalseReg))))
|
|
return false;
|
|
|
|
// Match the FCMP condition.
|
|
Register DotReg, CondZeroReg;
|
|
CmpInst::Predicate Pred;
|
|
if (!mi_match(CondReg, MRI,
|
|
m_GFCmp(m_Pred(Pred), m_Reg(DotReg), m_Reg(CondZeroReg))))
|
|
return false;
|
|
if (Pred == CmpInst::FCMP_OGT || Pred == CmpInst::FCMP_UGT)
|
|
std::swap(DotReg, CondZeroReg);
|
|
else if (!(Pred == CmpInst::FCMP_OLT || Pred == CmpInst::FCMP_ULT))
|
|
return false;
|
|
|
|
// Check if FCMP is a comparison between a dot product and 0.
|
|
MachineInstr *DotInstr = MRI.getVRegDef(DotReg);
|
|
if (DotInstr->getOpcode() != TargetOpcode::G_INTRINSIC ||
|
|
cast<GIntrinsic>(DotInstr)->getIntrinsicID() != Intrinsic::spv_fdot) {
|
|
Register DotOperand1, DotOperand2;
|
|
// Check for scalar dot product.
|
|
if (!mi_match(DotReg, MRI,
|
|
m_GFMul(m_Reg(DotOperand1), m_Reg(DotOperand2))) ||
|
|
!MRI.getType(DotOperand1).isScalar() ||
|
|
!MRI.getType(DotOperand2).isScalar())
|
|
return false;
|
|
}
|
|
|
|
const ConstantFP *ZeroVal;
|
|
if (!mi_match(CondZeroReg, MRI, m_GFCst(ZeroVal)) || !ZeroVal->isZero())
|
|
return false;
|
|
|
|
// Check if select's false operand is the negation of the true operand.
|
|
auto AreNegatedConstantsOrSplats = [&](Register TrueReg, Register FalseReg) {
|
|
std::optional<FPValueAndVReg> TrueVal, FalseVal;
|
|
if (!mi_match(TrueReg, MRI, m_GFCstOrSplat(TrueVal)) ||
|
|
!mi_match(FalseReg, MRI, m_GFCstOrSplat(FalseVal)))
|
|
return false;
|
|
APFloat TrueValNegated = TrueVal->Value;
|
|
TrueValNegated.changeSign();
|
|
return FalseVal->Value.compare(TrueValNegated) == APFloat::cmpEqual;
|
|
};
|
|
|
|
if (!mi_match(TrueReg, MRI, m_GFNeg(m_SpecificReg(FalseReg))) &&
|
|
!mi_match(FalseReg, MRI, m_GFNeg(m_SpecificReg(TrueReg)))) {
|
|
std::optional<FPValueAndVReg> MulConstant;
|
|
MachineInstr *TrueInstr = MRI.getVRegDef(TrueReg);
|
|
MachineInstr *FalseInstr = MRI.getVRegDef(FalseReg);
|
|
if (TrueInstr->getOpcode() == TargetOpcode::G_BUILD_VECTOR &&
|
|
FalseInstr->getOpcode() == TargetOpcode::G_BUILD_VECTOR &&
|
|
TrueInstr->getNumOperands() == FalseInstr->getNumOperands()) {
|
|
for (unsigned I = 1; I < TrueInstr->getNumOperands(); ++I)
|
|
if (!AreNegatedConstantsOrSplats(TrueInstr->getOperand(I).getReg(),
|
|
FalseInstr->getOperand(I).getReg()))
|
|
return false;
|
|
} else if (mi_match(TrueReg, MRI,
|
|
m_GFMul(m_SpecificReg(FalseReg),
|
|
m_GFCstOrSplat(MulConstant))) ||
|
|
mi_match(FalseReg, MRI,
|
|
m_GFMul(m_SpecificReg(TrueReg),
|
|
m_GFCstOrSplat(MulConstant))) ||
|
|
mi_match(TrueReg, MRI,
|
|
m_GFMul(m_GFCstOrSplat(MulConstant),
|
|
m_SpecificReg(FalseReg))) ||
|
|
mi_match(FalseReg, MRI,
|
|
m_GFMul(m_GFCstOrSplat(MulConstant),
|
|
m_SpecificReg(TrueReg)))) {
|
|
if (!MulConstant || !MulConstant->Value.isExactlyValue(-1.0))
|
|
return false;
|
|
} else if (!AreNegatedConstantsOrSplats(TrueReg, FalseReg))
|
|
return false;
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
/// Apply step paired with matchSelectToFaceForward. Emits an spv_faceforward
/// intrinsic that defines \p MI's result register from (N, I, Ng) and then
/// erases, in this order, the FCMP, the dot-product instruction and the
/// instruction defining the negated operand, each together with every
/// instruction that uses its result. \p MI uses the FCMP result, so it is
/// erased as part of that cleanup.
void SPIRVCombinerHelper::applySPIRVFaceForward(MachineInstr &MI) const {
  // Locate the dot product feeding the FCMP. For OGT/UGT it is the FCMP's
  // second source (operand 3), otherwise the first (operand 2).
  Register CondReg = MI.getOperand(1).getReg();
  MachineInstr *CondInstr = MRI.getVRegDef(CondReg);
  const CmpInst::Predicate Pred = cast<GFCmp>(CondInstr)->getCond();
  const bool DotIsSecondSrc =
      Pred == CmpInst::FCMP_OGT || Pred == CmpInst::FCMP_UGT;
  Register DotReg = CondInstr->getOperand(DotIsSecondSrc ? 3 : 2).getReg();
  MachineInstr *DotInstr = MRI.getVRegDef(DotReg);

  // I and Ng are the two sources of the dot product: operands 1 and 2 for a
  // G_FMUL, operands 2 and 3 for anything else.
  const unsigned FirstSrcIdx =
      DotInstr->getOpcode() == TargetOpcode::G_FMUL ? 1 : 2;
  Register DotOperand1 = DotInstr->getOperand(FirstSrcIdx).getReg();
  Register DotOperand2 = DotInstr->getOperand(FirstSrcIdx + 1).getReg();

  // After this, TrueReg is passed as N and FalseReg names the value whose
  // defining instruction gets erased below.
  // NOTE(review): any G_FMUL defining the true operand is treated as the
  // negated value here — confirm this is intended when N itself is a product.
  Register TrueReg = MI.getOperand(2).getReg();
  Register FalseReg = MI.getOperand(3).getReg();
  const unsigned TrueOpcode = MRI.getVRegDef(TrueReg)->getOpcode();
  if (TrueOpcode == TargetOpcode::G_FNEG ||
      TrueOpcode == TargetOpcode::G_FMUL)
    std::swap(TrueReg, FalseReg);
  MachineInstr *FalseInstr = MRI.getVRegDef(FalseReg);

  Register ResultReg = MI.getOperand(0).getReg();
  Builder.setInstrAndDebugLoc(MI);
  auto FaceForward =
      Builder.buildIntrinsic(Intrinsic::spv_faceforward, ResultReg);
  FaceForward.addUse(TrueReg);     // N
  FaceForward.addUse(DotOperand1); // I
  FaceForward.addUse(DotOperand2); // Ng

  // Fetch the registry while MI is still alive; the cleanup below erases it.
  SPIRVGlobalRegistry *GR =
      MI.getMF()->getSubtarget<SPIRVSubtarget>().getSPIRVGlobalRegistry();

  // Erases every instruction using Reg, then drops Def from the registry and
  // erases it as well.
  auto EraseDefAndUsers = [&](Register Reg, MachineInstr *Def) {
    // Snapshot the users first: erasing while walking the use list would
    // invalidate the iterator.
    // NOTE(review): assumes use_instructions yields each user only once —
    // confirm, otherwise a user could be erased twice.
    SmallVector<MachineInstr *, 4> Users;
    for (MachineInstr &UseMI : MRI.use_instructions(Reg))
      Users.push_back(&UseMI);
    for (MachineInstr *User : Users)
      User->eraseFromParent();

    GR->invalidateMachineInstr(Def);
    Def->eraseFromParent();
  };

  EraseDefAndUsers(CondReg, CondInstr);   // FCMP and its users
  EraseDefAndUsers(DotReg, DotInstr);     // spv_fdot/G_FMUL and its users
  EraseDefAndUsers(FalseReg, FalseInstr); // negated operand and its users
}
|
|
|
|
/// \returns true when \p MI is a G_INTRINSIC whose intrinsic ID is
/// matrix_transpose.
bool SPIRVCombinerHelper::matchMatrixTranspose(MachineInstr &MI) const {
  if (MI.getOpcode() != TargetOpcode::G_INTRINSIC)
    return false;
  return cast<GIntrinsic>(MI).getIntrinsicID() == Intrinsic::matrix_transpose;
}
|
|
|
|
/// Replaces a matrix_transpose intrinsic with generic MIR and erases \p MI.
///
/// Operand 2 is the input register; operands 3 and 4 are immediates read as
/// the row and column counts. A 1x1 input becomes a plain copy. Anything else
/// becomes a shuffle of the input with itself whose mask entry K is
/// (K % Cols) * Rows + (K / Cols).
void SPIRVCombinerHelper::applyMatrixTranspose(MachineInstr &MI) const {
  Register ResReg = MI.getOperand(0).getReg();
  Register InReg = MI.getOperand(2).getReg();
  uint32_t Rows = MI.getOperand(3).getImm();
  uint32_t Cols = MI.getOperand(4).getImm();

  Builder.setInstrAndDebugLoc(MI);

  if (Rows == 1 && Cols == 1) {
    Builder.buildCopy(ResReg, InReg);
  } else {
    SmallVector<int, 16> Mask;
    const uint32_t NumElts = Rows * Cols;
    for (uint32_t K = 0; K < NumElts; ++K) {
      const uint32_t Quot = K / Cols;
      const uint32_t Rem = K % Cols;
      Mask.push_back(Rem * Rows + Quot);
    }
    Builder.buildShuffleVector(ResReg, InReg, InReg, Mask);
  }

  MI.eraseFromParent();
}
|
|
|
|
/// \returns true when \p MI is a G_INTRINSIC whose intrinsic ID is
/// matrix_multiply.
bool SPIRVCombinerHelper::matchMatrixMultiply(MachineInstr &MI) const {
  if (MI.getOpcode() != TargetOpcode::G_INTRINSIC)
    return false;
  return cast<GIntrinsic>(MI).getIntrinsicID() == Intrinsic::matrix_multiply;
}
|
|
|
|
/// Splits \p MatrixReg into \p NumberOfCols registers of type \p SpvColType.
///
/// With a single column the matrix register itself is returned and no code is
/// emitted. Otherwise one fresh generic virtual register is created per
/// column, a single unmerge of \p MatrixReg defines all of them, and each one
/// is then tagged with \p SpvColType through setRegClassType.
SmallVector<Register, 4>
SPIRVCombinerHelper::extractColumns(Register MatrixReg, uint32_t NumberOfCols,
                                    SPIRVTypeInst SpvColType,
                                    SPIRVGlobalRegistry *GR) const {
  // A single-column matrix already is its only column.
  if (NumberOfCols == 1)
    return {MatrixReg};

  const LLT ColTy = GR->getRegType(SpvColType);
  SmallVector<Register, 4> Cols;
  for (uint32_t J = 0; J != NumberOfCols; ++J)
    Cols.push_back(MRI.createGenericVirtualRegister(ColTy));

  Builder.buildUnmerge(Cols, MatrixReg);

  for (Register Col : Cols)
    setRegClassType(Col, SpvColType, GR, &MRI, Builder.getMF());
  return Cols;
}
|
|
|
|
/// Produces one register per row of the matrix held in \p MatrixReg.
///
///  - \p NumCols == 1: every row is a single element, so \p MatrixReg is
///    unmerged into \p NumRows fresh registers (\p SpvRowType must not be a
///    vector type in this case).
///  - \p NumRows == 1: the matrix register itself is the only row.
///  - otherwise: row I is built by a shuffle of \p MatrixReg with itself that
///    picks elements I, I + NumRows, I + 2 * NumRows, ...
///
/// Every newly created register is tagged with \p SpvRowType through
/// setRegClassType.
SmallVector<Register, 4>
SPIRVCombinerHelper::extractRows(Register MatrixReg, uint32_t NumRows,
                                 uint32_t NumCols, SPIRVTypeInst SpvRowType,
                                 SPIRVGlobalRegistry *GR) const {
  const LLT RowTy = GR->getRegType(SpvRowType);
  SmallVector<Register, 4> Rows;

  // Attaches the SPIR-V row type to every register collected so far.
  auto TagRows = [&]() {
    for (Register Row : Rows)
      setRegClassType(Row, SpvRowType, GR, &MRI, Builder.getMF());
  };

  // With a single column each row is one scalar taken out of the matrix.
  if (NumCols == 1) {
    assert(SpvRowType->getOpcode() != SPIRV::OpTypeVector);
    for (uint32_t I = 0; I != NumRows; ++I)
      Rows.push_back(MRI.createGenericVirtualRegister(RowTy));
    Builder.buildUnmerge(Rows, MatrixReg);
    TagRows();
    return Rows;
  }

  // A single-row matrix already is its only row.
  if (NumRows == 1)
    return {MatrixReg};

  for (uint32_t I = 0; I != NumRows; ++I) {
    SmallVector<int, 4> Mask;
    for (uint32_t C = 0; C != NumCols; ++C)
      Mask.push_back(C * NumRows + I);
    auto Shuffle = Builder.buildShuffleVector(RowTy, MatrixReg, MatrixReg, Mask);
    Rows.push_back(Shuffle.getReg(0));
  }
  TagRows();
  return Rows;
}
|
|
|
|
/// Emits the product of \p RowA and \p ColB and returns the result register.
///
/// When \p SpvVecType is an OpTypeVector the product is an intrinsic call:
/// spv_fdot if the component type is OpTypeFloat, spv_udot otherwise. When it
/// is not a vector type the operands are single values and a G_FMUL (float) or
/// G_MUL (otherwise) is emitted instead. The result register is tagged with
/// the component type through setRegClassType.
Register SPIRVCombinerHelper::computeDotProduct(Register RowA, Register ColB,
                                                SPIRVTypeInst SpvVecType,
                                                SPIRVGlobalRegistry *GR) const {
  const bool OperandsAreVectors =
      SpvVecType->getOpcode() == SPIRV::OpTypeVector;
  SPIRVTypeInst SpvScalarType = GR->getScalarOrVectorComponentType(SpvVecType);
  const bool ComponentIsFloat =
      SpvScalarType->getOpcode() == SPIRV::OpTypeFloat;
  LLT OperandTy = GR->getRegType(SpvVecType);

  Register DotRes;
  if (!OperandsAreVectors) {
    // Single-element "vectors": the dot product is an ordinary multiply.
    DotRes = ComponentIsFloat
                 ? Builder.buildFMul(OperandTy, RowA, ColB).getReg(0)
                 : Builder.buildMul(OperandTy, RowA, ColB).getReg(0);
  } else {
    LLT ElemTy = OperandTy.getElementType();
    Intrinsic::SPVIntrinsics DotID =
        ComponentIsFloat ? Intrinsic::spv_fdot : Intrinsic::spv_udot;
    auto Dot = Builder.buildIntrinsic(DotID, {ElemTy});
    Dot.addUse(RowA);
    Dot.addUse(ColB);
    DotRes = Dot.getReg(0);
  }

  setRegClassType(DotRes, SpvScalarType, GR, &MRI, Builder.getMF());
  return DotRes;
}
|
|
|
|
/// Computes the dot product of every (row of A, column of B) pair.
///
/// Results are appended with the columns of B as the outer loop and the rows
/// of A as the inner loop, so entry J * RowsA.size() + I holds
/// dot(RowsA[I], ColsB[J]).
SmallVector<Register, 16>
SPIRVCombinerHelper::computeDotProducts(const SmallVector<Register, 4> &RowsA,
                                        const SmallVector<Register, 4> &ColsB,
                                        SPIRVTypeInst SpvVecType,
                                        SPIRVGlobalRegistry *GR) const {
  SmallVector<Register, 16> ResultScalars;
  for (Register Col : ColsB)
    for (Register Row : RowsA)
      ResultScalars.push_back(computeDotProduct(Row, Col, SpvVecType, GR));
  return ResultScalars;
}
|
|
|
|
/// Determines the SPIR-V type of the operands fed to each dot product.
///
/// The scalar type is taken from the first spv_assign_type intrinsic found
/// among the instructions using \p ResReg: its metadata operand names either
/// a vector type (whose element type is used) or the scalar type itself. The
/// returned type is that scalar when \p K is 1 or less, and a fixed vector of
/// \p K such scalars otherwise. Reaching the end of the use list without
/// finding such an intrinsic is treated as unreachable.
SPIRVTypeInst
SPIRVCombinerHelper::getDotProductVectorType(Register ResReg, uint32_t K,
                                             SPIRVGlobalRegistry *GR) const {
  Type *ScalarResType = nullptr;
  for (MachineInstr &UseMI : MRI.use_instructions(ResReg)) {
    // Only spv_assign_type users carry the type information needed here.
    if (UseMI.getOpcode() != TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS ||
        !isSpvIntrinsic(UseMI, Intrinsic::spv_assign_type))
      continue;

    Type *AssignedTy = getMDOperandAsType(UseMI.getOperand(2).getMetadata(), 0);
    ScalarResType = AssignedTy->isVectorTy()
                        ? cast<VectorType>(AssignedTy)->getElementType()
                        : AssignedTy;
    assert(ScalarResType->isIntegerTy() || ScalarResType->isFloatingPointTy());
    break;
  }
  if (!ScalarResType)
    llvm_unreachable("Could not determine scalar result type");

  Type *VecType = ScalarResType;
  if (K > 1)
    VecType = FixedVectorType::get(ScalarResType, K);
  return GR->getOrCreateSPIRVType(VecType, Builder,
                                  SPIRV::AccessQualifier::None, false);
}
|
|
|
|
/// Expands a matrix_multiply intrinsic into per-element dot products and
/// erases \p MI.
///
/// Operands 2 and 3 are the A and B matrix registers; operands 4, 5 and 6 are
/// immediates read as rows of A, columns of A and columns of B. The columns
/// of B are extracted first, then the rows of A. The result register is then
/// defined by a build-vector of all row/column dot products, in the order
/// produced by computeDotProducts.
void SPIRVCombinerHelper::applyMatrixMultiply(MachineInstr &MI) const {
  Register ResReg = MI.getOperand(0).getReg();
  Register AReg = MI.getOperand(2).getReg();
  Register BReg = MI.getOperand(3).getReg();
  uint32_t NumRowsA = MI.getOperand(4).getImm();
  uint32_t NumColsA = MI.getOperand(5).getImm();
  uint32_t NumColsB = MI.getOperand(6).getImm();

  Builder.setInstrAndDebugLoc(MI);

  SPIRVGlobalRegistry *GR =
      MI.getMF()->getSubtarget<SPIRVSubtarget>().getSPIRVGlobalRegistry();

  // Rows of A and columns of B both have NumColsA elements, so they share one
  // SPIR-V type.
  SPIRVTypeInst DotOperandType = getDotProductVectorType(ResReg, NumColsA, GR);
  SmallVector<Register, 4> BColumns =
      extractColumns(BReg, NumColsB, DotOperandType, GR);
  SmallVector<Register, 4> ARows =
      extractRows(AReg, NumRowsA, NumColsA, DotOperandType, GR);
  SmallVector<Register, 16> Products =
      computeDotProducts(ARows, BColumns, DotOperandType, GR);

  Builder.buildBuildVector(ResReg, Products);
  MI.eraseFromParent();
}
|