[mlir][math] Add constant folding for math.rsqrt (#184443)

Add a fold() method to RsqrtOp, matching the pattern used by SqrtOp and
other math unary ops. The fold computes `1.0 / sqrt(x)` using APFloat
division.

---------

Signed-off-by: Ian Wood <ianwood@u.northwestern.edu>
This commit is contained in:
Ian Wood 2026-03-03 15:50:10 -08:00 committed by GitHub
parent ece4b75932
commit 630b9570d1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 50 additions and 0 deletions

View File

@ -995,6 +995,7 @@ def Math_RsqrtOp : Math_FloatUnaryOp<"rsqrt"> {
%a = math.rsqrt %b : f64
```
}];
let hasFolder = 1;
}
//===----------------------------------------------------------------------===//

View File

@ -517,6 +517,28 @@ OpFoldResult math::PowFOp::fold(FoldAdaptor adaptor) {
});
}
//===----------------------------------------------------------------------===//
// RsqrtOp folder
//===----------------------------------------------------------------------===//
OpFoldResult math::RsqrtOp::fold(FoldAdaptor adaptor) {
return constFoldUnaryOpConditional<FloatAttr>(
adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
if (a.isNegative())
return {};
APFloat one(a.getSemantics(), 1);
switch (a.getSizeInBits(a.getSemantics())) {
case 64:
return one / APFloat(sqrt(a.convertToDouble()));
case 32:
return one / APFloat(sqrtf(a.convertToFloat()));
default:
return {};
}
});
}
//===----------------------------------------------------------------------===//
// SqrtOp folder
//===----------------------------------------------------------------------===//

View File

@ -102,6 +102,33 @@ func.func @powf_fold_vec() -> (vector<4xf32>) {
return %0 : vector<4xf32>
}
// CHECK-LABEL: @rsqrt_fold
// CHECK: %[[cst:.+]] = arith.constant 5.000000e-01 : f32
// CHECK: return %[[cst]]
func.func @rsqrt_fold() -> f32 {
%c = arith.constant 4.0 : f32
%r = math.rsqrt %c : f32
return %r : f32
}
// CHECK-LABEL: @rsqrt_fold_vec
// CHECK: %[[cst:.+]] = arith.constant dense<[1.000000e+00, 5.000000e-01]> : vector<2xf32>
// CHECK: return %[[cst]]
func.func @rsqrt_fold_vec() -> (vector<2xf32>) {
%v1 = arith.constant dense<[1.0, 4.0]> : vector<2xf32>
%0 = math.rsqrt %v1 : vector<2xf32>
return %0 : vector<2xf32>
}
// CHECK-LABEL: @rsqrt_poison
// CHECK: %[[P:.*]] = ub.poison : f32
// CHECK: return %[[P]]
func.func @rsqrt_poison() -> f32 {
%0 = ub.poison : f32
%1 = math.rsqrt %0 : f32
return %1 : f32
}
// CHECK-LABEL: @sqrt_fold
// CHECK: %[[cst:.+]] = arith.constant 2.000000e+00 : f32
// CHECK: return %[[cst]]