Add support for single reductions in ComplexDeinterleavingPass (#112875)
The Complex Deinterleaving pass assumes that all values emitted will result in complex numbers, this patch aims to remove that assumption and adds support for emitting just the real or imaginary components, not both.
This commit is contained in:
parent
f8d270474c
commit
b3eede5e1f
@ -35,6 +35,7 @@ public:
|
||||
enum class ComplexDeinterleavingOperation {
|
||||
CAdd,
|
||||
CMulPartial,
|
||||
CDot,
|
||||
// The following 'operations' are used to represent internal states. Backends
|
||||
// are not expected to try and support these in any capacity.
|
||||
Deinterleave,
|
||||
@ -43,6 +44,7 @@ enum class ComplexDeinterleavingOperation {
|
||||
ReductionPHI,
|
||||
ReductionOperation,
|
||||
ReductionSelect,
|
||||
ReductionSingle
|
||||
};
|
||||
|
||||
enum class ComplexDeinterleavingRotation {
|
||||
|
||||
@ -108,6 +108,13 @@ static bool isNeg(Value *V);
|
||||
static Value *getNegOperand(Value *V);
|
||||
|
||||
namespace {
|
||||
template <typename T, typename IterT>
|
||||
std::optional<T> findCommonBetweenCollections(IterT A, IterT B) {
|
||||
auto Common = llvm::find_if(A, [B](T I) { return llvm::is_contained(B, I); });
|
||||
if (Common != A.end())
|
||||
return std::make_optional(*Common);
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
class ComplexDeinterleavingLegacyPass : public FunctionPass {
|
||||
public:
|
||||
@ -144,6 +151,7 @@ private:
|
||||
friend class ComplexDeinterleavingGraph;
|
||||
using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>;
|
||||
using RawNodePtr = ComplexDeinterleavingCompositeNode *;
|
||||
bool OperandsValid = true;
|
||||
|
||||
public:
|
||||
ComplexDeinterleavingOperation Operation;
|
||||
@ -160,7 +168,11 @@ public:
|
||||
SmallVector<RawNodePtr> Operands;
|
||||
Value *ReplacementNode = nullptr;
|
||||
|
||||
void addOperand(NodePtr Node) { Operands.push_back(Node.get()); }
|
||||
void addOperand(NodePtr Node) {
|
||||
if (!Node || !Node.get())
|
||||
OperandsValid = false;
|
||||
Operands.push_back(Node.get());
|
||||
}
|
||||
|
||||
void dump() { dump(dbgs()); }
|
||||
void dump(raw_ostream &OS) {
|
||||
@ -194,6 +206,8 @@ public:
|
||||
PrintNodeRef(Op);
|
||||
}
|
||||
}
|
||||
|
||||
bool areOperandsValid() { return OperandsValid; }
|
||||
};
|
||||
|
||||
class ComplexDeinterleavingGraph {
|
||||
@ -293,7 +307,7 @@ private:
|
||||
|
||||
NodePtr submitCompositeNode(NodePtr Node) {
|
||||
CompositeNodes.push_back(Node);
|
||||
if (Node->Real && Node->Imag)
|
||||
if (Node->Real)
|
||||
CachedResult[{Node->Real, Node->Imag}] = Node;
|
||||
return Node;
|
||||
}
|
||||
@ -327,6 +341,8 @@ private:
|
||||
/// i: ai - br
|
||||
NodePtr identifyAdd(Instruction *Real, Instruction *Imag);
|
||||
NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag);
|
||||
NodePtr identifyPartialReduction(Value *R, Value *I);
|
||||
NodePtr identifyDotProduct(Value *Inst);
|
||||
|
||||
NodePtr identifyNode(Value *R, Value *I);
|
||||
|
||||
@ -396,6 +412,7 @@ private:
|
||||
/// * Deinterleave the final value outside of the loop and repurpose original
|
||||
/// reduction users
|
||||
void processReductionOperation(Value *OperationReplacement, RawNodePtr Node);
|
||||
void processReductionSingle(Value *OperationReplacement, RawNodePtr Node);
|
||||
|
||||
public:
|
||||
void dump() { dump(dbgs()); }
|
||||
@ -891,17 +908,163 @@ ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real,
|
||||
}
|
||||
|
||||
ComplexDeinterleavingGraph::NodePtr
|
||||
ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) {
|
||||
LLVM_DEBUG(dbgs() << "identifyNode on " << *R << " / " << *I << "\n");
|
||||
assert(R->getType() == I->getType() &&
|
||||
"Real and imaginary parts should not have different types");
|
||||
ComplexDeinterleavingGraph::identifyDotProduct(Value *V) {
|
||||
|
||||
if (!TL->isComplexDeinterleavingOperationSupported(
|
||||
ComplexDeinterleavingOperation::CDot, V->getType())) {
|
||||
LLVM_DEBUG(dbgs() << "Target doesn't support complex deinterleaving "
|
||||
"operation CDot with the type "
|
||||
<< *V->getType() << "\n");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto *Inst = cast<Instruction>(V);
|
||||
auto *RealUser = cast<Instruction>(*Inst->user_begin());
|
||||
|
||||
NodePtr CN =
|
||||
prepareCompositeNode(ComplexDeinterleavingOperation::CDot, Inst, nullptr);
|
||||
|
||||
NodePtr ANode;
|
||||
|
||||
const Intrinsic::ID PartialReduceInt =
|
||||
Intrinsic::experimental_vector_partial_reduce_add;
|
||||
|
||||
Value *AReal = nullptr;
|
||||
Value *AImag = nullptr;
|
||||
Value *BReal = nullptr;
|
||||
Value *BImag = nullptr;
|
||||
Value *Phi = nullptr;
|
||||
|
||||
auto UnwrapCast = [](Value *V) -> Value * {
|
||||
if (auto *CI = dyn_cast<CastInst>(V))
|
||||
return CI->getOperand(0);
|
||||
return V;
|
||||
};
|
||||
|
||||
auto PatternRot0 = m_Intrinsic<PartialReduceInt>(
|
||||
m_Intrinsic<PartialReduceInt>(m_Value(Phi),
|
||||
m_Mul(m_Value(BReal), m_Value(AReal))),
|
||||
m_Neg(m_Mul(m_Value(BImag), m_Value(AImag))));
|
||||
|
||||
auto PatternRot270 = m_Intrinsic<PartialReduceInt>(
|
||||
m_Intrinsic<PartialReduceInt>(
|
||||
m_Value(Phi), m_Neg(m_Mul(m_Value(BReal), m_Value(AImag)))),
|
||||
m_Mul(m_Value(BImag), m_Value(AReal)));
|
||||
|
||||
if (match(Inst, PatternRot0)) {
|
||||
CN->Rotation = ComplexDeinterleavingRotation::Rotation_0;
|
||||
} else if (match(Inst, PatternRot270)) {
|
||||
CN->Rotation = ComplexDeinterleavingRotation::Rotation_270;
|
||||
} else {
|
||||
Value *A0, *A1;
|
||||
// The rotations 90 and 180 share the same operation pattern, so inspect the
|
||||
// order of the operands, identifying where the real and imaginary
|
||||
// components of A go, to discern between the aforementioned rotations.
|
||||
auto PatternRot90Rot180 = m_Intrinsic<PartialReduceInt>(
|
||||
m_Intrinsic<PartialReduceInt>(m_Value(Phi),
|
||||
m_Mul(m_Value(BReal), m_Value(A0))),
|
||||
m_Mul(m_Value(BImag), m_Value(A1)));
|
||||
|
||||
if (!match(Inst, PatternRot90Rot180))
|
||||
return nullptr;
|
||||
|
||||
A0 = UnwrapCast(A0);
|
||||
A1 = UnwrapCast(A1);
|
||||
|
||||
// Test if A0 is real/A1 is imag
|
||||
ANode = identifyNode(A0, A1);
|
||||
if (!ANode) {
|
||||
// Test if A0 is imag/A1 is real
|
||||
ANode = identifyNode(A1, A0);
|
||||
// Unable to identify operand components, thus unable to identify rotation
|
||||
if (!ANode)
|
||||
return nullptr;
|
||||
CN->Rotation = ComplexDeinterleavingRotation::Rotation_90;
|
||||
AReal = A1;
|
||||
AImag = A0;
|
||||
} else {
|
||||
AReal = A0;
|
||||
AImag = A1;
|
||||
CN->Rotation = ComplexDeinterleavingRotation::Rotation_180;
|
||||
}
|
||||
}
|
||||
|
||||
AReal = UnwrapCast(AReal);
|
||||
AImag = UnwrapCast(AImag);
|
||||
BReal = UnwrapCast(BReal);
|
||||
BImag = UnwrapCast(BImag);
|
||||
|
||||
VectorType *VTy = cast<VectorType>(V->getType());
|
||||
Type *ExpectedOperandTy = VectorType::getSubdividedVectorType(VTy, 2);
|
||||
if (AReal->getType() != ExpectedOperandTy)
|
||||
return nullptr;
|
||||
if (AImag->getType() != ExpectedOperandTy)
|
||||
return nullptr;
|
||||
if (BReal->getType() != ExpectedOperandTy)
|
||||
return nullptr;
|
||||
if (BImag->getType() != ExpectedOperandTy)
|
||||
return nullptr;
|
||||
|
||||
if (Phi->getType() != VTy && RealUser->getType() != VTy)
|
||||
return nullptr;
|
||||
|
||||
NodePtr Node = identifyNode(AReal, AImag);
|
||||
|
||||
// In the case that a node was identified to figure out the rotation, ensure
|
||||
// that trying to identify a node with AReal and AImag post-unwrap results in
|
||||
// the same node
|
||||
if (ANode && Node != ANode) {
|
||||
LLVM_DEBUG(
|
||||
dbgs()
|
||||
<< "Identified node is different from previously identified node. "
|
||||
"Unable to confidently generate a complex operation node\n");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
CN->addOperand(Node);
|
||||
CN->addOperand(identifyNode(BReal, BImag));
|
||||
CN->addOperand(identifyNode(Phi, RealUser));
|
||||
|
||||
return submitCompositeNode(CN);
|
||||
}
|
||||
|
||||
ComplexDeinterleavingGraph::NodePtr
|
||||
ComplexDeinterleavingGraph::identifyPartialReduction(Value *R, Value *I) {
|
||||
// Partial reductions don't support non-vector types, so check these first
|
||||
if (!isa<VectorType>(R->getType()) || !isa<VectorType>(I->getType()))
|
||||
return nullptr;
|
||||
|
||||
auto CommonUser =
|
||||
findCommonBetweenCollections<Value *>(R->users(), I->users());
|
||||
if (!CommonUser)
|
||||
return nullptr;
|
||||
|
||||
auto *IInst = dyn_cast<IntrinsicInst>(*CommonUser);
|
||||
if (!IInst || IInst->getIntrinsicID() !=
|
||||
Intrinsic::experimental_vector_partial_reduce_add)
|
||||
return nullptr;
|
||||
|
||||
if (NodePtr CN = identifyDotProduct(IInst))
|
||||
return CN;
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ComplexDeinterleavingGraph::NodePtr
|
||||
ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) {
|
||||
auto It = CachedResult.find({R, I});
|
||||
if (It != CachedResult.end()) {
|
||||
LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
|
||||
return It->second;
|
||||
}
|
||||
|
||||
if (NodePtr CN = identifyPartialReduction(R, I))
|
||||
return CN;
|
||||
|
||||
bool IsReduction = RealPHI == R && (!ImagPHI || ImagPHI == I);
|
||||
if (!IsReduction && R->getType() != I->getType())
|
||||
return nullptr;
|
||||
|
||||
if (NodePtr CN = identifySplat(R, I))
|
||||
return CN;
|
||||
|
||||
@ -1427,12 +1590,20 @@ bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
|
||||
if (It != RootToNode.end()) {
|
||||
auto RootNode = It->second;
|
||||
assert(RootNode->Operation ==
|
||||
ComplexDeinterleavingOperation::ReductionOperation);
|
||||
ComplexDeinterleavingOperation::ReductionOperation ||
|
||||
RootNode->Operation ==
|
||||
ComplexDeinterleavingOperation::ReductionSingle);
|
||||
// Find out which part, Real or Imag, comes later, and only if we come to
|
||||
// the latest part, add it to OrderedRoots.
|
||||
auto *R = cast<Instruction>(RootNode->Real);
|
||||
auto *I = cast<Instruction>(RootNode->Imag);
|
||||
auto *ReplacementAnchor = R->comesBefore(I) ? I : R;
|
||||
auto *I = RootNode->Imag ? cast<Instruction>(RootNode->Imag) : nullptr;
|
||||
|
||||
Instruction *ReplacementAnchor;
|
||||
if (I)
|
||||
ReplacementAnchor = R->comesBefore(I) ? I : R;
|
||||
else
|
||||
ReplacementAnchor = R;
|
||||
|
||||
if (ReplacementAnchor != RootI)
|
||||
return false;
|
||||
OrderedRoots.push_back(RootI);
|
||||
@ -1523,7 +1694,6 @@ void ComplexDeinterleavingGraph::identifyReductionNodes() {
|
||||
for (size_t j = i + 1; j < OperationInstruction.size(); ++j) {
|
||||
if (Processed[j])
|
||||
continue;
|
||||
|
||||
auto *Real = OperationInstruction[i];
|
||||
auto *Imag = OperationInstruction[j];
|
||||
if (Real->getType() != Imag->getType())
|
||||
@ -1556,6 +1726,28 @@ void ComplexDeinterleavingGraph::identifyReductionNodes() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
auto *Real = OperationInstruction[i];
|
||||
// We want to check that we have 2 operands, but the function attributes
|
||||
// being counted as operands bloats this value.
|
||||
if (Real->getNumOperands() < 2)
|
||||
continue;
|
||||
|
||||
RealPHI = ReductionInfo[Real].first;
|
||||
ImagPHI = nullptr;
|
||||
PHIsFound = false;
|
||||
auto Node = identifyNode(Real->getOperand(0), Real->getOperand(1));
|
||||
if (Node && PHIsFound) {
|
||||
LLVM_DEBUG(
|
||||
dbgs() << "Identified single reduction starting from instruction: "
|
||||
<< *Real << "/" << *ReductionInfo[Real].second << "\n");
|
||||
Processed[i] = true;
|
||||
auto RootNode = prepareCompositeNode(
|
||||
ComplexDeinterleavingOperation::ReductionSingle, Real, nullptr);
|
||||
RootNode->addOperand(Node);
|
||||
RootToNode[Real] = RootNode;
|
||||
submitCompositeNode(RootNode);
|
||||
}
|
||||
}
|
||||
|
||||
RealPHI = nullptr;
|
||||
@ -1563,6 +1755,12 @@ void ComplexDeinterleavingGraph::identifyReductionNodes() {
|
||||
}
|
||||
|
||||
bool ComplexDeinterleavingGraph::checkNodes() {
|
||||
|
||||
for (NodePtr N : CompositeNodes) {
|
||||
if (!N->areOperandsValid())
|
||||
return false;
|
||||
}
|
||||
|
||||
// Collect all instructions from roots to leaves
|
||||
SmallPtrSet<Instruction *, 16> AllInstructions;
|
||||
SmallVector<Instruction *, 8> Worklist;
|
||||
@ -1831,7 +2029,7 @@ ComplexDeinterleavingGraph::identifySplat(Value *R, Value *I) {
|
||||
ComplexDeinterleavingGraph::NodePtr
|
||||
ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
|
||||
Instruction *Imag) {
|
||||
if (Real != RealPHI || Imag != ImagPHI)
|
||||
if (Real != RealPHI || (ImagPHI && Imag != ImagPHI))
|
||||
return nullptr;
|
||||
|
||||
PHIsFound = true;
|
||||
@ -1926,6 +2124,16 @@ Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
|
||||
|
||||
Value *ReplacementNode;
|
||||
switch (Node->Operation) {
|
||||
case ComplexDeinterleavingOperation::CDot: {
|
||||
Value *Input0 = ReplaceOperandIfExist(Node, 0);
|
||||
Value *Input1 = ReplaceOperandIfExist(Node, 1);
|
||||
Value *Accumulator = ReplaceOperandIfExist(Node, 2);
|
||||
assert(!Input1 || (Input0->getType() == Input1->getType() &&
|
||||
"Node inputs need to be of the same type"));
|
||||
ReplacementNode = TL->createComplexDeinterleavingIR(
|
||||
Builder, Node->Operation, Node->Rotation, Input0, Input1, Accumulator);
|
||||
break;
|
||||
}
|
||||
case ComplexDeinterleavingOperation::CAdd:
|
||||
case ComplexDeinterleavingOperation::CMulPartial:
|
||||
case ComplexDeinterleavingOperation::Symmetric: {
|
||||
@ -1969,13 +2177,18 @@ Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
|
||||
case ComplexDeinterleavingOperation::ReductionPHI: {
|
||||
// If Operation is ReductionPHI, a new empty PHINode is created.
|
||||
// It is filled later when the ReductionOperation is processed.
|
||||
auto *OldPHI = cast<PHINode>(Node->Real);
|
||||
auto *VTy = cast<VectorType>(Node->Real->getType());
|
||||
auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
|
||||
auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHIIt());
|
||||
OldToNewPHI[dyn_cast<PHINode>(Node->Real)] = NewPHI;
|
||||
OldToNewPHI[OldPHI] = NewPHI;
|
||||
ReplacementNode = NewPHI;
|
||||
break;
|
||||
}
|
||||
case ComplexDeinterleavingOperation::ReductionSingle:
|
||||
ReplacementNode = replaceNode(Builder, Node->Operands[0]);
|
||||
processReductionSingle(ReplacementNode, Node);
|
||||
break;
|
||||
case ComplexDeinterleavingOperation::ReductionOperation:
|
||||
ReplacementNode = replaceNode(Builder, Node->Operands[0]);
|
||||
processReductionOperation(ReplacementNode, Node);
|
||||
@ -2000,6 +2213,38 @@ Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
|
||||
return ReplacementNode;
|
||||
}
|
||||
|
||||
void ComplexDeinterleavingGraph::processReductionSingle(
|
||||
Value *OperationReplacement, RawNodePtr Node) {
|
||||
auto *Real = cast<Instruction>(Node->Real);
|
||||
auto *OldPHI = ReductionInfo[Real].first;
|
||||
auto *NewPHI = OldToNewPHI[OldPHI];
|
||||
auto *VTy = cast<VectorType>(Real->getType());
|
||||
auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
|
||||
|
||||
Value *Init = OldPHI->getIncomingValueForBlock(Incoming);
|
||||
|
||||
IRBuilder<> Builder(Incoming->getTerminator());
|
||||
|
||||
Value *NewInit = nullptr;
|
||||
if (auto *C = dyn_cast<Constant>(Init)) {
|
||||
if (C->isZeroValue())
|
||||
NewInit = Constant::getNullValue(NewVTy);
|
||||
}
|
||||
|
||||
if (!NewInit)
|
||||
NewInit = Builder.CreateIntrinsic(Intrinsic::vector_interleave2, NewVTy,
|
||||
{Init, Constant::getNullValue(VTy)});
|
||||
|
||||
NewPHI->addIncoming(NewInit, Incoming);
|
||||
NewPHI->addIncoming(OperationReplacement, BackEdge);
|
||||
|
||||
auto *FinalReduction = ReductionInfo[Real].second;
|
||||
Builder.SetInsertPoint(&*FinalReduction->getParent()->getFirstInsertionPt());
|
||||
|
||||
auto *AddReduce = Builder.CreateAddReduce(OperationReplacement);
|
||||
FinalReduction->replaceAllUsesWith(AddReduce);
|
||||
}
|
||||
|
||||
void ComplexDeinterleavingGraph::processReductionOperation(
|
||||
Value *OperationReplacement, RawNodePtr Node) {
|
||||
auto *Real = cast<Instruction>(Node->Real);
|
||||
@ -2059,8 +2304,13 @@ void ComplexDeinterleavingGraph::replaceNodes() {
|
||||
auto *RootImag = cast<Instruction>(RootNode->Imag);
|
||||
ReductionInfo[RootReal].first->removeIncomingValue(BackEdge);
|
||||
ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
|
||||
DeadInstrRoots.push_back(cast<Instruction>(RootReal));
|
||||
DeadInstrRoots.push_back(cast<Instruction>(RootImag));
|
||||
DeadInstrRoots.push_back(RootReal);
|
||||
DeadInstrRoots.push_back(RootImag);
|
||||
} else if (RootNode->Operation ==
|
||||
ComplexDeinterleavingOperation::ReductionSingle) {
|
||||
auto *RootInst = cast<Instruction>(RootNode->Real);
|
||||
ReductionInfo[RootInst].first->removeIncomingValue(BackEdge);
|
||||
DeadInstrRoots.push_back(ReductionInfo[RootInst].second);
|
||||
} else {
|
||||
assert(R && "Unable to find replacement for RootInstruction");
|
||||
DeadInstrRoots.push_back(RootInstruction);
|
||||
|
||||
@ -29542,9 +29542,16 @@ bool AArch64TargetLowering::isComplexDeinterleavingOperationSupported(
|
||||
|
||||
if (ScalarTy->isIntegerTy() && Subtarget->hasSVE2() && VTy->isScalableTy()) {
|
||||
unsigned ScalarWidth = ScalarTy->getScalarSizeInBits();
|
||||
|
||||
if (Operation == ComplexDeinterleavingOperation::CDot)
|
||||
return ScalarWidth == 32 || ScalarWidth == 64;
|
||||
return 8 <= ScalarWidth && ScalarWidth <= 64;
|
||||
}
|
||||
|
||||
// CDot is not supported outside of scalable/sve scopes
|
||||
if (Operation == ComplexDeinterleavingOperation::CDot)
|
||||
return false;
|
||||
|
||||
return (ScalarTy->isHalfTy() && Subtarget->hasFullFP16()) ||
|
||||
ScalarTy->isFloatTy() || ScalarTy->isDoubleTy();
|
||||
}
|
||||
@ -29554,6 +29561,8 @@ Value *AArch64TargetLowering::createComplexDeinterleavingIR(
|
||||
ComplexDeinterleavingRotation Rotation, Value *InputA, Value *InputB,
|
||||
Value *Accumulator) const {
|
||||
VectorType *Ty = cast<VectorType>(InputA->getType());
|
||||
if (Accumulator == nullptr)
|
||||
Accumulator = Constant::getNullValue(Ty);
|
||||
bool IsScalable = Ty->isScalableTy();
|
||||
bool IsInt = Ty->getElementType()->isIntegerTy();
|
||||
|
||||
@ -29565,6 +29574,10 @@ Value *AArch64TargetLowering::createComplexDeinterleavingIR(
|
||||
|
||||
if (TyWidth > 128) {
|
||||
int Stride = Ty->getElementCount().getKnownMinValue() / 2;
|
||||
int AccStride = cast<VectorType>(Accumulator->getType())
|
||||
->getElementCount()
|
||||
.getKnownMinValue() /
|
||||
2;
|
||||
auto *HalfTy = VectorType::getHalfElementsVectorType(Ty);
|
||||
auto *LowerSplitA = B.CreateExtractVector(HalfTy, InputA, B.getInt64(0));
|
||||
auto *LowerSplitB = B.CreateExtractVector(HalfTy, InputB, B.getInt64(0));
|
||||
@ -29574,25 +29587,26 @@ Value *AArch64TargetLowering::createComplexDeinterleavingIR(
|
||||
B.CreateExtractVector(HalfTy, InputB, B.getInt64(Stride));
|
||||
Value *LowerSplitAcc = nullptr;
|
||||
Value *UpperSplitAcc = nullptr;
|
||||
if (Accumulator) {
|
||||
LowerSplitAcc = B.CreateExtractVector(HalfTy, Accumulator, B.getInt64(0));
|
||||
UpperSplitAcc =
|
||||
B.CreateExtractVector(HalfTy, Accumulator, B.getInt64(Stride));
|
||||
}
|
||||
Type *FullTy = Ty;
|
||||
FullTy = Accumulator->getType();
|
||||
auto *HalfAccTy = VectorType::getHalfElementsVectorType(
|
||||
cast<VectorType>(Accumulator->getType()));
|
||||
LowerSplitAcc =
|
||||
B.CreateExtractVector(HalfAccTy, Accumulator, B.getInt64(0));
|
||||
UpperSplitAcc =
|
||||
B.CreateExtractVector(HalfAccTy, Accumulator, B.getInt64(AccStride));
|
||||
auto *LowerSplitInt = createComplexDeinterleavingIR(
|
||||
B, OperationType, Rotation, LowerSplitA, LowerSplitB, LowerSplitAcc);
|
||||
auto *UpperSplitInt = createComplexDeinterleavingIR(
|
||||
B, OperationType, Rotation, UpperSplitA, UpperSplitB, UpperSplitAcc);
|
||||
|
||||
auto *Result = B.CreateInsertVector(Ty, PoisonValue::get(Ty), LowerSplitInt,
|
||||
B.getInt64(0));
|
||||
return B.CreateInsertVector(Ty, Result, UpperSplitInt, B.getInt64(Stride));
|
||||
auto *Result = B.CreateInsertVector(FullTy, PoisonValue::get(FullTy),
|
||||
LowerSplitInt, B.getInt64(0));
|
||||
return B.CreateInsertVector(FullTy, Result, UpperSplitInt,
|
||||
B.getInt64(AccStride));
|
||||
}
|
||||
|
||||
if (OperationType == ComplexDeinterleavingOperation::CMulPartial) {
|
||||
if (Accumulator == nullptr)
|
||||
Accumulator = Constant::getNullValue(Ty);
|
||||
|
||||
if (IsScalable) {
|
||||
if (IsInt)
|
||||
return B.CreateIntrinsic(
|
||||
@ -29644,6 +29658,13 @@ Value *AArch64TargetLowering::createComplexDeinterleavingIR(
|
||||
return B.CreateIntrinsic(IntId, Ty, {InputA, InputB});
|
||||
}
|
||||
|
||||
if (OperationType == ComplexDeinterleavingOperation::CDot && IsInt &&
|
||||
IsScalable) {
|
||||
return B.CreateIntrinsic(
|
||||
Intrinsic::aarch64_sve_cdot, Accumulator->getType(),
|
||||
{Accumulator, InputA, InputB, B.getInt32((int)Rotation * 90)});
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
||||
1136
llvm/test/CodeGen/AArch64/complex-deinterleaving-cdot.ll
Normal file
1136
llvm/test/CodeGen/AArch64/complex-deinterleaving-cdot.ll
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user