[mlir][math] Add constant folding for math.rsqrt (#184443)
Add a fold() method to RsqrtOp, matching the pattern used by SqrtOp and other math unary ops. The fold computes `1.0 / sqrt(x)` using APFloat division.

---------

Signed-off-by: Ian Wood <ianwood@u.northwestern.edu>
This commit is contained in:
parent
ece4b75932
commit
630b9570d1
@ -995,6 +995,7 @@ def Math_RsqrtOp : Math_FloatUnaryOp<"rsqrt"> {
|
||||
%a = math.rsqrt %b : f64
|
||||
```
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@ -517,6 +517,28 @@ OpFoldResult math::PowFOp::fold(FoldAdaptor adaptor) {
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// RsqrtOp folder
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Folds `math.rsqrt` over constant operands.
///
/// Each element `x` is mapped to `1 / sqrt(x)`. The square root is taken with
/// the host `sqrt` / `sqrtf`, and the reciprocal is formed by dividing an
/// APFloat one (built in the input's semantics) by that root. The per-element
/// callback declines to fold (returns an empty optional) when:
///   * `isNegative()` is true for the input, or
///   * the input's semantics are neither 32 nor 64 bits wide.
OpFoldResult math::RsqrtOp::fold(FoldAdaptor adaptor) {
  return constFoldUnaryOpConditional<FloatAttr>(
      adaptor.getOperands(),
      [](const APFloat &value) -> std::optional<APFloat> {
        // Inputs reported as negative are left unfolded.
        if (value.isNegative())
          return std::nullopt;

        // Only widths that map onto a host floating-point type are handled.
        const auto &semantics = value.getSemantics();
        const unsigned bitWidth = APFloat::getSizeInBits(semantics);
        const bool isDouble = bitWidth == 64;
        const bool isSingle = bitWidth == 32;
        if (!isDouble && !isSingle)
          return std::nullopt;

        // Take the root at the matching host precision, then divide one by it
        // with APFloat arithmetic.
        APFloat numerator(semantics, 1);
        APFloat root = isDouble ? APFloat(sqrt(value.convertToDouble()))
                                : APFloat(sqrtf(value.convertToFloat()));
        return numerator / root;
      });
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SqrtOp folder
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@ -102,6 +102,33 @@ func.func @powf_fold_vec() -> (vector<4xf32>) {
|
||||
return %0 : vector<4xf32>
|
||||
}
|
||||
|
||||
// Scalar case: math.rsqrt of the f32 constant 4.0 is expected to be replaced
// by the f32 constant 0.5, which is then returned directly.
// CHECK-LABEL: @rsqrt_fold
|
||||
// CHECK: %[[cst:.+]] = arith.constant 5.000000e-01 : f32
|
||||
// CHECK: return %[[cst]]
|
||||
func.func @rsqrt_fold() -> f32 {
|
||||
%c = arith.constant 4.0 : f32
|
||||
%r = math.rsqrt %c : f32
|
||||
return %r : f32
|
||||
}
|
||||
|
||||
// Vector case: math.rsqrt of the constant vector [1.0, 4.0] is expected to be
// replaced element-wise by the constant vector [1.0, 0.5].
// CHECK-LABEL: @rsqrt_fold_vec
|
||||
// CHECK: %[[cst:.+]] = arith.constant dense<[1.000000e+00, 5.000000e-01]> : vector<2xf32>
|
||||
// CHECK: return %[[cst]]
|
||||
func.func @rsqrt_fold_vec() -> (vector<2xf32>) {
|
||||
%v1 = arith.constant dense<[1.0, 4.0]> : vector<2xf32>
|
||||
%0 = math.rsqrt %v1 : vector<2xf32>
|
||||
return %0 : vector<2xf32>
|
||||
}
|
||||
|
||||
// Poison case: math.rsqrt applied to a ub.poison operand is expected to
// disappear, leaving the poison value itself as the returned result.
// CHECK-LABEL: @rsqrt_poison
|
||||
// CHECK: %[[P:.*]] = ub.poison : f32
|
||||
// CHECK: return %[[P]]
|
||||
func.func @rsqrt_poison() -> f32 {
|
||||
%0 = ub.poison : f32
|
||||
|
||||
%1 = math.rsqrt %0 : f32
|
||||
return %1 : f32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @sqrt_fold
|
||||
// CHECK: %[[cst:.+]] = arith.constant 2.000000e+00 : f32
|
||||
// CHECK: return %[[cst]]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user