llvm-project/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
Sarah Spall 7f34d3acba
[DirectX] Add support for typedBufferLoad and Store for RWBuffer<double2> and RWBuffer<double> (#139996)
typedBufferLoad of double/double2 is expanded to a typedBufferLoad of a
<2 x i32>/<4 x i32> and asdouble
typedBufferStore of a double/double2 is expanded to a splitdouble and a
typedBufferStore of a <2 x i32>/<4 x i32>
Add tests showing result of intrinsic expansion for typedBufferLoad and
typedBufferStore
Add tests showing dxil op lowering can handle typedBufferLoad and
typedBufferStore where the target type doesn't match the typedBufferLoad
and typedBufferStore type
Closes #104423
2025-05-30 08:16:19 -07:00

838 lines
28 KiB
C++

//===- DXILIntrinsicExpansion.cpp - Prepare LLVM Module for DXIL encoding--===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file This file contains DXIL expansions for those intrinsics that don't
/// have opcodes in DirectX Intermediate Language (DXIL).
//===----------------------------------------------------------------------===//
#include "DXILIntrinsicExpansion.h"
#include "DirectX.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsDirectX.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Type.h"
#include "llvm/Pass.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/MathExtras.h"
#define DEBUG_TYPE "dxil-intrinsic-expansion"
using namespace llvm;
/// Legacy pass-manager wrapper for the DXIL intrinsic expansion. The actual
/// work is done in runOnModule, defined further down in this file.
class DXILIntrinsicExpansionLegacy : public ModulePass {
public:
  static char ID; // Pass identification.

  DXILIntrinsicExpansionLegacy() : ModulePass(ID) {}

  bool runOnModule(Module &M) override;
};
/// Returns true when calls to \p F are rewritten by this pass.
///
/// Most listed intrinsics are always expanded. Typed-buffer loads and stores
/// are only expanded when the buffer element type is double (scalar or
/// vector); other element types are left for DXIL op lowering.
static bool isIntrinsicExpansion(Function &F) {
  const Intrinsic::ID ID = F.getIntrinsicID();
  switch (ID) {
  case Intrinsic::abs:
  case Intrinsic::atan2:
  case Intrinsic::exp:
  case Intrinsic::is_fpclass:
  case Intrinsic::log:
  case Intrinsic::log10:
  case Intrinsic::pow:
  case Intrinsic::powi:
  case Intrinsic::dx_all:
  case Intrinsic::dx_any:
  case Intrinsic::dx_cross:
  case Intrinsic::dx_uclamp:
  case Intrinsic::dx_sclamp:
  case Intrinsic::dx_nclamp:
  case Intrinsic::dx_degrees:
  case Intrinsic::dx_lerp:
  case Intrinsic::dx_normalize:
  case Intrinsic::dx_fdot:
  case Intrinsic::dx_sdot:
  case Intrinsic::dx_udot:
  case Intrinsic::dx_sign:
  case Intrinsic::dx_step:
  case Intrinsic::dx_radians:
  case Intrinsic::usub_sat:
  case Intrinsic::vector_reduce_add:
  case Intrinsic::vector_reduce_fadd:
    return true;
  case Intrinsic::dx_resource_load_typedbuffer: {
    // The load returns a struct whose field 0 is the loaded value.
    Type *LoadedTy = F.getReturnType()->getStructElementType(0);
    return LoadedTy->getScalarType()->isDoubleTy();
  }
  case Intrinsic::dx_resource_store_typedbuffer: {
    // Parameter 2 is the value being stored.
    Type *StoredTy = F.getFunctionType()->getParamType(2);
    return StoredTy->getScalarType()->isDoubleTy();
  }
  default:
    return false;
  }
}
/// Expands usub.sat(A, B) as `A <u B ? 0 : A - B`.
static Value *expandUsubSat(CallInst *Orig) {
  IRBuilder<> Builder(Orig);
  Value *LHS = Orig->getArgOperand(0);
  Value *RHS = Orig->getArgOperand(1);

  Value *WouldUnderflow = Builder.CreateICmpULT(LHS, RHS, "usub.cmp");
  Value *Difference = Builder.CreateSub(LHS, RHS, "usub.sub");
  Constant *ZeroVal = ConstantInt::get(LHS->getType(), 0);
  return Builder.CreateSelect(WouldUnderflow, ZeroVal, Difference, "usub.sat");
}
/// Expands llvm.vector.reduce.add / llvm.vector.reduce.fadd into a sequential
/// chain of extractelement + add (or fadd).
///
/// For reduce.add the vector is operand 0. For reduce.fadd, operand 0 is the
/// start value and operand 1 is the vector. The start value is folded into the
/// sum unless it is a constant zero (Constant::isZeroValue, i.e. +0.0 or
/// -0.0). A non-constant start value is always accumulated.
static Value *expandVecReduceAdd(CallInst *Orig, Intrinsic::ID IntrinsicId) {
  assert(IntrinsicId == Intrinsic::vector_reduce_add ||
         IntrinsicId == Intrinsic::vector_reduce_fadd);
  IRBuilder<> Builder(Orig);
  bool IsFAdd = (IntrinsicId == Intrinsic::vector_reduce_fadd);
  Value *X = Orig->getOperand(IsFAdd ? 1 : 0);
  // The reduced operand is always a fixed-width vector here.
  auto *XVec = cast<FixedVectorType>(X->getType());
  unsigned XVecSize = XVec->getNumElements();

  Value *Sum = Builder.CreateExtractElement(X, static_cast<uint64_t>(0));

  // Handle the initial start value for floating-point addition. Previously a
  // non-constant start value failed the dyn_cast and was silently dropped,
  // producing a wrong reduction.
  if (IsFAdd) {
    Value *Start = Orig->getOperand(0);
    auto *StartConst = dyn_cast<Constant>(Start);
    if (!StartConst || !StartConst->isZeroValue())
      Sum = Builder.CreateFAdd(Sum, Start);
  }

  // Accumulate the remaining vector elements.
  for (unsigned I = 1; I < XVecSize; I++) {
    Value *Elt = Builder.CreateExtractElement(X, I);
    if (IsFAdd)
      Sum = Builder.CreateFAdd(Sum, Elt);
    else
      Sum = Builder.CreateAdd(Sum, Elt);
  }
  return Sum;
}
/// Expands llvm.abs(X) as smax(X, 0 - X). The result is named "dx.max".
static Value *expandAbs(CallInst *Orig) {
  Value *X = Orig->getOperand(0);
  Type *Ty = X->getType();
  IRBuilder<> Builder(Orig);

  // All-zero integer of X's type (a zero splat when X is a vector).
  Constant *Zero = Constant::getNullValue(Ty);
  Value *Negated = Builder.CreateSub(Zero, X);
  return Builder.CreateIntrinsic(Ty, Intrinsic::smax, {X, Negated}, nullptr,
                                 "dx.max");
}
/// Expands dx.cross on two 3-element float vectors into scalar arithmetic:
///   result[i] = a[j] * b[k] - a[k] * b[j],  with j = (i+1)%3, k = (i+2)%3.
/// A return vector of any other length is a fatal usage error.
static Value *expandCrossIntrinsic(CallInst *Orig) {
  auto *VecTy = cast<VectorType>(Orig->getType());
  if (cast<FixedVectorType>(VecTy)->getNumElements() != 3)
    reportFatalUsageError("return vector must have exactly 3 elements");

  Value *Lhs = Orig->getOperand(0);
  Value *Rhs = Orig->getOperand(1);
  IRBuilder<> Builder(Orig);

  // Scalarize both operands: lanes of the first are named x0..x2, lanes of
  // the second y0..y2.
  const char *const LhsNames[] = {"x0", "x1", "x2"};
  const char *const RhsNames[] = {"y0", "y1", "y2"};
  Value *A[3];
  Value *B[3];
  for (uint64_t I = 0; I < 3; ++I)
    A[I] = Builder.CreateExtractElement(Lhs, I, LhsNames[I]);
  for (uint64_t I = 0; I < 3; ++I)
    B[I] = Builder.CreateExtractElement(Rhs, I, RhsNames[I]);

  // Compute every component before building the result vector.
  Value *Components[3];
  for (unsigned I = 0; I < 3; ++I) {
    unsigned J = (I + 1) % 3;
    unsigned K = (I + 2) % 3;
    Value *Prod = Builder.CreateFMul(A[J], B[K]);
    Value *CrossProd = Builder.CreateFMul(A[K], B[J]);
    Components[I] = Builder.CreateFSub(Prod, CrossProd, Orig->getName());
  }

  Value *Result = PoisonValue::get(VecTy);
  for (uint64_t I = 0; I < 3; ++I)
    Result = Builder.CreateInsertElement(Result, Components[I], I);
  return Result;
}
/// Emits the DXIL float dot intrinsic for \p A and \p B before \p Orig.
///
/// The opcode is chosen by the vector length of \p A: dx.dot2, dx.dot3 or
/// dx.dot4. Each lane of A, then each lane of B, is extracted and passed as a
/// separate scalar argument. Any other length is a fatal usage error.
static Value *expandFloatDotIntrinsic(CallInst *Orig, Value *A, Value *B) {
  Type *ATy = A->getType();
  [[maybe_unused]] Type *BTy = B->getType();
  assert(ATy->isVectorTy() && BTy->isVectorTy());
  IRBuilder<> Builder(Orig);
  auto *AVec = dyn_cast<FixedVectorType>(ATy);
  assert(ATy->getScalarType()->isFloatingPointTy());

  const unsigned NumElts = AVec->getNumElements();
  Intrinsic::ID DotIntrinsic = Intrinsic::dx_dot4;
  if (NumElts == 2)
    DotIntrinsic = Intrinsic::dx_dot2;
  else if (NumElts == 3)
    DotIntrinsic = Intrinsic::dx_dot3;
  else if (NumElts != 4)
    reportFatalUsageError(
        "Invalid dot product input vector: length is outside 2-4");

  SmallVector<Value *> Args;
  for (Value *Operand : {A, B})
    for (unsigned I = 0; I < NumElts; ++I)
      Args.push_back(
          Builder.CreateExtractElement(Operand, Builder.getInt32(I)));
  return Builder.CreateIntrinsic(ATy->getScalarType(), DotIntrinsic, Args,
                                 nullptr, "dot");
}
/// Expands dx.fdot: the dot product of the first two operands of \p Orig,
/// emitted by the two-operand overload above and inserted before \p Orig.
static Value *expandFloatDotIntrinsic(CallInst *Orig) {
  Value *LHS = Orig->getOperand(0);
  Value *RHS = Orig->getOperand(1);
  return expandFloatDotIntrinsic(Orig, LHS, RHS);
}
/// Expands dx.sdot / dx.udot into a multiply of lane 0 followed by one
/// dx.imad (signed) or dx.umad (unsigned) per remaining lane:
///   acc = a[0] * b[0];  acc = mad(a[i], b[i], acc) for i >= 1.
static Value *expandIntegerDotIntrinsic(CallInst *Orig,
                                        Intrinsic::ID DotIntrinsic) {
  assert(DotIntrinsic == Intrinsic::dx_sdot ||
         DotIntrinsic == Intrinsic::dx_udot);
  Value *A = Orig->getOperand(0);
  Value *B = Orig->getOperand(1);
  Type *ATy = A->getType();
  [[maybe_unused]] Type *BTy = B->getType();
  assert(ATy->isVectorTy() && BTy->isVectorTy());
  IRBuilder<> Builder(Orig);
  auto *AVec = dyn_cast<FixedVectorType>(ATy);
  assert(ATy->getScalarType()->isIntegerTy());

  const Intrinsic::ID MadIntrinsic = DotIntrinsic == Intrinsic::dx_sdot
                                         ? Intrinsic::dx_imad
                                         : Intrinsic::dx_umad;

  Value *Acc = nullptr;
  for (unsigned I = 0, E = AVec->getNumElements(); I < E; ++I) {
    Value *LhsElt = Builder.CreateExtractElement(A, static_cast<uint64_t>(I));
    Value *RhsElt = Builder.CreateExtractElement(B, static_cast<uint64_t>(I));
    if (I == 0)
      Acc = Builder.CreateMul(LhsElt, RhsElt);
    else
      Acc = Builder.CreateIntrinsic(Acc->getType(), MadIntrinsic,
                                    ArrayRef<Value *>{LhsElt, RhsElt, Acc},
                                    nullptr, "dx.mad");
  }
  return Acc;
}
/// Expands exp(X) as exp2(log2(e) * X). The exp2 call copies Orig's tail-call
/// flag and attributes.
static Value *expandExpIntrinsic(CallInst *Orig) {
  IRBuilder<> Builder(Orig);
  Value *X = Orig->getOperand(0);
  Type *Ty = X->getType();

  // ConstantFP::get splats the scalar across all lanes for vector types.
  Constant *Log2e = ConstantFP::get(Ty, numbers::log2ef);
  Value *Scaled = Builder.CreateFMul(Log2e, X);
  CallInst *Exp2Call = Builder.CreateIntrinsic(Ty, Intrinsic::exp2, {Scaled},
                                               nullptr, "dx.exp2");
  Exp2Call->setTailCall(Orig->isTailCall());
  Exp2Call->setAttributes(Orig->getAttributes());
  return Exp2Call;
}
/// Expands llvm.is.fpclass for test masks that have no direct DXIL opcode.
///
/// Returns nullptr, leaving the call untouched, for fcInf, fcNan, fcNormal and
/// fcFinite; those are handled later by DXIL op lowering. fcNegZero becomes an
/// integer compare of the value's bits against the sign-bit-only pattern. Any
/// other mask is a fatal usage error.
static Value *expandIsFPClass(CallInst *Orig) {
  // The test mask must be a compile-time integer constant (see the
  // llvm.is.fpclass LangRef), so cast<> is safe here.
  auto *TCI = cast<ConstantInt>(Orig->getArgOperand(1));

  // These FPClassTest cases have DXIL opcodes, so they will be handled in
  // DXIL Op Lowering instead.
  switch (TCI->getZExtValue()) {
  case FPClassTest::fcInf:
  case FPClassTest::fcNan:
  case FPClassTest::fcNormal:
  case FPClassTest::fcFinite:
    return nullptr;
  }

  IRBuilder<> Builder(Orig);
  Value *F = Orig->getArgOperand(0);
  Type *FTy = F->getType();

  // Integer type with the same bit width as F, or as each element of F when
  // F is a vector.
  unsigned BitWidth = FTy->getScalarType()->getPrimitiveSizeInBits();
  Type *BitCastTy = Builder.getIntNTy(BitWidth);
  if (auto *FVecTy = dyn_cast<FixedVectorType>(FTy))
    BitCastTy = FixedVectorType::get(BitCastTy, FVecTy->getNumElements());
  Value *FBitCast = Builder.CreateBitCast(F, BitCastTy);

  switch (TCI->getZExtValue()) {
  case FPClassTest::fcNegZero: {
    // -0.0 is the bit pattern with only the sign bit set. APInt builds this
    // mask correctly for every width; an `int` shift (1 << (BitWidth - 1)) is
    // undefined for 64-bit doubles. ConstantInt::get splats it for vectors.
    Constant *NegZero =
        ConstantInt::get(BitCastTy, APInt::getSignMask(BitWidth));
    return Builder.CreateICmpEQ(FBitCast, NegZero, "is.fpclass.negzero");
  }
  default:
    reportFatalUsageError("Unsupported FPClassTest");
  }
}
/// Expands dx.any / dx.all.
///
/// Each element is compared against zero (fcmp une for FP, icmp ne for
/// integers). A scalar returns that single i1. For a vector, the per-lane
/// results are combined left to right with `or` (any) or `and` (all).
static Value *expandAnyOrAllIntrinsic(CallInst *Orig,
                                      Intrinsic::ID IntrinsicId) {
  Value *X = Orig->getOperand(0);
  IRBuilder<> Builder(Orig);
  Type *Ty = X->getType();
  Type *EltTy = Ty->getScalarType();

  // Zero of X's type (a zero splat for vectors).
  Constant *Zero = Constant::getNullValue(Ty);
  Value *IsNonZero = EltTy->isFloatingPointTy()
                         ? Builder.CreateFCmpUNE(X, Zero)
                         : Builder.CreateICmpNE(X, Zero);
  if (!Ty->isVectorTy())
    return IsNonZero;

  auto *XVec = dyn_cast<FixedVectorType>(Ty);
  Value *Result = Builder.CreateExtractElement(IsNonZero, (uint64_t)0);
  for (unsigned I = 1, E = XVec->getNumElements(); I < E; ++I) {
    Value *Lane = Builder.CreateExtractElement(IsNonZero, I);
    if (IntrinsicId == Intrinsic::dx_any) {
      Result = Builder.CreateOr(Result, Lane);
    } else {
      assert(IntrinsicId == Intrinsic::dx_all);
      Result = Builder.CreateAnd(Result, Lane);
    }
  }
  return Result;
}
/// Expands dx.lerp(x, y, s) as x + s * (y - x).
static Value *expandLerpIntrinsic(CallInst *Orig) {
  IRBuilder<> Builder(Orig);
  Value *Start = Orig->getOperand(0);
  Value *End = Orig->getOperand(1);
  Value *Weight = Orig->getOperand(2);

  Value *Delta = Builder.CreateFSub(End, Start);
  Value *Scaled = Builder.CreateFMul(Weight, Delta);
  return Builder.CreateFAdd(Start, Scaled, "dx.lerp");
}
/// Expands a logarithm as LogConstVal * log2(X).
///
/// With the default LogConstVal (ln 2) this computes the natural log.
/// expandLog10Intrinsic passes a different factor to rescale the base. The
/// log2 call copies Orig's tail-call flag and attributes.
static Value *expandLogIntrinsic(CallInst *Orig,
                                 float LogConstVal = numbers::ln2f) {
  IRBuilder<> Builder(Orig);
  Value *X = Orig->getOperand(0);
  Type *Ty = X->getType();

  CallInst *Log2Call =
      Builder.CreateIntrinsic(Ty, Intrinsic::log2, {X}, nullptr, "elt.log2");
  Log2Call->setTailCall(Orig->isTailCall());
  Log2Call->setAttributes(Orig->getAttributes());

  // ConstantFP::get splats the scale factor when Ty is a vector type.
  Constant *Scale = ConstantFP::get(Ty, LogConstVal);
  return Builder.CreateFMul(Scale, Log2Call);
}
/// Expands log10(X) as log2(X) * (ln 2 / ln 10).
static Value *expandLog10Intrinsic(CallInst *Orig) {
  // The ratio is computed from float operands (single precision).
  const float Log10Of2 = numbers::ln2f / numbers::ln10f;
  return expandLogIntrinsic(Orig, Log10Of2);
}
/// Expands dx.normalize.
///
/// A scalar is divided by itself. A vector is multiplied by
/// rsqrt(dot(X, X)), splatted across all lanes. An operand whose length is a
/// known constant zero is a fatal usage error.
static Value *expandNormalizeIntrinsic(CallInst *Orig) {
  Value *X = Orig->getOperand(0);
  Type *Ty = Orig->getType();
  IRBuilder<> Builder(Orig);

  // Reports Msg as a fatal usage error if V is a constant +0.0 or -0.0.
  auto RejectConstantZero = [](Value *V, const char *Msg) {
    if (auto *CFP = dyn_cast<ConstantFP>(V))
      if (CFP->getValueAPF().isZero())
        reportFatalUsageError(Msg);
  };

  auto *XVec = dyn_cast<FixedVectorType>(Ty);
  if (!XVec) {
    RejectConstantZero(X, "Invalid input scalar: length is zero");
    return Builder.CreateFDiv(X, X);
  }

  // A non-zero squared length (dot product) implies a non-zero length.
  Value *DotProduct = expandFloatDotIntrinsic(Orig, X, X);
  RejectConstantZero(DotProduct, "Invalid input vector: length is zero");

  Value *InvLength = Builder.CreateIntrinsic(
      Ty->getScalarType(), Intrinsic::dx_rsqrt, ArrayRef<Value *>{DotProduct},
      nullptr, "dx.rsqrt");
  Value *InvLengthVec =
      Builder.CreateVectorSplat(XVec->getNumElements(), InvLength);
  return Builder.CreateFMul(X, InvLengthVec);
}
/// Expands atan2(Y, X) into atan(Y / X) followed by quadrant corrections
/// (see https://en.wikipedia.org/wiki/Atan2):
///   x > 0           -> atan(y/x)          (default)
///   x < 0, y >= 0   -> atan(y/x) + pi
///   x < 0, y < 0    -> atan(y/x) - pi
///   x == 0, y < 0   -> -pi/2
///   x == 0, y >= 0  -> pi/2   (so atan2(0, 0) yields pi/2)
/// All compares are ordered, so NaN inputs keep the plain atan(y/x) value.
/// The builder's fast-math flags come from Orig, and the atan call copies
/// Orig's tail-call flag and attributes.
static Value *expandAtan2Intrinsic(CallInst *Orig) {
  Value *Y = Orig->getOperand(0);
  Value *X = Orig->getOperand(1);
  Type *Ty = X->getType();
  IRBuilder<> Builder(Orig);
  Builder.setFastMathFlags(Orig->getFastMathFlags());

  Value *Quotient = Builder.CreateFDiv(Y, X);
  CallInst *Atan = Builder.CreateIntrinsic(Ty, Intrinsic::atan, {Quotient},
                                           nullptr, "Elt.Atan");
  Atan->setTailCall(Orig->isTailCall());
  Atan->setAttributes(Orig->getAttributes());

  Constant *Pi = ConstantFP::get(Ty, llvm::numbers::pi);
  Constant *Zero = ConstantFP::get(Ty, 0);
  Value *AtanAddPi = Builder.CreateFAdd(Atan, Pi);
  Value *AtanSubPi = Builder.CreateFSub(Atan, Pi);

  Value *XLt0 = Builder.CreateFCmpOLT(X, Zero);
  Value *XEq0 = Builder.CreateFCmpOEQ(X, Zero);
  Value *YGe0 = Builder.CreateFCmpOGE(Y, Zero);
  Value *YLt0 = Builder.CreateFCmpOLT(Y, Zero);

  // Applied in order; a later matching fixup overrides an earlier one.
  struct QuadrantFixup {
    Value *XCond;
    Value *YCond;
    Value *Replacement;
  };
  const QuadrantFixup Fixups[] = {
      {XLt0, YGe0, AtanAddPi},
      {XLt0, YLt0, AtanSubPi},
      {XEq0, YLt0, ConstantFP::get(Ty, -llvm::numbers::pi / 2)},
      {XEq0, YGe0, ConstantFP::get(Ty, llvm::numbers::pi / 2)},
  };

  Value *Result = Atan;
  for (const QuadrantFixup &Fix : Fixups) {
    Value *Cond = Builder.CreateAnd(Fix.XCond, Fix.YCond);
    Result = Builder.CreateSelect(Cond, Fix.Replacement, Result);
  }
  return Result;
}
/// Expands pow(X, Y) / powi(X, N) as exp2(log2(X) * Y).
///
/// For powi, the integer exponent is converted with sitofp first. Per the
/// LangRef, llvm.powi takes a scalar integer exponent even when X is a
/// vector. The exponent is therefore converted to X's element type and then
/// splatted. A vector integer exponent, if ever present, is converted
/// lane-wise. The exp2 call copies Orig's tail-call flag and attributes.
static Value *expandPowIntrinsic(CallInst *Orig, Intrinsic::ID IntrinsicId) {
  Value *X = Orig->getOperand(0);
  Value *Y = Orig->getOperand(1);
  Type *Ty = X->getType();
  IRBuilder<> Builder(Orig);

  if (IntrinsicId == Intrinsic::powi) {
    if (Y->getType()->isVectorTy()) {
      Y = Builder.CreateSIToFP(Y, Ty);
    } else {
      // Scalar exponent: convert to the element FP type; splat for vector X.
      // (sitofp i32 -> <N x float> is invalid IR.)
      Y = Builder.CreateSIToFP(Y, Ty->getScalarType());
      if (auto *VecTy = dyn_cast<FixedVectorType>(Ty))
        Y = Builder.CreateVectorSplat(VecTy->getNumElements(), Y);
    }
  }

  auto *Log2Call =
      Builder.CreateIntrinsic(Ty, Intrinsic::log2, {X}, nullptr, "elt.log2");
  auto *Mul = Builder.CreateFMul(Log2Call, Y);
  auto *Exp2Call =
      Builder.CreateIntrinsic(Ty, Intrinsic::exp2, {Mul}, nullptr, "elt.exp2");
  Exp2Call->setTailCall(Orig->isTailCall());
  Exp2Call->setAttributes(Orig->getAttributes());
  return Exp2Call;
}
/// Expands dx.step(op0, op1): yields 0.0 when op1 < op0 (ordered compare) and
/// 1.0 otherwise, splatted to a vector when the operands are vectors.
static Value *expandStepIntrinsic(CallInst *Orig) {
  Value *Edge = Orig->getOperand(0);
  Value *Val = Orig->getOperand(1);
  Type *Ty = Edge->getType();
  IRBuilder<> Builder(Orig);

  Value *BelowEdge = Builder.CreateFCmpOLT(Val, Edge);
  // ConstantFP::get splats for vector types.
  Constant *ZeroVal = ConstantFP::get(Ty, 0.0);
  Constant *OneVal = ConstantFP::get(Ty, 1.0);
  return Builder.CreateSelect(BelowEdge, ZeroVal, OneVal);
}
/// Expands dx.radians(x) as x * (pi / 180).
static Value *expandRadiansIntrinsic(CallInst *Orig) {
  IRBuilder<> Builder(Orig);
  Value *Degrees = Orig->getOperand(0);
  Constant *DegToRad =
      ConstantFP::get(Degrees->getType(), llvm::numbers::pi / 180.0);
  return Builder.CreateFMul(Degrees, DegToRad);
}
/// Rewrites a typed-buffer load of double / double2 as a load of
/// <2 x i32> / <4 x i32>. Each consecutive i32 pair is then recombined with
/// dx.asdouble.
///
/// Every user of \p Orig must be a single-index extractvalue. Index 0 (the
/// value) is replaced by the recombined double(s). Index 1 (the check bit) is
/// replaced by field 1 of the new load. \p Orig is erased; always returns
/// true.
static bool expandTypedBufferLoadIntrinsic(CallInst *Orig) {
  IRBuilder<> Builder(Orig);
  Type *BufferTy = Orig->getType()->getStructElementType(0);
  assert(BufferTy->getScalarType()->isDoubleTy() &&
         "Only expand double or double2");
  auto *DblVecTy = dyn_cast<FixedVectorType>(BufferTy);
  assert((!DblVecTy || DblVecTy->getNumElements() == 2) &&
         "TypedBufferLoad double vector has wrong size");

  // Each double occupies two i32 lanes.
  const unsigned NumWords = DblVecTy ? 4 : 2;
  Type *WordVecTy = VectorType::get(Builder.getInt32Ty(), NumWords, false);
  CallInst *Load = Builder.CreateIntrinsic(
      StructType::get(WordVecTy, Builder.getInt1Ty()),
      Intrinsic::dx_resource_load_typedbuffer,
      {Orig->getOperand(0), Orig->getOperand(1)});

  // Scalarize the loaded i32 vector.
  Value *Words = Builder.CreateExtractValue(Load, {0});
  SmallVector<Value *, 4> Lanes;
  for (unsigned I = 0; I != NumWords; ++I)
    Lanes.push_back(Builder.CreateExtractElement(Words, Builder.getInt32(I)));

  // Recombine each i32 pair into a double.
  Value *Result = PoisonValue::get(BufferTy);
  for (unsigned I = 0; I != NumWords; I += 2) {
    Value *Dbl =
        Builder.CreateIntrinsic(Builder.getDoubleTy(), Intrinsic::dx_asdouble,
                                {Lanes[I], Lanes[I + 1]});
    Result = DblVecTy ? Builder.CreateInsertElement(Result, Dbl,
                                                    Builder.getInt32(I / 2))
                      : Dbl;
  }

  // Redirect the extractvalue users of the original load.
  Value *CheckBit = nullptr;
  for (User *U : make_early_inc_range(Orig->users())) {
    auto *EVI = cast<ExtractValueInst>(U);
    ArrayRef<unsigned> Indices = EVI->getIndices();
    assert(Indices.size() == 1);
    Value *Replacement;
    if (Indices[0] == 0) {
      Replacement = Result;
    } else {
      assert(Indices[0] == 1 && "Unexpected type for typedbufferload");
      if (!CheckBit)
        CheckBit = Builder.CreateExtractValue(Load, {1});
      Replacement = CheckBit;
    }
    EVI->replaceAllUsesWith(Replacement);
    EVI->eraseFromParent();
  }
  Orig->eraseFromParent();
  return true;
}
/// Rewrites a typed-buffer store of double / double2 as dx.splitdouble plus a
/// store of <2 x i32> / <4 x i32>.
///
/// Fields 0 and 1 of the splitdouble result are interleaved per double:
/// <f0, f1> for a scalar and <f0[0], f1[0], f0[1], f1[1]> for double2.
/// \p Orig is erased; always returns true.
static bool expandTypedBufferStoreIntrinsic(CallInst *Orig) {
  IRBuilder<> Builder(Orig);
  Value *StoredVal = Orig->getOperand(2);
  Type *BufferTy = Orig->getFunctionType()->getParamType(2);
  assert(BufferTy->getScalarType()->isDoubleTy() &&
         "Only expand double or double2");
  auto *DblVecTy = dyn_cast<FixedVectorType>(BufferTy);
  assert((!DblVecTy || DblVecTy->getNumElements() == 2) &&
         "TypedBufferStore double vector has wrong size");

  // splitdouble yields i32 halves for a scalar, <2 x i32> halves for double2.
  Type *HalfTy = Builder.getInt32Ty();
  if (DblVecTy)
    HalfTy = VectorType::get(HalfTy, 2, false);
  Value *Split = Builder.CreateIntrinsic(StructType::get(HalfTy, HalfTy),
                                         Intrinsic::dx_splitdouble, StoredVal);

  Value *Lo = Builder.CreateExtractValue(Split, 0);
  Value *Hi = Builder.CreateExtractValue(Split, 1);
  Value *Packed;
  if (DblVecTy) {
    // Interleave <lo0, lo1> and <hi0, hi1> into <lo0, hi0, lo1, hi1>.
    Packed = Builder.CreateShuffleVector(Lo, Hi, {0, 2, 1, 3});
  } else {
    Packed = PoisonValue::get(VectorType::get(HalfTy, 2, false));
    Packed = Builder.CreateInsertElement(Packed, Lo, Builder.getInt32(0));
    Packed = Builder.CreateInsertElement(Packed, Hi, Builder.getInt32(1));
  }

  Builder.CreateIntrinsic(Builder.getVoidTy(),
                          Intrinsic::dx_resource_store_typedbuffer,
                          {Orig->getOperand(0), Orig->getOperand(1), Packed});
  Orig->eraseFromParent();
  return true;
}
/// Maps a DXIL clamp intrinsic to the matching max intrinsic:
/// uclamp -> umax, sclamp -> smax, nclamp -> maxnum.
static Intrinsic::ID getMaxForClamp(Intrinsic::ID ClampIntrinsic) {
  switch (ClampIntrinsic) {
  case Intrinsic::dx_uclamp:
    return Intrinsic::umax;
  case Intrinsic::dx_sclamp:
    return Intrinsic::smax;
  default:
    assert(ClampIntrinsic == Intrinsic::dx_nclamp);
    return Intrinsic::maxnum;
  }
}
/// Maps a DXIL clamp intrinsic to the matching min intrinsic:
/// uclamp -> umin, sclamp -> smin, nclamp -> minnum.
static Intrinsic::ID getMinForClamp(Intrinsic::ID ClampIntrinsic) {
  switch (ClampIntrinsic) {
  case Intrinsic::dx_uclamp:
    return Intrinsic::umin;
  case Intrinsic::dx_sclamp:
    return Intrinsic::smin;
  default:
    assert(ClampIntrinsic == Intrinsic::dx_nclamp);
    return Intrinsic::minnum;
  }
}
/// Expands clamp(X, Min, Max) as min(max(X, Min), Max), using the
/// signedness/FP-specific min/max chosen by getMaxForClamp/getMinForClamp.
static Value *expandClampIntrinsic(CallInst *Orig,
                                   Intrinsic::ID ClampIntrinsic) {
  IRBuilder<> Builder(Orig);
  Value *X = Orig->getOperand(0);
  Value *Lo = Orig->getOperand(1);
  Value *Hi = Orig->getOperand(2);
  Type *Ty = X->getType();

  Value *Clamped = Builder.CreateIntrinsic(
      Ty, getMaxForClamp(ClampIntrinsic), {X, Lo}, nullptr, "dx.max");
  Clamped = Builder.CreateIntrinsic(Ty, getMinForClamp(ClampIntrinsic),
                                    {Clamped, Hi}, nullptr, "dx.min");
  return Clamped;
}
/// Expands dx.degrees(x) as x * (180 / pi).
static Value *expandDegreesIntrinsic(CallInst *Orig) {
  IRBuilder<> Builder(Orig);
  Value *Radians = Orig->getOperand(0);
  Constant *RadToDeg =
      ConstantFP::get(Radians->getType(), 180.0 * llvm::numbers::inv_pi);
  return Builder.CreateFMul(Radians, RadToDeg);
}
/// Expands dx.sign(X) as zext(0 < X) - zext(X < 0) in the call's return type.
/// FP inputs use ordered compares; integer inputs use signed compares.
static Value *expandSignIntrinsic(CallInst *Orig) {
  Value *X = Orig->getOperand(0);
  Type *Ty = X->getType();
  Type *RetTy = Orig->getType();
  Constant *Zero = Constant::getNullValue(Ty);
  IRBuilder<> Builder(Orig);

  const bool IsFP = Ty->getScalarType()->isFloatingPointTy();
  assert(IsFP || Ty->getScalarType()->isIntegerTy());
  auto CreateLess = [&](Value *L, Value *R) -> Value * {
    if (IsFP)
      return Builder.CreateFCmpOLT(L, R);
    return Builder.CreateICmpSLT(L, R);
  };

  Value *IsPositive = CreateLess(Zero, X);
  Value *IsNegative = CreateLess(X, Zero);
  Value *PosBit = Builder.CreateZExt(IsPositive, RetTy);
  Value *NegBit = Builder.CreateZExt(IsNegative, RetTy);
  return Builder.CreateSub(PosBit, NegBit);
}
/// Rewrites one call \p Orig to intrinsic \p F using the matching expand*
/// helper.
///
/// Returns true if the call was replaced and erased. Returns false if no
/// expansion applied (for example, expandIsFPClass returned nullptr), in
/// which case the call is left in place.
static bool expandIntrinsic(Function &F, CallInst *Orig) {
  const Intrinsic::ID IntrinsicId = F.getIntrinsicID();
  Value *Result = nullptr;
  switch (IntrinsicId) {
  // These replace and erase Orig themselves and report success directly.
  case Intrinsic::dx_resource_load_typedbuffer:
    return expandTypedBufferLoadIntrinsic(Orig);
  case Intrinsic::dx_resource_store_typedbuffer:
    return expandTypedBufferStoreIntrinsic(Orig);

  // Target-independent LLVM intrinsics.
  case Intrinsic::abs:
    Result = expandAbs(Orig);
    break;
  case Intrinsic::atan2:
    Result = expandAtan2Intrinsic(Orig);
    break;
  case Intrinsic::exp:
    Result = expandExpIntrinsic(Orig);
    break;
  case Intrinsic::is_fpclass:
    Result = expandIsFPClass(Orig);
    break;
  case Intrinsic::log:
    Result = expandLogIntrinsic(Orig);
    break;
  case Intrinsic::log10:
    Result = expandLog10Intrinsic(Orig);
    break;
  case Intrinsic::pow:
  case Intrinsic::powi:
    Result = expandPowIntrinsic(Orig, IntrinsicId);
    break;
  case Intrinsic::usub_sat:
    Result = expandUsubSat(Orig);
    break;
  case Intrinsic::vector_reduce_add:
  case Intrinsic::vector_reduce_fadd:
    Result = expandVecReduceAdd(Orig, IntrinsicId);
    break;

  // DirectX-specific intrinsics.
  case Intrinsic::dx_all:
  case Intrinsic::dx_any:
    Result = expandAnyOrAllIntrinsic(Orig, IntrinsicId);
    break;
  case Intrinsic::dx_cross:
    Result = expandCrossIntrinsic(Orig);
    break;
  case Intrinsic::dx_uclamp:
  case Intrinsic::dx_sclamp:
  case Intrinsic::dx_nclamp:
    Result = expandClampIntrinsic(Orig, IntrinsicId);
    break;
  case Intrinsic::dx_degrees:
    Result = expandDegreesIntrinsic(Orig);
    break;
  case Intrinsic::dx_lerp:
    Result = expandLerpIntrinsic(Orig);
    break;
  case Intrinsic::dx_normalize:
    Result = expandNormalizeIntrinsic(Orig);
    break;
  case Intrinsic::dx_fdot:
    Result = expandFloatDotIntrinsic(Orig);
    break;
  case Intrinsic::dx_sdot:
  case Intrinsic::dx_udot:
    Result = expandIntegerDotIntrinsic(Orig, IntrinsicId);
    break;
  case Intrinsic::dx_sign:
    Result = expandSignIntrinsic(Orig);
    break;
  case Intrinsic::dx_step:
    Result = expandStepIntrinsic(Orig);
    break;
  case Intrinsic::dx_radians:
    Result = expandRadiansIntrinsic(Orig);
    break;
  }

  if (!Result)
    return false;
  Orig->replaceAllUsesWith(Result);
  Orig->eraseFromParent();
  return true;
}
/// Expands every direct call to each function for which isIntrinsicExpansion
/// is true. A declaration is erased once it has no remaining users and at
/// least one of its calls was expanded.
///
/// Returns true iff the module was modified (at least one call rewritten).
/// The previous version always returned true, which forced
/// PreservedAnalyses::none() and a "modified" result from the legacy pass
/// even when nothing changed.
static bool expansionIntrinsics(Module &M) {
  bool Changed = false;
  for (auto &F : make_early_inc_range(M.functions())) {
    if (!isIntrinsicExpansion(F))
      continue;
    bool IntrinsicExpanded = false;
    for (User *U : make_early_inc_range(F.users())) {
      auto *IntrinsicCall = dyn_cast<CallInst>(U);
      if (!IntrinsicCall)
        continue;
      // Accumulate rather than overwrite so an unexpanded call later in the
      // use list does not mask earlier successful expansions.
      if (expandIntrinsic(F, IntrinsicCall))
        IntrinsicExpanded = true;
    }
    Changed |= IntrinsicExpanded;
    if (F.user_empty() && IntrinsicExpanded)
      F.eraseFromParent();
  }
  return Changed;
}
/// New-pass-manager entry point. All analyses are invalidated only if
/// expansionIntrinsics reports a change.
PreservedAnalyses DXILIntrinsicExpansion::run(Module &M,
                                              ModuleAnalysisManager &) {
  const bool Modified = expansionIntrinsics(M);
  return Modified ? PreservedAnalyses::none() : PreservedAnalyses::all();
}
/// Legacy-pass-manager entry point; returns whether the module was modified,
/// as reported by expansionIntrinsics.
bool DXILIntrinsicExpansionLegacy::runOnModule(Module &M) {
  const bool Modified = expansionIntrinsics(M);
  return Modified;
}
// Storage for the legacy pass's identification member (declared in the class
// as `static char ID; // Pass identification.`).
char DXILIntrinsicExpansionLegacy::ID = 0;
// Register the legacy pass under DEBUG_TYPE ("dxil-intrinsic-expansion") with
// the description "DXIL Intrinsic Expansion". No pass dependencies are listed
// between BEGIN and END. The two trailing `false` arguments are the CFG-only
// and is-analysis flags of the INITIALIZE_PASS macros.
INITIALIZE_PASS_BEGIN(DXILIntrinsicExpansionLegacy, DEBUG_TYPE,
"DXIL Intrinsic Expansion", false, false)
INITIALIZE_PASS_END(DXILIntrinsicExpansionLegacy, DEBUG_TYPE,
"DXIL Intrinsic Expansion", false, false)
/// Factory for the legacy DXIL intrinsic expansion pass; the caller (the
/// legacy pass manager) takes ownership of the returned object.
ModulePass *llvm::createDXILIntrinsicExpansionLegacyPass() {
  auto *Pass = new DXILIntrinsicExpansionLegacy();
  return Pass;
}