AMDGPU: Skip last corrections and scaling for afn llvm.sqrt.f64 (#183697)
Device libs has a fast sqrt macro implemented this way.
This commit is contained in:
parent
1264ffc4cc
commit
9be0cc173d
@ -5969,18 +5969,21 @@ bool AMDGPULegalizerInfo::legalizeFSQRTF64(MachineInstr &MI,
|
||||
Register X = MI.getOperand(1).getReg();
|
||||
unsigned Flags = MI.getFlags();
|
||||
|
||||
auto ScaleConstant = B.buildFConstant(F64, 0x1.0p-767);
|
||||
Register SqrtX = X;
|
||||
Register Scaling, ZeroInt;
|
||||
if (!MI.getFlag(MachineInstr::FmAfn)) {
|
||||
auto ScaleConstant = B.buildFConstant(F64, 0x1.0p-767);
|
||||
|
||||
auto ZeroInt = B.buildConstant(S32, 0);
|
||||
auto Scaling = B.buildFCmp(FCmpInst::FCMP_OLT, S1, X, ScaleConstant);
|
||||
ZeroInt = B.buildConstant(S32, 0).getReg(0);
|
||||
Scaling = B.buildFCmp(FCmpInst::FCMP_OLT, S1, X, ScaleConstant).getReg(0);
|
||||
|
||||
// Scale up input if it is too small.
|
||||
auto ScaleUpFactor = B.buildConstant(S32, 256);
|
||||
auto ScaleUp = B.buildSelect(S32, Scaling, ScaleUpFactor, ZeroInt);
|
||||
auto SqrtX = B.buildFLdexp(F64, X, ScaleUp, Flags);
|
||||
// Scale up input if it is too small.
|
||||
auto ScaleUpFactor = B.buildConstant(S32, 256);
|
||||
auto ScaleUp = B.buildSelect(S32, Scaling, ScaleUpFactor, ZeroInt);
|
||||
SqrtX = B.buildFLdexp(F64, X, ScaleUp, Flags).getReg(0);
|
||||
}
|
||||
|
||||
auto SqrtY =
|
||||
B.buildIntrinsic(Intrinsic::amdgcn_rsq, {F64}).addReg(SqrtX.getReg(0));
|
||||
auto SqrtY = B.buildIntrinsic(Intrinsic::amdgcn_rsq, {F64}).addReg(SqrtX);
|
||||
|
||||
auto Half = B.buildFConstant(F64, 0.5);
|
||||
auto SqrtH0 = B.buildFMul(F64, SqrtY, Half);
|
||||
@ -5997,15 +6000,17 @@ bool AMDGPULegalizerInfo::legalizeFSQRTF64(MachineInstr &MI,
|
||||
|
||||
auto SqrtS2 = B.buildFMA(F64, SqrtD0, SqrtH1, SqrtS1);
|
||||
|
||||
auto NegSqrtS2 = B.buildFNeg(F64, SqrtS2);
|
||||
auto SqrtD1 = B.buildFMA(F64, NegSqrtS2, SqrtS2, SqrtX);
|
||||
Register SqrtRet = SqrtS2.getReg(0);
|
||||
if (!MI.getFlag(MachineInstr::FmAfn)) {
|
||||
auto NegSqrtS2 = B.buildFNeg(F64, SqrtS2);
|
||||
auto SqrtD1 = B.buildFMA(F64, NegSqrtS2, SqrtS2, SqrtX);
|
||||
auto SqrtD2 = B.buildFMA(F64, SqrtD1, SqrtH1, SqrtS2);
|
||||
|
||||
auto SqrtRet = B.buildFMA(F64, SqrtD1, SqrtH1, SqrtS2);
|
||||
|
||||
// Scale down the result.
|
||||
auto ScaleDownFactor = B.buildConstant(S32, -128);
|
||||
auto ScaleDown = B.buildSelect(S32, Scaling, ScaleDownFactor, ZeroInt);
|
||||
SqrtRet = B.buildFLdexp(F64, SqrtRet, ScaleDown, Flags);
|
||||
// Scale down the result.
|
||||
auto ScaleDownFactor = B.buildConstant(S32, -128);
|
||||
auto ScaleDown = B.buildSelect(S32, Scaling, ScaleDownFactor, ZeroInt);
|
||||
SqrtRet = B.buildFLdexp(F64, SqrtD2, ScaleDown, Flags).getReg(0);
|
||||
}
|
||||
|
||||
Register IsZeroOrInf;
|
||||
if (MI.getFlag(MachineInstr::FmNoInfs)) {
|
||||
|
||||
@ -13416,17 +13416,20 @@ SDValue SITargetLowering::lowerFSQRTF64(SDValue Op, SelectionDAG &DAG) const {
|
||||
SDLoc DL(Op);
|
||||
|
||||
SDValue X = Op.getOperand(0);
|
||||
SDValue ScaleConstant = DAG.getConstantFP(0x1.0p-767, DL, MVT::f64);
|
||||
|
||||
SDValue Scaling = DAG.getSetCC(DL, MVT::i1, X, ScaleConstant, ISD::SETOLT);
|
||||
|
||||
SDValue ZeroInt = DAG.getConstant(0, DL, MVT::i32);
|
||||
|
||||
// Scale up input if it is too small.
|
||||
SDValue ScaleUpFactor = DAG.getConstant(256, DL, MVT::i32);
|
||||
SDValue ScaleUp =
|
||||
DAG.getNode(ISD::SELECT, DL, MVT::i32, Scaling, ScaleUpFactor, ZeroInt);
|
||||
SDValue SqrtX = DAG.getNode(ISD::FLDEXP, DL, MVT::f64, X, ScaleUp, Flags);
|
||||
SDValue SqrtX = X;
|
||||
SDValue Scaling;
|
||||
if (!Flags.hasApproximateFuncs()) {
|
||||
SDValue ScaleConstant = DAG.getConstantFP(0x1.0p-767, DL, MVT::f64);
|
||||
Scaling = DAG.getSetCC(DL, MVT::i1, X, ScaleConstant, ISD::SETOLT);
|
||||
|
||||
// Scale up input if it is too small.
|
||||
SDValue ScaleUpFactor = DAG.getConstant(256, DL, MVT::i32);
|
||||
SDValue ScaleUp =
|
||||
DAG.getNode(ISD::SELECT, DL, MVT::i32, Scaling, ScaleUpFactor, ZeroInt);
|
||||
SqrtX = DAG.getNode(ISD::FLDEXP, DL, MVT::f64, X, ScaleUp, Flags);
|
||||
}
|
||||
|
||||
SDValue SqrtY = DAG.getNode(AMDGPUISD::RSQ, DL, MVT::f64, SqrtX);
|
||||
|
||||
@ -13448,16 +13451,19 @@ SDValue SITargetLowering::lowerFSQRTF64(SDValue Op, SelectionDAG &DAG) const {
|
||||
|
||||
SDValue SqrtS2 = DAG.getNode(ISD::FMA, DL, MVT::f64, SqrtD0, SqrtH1, SqrtS1);
|
||||
|
||||
SDValue NegSqrtS2 = DAG.getNode(ISD::FNEG, DL, MVT::f64, SqrtS2);
|
||||
SDValue SqrtD1 =
|
||||
DAG.getNode(ISD::FMA, DL, MVT::f64, NegSqrtS2, SqrtS2, SqrtX);
|
||||
SDValue SqrtRet = SqrtS2;
|
||||
if (!Flags.hasApproximateFuncs()) {
|
||||
SDValue NegSqrtS2 = DAG.getNode(ISD::FNEG, DL, MVT::f64, SqrtS2);
|
||||
SDValue SqrtD1 =
|
||||
DAG.getNode(ISD::FMA, DL, MVT::f64, NegSqrtS2, SqrtS2, SqrtX);
|
||||
|
||||
SDValue SqrtRet = DAG.getNode(ISD::FMA, DL, MVT::f64, SqrtD1, SqrtH1, SqrtS2);
|
||||
SqrtRet = DAG.getNode(ISD::FMA, DL, MVT::f64, SqrtD1, SqrtH1, SqrtS2);
|
||||
|
||||
SDValue ScaleDownFactor = DAG.getSignedConstant(-128, DL, MVT::i32);
|
||||
SDValue ScaleDown =
|
||||
DAG.getNode(ISD::SELECT, DL, MVT::i32, Scaling, ScaleDownFactor, ZeroInt);
|
||||
SqrtRet = DAG.getNode(ISD::FLDEXP, DL, MVT::f64, SqrtRet, ScaleDown, Flags);
|
||||
SDValue ScaleDownFactor = DAG.getSignedConstant(-128, DL, MVT::i32);
|
||||
SDValue ScaleDown = DAG.getNode(ISD::SELECT, DL, MVT::i32, Scaling,
|
||||
ScaleDownFactor, ZeroInt);
|
||||
SqrtRet = DAG.getNode(ISD::FLDEXP, DL, MVT::f64, SqrtRet, ScaleDown, Flags);
|
||||
}
|
||||
|
||||
// TODO: Check for DAZ and expand to subnormals
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user