[RISCV] Combine a vsse from a vsseg with one active segment (#151198)

This is a rewrite of the current strided store optimization to be a DAG
combine. This allows it to kick in slightly more broadly, in particular
for the scalable lowering paths.
This commit is contained in:
Philip Reames 2025-07-29 14:05:48 -07:00 committed by GitHub
parent 616cef0883
commit ce23830508
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 91 additions and 28 deletions

View File

@ -20751,6 +20751,53 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
return DAG.getAllOnesConstant(DL, VT);
return DAG.getConstant(0, DL, VT);
}
case Intrinsic::riscv_vsseg2_mask:
case Intrinsic::riscv_vsseg3_mask:
case Intrinsic::riscv_vsseg4_mask:
case Intrinsic::riscv_vsseg5_mask:
case Intrinsic::riscv_vsseg6_mask:
case Intrinsic::riscv_vsseg7_mask:
case Intrinsic::riscv_vsseg8_mask: {
SDValue Tuple = N->getOperand(2);
unsigned NF = Tuple.getValueType().getRISCVVectorTupleNumFields();
if (Subtarget.hasOptimizedSegmentLoadStore(NF) || !Tuple.hasOneUse() ||
Tuple.getOpcode() != RISCVISD::TUPLE_INSERT ||
!Tuple.getOperand(0).isUndef())
return SDValue();
SDValue Val = Tuple.getOperand(1);
unsigned Idx = Tuple.getConstantOperandVal(2);
unsigned SEW = Val.getValueType().getScalarSizeInBits();
assert(Log2_64(SEW) == N->getConstantOperandVal(6) &&
"Type mismatch without bitcast?");
unsigned Stride = SEW / 8 * NF;
unsigned Offset = SEW / 8 * Idx;
SDValue Ops[] = {
/*Chain=*/N->getOperand(0),
/*IntID=*/
DAG.getTargetConstant(Intrinsic::riscv_vsse_mask, DL, XLenVT),
/*StoredVal=*/Val,
/*Ptr=*/
DAG.getNode(ISD::ADD, DL, XLenVT, N->getOperand(3),
DAG.getConstant(Offset, DL, XLenVT)),
/*Stride=*/DAG.getConstant(Stride, DL, XLenVT),
/*Mask=*/N->getOperand(4),
/*VL=*/N->getOperand(5)};
auto *OldMemSD = cast<MemIntrinsicSDNode>(N);
// Match getTgtMemIntrinsic for non-unit stride case
EVT MemVT = OldMemSD->getMemoryVT().getScalarType();
MachineFunction &MF = DAG.getMachineFunction();
MachineMemOperand *MMO = MF.getMachineMemOperand(
OldMemSD->getMemOperand(), Offset, MemoryLocation::UnknownSize);
SDVTList VTs = DAG.getVTList(MVT::Other);
return DAG.getMemIntrinsicNode(ISD::INTRINSIC_VOID, DL, VTs, Ops, MemVT,
MMO);
}
}
}
case ISD::EXPERIMENTAL_VP_REVERSE:
@ -20899,6 +20946,12 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
DAG.ReplaceAllUsesOfValueWith(Tuple.getValue(1), Result.getValue(1));
return Result.getValue(0);
}
case RISCVISD::TUPLE_INSERT: {
// tuple_insert tuple, undef, idx -> tuple
if (N->getOperand(1).isUndef())
return N->getOperand(0);
break;
}
}
return SDValue();

View File

