Wasm fmuladd relaxed (#163177)

Reland #161355, after fixing up the cross-project-tests for the wasm
simd intrinsics.

Original commit message:
Lower v4f32 and v2f64 fmuladd calls to relaxed_madd instructions.
If we have FP16, then lower v8f16 fmuladds to FMA.

I've introduced an ISD node for fmuladd to maintain the rounding
ambiguity through legalization / combine / isel.
This commit is contained in:
Sam Parker 2025-10-13 16:50:53 +01:00 committed by GitHub
parent 095cad6add
commit 1820102167
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 1447 additions and 53 deletions

View File

@ -1511,13 +1511,13 @@ v128_t test_f16x8_convert_u16x8(v128_t a) {
}
// CHECK-LABEL: test_f16x8_relaxed_madd:
// CHECK: f16x8.relaxed_madd{{$}}
// CHECK: f16x8.madd{{$}}
v128_t test_f16x8_relaxed_madd(v128_t a, v128_t b, v128_t c) {
return wasm_f16x8_relaxed_madd(a, b, c);
}
// CHECK-LABEL: test_f16x8_relaxed_nmadd:
// CHECK: f16x8.relaxed_nmadd{{$}}
// CHECK: f16x8.nmadd{{$}}
v128_t test_f16x8_relaxed_nmadd(v128_t a, v128_t b, v128_t c) {
return wasm_f16x8_relaxed_nmadd(a, b, c);
}

View File

@ -514,6 +514,12 @@ enum NodeType {
/// separately rounded operations.
FMAD,
/// FMULADD - Performs a * b + c, with or without intermediate rounding.
/// It is expected that this will be illegal for most targets, as it usually
/// makes sense to split this or use an FMA. But some targets, such as
/// WebAssembly, can directly support these semantics.
FMULADD,
/// FCOPYSIGN(X, Y) - Return the value of X with the sign of Y. NOTE: This
/// DAG node does not require that X and Y have the same type, just that
/// they are both floating point. X and the result must have the same type.

View File

@ -535,6 +535,7 @@ def fdiv : SDNode<"ISD::FDIV" , SDTFPBinOp>;
def frem : SDNode<"ISD::FREM" , SDTFPBinOp>;
def fma : SDNode<"ISD::FMA" , SDTFPTernaryOp, [SDNPCommutative]>;
def fmad : SDNode<"ISD::FMAD" , SDTFPTernaryOp, [SDNPCommutative]>;
def fmuladd : SDNode<"ISD::FMULADD" , SDTFPTernaryOp, [SDNPCommutative]>;
def fabs : SDNode<"ISD::FABS" , SDTFPUnaryOp>;
def fminnum : SDNode<"ISD::FMINNUM" , SDTFPBinOp,
[SDNPCommutative, SDNPAssociative]>;

View File

@ -509,6 +509,7 @@ namespace {
SDValue visitFMUL(SDNode *N);
template <class MatchContextClass> SDValue visitFMA(SDNode *N);
SDValue visitFMAD(SDNode *N);
SDValue visitFMULADD(SDNode *N);
SDValue visitFDIV(SDNode *N);
SDValue visitFREM(SDNode *N);
SDValue visitFSQRT(SDNode *N);
@ -1991,6 +1992,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
case ISD::FMUL: return visitFMUL(N);
case ISD::FMA: return visitFMA<EmptyMatchContext>(N);
case ISD::FMAD: return visitFMAD(N);
case ISD::FMULADD: return visitFMULADD(N);
case ISD::FDIV: return visitFDIV(N);
case ISD::FREM: return visitFREM(N);
case ISD::FSQRT: return visitFSQRT(N);
@ -18444,6 +18446,21 @@ SDValue DAGCombiner::visitFMAD(SDNode *N) {
return SDValue();
}
/// Combine an ISD::FMULADD node. The only transform attempted is constant
/// folding of its three operands; an empty SDValue is returned when nothing
/// folds.
SDValue DAGCombiner::visitFMULADD(SDNode *N) {
  const SDLoc Loc(N);
  const EVT ResultVT = N->getValueType(0);
  SDValue Operands[] = {N->getOperand(0), N->getOperand(1), N->getOperand(2)};

  // Constant fold FMULADD.
  SDValue Folded =
      DAG.FoldConstantArithmetic(ISD::FMULADD, Loc, ResultVT, Operands);
  if (!Folded)
    return SDValue();
  return Folded;
}
// Combine multiple FDIVs with the same divisor into multiple FMULs by the
// reciprocal.
// E.g., (a / D; b / D;) -> (recip = 1.0 / D; a * recip; b * recip)

View File

@ -5786,6 +5786,7 @@ bool SelectionDAG::canCreateUndefOrPoison(SDValue Op, const APInt &DemandedElts,
case ISD::FCOPYSIGN:
case ISD::FMA:
case ISD::FMAD:
case ISD::FMULADD:
case ISD::FP_EXTEND:
case ISD::FP_TO_SINT_SAT:
case ISD::FP_TO_UINT_SAT:
@ -5904,6 +5905,7 @@ bool SelectionDAG::isKnownNeverNaN(SDValue Op, const APInt &DemandedElts,
case ISD::FCOSH:
case ISD::FTANH:
case ISD::FMA:
case ISD::FMULADD:
case ISD::FMAD: {
if (SNaN)
return true;
@ -7231,7 +7233,7 @@ SDValue SelectionDAG::FoldConstantArithmetic(unsigned Opcode, const SDLoc &DL,
}
// Handle fma/fmad/fmuladd special cases.
if (Opcode == ISD::FMA || Opcode == ISD::FMAD) {
if (Opcode == ISD::FMA || Opcode == ISD::FMAD || Opcode == ISD::FMULADD) {
assert(VT.isFloatingPoint() && "This operator only applies to FP types!");
assert(Ops[0].getValueType() == VT && Ops[1].getValueType() == VT &&
Ops[2].getValueType() == VT && "FMA types must match!");
@ -7242,7 +7244,7 @@ SDValue SelectionDAG::FoldConstantArithmetic(unsigned Opcode, const SDLoc &DL,
APFloat V1 = C1->getValueAPF();
const APFloat &V2 = C2->getValueAPF();
const APFloat &V3 = C3->getValueAPF();
if (Opcode == ISD::FMAD) {
if (Opcode == ISD::FMAD || Opcode == ISD::FMULADD) {
V1.multiply(V2, APFloat::rmNearestTiesToEven);
V1.add(V3, APFloat::rmNearestTiesToEven);
} else

View File

@ -6996,6 +6996,13 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
getValue(I.getArgOperand(0)),
getValue(I.getArgOperand(1)),
getValue(I.getArgOperand(2)), Flags));
} else if (TLI.isOperationLegalOrCustom(ISD::FMULADD, VT)) {
// TODO: Support splitting the vector.
setValue(&I, DAG.getNode(ISD::FMULADD, sdl,
getValue(I.getArgOperand(0)).getValueType(),
getValue(I.getArgOperand(0)),
getValue(I.getArgOperand(1)),
getValue(I.getArgOperand(2)), Flags));
} else {
// TODO: Intrinsic calls should have fast-math-flags.
SDValue Mul = DAG.getNode(

View File

@ -310,6 +310,7 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
case ISD::FMA: return "fma";
case ISD::STRICT_FMA: return "strict_fma";
case ISD::FMAD: return "fmad";
case ISD::FMULADD: return "fmuladd";
case ISD::FREM: return "frem";
case ISD::STRICT_FREM: return "strict_frem";
case ISD::FCOPYSIGN: return "fcopysign";

View File

@ -7676,6 +7676,7 @@ SDValue TargetLowering::getNegatedExpression(SDValue Op, SelectionDAG &DAG,
break;
}
case ISD::FMA:
case ISD::FMULADD:
case ISD::FMAD: {
if (!Flags.hasNoSignedZeros())
break;

View File

@ -815,7 +815,8 @@ void TargetLoweringBase::initActions() {
ISD::FTAN, ISD::FACOS,
ISD::FASIN, ISD::FATAN,
ISD::FCOSH, ISD::FSINH,
ISD::FTANH, ISD::FATAN2},
ISD::FTANH, ISD::FATAN2,
ISD::FMULADD},
VT, Expand);
// Overflow operations default to expand

View File

@ -317,6 +317,15 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
setOperationAction(ISD::ZERO_EXTEND_VECTOR_INREG, T, Custom);
}
if (Subtarget->hasFP16()) {
setOperationAction(ISD::FMA, MVT::v8f16, Legal);
}
if (Subtarget->hasRelaxedSIMD()) {
setOperationAction(ISD::FMULADD, MVT::v4f32, Legal);
setOperationAction(ISD::FMULADD, MVT::v2f64, Legal);
}
// Partial MLA reductions.
for (auto Op : {ISD::PARTIAL_REDUCE_SMLA, ISD::PARTIAL_REDUCE_UMLA}) {
setPartialReduceMLAAction(Op, MVT::v4i32, MVT::v16i8, Legal);
@ -1120,6 +1129,18 @@ WebAssemblyTargetLowering::getPreferredVectorAction(MVT VT) const {
return TargetLoweringBase::getPreferredVectorAction(VT);
}
bool WebAssemblyTargetLowering::isFMAFasterThanFMulAndFAdd(
const MachineFunction &MF, EVT VT) const {
if (!Subtarget->hasFP16() || !VT.isVector())
return false;
EVT ScalarVT = VT.getScalarType();
if (!ScalarVT.isSimple())
return false;
return ScalarVT.getSimpleVT().SimpleTy == MVT::f16;
}
bool WebAssemblyTargetLowering::shouldSimplifyDemandedVectorElts(
SDValue Op, const TargetLoweringOpt &TLO) const {
// ISel process runs DAGCombiner after legalization; this step is called

View File

@ -81,6 +81,8 @@ private:
TargetLoweringBase::LegalizeTypeAction
getPreferredVectorAction(MVT VT) const override;
bool isFMAFasterThanFMulAndFAdd(const MachineFunction &MF,
EVT VT) const override;
SDValue LowerCall(CallLoweringInfo &CLI,
SmallVectorImpl<SDValue> &InVals) const override;

View File

@ -1626,7 +1626,8 @@ defm "" : RelaxedConvert<I32x4, F64x2, int_wasm_relaxed_trunc_unsigned_zero,
// Relaxed (Negative) Multiply-Add (madd/nmadd)
//===----------------------------------------------------------------------===//
multiclass SIMDMADD<Vec vec, bits<32> simdopA, bits<32> simdopS, list<Predicate> reqs> {
multiclass RELAXED_SIMDMADD<Vec vec, bits<32> simdopA, bits<32> simdopS,
list<Predicate> reqs> {
defm MADD_#vec :
SIMD_I<(outs V128:$dst), (ins V128:$a, V128:$b, V128:$c), (outs), (ins),
[(set (vec.vt V128:$dst), (int_wasm_relaxed_madd
@ -1640,16 +1641,46 @@ multiclass SIMDMADD<Vec vec, bits<32> simdopA, bits<32> simdopS, list<Predicate>
vec.prefix#".relaxed_nmadd\t$dst, $a, $b, $c",
vec.prefix#".relaxed_nmadd", simdopS, reqs>;
def : Pat<(fadd_contract (vec.vt V128:$a), (fmul_contract (vec.vt V128:$b), (vec.vt V128:$c))),
(!cast<Instruction>("MADD_"#vec) V128:$a, V128:$b, V128:$c)>, Requires<[HasRelaxedSIMD]>;
def : Pat<(fadd_contract (fmul_contract (vec.vt V128:$a), (vec.vt V128:$b)), (vec.vt V128:$c)),
(!cast<Instruction>("MADD_"#vec) V128:$a, V128:$b, V128:$c)>, Requires<reqs>;
def : Pat<(fmuladd (vec.vt V128:$a), (vec.vt V128:$b), (vec.vt V128:$c)),
(!cast<Instruction>("MADD_"#vec) V128:$a, V128:$b, V128:$c)>, Requires<reqs>;
def : Pat<(fsub_contract (vec.vt V128:$a), (fmul_contract (vec.vt V128:$b), (vec.vt V128:$c))),
(!cast<Instruction>("NMADD_"#vec) V128:$a, V128:$b, V128:$c)>, Requires<[HasRelaxedSIMD]>;
def : Pat<(fsub_contract (vec.vt V128:$c), (fmul_contract (vec.vt V128:$a), (vec.vt V128:$b))),
(!cast<Instruction>("NMADD_"#vec) V128:$a, V128:$b, V128:$c)>, Requires<reqs>;
def : Pat<(fmuladd (fneg (vec.vt V128:$a)), (vec.vt V128:$b), (vec.vt V128:$c)),
(!cast<Instruction>("NMADD_"#vec) V128:$a, V128:$b, V128:$c)>, Requires<reqs>;
}
defm "" : SIMDMADD<F32x4, 0x105, 0x106, [HasRelaxedSIMD]>;
defm "" : SIMDMADD<F64x2, 0x107, 0x108, [HasRelaxedSIMD]>;
defm "" : SIMDMADD<F16x8, 0x14e, 0x14f, [HasFP16]>;
defm "" : RELAXED_SIMDMADD<F32x4, 0x105, 0x106, [HasRelaxedSIMD]>;
defm "" : RELAXED_SIMDMADD<F64x2, 0x107, 0x108, [HasRelaxedSIMD]>;
//===----------------------------------------------------------------------===//
// FP16 (Negative) Multiply-Add (madd/nmadd)
//===----------------------------------------------------------------------===//
// Defines the FP16 multiply-add instruction pair for vector type `vec`:
//   MADD_<vec>  - selected from (fma a, b, c), printed as "<prefix>.madd".
//   NMADD_<vec> - selected from (fma (fneg a), b, c), printed as
//                 "<prefix>.nmadd".
// simdopA is forwarded to the MADD definition and simdopS to the NMADD
// definition; reqs is the predicate list forwarded to both.
multiclass HALF_PRECISION_SIMDMADD<Vec vec, bits<32> simdopA, bits<32> simdopS,
list<Predicate> reqs> {
// Plain fused multiply-add: $dst = fma($a, $b, $c).
defm MADD_#vec :
SIMD_I<(outs V128:$dst), (ins V128:$a, V128:$b, V128:$c), (outs), (ins),
[(set (vec.vt V128:$dst), (fma
(vec.vt V128:$a), (vec.vt V128:$b), (vec.vt V128:$c)))],
vec.prefix#".madd\t$dst, $a, $b, $c",
vec.prefix#".madd", simdopA, reqs>;
// Negated form: the first multiplicand is negated before the fma.
defm NMADD_#vec :
SIMD_I<(outs V128:$dst), (ins V128:$a, V128:$b, V128:$c), (outs), (ins),
[(set (vec.vt V128:$dst), (fma
(fneg (vec.vt V128:$a)), (vec.vt V128:$b), (vec.vt V128:$c)))],
vec.prefix#".nmadd\t$dst, $a, $b, $c",
vec.prefix#".nmadd", simdopS, reqs>;
}
defm "" : HALF_PRECISION_SIMDMADD<F16x8, 0x14e, 0x14f, [HasFP16]>;
// TODO: I think separate intrinsics should be introduced for these FP16 operations.
def : Pat<(v8f16 (int_wasm_relaxed_madd (v8f16 V128:$a), (v8f16 V128:$b), (v8f16 V128:$c))),
(MADD_F16x8 V128:$a, V128:$b, V128:$c)>;
def : Pat<(v8f16 (int_wasm_relaxed_nmadd (v8f16 V128:$a), (v8f16 V128:$b), (v8f16 V128:$c))),
(NMADD_F16x8 V128:$a, V128:$b, V128:$c)>;
//===----------------------------------------------------------------------===//
// Laneselect

File diff suppressed because it is too large Load Diff

View File

@ -27,7 +27,7 @@ define <4 x float> @fsub_fmul_contract_4xf32(<4 x float> %a, <4 x float> %b, <4
; RELAXED-LABEL: fsub_fmul_contract_4xf32:
; RELAXED: .functype fsub_fmul_contract_4xf32 (v128, v128, v128) -> (v128)
; RELAXED-NEXT: # %bb.0:
; RELAXED-NEXT: f32x4.relaxed_nmadd $push0=, $2, $1, $0
; RELAXED-NEXT: f32x4.relaxed_nmadd $push0=, $1, $0, $2
; RELAXED-NEXT: return $pop0
;
; STRICT-LABEL: fsub_fmul_contract_4xf32:
@ -46,15 +46,14 @@ define <8 x half> @fsub_fmul_contract_8xf16(<8 x half> %a, <8 x half> %b, <8 x h
; RELAXED-LABEL: fsub_fmul_contract_8xf16:
; RELAXED: .functype fsub_fmul_contract_8xf16 (v128, v128, v128) -> (v128)
; RELAXED-NEXT: # %bb.0:
; RELAXED-NEXT: f16x8.relaxed_nmadd $push0=, $2, $1, $0
; RELAXED-NEXT: f16x8.nmadd $push0=, $1, $0, $2
; RELAXED-NEXT: return $pop0
;
; STRICT-LABEL: fsub_fmul_contract_8xf16:
; STRICT: .functype fsub_fmul_contract_8xf16 (v128, v128, v128) -> (v128)
; STRICT-NEXT: # %bb.0:
; STRICT-NEXT: f16x8.mul $push0=, $1, $0
; STRICT-NEXT: f16x8.sub $push1=, $2, $pop0
; STRICT-NEXT: return $pop1
; STRICT-NEXT: f16x8.nmadd $push0=, $1, $0, $2
; STRICT-NEXT: return $pop0
%mul = fmul contract <8 x half> %b, %a
%sub = fsub contract <8 x half> %c, %mul
ret <8 x half> %sub
@ -84,9 +83,9 @@ define <8 x float> @fsub_fmul_contract_8xf32(<8 x float> %a, <8 x float> %b, <8
; RELAXED-LABEL: fsub_fmul_contract_8xf32:
; RELAXED: .functype fsub_fmul_contract_8xf32 (i32, v128, v128, v128, v128, v128, v128) -> ()
; RELAXED-NEXT: # %bb.0:
; RELAXED-NEXT: f32x4.relaxed_nmadd $push0=, $6, $4, $2
; RELAXED-NEXT: f32x4.relaxed_nmadd $push0=, $4, $2, $6
; RELAXED-NEXT: v128.store 16($0), $pop0
; RELAXED-NEXT: f32x4.relaxed_nmadd $push1=, $5, $3, $1
; RELAXED-NEXT: f32x4.relaxed_nmadd $push1=, $3, $1, $5
; RELAXED-NEXT: v128.store 0($0), $pop1
; RELAXED-NEXT: return
;
@ -110,7 +109,7 @@ define <2 x double> @fsub_fmul_contract_2xf64(<2 x double> %a, <2 x double> %b,
; RELAXED-LABEL: fsub_fmul_contract_2xf64:
; RELAXED: .functype fsub_fmul_contract_2xf64 (v128, v128, v128) -> (v128)
; RELAXED-NEXT: # %bb.0:
; RELAXED-NEXT: f64x2.relaxed_nmadd $push0=, $2, $1, $0
; RELAXED-NEXT: f64x2.relaxed_nmadd $push0=, $1, $0, $2
; RELAXED-NEXT: return $pop0
;
; STRICT-LABEL: fsub_fmul_contract_2xf64:
@ -143,3 +142,55 @@ define float @fsub_fmul_contract_f32(float %a, float %b, float %c) {
ret float %sub
}
; llvm.fmuladd on <8 x half> with a negated first multiplicand. Both the
; RELAXED and STRICT check prefixes expect a single f16x8.nmadd.
define <8 x half> @fmuladd_8xf16(<8 x half> %a, <8 x half> %b, <8 x half> %c) {
; RELAXED-LABEL: fmuladd_8xf16:
; RELAXED: .functype fmuladd_8xf16 (v128, v128, v128) -> (v128)
; RELAXED-NEXT: # %bb.0:
; RELAXED-NEXT: f16x8.nmadd $push0=, $0, $1, $2
; RELAXED-NEXT: return $pop0
;
; STRICT-LABEL: fmuladd_8xf16:
; STRICT: .functype fmuladd_8xf16 (v128, v128, v128) -> (v128)
; STRICT-NEXT: # %bb.0:
; STRICT-NEXT: f16x8.nmadd $push0=, $0, $1, $2
; STRICT-NEXT: return $pop0
%fneg = fneg <8 x half> %a
%fma = call <8 x half> @llvm.fmuladd(<8 x half> %fneg, <8 x half> %b, <8 x half> %c)
ret <8 x half> %fma
}
; llvm.fmuladd on <4 x float> with a negated first multiplicand. The RELAXED
; prefix expects a single f32x4.relaxed_nmadd; the STRICT prefix expects the
; operation split into f32x4.mul followed by f32x4.sub.
define <4 x float> @fmuladd_4xf32(<4 x float> %a, <4 x float> %b, <4 x float> %c) {
; RELAXED-LABEL: fmuladd_4xf32:
; RELAXED: .functype fmuladd_4xf32 (v128, v128, v128) -> (v128)
; RELAXED-NEXT: # %bb.0:
; RELAXED-NEXT: f32x4.relaxed_nmadd $push0=, $0, $1, $2
; RELAXED-NEXT: return $pop0
;
; STRICT-LABEL: fmuladd_4xf32:
; STRICT: .functype fmuladd_4xf32 (v128, v128, v128) -> (v128)
; STRICT-NEXT: # %bb.0:
; STRICT-NEXT: f32x4.mul $push0=, $0, $1
; STRICT-NEXT: f32x4.sub $push1=, $2, $pop0
; STRICT-NEXT: return $pop1
%fneg = fneg <4 x float> %a
%fma = call <4 x float> @llvm.fmuladd(<4 x float> %fneg, <4 x float> %b, <4 x float> %c)
ret <4 x float> %fma
}
; llvm.fmuladd on <2 x double> with a negated first multiplicand. The RELAXED
; prefix expects a single f64x2.relaxed_nmadd; the STRICT prefix expects the
; operation split into f64x2.mul followed by f64x2.sub.
define <2 x double> @fmuladd_2xf64(<2 x double> %a, <2 x double> %b, <2 x double> %c) {
; RELAXED-LABEL: fmuladd_2xf64:
; RELAXED: .functype fmuladd_2xf64 (v128, v128, v128) -> (v128)
; RELAXED-NEXT: # %bb.0:
; RELAXED-NEXT: f64x2.relaxed_nmadd $push0=, $0, $1, $2
; RELAXED-NEXT: return $pop0
;
; STRICT-LABEL: fmuladd_2xf64:
; STRICT: .functype fmuladd_2xf64 (v128, v128, v128) -> (v128)
; STRICT-NEXT: # %bb.0:
; STRICT-NEXT: f64x2.mul $push0=, $0, $1
; STRICT-NEXT: f64x2.sub $push1=, $2, $pop0
; STRICT-NEXT: return $pop1
%fneg = fneg <2 x double> %a
%fma = call <2 x double> @llvm.fmuladd(<2 x double> %fneg, <2 x double> %b, <2 x double> %c)
ret <2 x double> %fma
}

View File

@ -917,11 +917,11 @@ main:
# CHECK: f16x8.nearest # encoding: [0xfd,0xb6,0x02]
f16x8.nearest
# CHECK: f16x8.relaxed_madd # encoding: [0xfd,0xce,0x02]
f16x8.relaxed_madd
# CHECK: f16x8.madd # encoding: [0xfd,0xce,0x02]
f16x8.madd
# CHECK: f16x8.relaxed_nmadd # encoding: [0xfd,0xcf,0x02]
f16x8.relaxed_nmadd
# CHECK: f16x8.nmadd # encoding: [0xfd,0xcf,0x02]
f16x8.nmadd
# CHECK: i16x8.trunc_sat_f16x8_s # encoding: [0xfd,0xc5,0x02]
i16x8.trunc_sat_f16x8_s