AMDGPU: Skip last corrections and scaling for afn llvm.sqrt.f64 (#183697)

Device libs has a fast sqrt macro implemented this way.
This commit is contained in:
Matt Arsenault 2026-03-28 00:59:25 +01:00 committed by GitHub
parent 1264ffc4cc
commit 9be0cc173d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 425 additions and 1288 deletions

View File

@ -5969,18 +5969,21 @@ bool AMDGPULegalizerInfo::legalizeFSQRTF64(MachineInstr &MI,
Register X = MI.getOperand(1).getReg();
unsigned Flags = MI.getFlags();
auto ScaleConstant = B.buildFConstant(F64, 0x1.0p-767);
Register SqrtX = X;
Register Scaling, ZeroInt;
if (!MI.getFlag(MachineInstr::FmAfn)) {
auto ScaleConstant = B.buildFConstant(F64, 0x1.0p-767);
auto ZeroInt = B.buildConstant(S32, 0);
auto Scaling = B.buildFCmp(FCmpInst::FCMP_OLT, S1, X, ScaleConstant);
ZeroInt = B.buildConstant(S32, 0).getReg(0);
Scaling = B.buildFCmp(FCmpInst::FCMP_OLT, S1, X, ScaleConstant).getReg(0);
// Scale up input if it is too small.
auto ScaleUpFactor = B.buildConstant(S32, 256);
auto ScaleUp = B.buildSelect(S32, Scaling, ScaleUpFactor, ZeroInt);
auto SqrtX = B.buildFLdexp(F64, X, ScaleUp, Flags);
// Scale up input if it is too small.
auto ScaleUpFactor = B.buildConstant(S32, 256);
auto ScaleUp = B.buildSelect(S32, Scaling, ScaleUpFactor, ZeroInt);
SqrtX = B.buildFLdexp(F64, X, ScaleUp, Flags).getReg(0);
}
auto SqrtY =
B.buildIntrinsic(Intrinsic::amdgcn_rsq, {F64}).addReg(SqrtX.getReg(0));
auto SqrtY = B.buildIntrinsic(Intrinsic::amdgcn_rsq, {F64}).addReg(SqrtX);
auto Half = B.buildFConstant(F64, 0.5);
auto SqrtH0 = B.buildFMul(F64, SqrtY, Half);
@ -5997,15 +6000,17 @@ bool AMDGPULegalizerInfo::legalizeFSQRTF64(MachineInstr &MI,
auto SqrtS2 = B.buildFMA(F64, SqrtD0, SqrtH1, SqrtS1);
auto NegSqrtS2 = B.buildFNeg(F64, SqrtS2);
auto SqrtD1 = B.buildFMA(F64, NegSqrtS2, SqrtS2, SqrtX);
Register SqrtRet = SqrtS2.getReg(0);
if (!MI.getFlag(MachineInstr::FmAfn)) {
auto NegSqrtS2 = B.buildFNeg(F64, SqrtS2);
auto SqrtD1 = B.buildFMA(F64, NegSqrtS2, SqrtS2, SqrtX);
auto SqrtD2 = B.buildFMA(F64, SqrtD1, SqrtH1, SqrtS2);
auto SqrtRet = B.buildFMA(F64, SqrtD1, SqrtH1, SqrtS2);
// Scale down the result.
auto ScaleDownFactor = B.buildConstant(S32, -128);
auto ScaleDown = B.buildSelect(S32, Scaling, ScaleDownFactor, ZeroInt);
SqrtRet = B.buildFLdexp(F64, SqrtRet, ScaleDown, Flags);
// Scale down the result.
auto ScaleDownFactor = B.buildConstant(S32, -128);
auto ScaleDown = B.buildSelect(S32, Scaling, ScaleDownFactor, ZeroInt);
SqrtRet = B.buildFLdexp(F64, SqrtD2, ScaleDown, Flags).getReg(0);
}
Register IsZeroOrInf;
if (MI.getFlag(MachineInstr::FmNoInfs)) {

View File

@ -13416,17 +13416,20 @@ SDValue SITargetLowering::lowerFSQRTF64(SDValue Op, SelectionDAG &DAG) const {
SDLoc DL(Op);
SDValue X = Op.getOperand(0);
SDValue ScaleConstant = DAG.getConstantFP(0x1.0p-767, DL, MVT::f64);
SDValue Scaling = DAG.getSetCC(DL, MVT::i1, X, ScaleConstant, ISD::SETOLT);
SDValue ZeroInt = DAG.getConstant(0, DL, MVT::i32);
// Scale up input if it is too small.
SDValue ScaleUpFactor = DAG.getConstant(256, DL, MVT::i32);
SDValue ScaleUp =
DAG.getNode(ISD::SELECT, DL, MVT::i32, Scaling, ScaleUpFactor, ZeroInt);
SDValue SqrtX = DAG.getNode(ISD::FLDEXP, DL, MVT::f64, X, ScaleUp, Flags);
SDValue SqrtX = X;
SDValue Scaling;
if (!Flags.hasApproximateFuncs()) {
SDValue ScaleConstant = DAG.getConstantFP(0x1.0p-767, DL, MVT::f64);
Scaling = DAG.getSetCC(DL, MVT::i1, X, ScaleConstant, ISD::SETOLT);
// Scale up input if it is too small.
SDValue ScaleUpFactor = DAG.getConstant(256, DL, MVT::i32);
SDValue ScaleUp =
DAG.getNode(ISD::SELECT, DL, MVT::i32, Scaling, ScaleUpFactor, ZeroInt);
SqrtX = DAG.getNode(ISD::FLDEXP, DL, MVT::f64, X, ScaleUp, Flags);
}
SDValue SqrtY = DAG.getNode(AMDGPUISD::RSQ, DL, MVT::f64, SqrtX);
@ -13448,16 +13451,19 @@ SDValue SITargetLowering::lowerFSQRTF64(SDValue Op, SelectionDAG &DAG) const {
SDValue SqrtS2 = DAG.getNode(ISD::FMA, DL, MVT::f64, SqrtD0, SqrtH1, SqrtS1);
SDValue NegSqrtS2 = DAG.getNode(ISD::FNEG, DL, MVT::f64, SqrtS2);
SDValue SqrtD1 =
DAG.getNode(ISD::FMA, DL, MVT::f64, NegSqrtS2, SqrtS2, SqrtX);
SDValue SqrtRet = SqrtS2;
if (!Flags.hasApproximateFuncs()) {
SDValue NegSqrtS2 = DAG.getNode(ISD::FNEG, DL, MVT::f64, SqrtS2);
SDValue SqrtD1 =
DAG.getNode(ISD::FMA, DL, MVT::f64, NegSqrtS2, SqrtS2, SqrtX);
SDValue SqrtRet = DAG.getNode(ISD::FMA, DL, MVT::f64, SqrtD1, SqrtH1, SqrtS2);
SqrtRet = DAG.getNode(ISD::FMA, DL, MVT::f64, SqrtD1, SqrtH1, SqrtS2);
SDValue ScaleDownFactor = DAG.getSignedConstant(-128, DL, MVT::i32);
SDValue ScaleDown =
DAG.getNode(ISD::SELECT, DL, MVT::i32, Scaling, ScaleDownFactor, ZeroInt);
SqrtRet = DAG.getNode(ISD::FLDEXP, DL, MVT::f64, SqrtRet, ScaleDown, Flags);
SDValue ScaleDownFactor = DAG.getSignedConstant(-128, DL, MVT::i32);
SDValue ScaleDown = DAG.getNode(ISD::SELECT, DL, MVT::i32, Scaling,
ScaleDownFactor, ZeroInt);
SqrtRet = DAG.getNode(ISD::FLDEXP, DL, MVT::f64, SqrtRet, ScaleDown, Flags);
}
// TODO: Check for DAZ and expand to subnormals

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff