[RISCV] Add intrinsics for strided segment loads with fixed vectors (#151611)

These intrinsics are the strided versions of the `llvm.riscv.segN.load.mask`
intrinsics.
This commit is contained in:
Min-Yih Hsu 2025-08-01 10:13:46 -07:00 committed by GitHub
parent 6c072c06cc
commit 401e72c830
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 189 additions and 50 deletions

View File

@ -1717,6 +1717,16 @@ let TargetPrefix = "riscv" in {
llvm_anyint_ty],
[NoCapture<ArgIndex<0>>, IntrReadMem]>;
// Strided segment load for fixed vectors. `nf` is supplied by the enclosing
// context (not visible in this hunk).
// Results: `nf` vectors; the first is an overloaded vector type and the
// remaining `nf - 1` results match it.
// Input: (pointer, offset, mask, vl)
//   pointer - overloaded pointer type; marked NoCapture.
//   offset  - overloaded integer; the RISC-V lowering passes it on as the
//             stride operand of the corresponding vlsseg<nf> mask intrinsic.
//   mask    - i1-element type tied to result 0 through
//             LLVMScalarOrSameVectorWidth.
//   vl      - overloaded integer (presumably the vector length; confirm
//             against the seg<nf>_load_mask definition).
// The intrinsic is declared IntrReadMem.
def int_riscv_sseg # nf # _load_mask
: DefaultAttrsIntrinsic<!listconcat([llvm_anyvector_ty],
!listsplat(LLVMMatchType<0>,
!add(nf, -1))),
[llvm_anyptr_ty, llvm_anyint_ty,
LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>,
llvm_anyint_ty],
[NoCapture<ArgIndex<0>>, IntrReadMem]>;
// Input: (<stored values>..., pointer, mask, vl)
def int_riscv_seg # nf # _store_mask
: DefaultAttrsIntrinsic<[],

View File

@ -1819,6 +1819,13 @@ bool RISCVTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
case Intrinsic::riscv_seg6_load_mask:
case Intrinsic::riscv_seg7_load_mask:
case Intrinsic::riscv_seg8_load_mask:
case Intrinsic::riscv_sseg2_load_mask:
case Intrinsic::riscv_sseg3_load_mask:
case Intrinsic::riscv_sseg4_load_mask:
case Intrinsic::riscv_sseg5_load_mask:
case Intrinsic::riscv_sseg6_load_mask:
case Intrinsic::riscv_sseg7_load_mask:
case Intrinsic::riscv_sseg8_load_mask:
return SetRVVLoadStoreInfo(/*PtrOp*/ 0, /*IsStore*/ false,
/*IsUnitStrided*/ false, /*UsePtrVal*/ true);
case Intrinsic::riscv_seg2_store_mask:
@ -10938,6 +10945,97 @@ static inline SDValue getVCIXISDNodeVOID(SDValue &Op, SelectionDAG &DAG,
return DAG.getNode(Type, SDLoc(Op), Op.getValueType(), Operands);
}
/// Lower a fixed-vector segment load intrinsic (riscv.segN.load.mask or
/// riscv.ssegN.load.mask, N in [2, 8]) into the matching riscv.vlsegN.mask /
/// riscv.vlssegN.mask memory intrinsic node that produces a vector tuple.
///
/// \param IntNo     ID of the intrinsic being lowered; any ID outside the two
///                  families above is unreachable.
/// \param Op        The INTRINSIC_W_CHAIN node. Its operands are
///                  (chain, int_id, pointer, mask, vl) for the seg form and
///                  (chain, int_id, pointer, offset, mask, vl) for the sseg
///                  form.
/// \param Subtarget Used for XLenVT and the fixed-to-scalable container types.
/// \param DAG       DAG in which the replacement nodes are created.
/// \returns A merge of the N fixed-length result vectors followed by the
///          output chain of the new load.
static SDValue
lowerFixedVectorSegLoadIntrinsics(unsigned IntNo, SDValue Op,
                                  const RISCVSubtarget &Subtarget,
                                  SelectionDAG &DAG) {
  // Only the sseg family carries the extra stride ("offset") operand.
  const bool HasStride = [IntNo]() -> bool {
    switch (IntNo) {
    case Intrinsic::riscv_seg2_load_mask:
    case Intrinsic::riscv_seg3_load_mask:
    case Intrinsic::riscv_seg4_load_mask:
    case Intrinsic::riscv_seg5_load_mask:
    case Intrinsic::riscv_seg6_load_mask:
    case Intrinsic::riscv_seg7_load_mask:
    case Intrinsic::riscv_seg8_load_mask:
      return false;
    case Intrinsic::riscv_sseg2_load_mask:
    case Intrinsic::riscv_sseg3_load_mask:
    case Intrinsic::riscv_sseg4_load_mask:
    case Intrinsic::riscv_sseg5_load_mask:
    case Intrinsic::riscv_sseg6_load_mask:
    case Intrinsic::riscv_sseg7_load_mask:
    case Intrinsic::riscv_sseg8_load_mask:
      return true;
    default:
      llvm_unreachable("unexpected intrinsic ID");
    }
  }();

  // Target intrinsic IDs: row 0 is unit-stride, row 1 is strided; the column
  // is NF - 2.
  static const Intrinsic::ID SegLoadInts[2][7] = {
      {Intrinsic::riscv_vlseg2_mask, Intrinsic::riscv_vlseg3_mask,
       Intrinsic::riscv_vlseg4_mask, Intrinsic::riscv_vlseg5_mask,
       Intrinsic::riscv_vlseg6_mask, Intrinsic::riscv_vlseg7_mask,
       Intrinsic::riscv_vlseg8_mask},
      {Intrinsic::riscv_vlsseg2_mask, Intrinsic::riscv_vlsseg3_mask,
       Intrinsic::riscv_vlsseg4_mask, Intrinsic::riscv_vlsseg5_mask,
       Intrinsic::riscv_vlsseg6_mask, Intrinsic::riscv_vlsseg7_mask,
       Intrinsic::riscv_vlsseg8_mask}};

  SDLoc DL(Op);

  // Every result except the trailing chain is one field of the segment.
  unsigned NF = Op->getNumValues() - 1;
  assert(NF >= 2 && NF <= 8 && "Unexpected seg number");

  MVT XLenVT = Subtarget.getXLenVT();
  MVT FixedVT = Op->getSimpleValueType(0);
  MVT ScalableVT = ::getContainerForFixedLengthVector(DAG, FixedVT, Subtarget);

  // The tuple type is sized as NF copies of the container's minimum size.
  unsigned TupleBits = NF * ScalableVT.getVectorMinNumElements() *
                       ScalableVT.getScalarSizeInBits();
  EVT TupleVT = MVT::getRISCVVectorTupleVT(TupleBits, NF);

  // Operands: (chain, int_id, pointer, mask, vl) or
  // (chain, int_id, pointer, offset, mask, vl)
  // VL and mask are always the last two operands, so index from the back.
  const unsigned NumOperands = Op.getNumOperands();
  SDValue VL = Op.getOperand(NumOperands - 1);
  SDValue FixedMask = Op.getOperand(NumOperands - 2);
  MVT MaskContainerVT = ::getContainerForFixedLengthVector(
      DAG, FixedMask.getSimpleValueType(), Subtarget);
  SDValue Mask =
      convertToScalableVector(MaskContainerVT, FixedMask, DAG, Subtarget);

  SDValue IntID =
      DAG.getTargetConstant(SegLoadInts[HasStride][NF - 2], DL, XLenVT);

  auto *MemNode = cast<MemIntrinsicSDNode>(Op);
  SDVTList ResultVTs = DAG.getVTList({TupleVT, MVT::Other});

  // Assemble the operand list in its final order; the stride, when present,
  // sits between the pointer and the mask.
  SmallVector<SDValue, 9> Ops;
  Ops.push_back(MemNode->getChain());
  Ops.push_back(IntID);
  Ops.push_back(DAG.getUNDEF(TupleVT));
  Ops.push_back(Op.getOperand(2));
  if (HasStride)
    Ops.push_back(Op.getOperand(3));
  Ops.push_back(Mask);
  Ops.push_back(VL);
  // Policy operand: tail agnostic and mask agnostic.
  Ops.push_back(DAG.getTargetConstant(
      RISCVVType::TAIL_AGNOSTIC | RISCVVType::MASK_AGNOSTIC, DL, XLenVT));
  // Log2 of the element width in bits.
  Ops.push_back(DAG.getTargetConstant(Log2_64(FixedVT.getScalarSizeInBits()),
                                      DL, XLenVT));

  SDValue TupleLoad = DAG.getMemIntrinsicNode(
      ISD::INTRINSIC_W_CHAIN, DL, ResultVTs, Ops, MemNode->getMemoryVT(),
      MemNode->getMemOperand());

  // Pull each field out of the tuple and convert it back to the fixed type.
  SmallVector<SDValue, 9> Results;
  SDValue Tuple = TupleLoad.getValue(0);
  for (unsigned FieldIdx = 0; FieldIdx != NF; ++FieldIdx) {
    SDValue FieldNo = DAG.getTargetConstant(FieldIdx, DL, MVT::i32);
    SDValue Field =
        DAG.getNode(RISCVISD::TUPLE_EXTRACT, DL, ScalableVT, Tuple, FieldNo);
    Results.push_back(
        convertFromScalableVector(FixedVT, Field, DAG, Subtarget));
  }

  // The chain of the new load is the final merged value.
  Results.push_back(TupleLoad.getValue(1));
  return DAG.getMergeValues(Results, DL);
}
SDValue RISCVTargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op,
SelectionDAG &DAG) const {
unsigned IntNo = Op.getConstantOperandVal(1);
@ -10950,57 +11048,16 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op,
case Intrinsic::riscv_seg5_load_mask:
case Intrinsic::riscv_seg6_load_mask:
case Intrinsic::riscv_seg7_load_mask:
case Intrinsic::riscv_seg8_load_mask: {
SDLoc DL(Op);
static const Intrinsic::ID VlsegInts[7] = {
Intrinsic::riscv_vlseg2_mask, Intrinsic::riscv_vlseg3_mask,
Intrinsic::riscv_vlseg4_mask, Intrinsic::riscv_vlseg5_mask,
Intrinsic::riscv_vlseg6_mask, Intrinsic::riscv_vlseg7_mask,
Intrinsic::riscv_vlseg8_mask};
unsigned NF = Op->getNumValues() - 1;
assert(NF >= 2 && NF <= 8 && "Unexpected seg number");
MVT XLenVT = Subtarget.getXLenVT();
MVT VT = Op->getSimpleValueType(0);
MVT ContainerVT = getContainerForFixedLengthVector(VT);
unsigned Sz = NF * ContainerVT.getVectorMinNumElements() *
ContainerVT.getScalarSizeInBits();
EVT VecTupTy = MVT::getRISCVVectorTupleVT(Sz, NF);
case Intrinsic::riscv_seg8_load_mask:
case Intrinsic::riscv_sseg2_load_mask:
case Intrinsic::riscv_sseg3_load_mask:
case Intrinsic::riscv_sseg4_load_mask:
case Intrinsic::riscv_sseg5_load_mask:
case Intrinsic::riscv_sseg6_load_mask:
case Intrinsic::riscv_sseg7_load_mask:
case Intrinsic::riscv_sseg8_load_mask:
return lowerFixedVectorSegLoadIntrinsics(IntNo, Op, Subtarget, DAG);
// Operands: (chain, int_id, pointer, mask, vl)
SDValue VL = Op.getOperand(Op.getNumOperands() - 1);
SDValue Mask = Op.getOperand(3);
MVT MaskVT = Mask.getSimpleValueType();
MVT MaskContainerVT =
::getContainerForFixedLengthVector(DAG, MaskVT, Subtarget);
Mask = convertToScalableVector(MaskContainerVT, Mask, DAG, Subtarget);
SDValue IntID = DAG.getTargetConstant(VlsegInts[NF - 2], DL, XLenVT);
auto *Load = cast<MemIntrinsicSDNode>(Op);
SDVTList VTs = DAG.getVTList({VecTupTy, MVT::Other});
SDValue Ops[] = {
Load->getChain(),
IntID,
DAG.getUNDEF(VecTupTy),
Op.getOperand(2),
Mask,
VL,
DAG.getTargetConstant(
RISCVVType::TAIL_AGNOSTIC | RISCVVType::MASK_AGNOSTIC, DL, XLenVT),
DAG.getTargetConstant(Log2_64(VT.getScalarSizeInBits()), DL, XLenVT)};
SDValue Result =
DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs, Ops,
Load->getMemoryVT(), Load->getMemOperand());
SmallVector<SDValue, 9> Results;
for (unsigned int RetIdx = 0; RetIdx < NF; RetIdx++) {
SDValue SubVec = DAG.getNode(RISCVISD::TUPLE_EXTRACT, DL, ContainerVT,
Result.getValue(0),
DAG.getTargetConstant(RetIdx, DL, MVT::i32));
Results.push_back(convertFromScalableVector(VT, SubVec, DAG, Subtarget));
}
Results.push_back(Result.getValue(1));
return DAG.getMergeValues(Results, DL);
}
case Intrinsic::riscv_sf_vc_v_x_se:
return getVCIXISDNodeWCHAIN(Op, DAG, RISCVISD::SF_VC_V_X_SE);
case Intrinsic::riscv_sf_vc_v_i_se:

