[LoongArch] Broadcast repeated subsequence in build_vector instead of inserting per element

Author: Qi Zhao
Date:   2025-08-20 20:39:50 +08:00
Parent: 8375c79afe
Commit: 3674bad63b
3 changed files with 76 additions and 1 deletion
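
To illustrate the transform (an editorial sketch, not part of the commit message): a BUILD_VECTOR such as {a, b, a, b, a, b, a, b} repeats a two-element sub-sequence, so the lowering can fill {a, b} with INSERT_VECTOR_ELT and then broadcast it, rather than inserting all eight elements individually. A minimal standalone C++ sketch of the detection step, mirroring what BuildVectorSDNode::getRepeatedSequence reports (findRepeatedSequence is a hypothetical stand-in, not the LLVM API):

#include <cstddef>
#include <vector>

// Hypothetical helper (illustration only): returns the length of the
// shortest proper prefix of Elts whose repetition reproduces the whole
// vector, or 0 if no such repeated sub-sequence exists.
template <typename T>
std::size_t findRepeatedSequence(const std::vector<T> &Elts) {
  std::size_t N = Elts.size();
  for (std::size_t Len = 1; Len * 2 <= N; ++Len) {
    if (N % Len != 0)
      continue;
    bool Repeats = true;
    for (std::size_t I = Len; I < N && Repeats; ++I)
      Repeats = (Elts[I] == Elts[I % Len]);
    if (Repeats)
      return Len; // e.g. {a, b, a, b, a, b, a, b} -> 2
  }
  return 0;
}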

llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp

@@ -2434,6 +2434,7 @@ static SDValue lowerBUILD_VECTORAsBroadCastLoad(BuildVectorSDNode *BVOp,
 SDValue LoongArchTargetLowering::lowerBUILD_VECTOR(SDValue Op,
                                                    SelectionDAG &DAG) const {
   BuildVectorSDNode *Node = cast<BuildVectorSDNode>(Op);
+  MVT VT = Node->getSimpleValueType(0);
   EVT ResTy = Op->getValueType(0);
   unsigned NumElts = ResTy.getVectorNumElements();
   SDLoc DL(Op);
@@ -2517,6 +2518,56 @@ SDValue LoongArchTargetLowering::lowerBUILD_VECTOR(SDValue Op,
   }
 
   if (!IsConstant) {
+    // If the BUILD_VECTOR has a repeated pattern, use INSERT_VECTOR_ELT to
+    // fill the sub-sequence of the vector and then broadcast the sub-sequence.
+    SmallVector<SDValue> Sequence;
+    BitVector UndefElements;
+    if (Node->getRepeatedSequence(Sequence, &UndefElements)) {
+      // TODO: If the BUILD_VECTOR contains undef elements, consider falling
+      // back to INSERT_VECTOR_ELT to materialize the vector, because the
+      // broadcast generates worse code in some cases. This could be further
+      // optimized with more consideration.
+      if (UndefElements.count() == 0) {
+        unsigned SeqLen = Sequence.size();
+        SDValue Op0 = Sequence[0];
+        SDValue Vector = DAG.getUNDEF(ResTy);
+        if (!Op0.isUndef())
+          Vector = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, ResTy, Op0);
+        for (unsigned i = 1; i < SeqLen; ++i) {
+          SDValue Opi = Sequence[i];
+          if (Opi.isUndef())
+            continue;
+          Vector = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, ResTy, Vector, Opi,
+                               DAG.getConstant(i, DL, Subtarget.getGRLenVT()));
+        }
+
+        unsigned SplatLen = NumElts / SeqLen;
+        MVT SplatEltTy = MVT::getIntegerVT(VT.getScalarSizeInBits() * SeqLen);
+        MVT SplatTy = MVT::getVectorVT(SplatEltTy, SplatLen);
+
+        // If the size of the sub-sequence is half of a 256-bit vector,
+        // bitcast the vector to v4i64 in order to match the pattern of
+        // XVREPLVE0Q.
+        if (SplatEltTy == MVT::i128)
+          SplatTy = MVT::v4i64;
+
+        SDValue SrcVec = DAG.getBitcast(SplatTy, Vector);
+        SDValue SplatVec;
+        if (SplatTy.is256BitVector()) {
+          SplatVec =
+              DAG.getNode((SplatEltTy == MVT::i128) ? LoongArchISD::XVREPLVE0Q
+                                                    : LoongArchISD::XVREPLVE0,
+                          DL, SplatTy, SrcVec);
+        } else {
+          SplatVec =
+              DAG.getNode(LoongArchISD::VREPLVEI, DL, SplatTy, SrcVec,
+                          DAG.getConstant(0, DL, Subtarget.getGRLenVT()));
+        }
+        return DAG.getBitcast(ResTy, SplatVec);
+      }
+    }
+
     // Use INSERT_VECTOR_ELT operations rather than expand to stores.
     // The resulting code is the same length as the expansion, but it doesn't
     // use memory operations.
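
Worked example of the SplatTy arithmetic in the hunk above (a hedged sketch with plain integers standing in for MVTs): for a v8i32 BUILD_VECTOR whose repeated sub-sequence is {a, b}, SeqLen = 2, so the splat element widens to 64 bits and the lane count drops to 4, giving v4i64, which a 256-bit LASX vector broadcasts with XVREPLVE0; a SeqLen of 4 would give a 128-bit splat element, the XVREPLVE0Q case, bitcast to v4i64 to match that pattern.

#include <cassert>
#include <cstdio>

int main() {
  // ResTy = v8i32: a 256-bit LASX vector of eight 32-bit lanes.
  unsigned NumElts = 8, EltBits = 32;
  unsigned SeqLen = 2; // repeated sub-sequence {a, b}

  unsigned SplatEltBits = EltBits * SeqLen; // 64: one repetition per lane
  unsigned SplatLen = NumElts / SeqLen;     // 4:  SplatTy = v4i64
  assert(SplatEltBits * SplatLen == EltBits * NumElts); // total width kept

  std::printf("broadcast as v%ui%u\n", SplatLen, SplatEltBits); // v4i64
  return 0;
}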
@@ -6637,6 +6688,8 @@ const char *LoongArchTargetLowering::getTargetNodeName(unsigned Opcode) const {
     NODE_NAME_CASE(VREPLVEI)
     NODE_NAME_CASE(VREPLGR2VR)
     NODE_NAME_CASE(XVPERMI)
+    NODE_NAME_CASE(XVREPLVE0)
+    NODE_NAME_CASE(XVREPLVE0Q)
     NODE_NAME_CASE(VPICK_SEXT_ELT)
     NODE_NAME_CASE(VPICK_ZEXT_ELT)
     NODE_NAME_CASE(VREPLVE)

llvm/lib/Target/LoongArch/LoongArchISelLowering.h

@@ -141,6 +141,8 @@ enum NodeType : unsigned {
   VREPLVEI,
   VREPLGR2VR,
   XVPERMI,
+  XVREPLVE0,
+  XVREPLVE0Q,
 
   // Extended vector element extraction
   VPICK_SEXT_ELT,

llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td

@@ -10,8 +10,13 @@
 //
 //===----------------------------------------------------------------------===//
 
+def SDT_LoongArchXVREPLVE0 : SDTypeProfile<1, 1, [SDTCisVec<0>,
+                                                  SDTCisSameAs<0, 1>]>;
+
 // Target nodes.
 def loongarch_xvpermi: SDNode<"LoongArchISD::XVPERMI", SDT_LoongArchV1RUimm>;
+def loongarch_xvreplve0: SDNode<"LoongArchISD::XVREPLVE0", SDT_LoongArchXVREPLVE0>;
+def loongarch_xvreplve0q: SDNode<"LoongArchISD::XVREPLVE0Q", SDT_LoongArchXVREPLVE0>;
 def loongarch_xvmskltz: SDNode<"LoongArchISD::XVMSKLTZ", SDT_LoongArchVMSKCOND>;
 def loongarch_xvmskgez: SDNode<"LoongArchISD::XVMSKGEZ", SDT_LoongArchVMSKCOND>;
 def loongarch_xvmskeqz: SDNode<"LoongArchISD::XVMSKEQZ", SDT_LoongArchVMSKCOND>;
@@ -1852,11 +1857,26 @@ def : Pat<(loongarch_xvpermi v4i64:$xj, immZExt8: $ui8),
 def : Pat<(loongarch_xvpermi v4f64:$xj, immZExt8: $ui8),
           (XVPERMI_D v4f64:$xj, immZExt8: $ui8)>;
 
-// XVREPLVE0_{W/D}
+// XVREPLVE0_{B/H/W/D/Q}
+def : Pat<(loongarch_xvreplve0 v32i8:$xj),
+          (XVREPLVE0_B v32i8:$xj)>;
+def : Pat<(loongarch_xvreplve0 v16i16:$xj),
+          (XVREPLVE0_H v16i16:$xj)>;
+def : Pat<(loongarch_xvreplve0 v8i32:$xj),
+          (XVREPLVE0_W v8i32:$xj)>;
+def : Pat<(loongarch_xvreplve0 v4i64:$xj),
+          (XVREPLVE0_D v4i64:$xj)>;
+def : Pat<(loongarch_xvreplve0 v8f32:$xj),
+          (XVREPLVE0_W v8f32:$xj)>;
+def : Pat<(loongarch_xvreplve0 v4f64:$xj),
+          (XVREPLVE0_D v4f64:$xj)>;
 def : Pat<(lasxsplatf32 FPR32:$fj),
           (XVREPLVE0_W (SUBREG_TO_REG (i64 0), FPR32:$fj, sub_32))>;
 def : Pat<(lasxsplatf64 FPR64:$fj),
           (XVREPLVE0_D (SUBREG_TO_REG (i64 0), FPR64:$fj, sub_64))>;
+foreach vt = [v32i8, v16i16, v8i32, v4i64, v8f32, v4f64] in
+  def : Pat<(vt (loongarch_xvreplve0q LASX256:$xj)),
+            (XVREPLVE0_Q LASX256:$xj)>;
 
 // VSTELM
 defm : VstelmPat<truncstorei8, v32i8, XVSTELM_B, simm8, uimm5>;
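
For reference, a hedged summary of the instruction choice the XVREPLVE0 patterns above encode (xvreplve0For is an illustrative helper, not generated from TableGen):

#include <cstdio>

// Illustrative mapping from broadcast element width to the LASX
// instruction that replicates element 0 across a 256-bit register,
// mirroring the Pat definitions above.
const char *xvreplve0For(unsigned EltBits) {
  switch (EltBits) {
  case 8:   return "xvreplve0.b"; // v32i8
  case 16:  return "xvreplve0.h"; // v16i16
  case 32:  return "xvreplve0.w"; // v8i32 and v8f32
  case 64:  return "xvreplve0.d"; // v4i64 and v4f64
  case 128: return "xvreplve0.q"; // 128-bit sub-sequence (XVREPLVE0Q)
  default:  return "<no single-instruction broadcast>";
  }
}

int main() { std::printf("%s\n", xvreplve0For(64)); } // prints xvreplve0.d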