@ -266,33 +266,6 @@ bool RISCVTargetLowering::lowerInterleavedStore(Instruction *Store,
if (!isLegalInterleavedAccessType(VTy, Factor, Alignment, AS, DL))
return false;
unsigned Index;
// If the segment store only has one active lane (i.e. the interleave is
// just a spread shuffle), we can use a strided store instead. This will
// be equally fast, and create less vector register pressure.
if (!Subtarget.hasOptimizedSegmentLoadStore(Factor) &&
isSpreadMask(Mask, Factor, Index)) {
unsigned ScalarSizeInBytes =
DL.getTypeStoreSize(ShuffleVTy->getElementType());
Value *Data = SVI->getOperand(0);
Data = Builder.CreateExtractVector(VTy, Data, uint64_t(0));
Value *Stride = ConstantInt::get(XLenTy, Factor * ScalarSizeInBytes);
Value *Offset = ConstantInt::get(XLenTy, Index * ScalarSizeInBytes);
Value *BasePtr = Builder.CreatePtrAdd(Ptr, Offset);
// For rv64, need to truncate i64 to i32 to match signature. As VL is at
// most the number of active lanes (which is bounded by i32) this is safe.
VL = Builder.CreateTrunc(VL, Builder.getInt32Ty());
CallInst *CI =
Builder.CreateIntrinsic(Intrinsic::experimental_vp_strided_store,
{VTy, BasePtr->getType(), Stride->getType()},
{Data, BasePtr, Stride, LaneMask, VL});
Alignment = commonAlignment(Alignment, Index * ScalarSizeInBytes);
CI->addParamAttr(1,
Attribute::getWithAlignment(CI->getContext(), Alignment));
return true;
}
Function *VssegNFunc = Intrinsic::getOrInsertDeclaration(
Store->getModule(), FixedVssegIntrIds[Factor - 2], {VTy, PtrTy, XLenTy});

View File

@ -1883,7 +1883,8 @@ define void @store_factor4_one_active_slidedown(ptr %ptr, <4 x i32> %v) {
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 4, e32, m1, ta, ma
; CHECK-NEXT: vslidedown.vi v8, v8, 1
; CHECK-NEXT: vsseg4e32.v v8, (a0)
; CHECK-NEXT: li a1, 16
; CHECK-NEXT: vsse32.v v8, (a0), a1
; CHECK-NEXT: ret
%v0 = shufflevector <4 x i32> %v, <4 x i32> poison, <16 x i32> <i32 1, i32 undef, i32 undef, i32 undef, i32 2, i32 undef, i32 undef, i32 undef, i32 3, i32 undef, i32 undef, i32 undef, i32 4, i32 undef, i32 undef, i32 undef>
store <16 x i32> %v0, ptr %ptr

View File

@ -326,3 +326,39 @@ define void @masked_store_factor3_masked(<vscale x 2 x i32> %a, <vscale x 2 x i3
call void @llvm.masked.store(<vscale x 6 x i32> %v, ptr %p, i32 4, <vscale x 6 x i1> %interleaved.mask)
ret void
}
define void @store_factor2_oneactive(<vscale x 2 x i32> %a, ptr %p) {
; CHECK-LABEL: store_factor2_oneactive:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli a1, zero, e32, m1, ta, ma
; CHECK-NEXT: vsseg2e32.v v8, (a0)
; CHECK-NEXT: ret
%v = call <vscale x 4 x i32> @llvm.vector.interleave2(<vscale x 2 x i32> %a, <vscale x 2 x i32> poison)
store <vscale x 4 x i32> %v, ptr %p
ret void
}
define void @store_factor3_oneactive(<vscale x 2 x i32> %a, ptr %p) {
; CHECK-LABEL: store_factor3_oneactive:
; CHECK: # %bb.0:
; CHECK-NEXT: li a1, 12
; CHECK-NEXT: vsetvli a2, zero, e32, m1, ta, ma
; CHECK-NEXT: vsse32.v v8, (a0), a1
; CHECK-NEXT: ret
%v = call <vscale x 6 x i32> @llvm.vector.interleave3(<vscale x 2 x i32> %a, <vscale x 2 x i32> poison, <vscale x 2 x i32> poison)
store <vscale x 6 x i32> %v, ptr %p
ret void
}
define void @store_factor7_oneactive(<vscale x 2 x i32> %a, ptr %p) {
; CHECK-LABEL: store_factor7_oneactive:
; CHECK: # %bb.0:
; CHECK-NEXT: addi a0, a0, 24
; CHECK-NEXT: li a1, 28
; CHECK-NEXT: vsetvli a2, zero, e32, m1, ta, ma
; CHECK-NEXT: vsse32.v v8, (a0), a1
; CHECK-NEXT: ret
%v = call <vscale x 14 x i32> @llvm.vector.interleave7(<vscale x 2 x i32> poison, <vscale x 2 x i32> poison, <vscale x 2 x i32> poison, <vscale x 2 x i32> poison, <vscale x 2 x i32> poison, <vscale x 2 x i32> poison, <vscale x 2 x i32> %a)
store <vscale x 14 x i32> %v, ptr %p
ret void
}