View File

@ -0,0 +1,72 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc -mtriple riscv64 -mattr=+zve64x,+zvl128b < %s | FileCheck %s
; Strided segment load of 2 fields of <8 x i8> with an all-true mask and a
; vl of 8. The assertions expect one unmasked vlsseg2e8.v whose stride
; register is a1.
define {<8 x i8>, <8 x i8>} @load_factor2(ptr %ptr, i64 %stride) {
; CHECK-LABEL: load_factor2:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
; CHECK-NEXT: vlsseg2e8.v v8, (a0), a1
; CHECK-NEXT: ret
%1 = call { <8 x i8>, <8 x i8> } @llvm.riscv.sseg2.load.mask.v8i8.i64.i64(ptr %ptr, i64 %stride, <8 x i1> splat (i1 true), i64 8)
ret {<8 x i8>, <8 x i8>} %1
}
; Strided segment load of 3 fields of <8 x i8> with an all-true mask and a
; vl of 8. The assertions expect one unmasked vlsseg3e8.v whose stride
; register is a1.
define {<8 x i8>, <8 x i8>, <8 x i8>} @load_factor3(ptr %ptr, i64 %stride) {
; CHECK-LABEL: load_factor3:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
; CHECK-NEXT: vlsseg3e8.v v8, (a0), a1
; CHECK-NEXT: ret
%1 = call { <8 x i8>, <8 x i8>, <8 x i8> } @llvm.riscv.sseg3.load.mask.v8i8.i64.i64(ptr %ptr, i64 %stride, <8 x i1> splat (i1 true), i64 8)
ret { <8 x i8>, <8 x i8>, <8 x i8> } %1
}
; Strided segment load of 4 fields of <8 x i8> with an all-true mask and a
; vl of 8. The assertions expect one unmasked vlsseg4e8.v whose stride
; register is a1.
define {<8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>} @load_factor4(ptr %ptr, i64 %stride) {
; CHECK-LABEL: load_factor4:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
; CHECK-NEXT: vlsseg4e8.v v8, (a0), a1
; CHECK-NEXT: ret
%1 = call { <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8> } @llvm.riscv.sseg4.load.mask.v8i8.i64.i64(ptr %ptr, i64 %stride, <8 x i1> splat (i1 true), i64 8)
ret { <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8> } %1
}
; Strided segment load of 5 fields of <8 x i8> with an all-true mask and a
; vl of 8. The assertions expect one unmasked vlsseg5e8.v whose stride
; register is a1.
define {<8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>} @load_factor5(ptr %ptr, i64 %stride) {
; CHECK-LABEL: load_factor5:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
; CHECK-NEXT: vlsseg5e8.v v8, (a0), a1
; CHECK-NEXT: ret
%1 = call { <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8> } @llvm.riscv.sseg5.load.mask.v8i8.i64.i64(ptr %ptr, i64 %stride, <8 x i1> splat (i1 true), i64 8)
ret { <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8> } %1
}
; Strided segment load of 6 fields of <8 x i8> with an all-true mask and a
; vl of 8. The assertions expect one unmasked vlsseg6e8.v whose stride
; register is a1.
define {<8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>} @load_factor6(ptr %ptr, i64 %stride) {
; CHECK-LABEL: load_factor6:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
; CHECK-NEXT: vlsseg6e8.v v8, (a0), a1
; CHECK-NEXT: ret
%1 = call { <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8> } @llvm.riscv.sseg6.load.mask.v8i8.i64.i64(ptr %ptr, i64 %stride, <8 x i1> splat (i1 true), i64 8)
ret { <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8> } %1
}
; Strided segment load of 7 fields of <8 x i8> with an all-true mask and a
; vl of 8. The assertions expect one unmasked vlsseg7e8.v whose stride
; register is a1.
define {<8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>} @load_factor7(ptr %ptr, i64 %stride) {
; CHECK-LABEL: load_factor7:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
; CHECK-NEXT: vlsseg7e8.v v8, (a0), a1
; CHECK-NEXT: ret
%1 = call { <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8> } @llvm.riscv.sseg7.load.mask.v8i8.i64.i64(ptr %ptr, i64 %stride, <8 x i1> splat (i1 true), i64 8)
ret { <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8> } %1
}
; Strided segment load of 8 fields of <8 x i8> with an all-true mask and a
; vl of 8. The assertions expect one unmasked vlsseg8e8.v whose stride
; register is a1.
define {<8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>} @load_factor8(ptr %ptr, i64 %stride) {
; CHECK-LABEL: load_factor8:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
; CHECK-NEXT: vlsseg8e8.v v8, (a0), a1
; CHECK-NEXT: ret
%1 = call { <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8> } @llvm.riscv.sseg8.load.mask.v8i8.i64.i64(ptr %ptr, i64 %stride, <8 x i1> splat (i1 true), i64 8)
ret { <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8> } %1
}