
X86 at least is able to use movmsk or kmov to move the mask to the scalar domain. Then we can just use test instructions to test individual bits. This is more efficient than extracting each mask element individually. I special-cased v1i1 to use the previous behavior. This avoids poor type legalization of a bitcast of v1i1 to i1. I've skipped expandload/compressstore as I think we need to handle constant masks for those better first. Many tests end up with duplicate test instructions due to tail duplication in the branch folding pass. But the same thing happens when constructing similar code in C, so it's not unique to the scalarization. Not sure if this lowering code will also be good for other targets, but we're only testing X86 today. Differential Revision: https://reviews.llvm.org/D65319 llvm-svn: 367489
827 lines
28 KiB
C++
827 lines
28 KiB
C++
//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
|
|
//                                    intrinsics
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This pass replaces masked memory intrinsics - when unsupported by the target
|
|
// - with a chain of basic blocks, that deal with the elements one-by-one if the
|
|
// appropriate mask bit is set.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include <algorithm>
#include <cassert>
|
|
|
|
using namespace llvm;
|
|
|
|
#define DEBUG_TYPE "scalarize-masked-mem-intrin"
|
|
|
|
namespace {
|
|
|
|
class ScalarizeMaskedMemIntrin : public FunctionPass {
|
|
const TargetTransformInfo *TTI = nullptr;
|
|
|
|
public:
|
|
static char ID; // Pass identification, replacement for typeid
|
|
|
|
explicit ScalarizeMaskedMemIntrin() : FunctionPass(ID) {
|
|
initializeScalarizeMaskedMemIntrinPass(*PassRegistry::getPassRegistry());
|
|
}
|
|
|
|
bool runOnFunction(Function &F) override;
|
|
|
|
StringRef getPassName() const override {
|
|
return "Scalarize Masked Memory Intrinsics";
|
|
}
|
|
|
|
void getAnalysisUsage(AnalysisUsage &AU) const override {
|
|
AU.addRequired<TargetTransformInfoWrapperPass>();
|
|
}
|
|
|
|
private:
|
|
bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT);
|
|
bool optimizeCallInst(CallInst *CI, bool &ModifiedDT);
|
|
};
|
|
|
|
} // end anonymous namespace
|
|
|
|
char ScalarizeMaskedMemIntrin::ID = 0;
|
|
|
|
INITIALIZE_PASS(ScalarizeMaskedMemIntrin, DEBUG_TYPE,
|
|
"Scalarize unsupported masked memory intrinsics", false, false)
|
|
|
|
FunctionPass *llvm::createScalarizeMaskedMemIntrinPass() {
|
|
return new ScalarizeMaskedMemIntrin();
|
|
}
|
|
|
|
// Return true if Mask is a Constant whose every element can be retrieved and
// is a ConstantInt. Anything else (non-constant masks, or constants with an
// element that is missing or not a ConstantInt) yields false.
static bool isConstantIntVector(Value *Mask) {
  auto *MaskC = dyn_cast<Constant>(Mask);
  if (!MaskC)
    return false;

  for (unsigned Elt = 0, E = Mask->getType()->getVectorNumElements(); Elt != E;
       ++Elt) {
    // getAggregateElement may return null, hence the _or_null cast.
    if (!dyn_cast_or_null<ConstantInt>(MaskC->getAggregateElement(Elt)))
      return false;
  }

  return true;
}
|
|
|
|
// Translate a masked load intrinsic like
|
|
// <16 x i32 > @llvm.masked.load( <16 x i32>* %addr, i32 align,
|
|
// <16 x i1> %mask, <16 x i32> %passthru)
|
|
// to a chain of basic blocks, with loading element one-by-one if
|
|
// the appropriate mask bit is set
|
|
//
|
|
// %1 = bitcast i8* %addr to i32*
|
|
// %2 = extractelement <16 x i1> %mask, i32 0
|
|
// br i1 %2, label %cond.load, label %else
|
|
//
|
|
// cond.load: ; preds = %0
|
|
// %3 = getelementptr i32* %1, i32 0
|
|
// %4 = load i32* %3
|
|
// %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
|
|
// br label %else
|
|
//
|
|
// else: ; preds = %0, %cond.load
|
|
// %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ undef, %0 ]
|
|
// %6 = extractelement <16 x i1> %mask, i32 1
|
|
// br i1 %6, label %cond.load1, label %else2
|
|
//
|
|
// cond.load1: ; preds = %else
|
|
// %7 = getelementptr i32* %1, i32 1
|
|
// %8 = load i32* %7
|
|
// %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
|
|
// br label %else2
|
|
//
|
|
// else2: ; preds = %else, %cond.load1
|
|
// %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
|
|
// %10 = extractelement <16 x i1> %mask, i32 2
|
|
// br i1 %10, label %cond.load4, label %else5
|
|
//
|
|
static void scalarizeMaskedLoad(CallInst *CI, bool &ModifiedDT) {
|
|
Value *Ptr = CI->getArgOperand(0);
|
|
Value *Alignment = CI->getArgOperand(1);
|
|
Value *Mask = CI->getArgOperand(2);
|
|
Value *Src0 = CI->getArgOperand(3);
|
|
|
|
unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
|
|
VectorType *VecType = cast<VectorType>(CI->getType());
|
|
|
|
Type *EltTy = VecType->getElementType();
|
|
|
|
IRBuilder<> Builder(CI->getContext());
|
|
Instruction *InsertPt = CI;
|
|
BasicBlock *IfBlock = CI->getParent();
|
|
|
|
Builder.SetInsertPoint(InsertPt);
|
|
Builder.SetCurrentDebugLocation(CI->getDebugLoc());
|
|
|
|
// Short-cut if the mask is all-true.
|
|
if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
|
|
Value *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
|
|
CI->replaceAllUsesWith(NewI);
|
|
CI->eraseFromParent();
|
|
return;
|
|
}
|
|
|
|
// Adjust alignment for the scalar instruction.
|
|
AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
|
|
// Bitcast %addr from i8* to EltTy*
|
|
Type *NewPtrType =
|
|
EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
|
|
Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
|
|
unsigned VectorWidth = VecType->getNumElements();
|
|
|
|
// The result vector
|
|
Value *VResult = Src0;
|
|
|
|
if (isConstantIntVector(Mask)) {
|
|
for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
|
|
if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
|
|
continue;
|
|
Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
|
|
LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AlignVal);
|
|
VResult = Builder.CreateInsertElement(VResult, Load, Idx);
|
|
}
|
|
CI->replaceAllUsesWith(VResult);
|
|
CI->eraseFromParent();
|
|
return;
|
|
}
|
|
|
|
// If the mask is not v1i1, use scalar bit test operations. This generates
|
|
// better results on X86 at least.
|
|
Value *SclrMask;
|
|
if (VectorWidth != 1) {
|
|
Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
|
|
SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
|
|
}
|
|
|
|
for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
|
|
// Fill the "else" block, created in the previous iteration
|
|
//
|
|
// %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
|
|
// %mask_1 = and i16 %scalar_mask, i32 1 << Idx
|
|
// %cond = icmp ne i16 %mask_1, 0
|
|
// br i1 %mask_1, label %cond.load, label %else
|
|
//
|
|
Value *Predicate;
|
|
if (VectorWidth != 1) {
|
|
Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
|
|
Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
|
|
Builder.getIntN(VectorWidth, 0));
|
|
} else {
|
|
Predicate = Builder.CreateExtractElement(Mask, Idx);
|
|
}
|
|
|
|
// Create "cond" block
|
|
//
|
|
// %EltAddr = getelementptr i32* %1, i32 0
|
|
// %Elt = load i32* %EltAddr
|
|
// VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
|
|
//
|
|
BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(),
|
|
"cond.load");
|
|
Builder.SetInsertPoint(InsertPt);
|
|
|
|
Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
|
|
LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AlignVal);
|
|
Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);
|
|
|
|
// Create "else" block, fill it in the next iteration
|
|
BasicBlock *NewIfBlock =
|
|
CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
|
|
Builder.SetInsertPoint(InsertPt);
|
|
Instruction *OldBr = IfBlock->getTerminator();
|
|
BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
|
|
OldBr->eraseFromParent();
|
|
BasicBlock *PrevIfBlock = IfBlock;
|
|
IfBlock = NewIfBlock;
|
|
|
|
// Create the phi to join the new and previous value.
|
|
PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
|
|
Phi->addIncoming(NewVResult, CondBlock);
|
|
Phi->addIncoming(VResult, PrevIfBlock);
|
|
VResult = Phi;
|
|
}
|
|
|
|
CI->replaceAllUsesWith(VResult);
|
|
CI->eraseFromParent();
|
|
|
|
ModifiedDT = true;
|
|
}
|
|
|
|
// Translate a masked store intrinsic, like
|
|
// void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
|
|
// <16 x i1> %mask)
|
|
// to a chain of basic blocks, that stores element one-by-one if
|
|
// the appropriate mask bit is set
|
|
//
|
|
// %1 = bitcast i8* %addr to i32*
|
|
// %2 = extractelement <16 x i1> %mask, i32 0
|
|
// br i1 %2, label %cond.store, label %else
|
|
//
|
|
// cond.store: ; preds = %0
|
|
// %3 = extractelement <16 x i32> %val, i32 0
|
|
// %4 = getelementptr i32* %1, i32 0
|
|
// store i32 %3, i32* %4
|
|
// br label %else
|
|
//
|
|
// else: ; preds = %0, %cond.store
|
|
// %5 = extractelement <16 x i1> %mask, i32 1
|
|
// br i1 %5, label %cond.store1, label %else2
|
|
//
|
|
// cond.store1: ; preds = %else
|
|
// %6 = extractelement <16 x i32> %val, i32 1
|
|
// %7 = getelementptr i32* %1, i32 1
|
|
// store i32 %6, i32* %7
|
|
// br label %else2
|
|
// . . .
|
|
static void scalarizeMaskedStore(CallInst *CI, bool &ModifiedDT) {
|
|
Value *Src = CI->getArgOperand(0);
|
|
Value *Ptr = CI->getArgOperand(1);
|
|
Value *Alignment = CI->getArgOperand(2);
|
|
Value *Mask = CI->getArgOperand(3);
|
|
|
|
unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
|
|
VectorType *VecType = cast<VectorType>(Src->getType());
|
|
|
|
Type *EltTy = VecType->getElementType();
|
|
|
|
IRBuilder<> Builder(CI->getContext());
|
|
Instruction *InsertPt = CI;
|
|
BasicBlock *IfBlock = CI->getParent();
|
|
Builder.SetInsertPoint(InsertPt);
|
|
Builder.SetCurrentDebugLocation(CI->getDebugLoc());
|
|
|
|
// Short-cut if the mask is all-true.
|
|
if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
|
|
Builder.CreateAlignedStore(Src, Ptr, AlignVal);
|
|
CI->eraseFromParent();
|
|
return;
|
|
}
|
|
|
|
// Adjust alignment for the scalar instruction.
|
|
AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
|
|
// Bitcast %addr from i8* to EltTy*
|
|
Type *NewPtrType =
|
|
EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
|
|
Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
|
|
unsigned VectorWidth = VecType->getNumElements();
|
|
|
|
if (isConstantIntVector(Mask)) {
|
|
for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
|
|
if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
|
|
continue;
|
|
Value *OneElt = Builder.CreateExtractElement(Src, Idx);
|
|
Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
|
|
Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
|
|
}
|
|
CI->eraseFromParent();
|
|
return;
|
|
}
|
|
|
|
// If the mask is not v1i1, use scalar bit test operations. This generates
|
|
// better results on X86 at least.
|
|
Value *SclrMask;
|
|
if (VectorWidth != 1) {
|
|
Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
|
|
SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
|
|
}
|
|
|
|
for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
|
|
// Fill the "else" block, created in the previous iteration
|
|
//
|
|
// %mask_1 = and i16 %scalar_mask, i32 1 << Idx
|
|
// %cond = icmp ne i16 %mask_1, 0
|
|
// br i1 %mask_1, label %cond.store, label %else
|
|
//
|
|
Value *Predicate;
|
|
if (VectorWidth != 1) {
|
|
Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
|
|
Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
|
|
Builder.getIntN(VectorWidth, 0));
|
|
} else {
|
|
Predicate = Builder.CreateExtractElement(Mask, Idx);
|
|
}
|
|
|
|
// Create "cond" block
|
|
//
|
|
// %OneElt = extractelement <16 x i32> %Src, i32 Idx
|
|
// %EltAddr = getelementptr i32* %1, i32 0
|
|
// %store i32 %OneElt, i32* %EltAddr
|
|
//
|
|
BasicBlock *CondBlock =
|
|
IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
|
|
Builder.SetInsertPoint(InsertPt);
|
|
|
|
Value *OneElt = Builder.CreateExtractElement(Src, Idx);
|
|
Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
|
|
Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
|
|
|
|
// Create "else" block, fill it in the next iteration
|
|
BasicBlock *NewIfBlock =
|
|
CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
|
|
Builder.SetInsertPoint(InsertPt);
|
|
Instruction *OldBr = IfBlock->getTerminator();
|
|
BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
|
|
OldBr->eraseFromParent();
|
|
IfBlock = NewIfBlock;
|
|
}
|
|
CI->eraseFromParent();
|
|
|
|
ModifiedDT = true;
|
|
}
|
|
|
|
// Translate a masked gather intrinsic like
|
|
// <16 x i32 > @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4,
|
|
// <16 x i1> %Mask, <16 x i32> %Src)
|
|
// to a chain of basic blocks, with loading element one-by-one if
|
|
// the appropriate mask bit is set
|
|
//
|
|
// %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
|
|
// %Mask0 = extractelement <16 x i1> %Mask, i32 0
|
|
// br i1 %Mask0, label %cond.load, label %else
|
|
//
|
|
// cond.load:
|
|
// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
|
|
// %Load0 = load i32, i32* %Ptr0, align 4
|
|
// %Res0 = insertelement <16 x i32> undef, i32 %Load0, i32 0
|
|
// br label %else
|
|
//
|
|
// else:
|
|
// %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [undef, %0]
|
|
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
|
|
// br i1 %Mask1, label %cond.load1, label %else2
|
|
//
|
|
// cond.load1:
|
|
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
|
|
// %Load1 = load i32, i32* %Ptr1, align 4
|
|
// %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
|
|
// br label %else2
|
|
// . . .
|
|
// %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
|
|
// ret <16 x i32> %Result
|
|
static void scalarizeMaskedGather(CallInst *CI, bool &ModifiedDT) {
|
|
Value *Ptrs = CI->getArgOperand(0);
|
|
Value *Alignment = CI->getArgOperand(1);
|
|
Value *Mask = CI->getArgOperand(2);
|
|
Value *Src0 = CI->getArgOperand(3);
|
|
|
|
VectorType *VecType = cast<VectorType>(CI->getType());
|
|
Type *EltTy = VecType->getElementType();
|
|
|
|
IRBuilder<> Builder(CI->getContext());
|
|
Instruction *InsertPt = CI;
|
|
BasicBlock *IfBlock = CI->getParent();
|
|
Builder.SetInsertPoint(InsertPt);
|
|
unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
|
|
|
|
Builder.SetCurrentDebugLocation(CI->getDebugLoc());
|
|
|
|
// The result vector
|
|
Value *VResult = Src0;
|
|
unsigned VectorWidth = VecType->getNumElements();
|
|
|
|
// Shorten the way if the mask is a vector of constants.
|
|
if (isConstantIntVector(Mask)) {
|
|
for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
|
|
if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
|
|
continue;
|
|
Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
|
|
LoadInst *Load =
|
|
Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
|
|
VResult =
|
|
Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
|
|
}
|
|
CI->replaceAllUsesWith(VResult);
|
|
CI->eraseFromParent();
|
|
return;
|
|
}
|
|
|
|
// If the mask is not v1i1, use scalar bit test operations. This generates
|
|
// better results on X86 at least.
|
|
Value *SclrMask;
|
|
if (VectorWidth != 1) {
|
|
Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
|
|
SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
|
|
}
|
|
|
|
for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
|
|
// Fill the "else" block, created in the previous iteration
|
|
//
|
|
// %Mask1 = and i16 %scalar_mask, i32 1 << Idx
|
|
// %cond = icmp ne i16 %mask_1, 0
|
|
// br i1 %Mask1, label %cond.load, label %else
|
|
//
|
|
|
|
Value *Predicate;
|
|
if (VectorWidth != 1) {
|
|
Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
|
|
Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
|
|
Builder.getIntN(VectorWidth, 0));
|
|
} else {
|
|
Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
|
|
}
|
|
|
|
// Create "cond" block
|
|
//
|
|
// %EltAddr = getelementptr i32* %1, i32 0
|
|
// %Elt = load i32* %EltAddr
|
|
// VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
|
|
//
|
|
BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load");
|
|
Builder.SetInsertPoint(InsertPt);
|
|
|
|
Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
|
|
LoadInst *Load =
|
|
Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
|
|
Value *NewVResult =
|
|
Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
|
|
|
|
// Create "else" block, fill it in the next iteration
|
|
BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
|
|
Builder.SetInsertPoint(InsertPt);
|
|
Instruction *OldBr = IfBlock->getTerminator();
|
|
BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
|
|
OldBr->eraseFromParent();
|
|
BasicBlock *PrevIfBlock = IfBlock;
|
|
IfBlock = NewIfBlock;
|
|
|
|
PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
|
|
Phi->addIncoming(NewVResult, CondBlock);
|
|
Phi->addIncoming(VResult, PrevIfBlock);
|
|
VResult = Phi;
|
|
}
|
|
|
|
CI->replaceAllUsesWith(VResult);
|
|
CI->eraseFromParent();
|
|
|
|
ModifiedDT = true;
|
|
}
|
|
|
|
// Translate a masked scatter intrinsic, like
|
|
// void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*>* %Ptrs, i32 4,
|
|
// <16 x i1> %Mask)
|
|
// to a chain of basic blocks, that stores element one-by-one if
|
|
// the appropriate mask bit is set.
|
|
//
|
|
// %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
|
|
// %Mask0 = extractelement <16 x i1> %Mask, i32 0
|
|
// br i1 %Mask0, label %cond.store, label %else
|
|
//
|
|
// cond.store:
|
|
// %Elt0 = extractelement <16 x i32> %Src, i32 0
|
|
// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
|
|
// store i32 %Elt0, i32* %Ptr0, align 4
|
|
// br label %else
|
|
//
|
|
// else:
|
|
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
|
|
// br i1 %Mask1, label %cond.store1, label %else2
|
|
//
|
|
// cond.store1:
|
|
// %Elt1 = extractelement <16 x i32> %Src, i32 1
|
|
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
|
|
// store i32 %Elt1, i32* %Ptr1, align 4
|
|
// br label %else2
|
|
// . . .
|
|
static void scalarizeMaskedScatter(CallInst *CI, bool &ModifiedDT) {
|
|
Value *Src = CI->getArgOperand(0);
|
|
Value *Ptrs = CI->getArgOperand(1);
|
|
Value *Alignment = CI->getArgOperand(2);
|
|
Value *Mask = CI->getArgOperand(3);
|
|
|
|
assert(isa<VectorType>(Src->getType()) &&
|
|
"Unexpected data type in masked scatter intrinsic");
|
|
assert(isa<VectorType>(Ptrs->getType()) &&
|
|
isa<PointerType>(Ptrs->getType()->getVectorElementType()) &&
|
|
"Vector of pointers is expected in masked scatter intrinsic");
|
|
|
|
IRBuilder<> Builder(CI->getContext());
|
|
Instruction *InsertPt = CI;
|
|
BasicBlock *IfBlock = CI->getParent();
|
|
Builder.SetInsertPoint(InsertPt);
|
|
Builder.SetCurrentDebugLocation(CI->getDebugLoc());
|
|
|
|
unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
|
|
unsigned VectorWidth = Src->getType()->getVectorNumElements();
|
|
|
|
// Shorten the way if the mask is a vector of constants.
|
|
if (isConstantIntVector(Mask)) {
|
|
for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
|
|
if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
|
|
continue;
|
|
Value *OneElt =
|
|
Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
|
|
Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
|
|
Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
|
|
}
|
|
CI->eraseFromParent();
|
|
return;
|
|
}
|
|
|
|
// If the mask is not v1i1, use scalar bit test operations. This generates
|
|
// better results on X86 at least.
|
|
Value *SclrMask;
|
|
if (VectorWidth != 1) {
|
|
Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
|
|
SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
|
|
}
|
|
|
|
for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
|
|
// Fill the "else" block, created in the previous iteration
|
|
//
|
|
// %Mask1 = and i16 %scalar_mask, i32 1 << Idx
|
|
// %cond = icmp ne i16 %mask_1, 0
|
|
// br i1 %Mask1, label %cond.store, label %else
|
|
//
|
|
Value *Predicate;
|
|
if (VectorWidth != 1) {
|
|
Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
|
|
Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
|
|
Builder.getIntN(VectorWidth, 0));
|
|
} else {
|
|
Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
|
|
}
|
|
|
|
// Create "cond" block
|
|
//
|
|
// %Elt1 = extractelement <16 x i32> %Src, i32 1
|
|
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
|
|
// %store i32 %Elt1, i32* %Ptr1
|
|
//
|
|
BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store");
|
|
Builder.SetInsertPoint(InsertPt);
|
|
|
|
Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
|
|
Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
|
|
Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
|
|
|
|
// Create "else" block, fill it in the next iteration
|
|
BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
|
|
Builder.SetInsertPoint(InsertPt);
|
|
Instruction *OldBr = IfBlock->getTerminator();
|
|
BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
|
|
OldBr->eraseFromParent();
|
|
IfBlock = NewIfBlock;
|
|
}
|
|
CI->eraseFromParent();
|
|
|
|
ModifiedDT = true;
|
|
}
|
|
|
|
// Scalarize a masked expandload call (operands: pointer, mask, pass-through).
// For each vector element a conditional block is created that loads one
// scalar from the current pointer, inserts it at that element's index and
// advances the pointer by one element. Elements whose mask bit is clear keep
// the value carried in from the pass-through operand and leave the pointer
// untouched. Always creates new basic blocks, so ModifiedDT is set.
static void scalarizeMaskedExpandLoad(CallInst *CI, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Mask = CI->getArgOperand(1);
  Value *PassThru = CI->getArgOperand(2);

  auto *VecType = cast<VectorType>(CI->getType());
  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  const unsigned VectorWidth = VecType->getNumElements();

  // Running result; starts out as the pass-through vector.
  Value *VResult = PassThru;

  for (unsigned Idx = 0; Idx != VectorWidth; ++Idx) {
    // The pointer only needs to advance while further elements follow.
    const bool HasMoreElts = (Idx + 1) != VectorWidth;

    // Emitted into the "else" block produced by the previous iteration:
    //
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.load, label %else
    //
    Value *MaskBit = Builder.CreateExtractElement(Mask, Idx);

    // "cond" block:
    //
    //  %Elt = load i32* %Ptr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    BasicBlock *LoadBlock =
        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.load");
    Builder.SetInsertPoint(InsertPt);

    LoadInst *EltLoad = Builder.CreateAlignedLoad(EltTy, Ptr, 1);
    Value *UpdatedResult = Builder.CreateInsertElement(VResult, EltLoad, Idx);

    // Move the pointer if there are more blocks to come.
    Value *AdvancedPtr = nullptr;
    if (HasMoreElts)
      AdvancedPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // "else" block; its contents are produced by the next iteration.
    BasicBlock *JoinBlock =
        LoadBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);

    // Swap the fall-through branch of the first split for the mask test.
    Instruction *FallThroughBr = IfBlock->getTerminator();
    BranchInst::Create(LoadBlock, JoinBlock, MaskBit, FallThroughBr);
    FallThroughBr->eraseFromParent();

    BasicBlock *TestBlock = IfBlock;
    IfBlock = JoinBlock;

    // Merge the result with and without the freshly loaded element.
    PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    ResultPhi->addIncoming(UpdatedResult, LoadBlock);
    ResultPhi->addIncoming(VResult, TestBlock);
    VResult = ResultPhi;

    // Likewise merge the advanced and the unchanged pointer, unless this was
    // the last element.
    if (HasMoreElts) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(AdvancedPtr, LoadBlock);
      PtrPhi->addIncoming(Ptr, TestBlock);
      Ptr = PtrPhi;
    }
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}
|
|
|
|
// Scalarize a masked compressstore call (operands: source vector, pointer,
// mask). For each vector element a conditional block is created that stores
// that element to the current pointer and advances the pointer by one
// element; elements whose mask bit is clear store nothing and leave the
// pointer untouched. Always creates new basic blocks, so ModifiedDT is set.
static void scalarizeMaskedCompressStore(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);

  auto *VecType = cast<VectorType>(Src->getType());

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  Type *EltTy = VecType->getVectorElementType();
  const unsigned VectorWidth = VecType->getNumElements();

  for (unsigned Idx = 0; Idx != VectorWidth; ++Idx) {
    // The pointer only needs to advance while further elements follow.
    const bool HasMoreElts = (Idx + 1) != VectorWidth;

    // Emitted into the "else" block produced by the previous iteration:
    //
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.store, label %else
    //
    Value *MaskBit = Builder.CreateExtractElement(Mask, Idx);

    // "cond" block:
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  store i32 %OneElt, i32* %Ptr
    //
    BasicBlock *StoreBlock =
        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *Elt = Builder.CreateExtractElement(Src, Idx);
    Builder.CreateAlignedStore(Elt, Ptr, 1);

    // Move the pointer if there are more blocks to come.
    Value *AdvancedPtr = nullptr;
    if (HasMoreElts)
      AdvancedPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // "else" block; its contents are produced by the next iteration.
    BasicBlock *JoinBlock =
        StoreBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);

    // Swap the fall-through branch of the first split for the mask test.
    Instruction *FallThroughBr = IfBlock->getTerminator();
    BranchInst::Create(StoreBlock, JoinBlock, MaskBit, FallThroughBr);
    FallThroughBr->eraseFromParent();

    BasicBlock *TestBlock = IfBlock;
    IfBlock = JoinBlock;

    // Merge the advanced and the unchanged pointer, unless this was the last
    // element.
    if (HasMoreElts) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(AdvancedPtr, StoreBlock);
      PtrPhi->addIncoming(Ptr, TestBlock);
      Ptr = PtrPhi;
    }
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}
|
|
|
|
// Pass entry point. Sweeps over the blocks of F repeatedly until a full sweep
// makes no change. A sweep is abandoned and restarted from the first block as
// soon as a block reports that it changed the CFG, since new blocks were
// inserted. Returns true if any sweep changed the function.
bool ScalarizeMaskedMemIntrin::runOnFunction(Function &F) {
  TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);

  bool EverMadeChange = false;
  bool MadeChange;
  do {
    MadeChange = false;

    Function::iterator BBIt = F.begin();
    while (BBIt != F.end()) {
      // Step the iterator before processing the current block.
      BasicBlock &BB = *BBIt++;

      bool ModifiedDTOnIteration = false;
      if (optimizeBlock(BB, ModifiedDTOnIteration))
        MadeChange = true;

      // Restart BB iteration if the dominator tree of the Function was changed
      if (ModifiedDTOnIteration)
        break;
    }

    EverMadeChange |= MadeChange;
  } while (MadeChange);

  return EverMadeChange;
}
|
|
|
|
// Visit the instructions of BB in order and hand every call to
// optimizeCallInst. Returns true immediately once ModifiedDT is set, because
// the block has been split and the caller restarts its walk; otherwise
// returns whether any call was rewritten.
bool ScalarizeMaskedMemIntrin::optimizeBlock(BasicBlock &BB, bool &ModifiedDT) {
  bool MadeChange = false;

  for (BasicBlock::iterator It = BB.begin(); It != BB.end();) {
    // Advance first: the current instruction may be erased by the rewrite.
    Instruction &Inst = *It++;
    if (auto *CI = dyn_cast<CallInst>(&Inst))
      MadeChange |= optimizeCallInst(CI, ModifiedDT);
    if (ModifiedDT)
      return true;
  }

  return MadeChange;
}
|
|
|
|
// If CI is one of the masked memory intrinsics and TTI reports the operation
// as not legal for the involved data type, replace it with scalarized code
// and return true. Returns false for every other call, and for intrinsics
// the target handles natively.
bool ScalarizeMaskedMemIntrin::optimizeCallInst(CallInst *CI,
                                                bool &ModifiedDT) {
  auto *II = dyn_cast<IntrinsicInst>(CI);
  if (!II)
    return false;

  switch (II->getIntrinsicID()) {
  case Intrinsic::masked_load:
    // Scalarize unsupported vector masked load
    if (!TTI->isLegalMaskedLoad(CI->getType())) {
      scalarizeMaskedLoad(CI, ModifiedDT);
      return true;
    }
    return false;
  case Intrinsic::masked_store:
    // Legality of stores is decided by the type of the stored value.
    if (!TTI->isLegalMaskedStore(CI->getArgOperand(0)->getType())) {
      scalarizeMaskedStore(CI, ModifiedDT);
      return true;
    }
    return false;
  case Intrinsic::masked_gather:
    if (!TTI->isLegalMaskedGather(CI->getType())) {
      scalarizeMaskedGather(CI, ModifiedDT);
      return true;
    }
    return false;
  case Intrinsic::masked_scatter:
    if (!TTI->isLegalMaskedScatter(CI->getArgOperand(0)->getType())) {
      scalarizeMaskedScatter(CI, ModifiedDT);
      return true;
    }
    return false;
  case Intrinsic::masked_expandload:
    if (!TTI->isLegalMaskedExpandLoad(CI->getType())) {
      scalarizeMaskedExpandLoad(CI, ModifiedDT);
      return true;
    }
    return false;
  case Intrinsic::masked_compressstore:
    if (!TTI->isLegalMaskedCompressStore(CI->getArgOperand(0)->getType())) {
      scalarizeMaskedCompressStore(CI, ModifiedDT);
      return true;
    }
    return false;
  default:
    break;
  }

  return false;
}
|