[AMDGPU] Add support for safe bfloat16 fdiv on targets with bf16 trans instructions (#154373)
Recent changes introduced custom lowering for bf16 fdiv on targets that support bf16 trans instructions, but only covered the unsafe version. This PR extends that support to the safe variant. For the safe version, the op is lowered by converting to float, performing the div in float, and converting the result back to bf16. This matches the behavior on targets that don't support bf16 trans instructions. Fixes SWDEV-550381.
This commit is contained in:
parent
b35b6297fd
commit
b170f17861
@ -11540,9 +11540,22 @@ SDValue SITargetLowering::LowerFDIV16(SDValue Op, SelectionDAG &DAG) const {
|
||||
return FastLowered;
|
||||
|
||||
SDLoc SL(Op);
|
||||
EVT VT = Op.getValueType();
|
||||
SDValue LHS = Op.getOperand(0);
|
||||
SDValue RHS = Op.getOperand(1);
|
||||
|
||||
SDValue LHSExt = DAG.getNode(ISD::FP_EXTEND, SL, MVT::f32, LHS);
|
||||
SDValue RHSExt = DAG.getNode(ISD::FP_EXTEND, SL, MVT::f32, RHS);
|
||||
|
||||
if (VT == MVT::bf16) {
|
||||
SDValue ExtDiv =
|
||||
DAG.getNode(ISD::FDIV, SL, MVT::f32, LHSExt, RHSExt, Op->getFlags());
|
||||
return DAG.getNode(ISD::FP_ROUND, SL, MVT::bf16, ExtDiv,
|
||||
DAG.getTargetConstant(0, SL, MVT::i32));
|
||||
}
|
||||
|
||||
assert(VT == MVT::f16);
|
||||
|
||||
// a32.u = opx(V_CVT_F32_F16, a.u); // CVT to F32
|
||||
// b32.u = opx(V_CVT_F32_F16, b.u); // CVT to F32
|
||||
// r32.u = opx(V_RCP_F32, b32.u); // rcp = 1 / d
|
||||
@ -11559,9 +11572,6 @@ SDValue SITargetLowering::LowerFDIV16(SDValue Op, SelectionDAG &DAG) const {
|
||||
// We will use ISD::FMA on targets that don't support ISD::FMAD.
|
||||
unsigned FMADOpCode =
|
||||
isOperationLegal(ISD::FMAD, MVT::f32) ? ISD::FMAD : ISD::FMA;
|
||||
|
||||
SDValue LHSExt = DAG.getNode(ISD::FP_EXTEND, SL, MVT::f32, LHS);
|
||||
SDValue RHSExt = DAG.getNode(ISD::FP_EXTEND, SL, MVT::f32, RHS);
|
||||
SDValue NegRHSExt = DAG.getNode(ISD::FNEG, SL, MVT::f32, RHSExt);
|
||||
SDValue Rcp =
|
||||
DAG.getNode(AMDGPUISD::RCP, SL, MVT::f32, RHSExt, Op->getFlags());
|
||||
|
@ -2,12 +2,68 @@
|
||||
; RUN: llc -mtriple=amdgcn -mcpu=gfx1250 -mattr=+real-true16 -denormal-fp-math-f32=preserve-sign < %s | FileCheck -check-prefixes=GFX1250-TRUE16 %s
|
||||
; RUN: llc -mtriple=amdgcn -mcpu=gfx1250 -mattr=-real-true16 -denormal-fp-math-f32=preserve-sign < %s | FileCheck -check-prefixes=GFX1250-FAKE16 %s
|
||||
|
||||
/* TODO: Support safe bf16 fdiv lowering.
|
||||
define bfloat @v_fdiv_bf16(bfloat %x, bfloat %y) {
|
||||
; GFX1250-TRUE16-LABEL: v_fdiv_bf16:
|
||||
; GFX1250-TRUE16: ; %bb.0:
|
||||
; GFX1250-TRUE16-NEXT: s_wait_loadcnt_dscnt 0x0
|
||||
; GFX1250-TRUE16-NEXT: s_wait_kmcnt 0x0
|
||||
; GFX1250-TRUE16-NEXT: v_mov_b16_e32 v2.l, 0
|
||||
; GFX1250-TRUE16-NEXT: v_mov_b16_e32 v2.h, v1.l
|
||||
; GFX1250-TRUE16-NEXT: v_mov_b16_e32 v1.h, v0.l
|
||||
; GFX1250-TRUE16-NEXT: s_delay_alu instid0(VALU_DEP_3) | instskip(NEXT) | instid1(VALU_DEP_1)
|
||||
; GFX1250-TRUE16-NEXT: v_mov_b16_e32 v1.l, v2.l
|
||||
; GFX1250-TRUE16-NEXT: v_div_scale_f32 v0, null, v2, v2, v1
|
||||
; GFX1250-TRUE16-NEXT: v_div_scale_f32 v4, vcc_lo, v1, v2, v1
|
||||
; GFX1250-TRUE16-NEXT: s_delay_alu instid0(VALU_DEP_2) | instskip(SKIP_2) | instid1(TRANS32_DEP_1)
|
||||
; GFX1250-TRUE16-NEXT: v_rcp_f32_e32 v3, v0
|
||||
; GFX1250-TRUE16-NEXT: s_denorm_mode 15
|
||||
; GFX1250-TRUE16-NEXT: v_nop
|
||||
; GFX1250-TRUE16-NEXT: v_fma_f32 v5, -v0, v3, 1.0
|
||||
; GFX1250-TRUE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
|
||||
; GFX1250-TRUE16-NEXT: v_fmac_f32_e32 v3, v5, v3
|
||||
; GFX1250-TRUE16-NEXT: v_mul_f32_e32 v5, v4, v3
|
||||
; GFX1250-TRUE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
|
||||
; GFX1250-TRUE16-NEXT: v_fma_f32 v6, -v0, v5, v4
|
||||
; GFX1250-TRUE16-NEXT: v_fmac_f32_e32 v5, v6, v3
|
||||
; GFX1250-TRUE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(SKIP_1) | instid1(VALU_DEP_1)
|
||||
; GFX1250-TRUE16-NEXT: v_fma_f32 v0, -v0, v5, v4
|
||||
; GFX1250-TRUE16-NEXT: s_denorm_mode 12
|
||||
; GFX1250-TRUE16-NEXT: v_div_fmas_f32 v0, v0, v3, v5
|
||||
; GFX1250-TRUE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
|
||||
; GFX1250-TRUE16-NEXT: v_div_fixup_f32 v0, v0, v2, v1
|
||||
; GFX1250-TRUE16-NEXT: v_cvt_pk_bf16_f32 v0, v0, s0
|
||||
; GFX1250-TRUE16-NEXT: s_set_pc_i64 s[30:31]
|
||||
;
|
||||
; GFX1250-FAKE16-LABEL: v_fdiv_bf16:
|
||||
; GFX1250-FAKE16: ; %bb.0:
|
||||
; GFX1250-FAKE16-NEXT: s_wait_loadcnt_dscnt 0x0
|
||||
; GFX1250-FAKE16-NEXT: s_wait_kmcnt 0x0
|
||||
; GFX1250-FAKE16-NEXT: v_dual_lshlrev_b32 v1, 16, v1 :: v_dual_lshlrev_b32 v0, 16, v0
|
||||
; GFX1250-FAKE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(SKIP_1) | instid1(VALU_DEP_2)
|
||||
; GFX1250-FAKE16-NEXT: v_div_scale_f32 v2, null, v1, v1, v0
|
||||
; GFX1250-FAKE16-NEXT: v_div_scale_f32 v4, vcc_lo, v0, v1, v0
|
||||
; GFX1250-FAKE16-NEXT: v_rcp_f32_e32 v3, v2
|
||||
; GFX1250-FAKE16-NEXT: s_denorm_mode 15
|
||||
; GFX1250-FAKE16-NEXT: v_nop
|
||||
; GFX1250-FAKE16-NEXT: s_delay_alu instid0(TRANS32_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
|
||||
; GFX1250-FAKE16-NEXT: v_fma_f32 v5, -v2, v3, 1.0
|
||||
; GFX1250-FAKE16-NEXT: v_fmac_f32_e32 v3, v5, v3
|
||||
; GFX1250-FAKE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
|
||||
; GFX1250-FAKE16-NEXT: v_mul_f32_e32 v5, v4, v3
|
||||
; GFX1250-FAKE16-NEXT: v_fma_f32 v6, -v2, v5, v4
|
||||
; GFX1250-FAKE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
|
||||
; GFX1250-FAKE16-NEXT: v_fmac_f32_e32 v5, v6, v3
|
||||
; GFX1250-FAKE16-NEXT: v_fma_f32 v2, -v2, v5, v4
|
||||
; GFX1250-FAKE16-NEXT: s_denorm_mode 12
|
||||
; GFX1250-FAKE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
|
||||
; GFX1250-FAKE16-NEXT: v_div_fmas_f32 v2, v2, v3, v5
|
||||
; GFX1250-FAKE16-NEXT: v_div_fixup_f32 v0, v2, v1, v0
|
||||
; GFX1250-FAKE16-NEXT: s_delay_alu instid0(VALU_DEP_1)
|
||||
; GFX1250-FAKE16-NEXT: v_cvt_pk_bf16_f32 v0, v0, s0
|
||||
; GFX1250-FAKE16-NEXT: s_set_pc_i64 s[30:31]
|
||||
%fdiv = fdiv bfloat %x, %y
|
||||
ret bfloat %fdiv
|
||||
}
|
||||
*/
|
||||
|
||||
define bfloat @v_rcp_bf16(bfloat %x) {
|
||||
; GFX1250-TRUE16-LABEL: v_rcp_bf16:
|
||||
|
Loading…
x
Reference in New Issue
Block a user