Add support for single reductions in ComplexDeinterleavingPass (#112875)

The Complex Deinterleaving pass assumes that all values emitted will
result in complex numbers. This patch removes that assumption and adds
support for emitting just the real or just the imaginary component,
rather than both.
This commit is contained in:
Nicholas Guy 2024-12-18 10:34:26 +00:00 committed by GitHub
parent f8d270474c
commit b3eede5e1f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 1434 additions and 25 deletions

View File

@ -35,6 +35,7 @@ public:
enum class ComplexDeinterleavingOperation {
CAdd,
CMulPartial,
CDot,
// The following 'operations' are used to represent internal states. Backends
// are not expected to try and support these in any capacity.
Deinterleave,
@ -43,6 +44,7 @@ enum class ComplexDeinterleavingOperation {
ReductionPHI,
ReductionOperation,
ReductionSelect,
ReductionSingle
};
enum class ComplexDeinterleavingRotation {

View File

@ -108,6 +108,13 @@ static bool isNeg(Value *V);
static Value *getNegOperand(Value *V);
namespace {
/// Returns the first element of \p A (in A's iteration order) that also occurs
/// in \p B, converted to \p T, or std::nullopt when the two collections have
/// no element in common.
///
/// Both collections are taken by const reference so that passing an owning
/// container does not copy it; temporaries such as the result of a users()
/// call bind to the references just as well.
template <typename T, typename IterT>
std::optional<T> findCommonBetweenCollections(const IterT &A, const IterT &B) {
  for (const T Candidate : A) {
    for (const auto &Other : B) {
      if (Other == Candidate)
        return Candidate;
    }
  }
  return std::nullopt;
}
class ComplexDeinterleavingLegacyPass : public FunctionPass {
public:
@ -144,6 +151,7 @@ private:
friend class ComplexDeinterleavingGraph;
using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>;
using RawNodePtr = ComplexDeinterleavingCompositeNode *;
bool OperandsValid = true;
public:
ComplexDeinterleavingOperation Operation;
@ -160,7 +168,11 @@ public:
SmallVector<RawNodePtr> Operands;
Value *ReplacementNode = nullptr;
void addOperand(NodePtr Node) { Operands.push_back(Node.get()); }
/// Appends the raw pointer held by \p Node to Operands. A null \p Node is
/// still appended, but it clears OperandsValid so the node can later be
/// recognised as having a missing operand.
void addOperand(NodePtr Node) {
  RawNodePtr Raw = Node.get();
  if (Raw == nullptr)
    OperandsValid = false;
  Operands.push_back(Raw);
}
void dump() { dump(dbgs()); }
void dump(raw_ostream &OS) {
@ -194,6 +206,8 @@ public:
PrintNodeRef(Op);
}
}
bool areOperandsValid() { return OperandsValid; }
};
class ComplexDeinterleavingGraph {
@ -293,7 +307,7 @@ private:
NodePtr submitCompositeNode(NodePtr Node) {
CompositeNodes.push_back(Node);
if (Node->Real && Node->Imag)
if (Node->Real)
CachedResult[{Node->Real, Node->Imag}] = Node;
return Node;
}
@ -327,6 +341,8 @@ private:
/// i: ai - br
NodePtr identifyAdd(Instruction *Real, Instruction *Imag);
NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag);
NodePtr identifyPartialReduction(Value *R, Value *I);
NodePtr identifyDotProduct(Value *Inst);
NodePtr identifyNode(Value *R, Value *I);
@ -396,6 +412,7 @@ private:
/// * Deinterleave the final value outside of the loop and repurpose original
/// reduction users
void processReductionOperation(Value *OperationReplacement, RawNodePtr Node);
void processReductionSingle(Value *OperationReplacement, RawNodePtr Node);
public:
void dump() { dump(dbgs()); }
@ -891,17 +908,163 @@ ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real,
}
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) {
LLVM_DEBUG(dbgs() << "identifyNode on " << *R << " / " << *I << "\n");
assert(R->getType() == I->getType() &&
"Real and imaginary parts should not have different types");
ComplexDeinterleavingGraph::identifyDotProduct(Value *V) {
  // Tries to recognise V as the outer call of two chained
  // experimental_vector_partial_reduce_add intrinsics and, on success, submits
  // a CDot composite node whose operands are (A, B, accumulator).
  //
  // Returns nullptr when the target does not support CDot for V's type, when V
  // has no users, when none of the rotation patterns match, or when the
  // operand/accumulator types are not the expected ones. Operands that
  // identifyNode cannot resolve are still added as null; addOperand records
  // that so the node can be rejected later.
  if (!TL->isComplexDeinterleavingOperationSupported(
          ComplexDeinterleavingOperation::CDot, V->getType())) {
    LLVM_DEBUG(dbgs() << "Target doesn't support complex deinterleaving "
                         "operation CDot with the type "
                      << *V->getType() << "\n");
    return nullptr;
  }

  auto *Inst = cast<Instruction>(V);
  // The first user of the reduction is inspected below. Dereferencing
  // user_begin() on a value without users is invalid, so bail out instead.
  if (Inst->user_empty())
    return nullptr;
  auto *RealUser = cast<Instruction>(*Inst->user_begin());

  NodePtr CN =
      prepareCompositeNode(ComplexDeinterleavingOperation::CDot, Inst, nullptr);

  // Only set when the rotation had to be determined by identifying A's
  // components (the 90/180 case below).
  NodePtr ANode;

  const Intrinsic::ID PartialReduceInt =
      Intrinsic::experimental_vector_partial_reduce_add;

  Value *AReal = nullptr;
  Value *AImag = nullptr;
  Value *BReal = nullptr;
  Value *BImag = nullptr;
  Value *Phi = nullptr;

  // Strips a single cast instruction, if V is one.
  auto UnwrapCast = [](Value *V) -> Value * {
    if (auto *CI = dyn_cast<CastInst>(V))
      return CI->getOperand(0);
    return V;
  };

  // Rotation 0: reduce(reduce(Phi, BReal * AReal), -(BImag * AImag))
  auto PatternRot0 = m_Intrinsic<PartialReduceInt>(
      m_Intrinsic<PartialReduceInt>(m_Value(Phi),
                                    m_Mul(m_Value(BReal), m_Value(AReal))),
      m_Neg(m_Mul(m_Value(BImag), m_Value(AImag))));

  // Rotation 270: reduce(reduce(Phi, -(BReal * AImag)), BImag * AReal)
  auto PatternRot270 = m_Intrinsic<PartialReduceInt>(
      m_Intrinsic<PartialReduceInt>(
          m_Value(Phi), m_Neg(m_Mul(m_Value(BReal), m_Value(AImag)))),
      m_Mul(m_Value(BImag), m_Value(AReal)));

  if (match(Inst, PatternRot0)) {
    CN->Rotation = ComplexDeinterleavingRotation::Rotation_0;
  } else if (match(Inst, PatternRot270)) {
    CN->Rotation = ComplexDeinterleavingRotation::Rotation_270;
  } else {
    Value *A0, *A1;
    // The rotations 90 and 180 share the same operation pattern, so inspect the
    // order of the operands, identifying where the real and imaginary
    // components of A go, to discern between the aforementioned rotations.
    auto PatternRot90Rot180 = m_Intrinsic<PartialReduceInt>(
        m_Intrinsic<PartialReduceInt>(m_Value(Phi),
                                      m_Mul(m_Value(BReal), m_Value(A0))),
        m_Mul(m_Value(BImag), m_Value(A1)));

    if (!match(Inst, PatternRot90Rot180))
      return nullptr;

    A0 = UnwrapCast(A0);
    A1 = UnwrapCast(A1);

    // Test if A0 is real/A1 is imag
    ANode = identifyNode(A0, A1);
    if (!ANode) {
      // Test if A0 is imag/A1 is real
      ANode = identifyNode(A1, A0);
      // Unable to identify operand components, thus unable to identify rotation
      if (!ANode)
        return nullptr;
      CN->Rotation = ComplexDeinterleavingRotation::Rotation_90;
      AReal = A1;
      AImag = A0;
    } else {
      AReal = A0;
      AImag = A1;
      CN->Rotation = ComplexDeinterleavingRotation::Rotation_180;
    }
  }

  // Look through one cast on each multiplicand before checking its type.
  AReal = UnwrapCast(AReal);
  AImag = UnwrapCast(AImag);
  BReal = UnwrapCast(BReal);
  BImag = UnwrapCast(BImag);

  // Every multiplicand must have the operand type derived from the result
  // type via getSubdividedVectorType(VTy, 2).
  VectorType *VTy = cast<VectorType>(V->getType());
  Type *ExpectedOperandTy = VectorType::getSubdividedVectorType(VTy, 2);
  if (AReal->getType() != ExpectedOperandTy)
    return nullptr;
  if (AImag->getType() != ExpectedOperandTy)
    return nullptr;
  if (BReal->getType() != ExpectedOperandTy)
    return nullptr;
  if (BImag->getType() != ExpectedOperandTy)
    return nullptr;

  // Reject only when neither the accumulator nor the first user has the
  // result type.
  if (Phi->getType() != VTy && RealUser->getType() != VTy)
    return nullptr;

  NodePtr Node = identifyNode(AReal, AImag);

  // In the case that a node was identified to figure out the rotation, ensure
  // that trying to identify a node with AReal and AImag post-unwrap results in
  // the same node
  if (ANode && Node != ANode) {
    LLVM_DEBUG(
        dbgs()
        << "Identified node is different from previously identified node. "
           "Unable to confidently generate a complex operation node\n");
    return nullptr;
  }

  CN->addOperand(Node);
  CN->addOperand(identifyNode(BReal, BImag));
  CN->addOperand(identifyNode(Phi, RealUser));

  return submitCompositeNode(CN);
}
/// Looks for a user shared by \p R and \p I that is a call to
/// experimental_vector_partial_reduce_add and hands it to identifyDotProduct.
/// Returns nullptr when either value is not a vector, when no shared user is
/// found, when the shared user found is not that intrinsic, or when
/// identifyDotProduct fails.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyPartialReduction(Value *R, Value *I) {
  // Partial reductions only operate on vectors.
  const bool BothVectors =
      isa<VectorType>(R->getType()) && isa<VectorType>(I->getType());
  if (!BothVectors)
    return nullptr;

  std::optional<Value *> SharedUser =
      findCommonBetweenCollections<Value *>(R->users(), I->users());
  if (!SharedUser.has_value())
    return nullptr;

  auto *Reduce = dyn_cast<IntrinsicInst>(*SharedUser);
  if (Reduce == nullptr)
    return nullptr;
  if (Reduce->getIntrinsicID() !=
      Intrinsic::experimental_vector_partial_reduce_add)
    return nullptr;

  // A failed identification already yields a null NodePtr.
  return identifyDotProduct(Reduce);
}
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) {
auto It = CachedResult.find({R, I});
if (It != CachedResult.end()) {
LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
return It->second;
}
if (NodePtr CN = identifyPartialReduction(R, I))
return CN;
bool IsReduction = RealPHI == R && (!ImagPHI || ImagPHI == I);
if (!IsReduction && R->getType() != I->getType())
return nullptr;
if (NodePtr CN = identifySplat(R, I))
return CN;
@ -1427,12 +1590,20 @@ bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
if (It != RootToNode.end()) {
auto RootNode = It->second;
assert(RootNode->Operation ==
ComplexDeinterleavingOperation::ReductionOperation);
ComplexDeinterleavingOperation::ReductionOperation ||
RootNode->Operation ==
ComplexDeinterleavingOperation::ReductionSingle);
// Find out which part, Real or Imag, comes later, and only if we come to
// the latest part, add it to OrderedRoots.
auto *R = cast<Instruction>(RootNode->Real);
auto *I = cast<Instruction>(RootNode->Imag);
auto *ReplacementAnchor = R->comesBefore(I) ? I : R;
auto *I = RootNode->Imag ? cast<Instruction>(RootNode->Imag) : nullptr;
Instruction *ReplacementAnchor;
if (I)
ReplacementAnchor = R->comesBefore(I) ? I : R;
else
ReplacementAnchor = R;
if (ReplacementAnchor != RootI)
return false;
OrderedRoots.push_back(RootI);
@ -1523,7 +1694,6 @@ void ComplexDeinterleavingGraph::identifyReductionNodes() {
for (size_t j = i + 1; j < OperationInstruction.size(); ++j) {
if (Processed[j])
continue;
auto *Real = OperationInstruction[i];
auto *Imag = OperationInstruction[j];
if (Real->getType() != Imag->getType())
@ -1556,6 +1726,28 @@ void ComplexDeinterleavingGraph::identifyReductionNodes() {
break;
}
}
auto *Real = OperationInstruction[i];
// We want to check that we have 2 operands, but the function attributes
// being counted as operands bloats this value.
if (Real->getNumOperands() < 2)
continue;
RealPHI = ReductionInfo[Real].first;
ImagPHI = nullptr;
PHIsFound = false;
auto Node = identifyNode(Real->getOperand(0), Real->getOperand(1));
if (Node && PHIsFound) {
LLVM_DEBUG(
dbgs() << "Identified single reduction starting from instruction: "
<< *Real << "/" << *ReductionInfo[Real].second << "\n");
Processed[i] = true;
auto RootNode = prepareCompositeNode(
ComplexDeinterleavingOperation::ReductionSingle, Real, nullptr);
RootNode->addOperand(Node);
RootToNode[Real] = RootNode;
submitCompositeNode(RootNode);
}
}
RealPHI = nullptr;
@ -1563,6 +1755,12 @@ void ComplexDeinterleavingGraph::identifyReductionNodes() {
}
bool ComplexDeinterleavingGraph::checkNodes() {
for (NodePtr N : CompositeNodes) {
if (!N->areOperandsValid())
return false;
}
// Collect all instructions from roots to leaves
SmallPtrSet<Instruction *, 16> AllInstructions;
SmallVector<Instruction *, 8> Worklist;
@ -1831,7 +2029,7 @@ ComplexDeinterleavingGraph::identifySplat(Value *R, Value *I) {
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
Instruction *Imag) {
if (Real != RealPHI || Imag != ImagPHI)
if (Real != RealPHI || (ImagPHI && Imag != ImagPHI))
return nullptr;
PHIsFound = true;
@ -1926,6 +2124,16 @@ Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
Value *ReplacementNode;
switch (Node->Operation) {
case ComplexDeinterleavingOperation::CDot: {
Value *Input0 = ReplaceOperandIfExist(Node, 0);
Value *Input1 = ReplaceOperandIfExist(Node, 1);
Value *Accumulator = ReplaceOperandIfExist(Node, 2);
assert(!Input1 || (Input0->getType() == Input1->getType() &&
"Node inputs need to be of the same type"));
ReplacementNode = TL->createComplexDeinterleavingIR(
Builder, Node->Operation, Node->Rotation, Input0, Input1, Accumulator);
break;
}
case ComplexDeinterleavingOperation::CAdd:
case ComplexDeinterleavingOperation::CMulPartial:
case ComplexDeinterleavingOperation::Symmetric: {
@ -1969,13 +2177,18 @@ Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
case ComplexDeinterleavingOperation::ReductionPHI: {
// If Operation is ReductionPHI, a new empty PHINode is created.
// It is filled later when the ReductionOperation is processed.
auto *OldPHI = cast<PHINode>(Node->Real);
auto *VTy = cast<VectorType>(Node->Real->getType());
auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHIIt());
OldToNewPHI[dyn_cast<PHINode>(Node->Real)] = NewPHI;
OldToNewPHI[OldPHI] = NewPHI;
ReplacementNode = NewPHI;
break;
}
case ComplexDeinterleavingOperation::ReductionSingle:
ReplacementNode = replaceNode(Builder, Node->Operands[0]);
processReductionSingle(ReplacementNode, Node);
break;
case ComplexDeinterleavingOperation::ReductionOperation:
ReplacementNode = replaceNode(Builder, Node->Operands[0]);
processReductionOperation(ReplacementNode, Node);
@ -2000,6 +2213,38 @@ Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
return ReplacementNode;
}
/// Completes the rewrite of a ReductionSingle root.
///
/// Fills the new PHI recorded in OldToNewPHI for the reduction's original PHI
/// with an initial value coming from the Incoming block and with
/// \p OperationReplacement coming from the BackEdge block, then replaces all
/// uses of the recorded final reduction with an add-reduce of
/// \p OperationReplacement.
void ComplexDeinterleavingGraph::processReductionSingle(
    Value *OperationReplacement, RawNodePtr Node) {
  auto *ReductionOp = cast<Instruction>(Node->Real);
  const auto Info = ReductionInfo[ReductionOp];
  auto *OrigPHI = Info.first;
  auto *WidePHI = OldToNewPHI[OrigPHI];

  auto *NarrowTy = cast<VectorType>(ReductionOp->getType());
  auto *WideTy = VectorType::getDoubleElementsVectorType(NarrowTy);

  Value *Init = OrigPHI->getIncomingValueForBlock(Incoming);

  IRBuilder<> Builder(Incoming->getTerminator());

  // A zero constant start value widens to a zero constant; any other start
  // value is interleaved with a zero vector of the original type.
  auto *InitConst = dyn_cast<Constant>(Init);
  Value *WideInit;
  if (InitConst && InitConst->isZeroValue())
    WideInit = Constant::getNullValue(WideTy);
  else
    WideInit =
        Builder.CreateIntrinsic(Intrinsic::vector_interleave2, WideTy,
                                {Init, Constant::getNullValue(NarrowTy)});

  WidePHI->addIncoming(WideInit, Incoming);
  WidePHI->addIncoming(OperationReplacement, BackEdge);

  // Emit the replacement reduction at the first insertion point of the block
  // that holds the original final reduction.
  auto *FinalReduce = Info.second;
  Builder.SetInsertPoint(&*FinalReduce->getParent()->getFirstInsertionPt());

  auto *NewReduce = Builder.CreateAddReduce(OperationReplacement);
  FinalReduce->replaceAllUsesWith(NewReduce);
}
void ComplexDeinterleavingGraph::processReductionOperation(
Value *OperationReplacement, RawNodePtr Node) {
auto *Real = cast<Instruction>(Node->Real);
@ -2059,8 +2304,13 @@ void ComplexDeinterleavingGraph::replaceNodes() {
auto *RootImag = cast<Instruction>(RootNode->Imag);
ReductionInfo[RootReal].first->removeIncomingValue(BackEdge);
ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
DeadInstrRoots.push_back(cast<Instruction>(RootReal));
DeadInstrRoots.push_back(cast<Instruction>(RootImag));
DeadInstrRoots.push_back(RootReal);
DeadInstrRoots.push_back(RootImag);
} else if (RootNode->Operation ==
ComplexDeinterleavingOperation::ReductionSingle) {
auto *RootInst = cast<Instruction>(RootNode->Real);
ReductionInfo[RootInst].first->removeIncomingValue(BackEdge);
DeadInstrRoots.push_back(ReductionInfo[RootInst].second);
} else {
assert(R && "Unable to find replacement for RootInstruction");
DeadInstrRoots.push_back(RootInstruction);

View File

@ -29542,9 +29542,16 @@ bool AArch64TargetLowering::isComplexDeinterleavingOperationSupported(
if (ScalarTy->isIntegerTy() && Subtarget->hasSVE2() && VTy->isScalableTy()) {
unsigned ScalarWidth = ScalarTy->getScalarSizeInBits();
if (Operation == ComplexDeinterleavingOperation::CDot)
return ScalarWidth == 32 || ScalarWidth == 64;
return 8 <= ScalarWidth && ScalarWidth <= 64;
}
// CDot is not supported outside of scalable/sve scopes
if (Operation == ComplexDeinterleavingOperation::CDot)
return false;
return (ScalarTy->isHalfTy() && Subtarget->hasFullFP16()) ||
ScalarTy->isFloatTy() || ScalarTy->isDoubleTy();
}
@ -29554,6 +29561,8 @@ Value *AArch64TargetLowering::createComplexDeinterleavingIR(
ComplexDeinterleavingRotation Rotation, Value *InputA, Value *InputB,
Value *Accumulator) const {
VectorType *Ty = cast<VectorType>(InputA->getType());
if (Accumulator == nullptr)
Accumulator = Constant::getNullValue(Ty);
bool IsScalable = Ty->isScalableTy();
bool IsInt = Ty->getElementType()->isIntegerTy();
@ -29565,6 +29574,10 @@ Value *AArch64TargetLowering::createComplexDeinterleavingIR(
if (TyWidth > 128) {
int Stride = Ty->getElementCount().getKnownMinValue() / 2;
int AccStride = cast<VectorType>(Accumulator->getType())
->getElementCount()
.getKnownMinValue() /
2;
auto *HalfTy = VectorType::getHalfElementsVectorType(Ty);
auto *LowerSplitA = B.CreateExtractVector(HalfTy, InputA, B.getInt64(0));
auto *LowerSplitB = B.CreateExtractVector(HalfTy, InputB, B.getInt64(0));
@ -29574,25 +29587,26 @@ Value *AArch64TargetLowering::createComplexDeinterleavingIR(
B.CreateExtractVector(HalfTy, InputB, B.getInt64(Stride));
Value *LowerSplitAcc = nullptr;
Value *UpperSplitAcc = nullptr;
if (Accumulator) {
LowerSplitAcc = B.CreateExtractVector(HalfTy, Accumulator, B.getInt64(0));
UpperSplitAcc =
B.CreateExtractVector(HalfTy, Accumulator, B.getInt64(Stride));
}
Type *FullTy = Ty;
FullTy = Accumulator->getType();
auto *HalfAccTy = VectorType::getHalfElementsVectorType(
cast<VectorType>(Accumulator->getType()));
LowerSplitAcc =
B.CreateExtractVector(HalfAccTy, Accumulator, B.getInt64(0));
UpperSplitAcc =
B.CreateExtractVector(HalfAccTy, Accumulator, B.getInt64(AccStride));
auto *LowerSplitInt = createComplexDeinterleavingIR(
B, OperationType, Rotation, LowerSplitA, LowerSplitB, LowerSplitAcc);
auto *UpperSplitInt = createComplexDeinterleavingIR(
B, OperationType, Rotation, UpperSplitA, UpperSplitB, UpperSplitAcc);
auto *Result = B.CreateInsertVector(Ty, PoisonValue::get(Ty), LowerSplitInt,
B.getInt64(0));
return B.CreateInsertVector(Ty, Result, UpperSplitInt, B.getInt64(Stride));
auto *Result = B.CreateInsertVector(FullTy, PoisonValue::get(FullTy),
LowerSplitInt, B.getInt64(0));
return B.CreateInsertVector(FullTy, Result, UpperSplitInt,
B.getInt64(AccStride));
}
if (OperationType == ComplexDeinterleavingOperation::CMulPartial) {
if (Accumulator == nullptr)
Accumulator = Constant::getNullValue(Ty);
if (IsScalable) {
if (IsInt)
return B.CreateIntrinsic(
@ -29644,6 +29658,13 @@ Value *AArch64TargetLowering::createComplexDeinterleavingIR(
return B.CreateIntrinsic(IntId, Ty, {InputA, InputB});
}
if (OperationType == ComplexDeinterleavingOperation::CDot && IsInt &&
IsScalable) {
return B.CreateIntrinsic(
Intrinsic::aarch64_sve_cdot, Accumulator->getType(),
{Accumulator, InputA, InputB, B.getInt32((int)Rotation * 90)});
}
return nullptr;
}

File diff suppressed because it is too large Load Diff