diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index df5787dc4840..1265bfb18aaa 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -995,6 +995,7 @@ def Math_RsqrtOp : Math_FloatUnaryOp<"rsqrt"> {
     %a = math.rsqrt %b : f64
     ```
   }];
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index 7e9d4acae682..4c0274ddb18a 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -517,6 +517,32 @@ OpFoldResult math::PowFOp::fold(FoldAdaptor adaptor) {
       });
 }
 
+//===----------------------------------------------------------------------===//
+// RsqrtOp folder
+//===----------------------------------------------------------------------===//
+
+/// Constant-folds `math.rsqrt` as `1 / sqrt(x)` using the host `sqrt`/`sqrtf`.
+/// The callback returns an empty optional for sign-negative inputs and for any
+/// float type that is not exactly 32 or 64 bits wide.
+OpFoldResult math::RsqrtOp::fold(FoldAdaptor adaptor) {
+  return constFoldUnaryOpConditional<FloatAttr>(
+      adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
+        // isNegative() tests the sign bit, so -0.0 is rejected here as well.
+        if (a.isNegative())
+          return {};
+
+        APFloat one(a.getSemantics(), 1);
+        switch (a.getSizeInBits(a.getSemantics())) {
+        case 64:
+          return one / APFloat(sqrt(a.convertToDouble()));
+        case 32:
+          return one / APFloat(sqrtf(a.convertToFloat()));
+        default:
+          return {};
+        }
+      });
+}
+
 //===----------------------------------------------------------------------===//
 // SqrtOp folder
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Math/canonicalize.mlir b/mlir/test/Dialect/Math/canonicalize.mlir
index c900a24e821b..67235c38e9cd 100644
--- a/mlir/test/Dialect/Math/canonicalize.mlir
+++ b/mlir/test/Dialect/Math/canonicalize.mlir
@@ -102,6 +102,33 @@ func.func @powf_fold_vec() -> (vector<4xf32>) {
   return %0 : vector<4xf32>
 }
 
+// CHECK-LABEL: @rsqrt_fold
+// CHECK: %[[cst:.+]] = arith.constant 5.000000e-01 : f32
+// CHECK: return %[[cst]]
+func.func @rsqrt_fold() -> f32 {
+  %c = arith.constant 4.0 : f32
+  %r = math.rsqrt %c : f32
+  return %r : f32
+}
+
+// CHECK-LABEL: @rsqrt_fold_vec
+// CHECK: %[[cst:.+]] = arith.constant dense<[1.000000e+00, 5.000000e-01]> : vector<2xf32>
+// CHECK: return %[[cst]]
+func.func @rsqrt_fold_vec() -> (vector<2xf32>) {
+  %v1 = arith.constant dense<[1.0, 4.0]> : vector<2xf32>
+  %0 = math.rsqrt %v1 : vector<2xf32>
+  return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL: @rsqrt_poison
+// CHECK: %[[P:.*]] = ub.poison : f32
+// CHECK: return %[[P]]
+func.func @rsqrt_poison() -> f32 {
+  %0 = ub.poison : f32
+  %1 = math.rsqrt %0 : f32
+  return %1 : f32
+}
+
 // CHECK-LABEL: @sqrt_fold
 // CHECK: %[[cst:.+]] = arith.constant 2.000000e+00 : f32
 // CHECK: return %[[cst]]