AMDGPU: Make copysign with matching v2f16/v2bf16 inputs legal (#142173)

Fixes #141931
This commit is contained in:
Matt Arsenault 2025-05-31 08:06:49 +02:00 committed by GitHub
parent 64e9a3f8f0
commit 4aa4005e04
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 1124 additions and 1281 deletions

View File

@ -756,6 +756,9 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
// allows matching fneg (fabs x) patterns) // allows matching fneg (fabs x) patterns)
setOperationAction(ISD::FABS, MVT::v2f16, Legal); setOperationAction(ISD::FABS, MVT::v2f16, Legal);
// Can do this in one BFI plus a constant materialize.
setOperationAction(ISD::FCOPYSIGN, {MVT::v2f16, MVT::v2bf16}, Custom);
setOperationAction({ISD::FMAXNUM, ISD::FMINNUM}, MVT::f16, Custom); setOperationAction({ISD::FMAXNUM, ISD::FMINNUM}, MVT::f16, Custom);
setOperationAction({ISD::FMAXNUM_IEEE, ISD::FMINNUM_IEEE}, MVT::f16, Legal); setOperationAction({ISD::FMAXNUM_IEEE, ISD::FMINNUM_IEEE}, MVT::f16, Legal);
@ -6088,6 +6091,8 @@ SDValue SITargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
case ISD::SADDSAT: case ISD::SADDSAT:
case ISD::SSUBSAT: case ISD::SSUBSAT:
return splitBinaryVectorOp(Op, DAG); return splitBinaryVectorOp(Op, DAG);
case ISD::FCOPYSIGN:
return lowerFCOPYSIGN(Op, DAG);
case ISD::MUL: case ISD::MUL:
return lowerMUL(Op, DAG); return lowerMUL(Op, DAG);
case ISD::SMULO: case ISD::SMULO:
@ -7115,6 +7120,32 @@ SDValue SITargetLowering::promoteUniformOpToI32(SDValue Op,
return DAG.getZExtOrTrunc(NewVal, DL, OpTy); return DAG.getZExtOrTrunc(NewVal, DL, OpTy);
} }
SDValue SITargetLowering::lowerFCOPYSIGN(SDValue Op, SelectionDAG &DAG) const {
SDValue Mag = Op.getOperand(0);
SDValue Sign = Op.getOperand(1);
EVT MagVT = Mag.getValueType();
EVT SignVT = Sign.getValueType();
assert(MagVT.isVector());
if (MagVT == SignVT)
return Op;
assert(MagVT.getVectorNumElements() == 2);
// fcopysign v2f16:mag, v2f32:sign ->
// fcopysign v2f16:mag, bitcast (trunc (bitcast sign to v2i32) to v2i16)
SDLoc SL(Op);
SDValue SignAsInt32 = DAG.getNode(ISD::BITCAST, SL, MVT::v2i32, Sign);
SDValue SignAsInt16 = DAG.getNode(ISD::TRUNCATE, SL, MVT::v2i16, SignAsInt32);
SDValue SignAsHalf16 = DAG.getNode(ISD::BITCAST, SL, MagVT, SignAsInt16);
return DAG.getNode(ISD::FCOPYSIGN, SL, MagVT, Mag, SignAsHalf16);
}
// Custom lowering for vector multiplications and s_mul_u64. // Custom lowering for vector multiplications and s_mul_u64.
SDValue SITargetLowering::lowerMUL(SDValue Op, SelectionDAG &DAG) const { SDValue SITargetLowering::lowerMUL(SDValue Op, SelectionDAG &DAG) const {
EVT VT = Op.getValueType(); EVT VT = Op.getValueType();

View File

@ -149,6 +149,7 @@ private:
SDValue lowerFMINIMUM_FMAXIMUM(SDValue Op, SelectionDAG &DAG) const; SDValue lowerFMINIMUM_FMAXIMUM(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerFLDEXP(SDValue Op, SelectionDAG &DAG) const; SDValue lowerFLDEXP(SDValue Op, SelectionDAG &DAG) const;
SDValue promoteUniformOpToI32(SDValue Op, DAGCombinerInfo &DCI) const; SDValue promoteUniformOpToI32(SDValue Op, DAGCombinerInfo &DCI) const;
SDValue lowerFCOPYSIGN(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerMUL(SDValue Op, SelectionDAG &DAG) const; SDValue lowerMUL(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerXMULO(SDValue Op, SelectionDAG &DAG) const; SDValue lowerXMULO(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerXMUL_LOHI(SDValue Op, SelectionDAG &DAG) const; SDValue lowerXMUL_LOHI(SDValue Op, SelectionDAG &DAG) const;

View File

@ -2062,6 +2062,16 @@ def : GCNPat <
>; >;
} // End foreach fp16vt = [f16, bf16] } // End foreach fp16vt = [f16, bf16]
foreach fp16vt = [v2f16, v2bf16] in {
def : GCNPat <
(fcopysign fp16vt:$src0, fp16vt:$src1),
(V_BFI_B32_e64 (S_MOV_B32 (i32 0x7fff7fff)), $src0, $src1)
>;
}
/********** ================== **********/ /********** ================== **********/
/********** Immediate Patterns **********/ /********** Immediate Patterns **********/
/********** ================== **********/ /********** ================== **********/

View File

@ -36,17 +36,12 @@ define <2 x half> @test_pown_reduced_fast_v2f16_known_odd(<2 x half> %x, <2 x i3
; GFX9-NEXT: v_cvt_f32_i32_e32 v2, v2 ; GFX9-NEXT: v_cvt_f32_i32_e32 v2, v2
; GFX9-NEXT: v_cvt_f32_i32_e32 v1, v1 ; GFX9-NEXT: v_cvt_f32_i32_e32 v1, v1
; GFX9-NEXT: v_and_b32_e32 v3, 0x7fff7fff, v0 ; GFX9-NEXT: v_and_b32_e32 v3, 0x7fff7fff, v0
; GFX9-NEXT: s_movk_i32 s4, 0x7fff ; GFX9-NEXT: s_mov_b32 s4, 0x7fff7fff
; GFX9-NEXT: v_cvt_f16_f32_e32 v2, v2 ; GFX9-NEXT: v_cvt_f16_f32_e32 v2, v2
; GFX9-NEXT: v_cvt_f16_f32_e32 v1, v1 ; GFX9-NEXT: v_cvt_f16_f32_e32 v1, v1
; GFX9-NEXT: v_pack_b32_f16 v1, v1, v2 ; GFX9-NEXT: v_pack_b32_f16 v1, v1, v2
; GFX9-NEXT: v_pk_mul_f16 v1, v3, v1 ; GFX9-NEXT: v_pk_mul_f16 v1, v3, v1
; GFX9-NEXT: v_bfi_b32 v2, s4, v1, v0
; GFX9-NEXT: v_lshrrev_b32_e32 v1, 16, v1
; GFX9-NEXT: v_lshrrev_b32_e32 v0, 16, v0
; GFX9-NEXT: v_bfi_b32 v0, s4, v1, v0 ; GFX9-NEXT: v_bfi_b32 v0, s4, v1, v0
; GFX9-NEXT: s_mov_b32 s4, 0x5040100
; GFX9-NEXT: v_perm_b32 v0, v0, v2, s4
; GFX9-NEXT: s_setpc_b64 s[30:31] ; GFX9-NEXT: s_setpc_b64 s[30:31]
%y = or <2 x i32> %y.arg, <i32 1, i32 1> %y = or <2 x i32> %y.arg, <i32 1, i32 1>
%fabs = call <2 x half> @llvm.fabs.v2f16(<2 x half> %x) %fabs = call <2 x half> @llvm.fabs.v2f16(<2 x half> %x)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff