From 1cade8699719c934a8debb7bef9fdc3ff11e9602 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Fri, 3 Jan 2025 18:02:59 +0100
Subject: [PATCH] [mlir][arith] Fold `(a * b) / b -> a` (#121534)

If overflow flags allow it.

Alive2 check: https://alive2.llvm.org/ce/z/5XWjWE
---
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp    | 24 +++++++++
 mlir/test/Dialect/Arith/canonicalize.mlir | 64 +++++++++++++++++++++++
 2 files changed, 88 insertions(+)

diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index d8b314a3fa43..e016a6e16e59 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -580,11 +580,31 @@ void arith::MulUIExtendedOp::getCanonicalizationPatterns(
 // DivUIOp
 //===----------------------------------------------------------------------===//
 
+/// Fold `(a * b) / b -> a` when the `muli` carries all of `ovfFlags`.
+static Value foldDivMul(Value lhs, Value rhs,
+                        arith::IntegerOverflowFlags ovfFlags) {
+  auto mul = lhs.getDefiningOp<arith::MulIOp>();
+  if (!mul || !bitEnumContainsAll(mul.getOverflowFlags(), ovfFlags))
+    return {};
+
+  if (mul.getLhs() == rhs)
+    return mul.getRhs();
+
+  if (mul.getRhs() == rhs)
+    return mul.getLhs();
+
+  return {};
+}
+
 OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
   // divui (x, 1) -> x.
   if (matchPattern(adaptor.getRhs(), m_One()))
     return getLhs();
 
+  // (a * b) / b -> a, folded only when the multiplication is `nuw`.
+  if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nuw))
+    return val;
+
   // Don't fold if it would require a division by zero.
   bool div0 = false;
   auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
@@ -621,6 +641,10 @@ OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
   if (matchPattern(adaptor.getRhs(), m_One()))
     return getLhs();
 
+  // (a * b) / b -> a, folded only when the multiplication is `nsw`.
+  if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nsw))
+    return val;
+
   // Don't fold if it would overflow or if it requires a division by zero.
   bool overflowOrDiv0 = false;
   auto result = constFoldBinaryOp<IntegerAttr>(
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 6a186a0c6cec..522711b08f28 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -2060,6 +2060,70 @@ func.func @test_divf1(%arg0 : f32, %arg1 : f32) -> (f32) {
 
 // -----
 
+func.func @fold_divui_of_muli_0(%arg0 : index, %arg1 : index) -> index {
+  %0 = arith.muli %arg0, %arg1 overflow<nuw> : index
+  %1 = arith.divui %0, %arg0 : index
+  return %1 : index
+}
+// CHECK-LABEL: func @fold_divui_of_muli_0(
+//  CHECK-SAME:     %[[ARG0:.+]]: index,
+//  CHECK-SAME:     %[[ARG1:.+]]: index)
+//       CHECK:   return %[[ARG1]]
+
+func.func @fold_divui_of_muli_1(%arg0 : index, %arg1 : index) -> index {
+  %0 = arith.muli %arg0, %arg1 overflow<nuw> : index
+  %1 = arith.divui %0, %arg1 : index
+  return %1 : index
+}
+// CHECK-LABEL: func @fold_divui_of_muli_1(
+//  CHECK-SAME:     %[[ARG0:.+]]: index,
+//  CHECK-SAME:     %[[ARG1:.+]]: index)
+//       CHECK:   return %[[ARG0]]
+
+func.func @fold_divsi_of_muli_0(%arg0 : index, %arg1 : index) -> index {
+  %0 = arith.muli %arg0, %arg1 overflow<nsw> : index
+  %1 = arith.divsi %0, %arg0 : index
+  return %1 : index
+}
+// CHECK-LABEL: func @fold_divsi_of_muli_0(
+//  CHECK-SAME:     %[[ARG0:.+]]: index,
+//  CHECK-SAME:     %[[ARG1:.+]]: index)
+//       CHECK:   return %[[ARG1]]
+
+func.func @fold_divsi_of_muli_1(%arg0 : index, %arg1 : index) -> index {
+  %0 = arith.muli %arg0, %arg1 overflow<nsw> : index
+  %1 = arith.divsi %0, %arg1 : index
+  return %1 : index
+}
+// CHECK-LABEL: func @fold_divsi_of_muli_1(
+//  CHECK-SAME:     %[[ARG0:.+]]: index,
+//  CHECK-SAME:     %[[ARG1:.+]]: index)
+//       CHECK:   return %[[ARG0]]
+
+// Do not fold divui(mul(a, v), v) -> a without nuw attribute.
+func.func @no_fold_divui_of_muli(%arg0 : index, %arg1 : index) -> index {
+  %0 = arith.muli %arg0, %arg1 : index
+  %1 = arith.divui %0, %arg0 : index
+  return %1 : index
+}
+// CHECK-LABEL: func @no_fold_divui_of_muli
+//       CHECK:   %[[T0:.+]] = arith.muli
+//       CHECK:   %[[T1:.+]] = arith.divui %[[T0]],
+//       CHECK:   return %[[T1]]
+
+// Do not fold divsi(mul(a, v), v) -> a without nsw attribute.
+func.func @no_fold_divsi_of_muli(%arg0 : index, %arg1 : index) -> index {
+  %0 = arith.muli %arg0, %arg1 : index
+  %1 = arith.divsi %0, %arg0 : index
+  return %1 : index
+}
+// CHECK-LABEL: func @no_fold_divsi_of_muli
+//       CHECK:   %[[T0:.+]] = arith.muli
+//       CHECK:   %[[T1:.+]] = arith.divsi %[[T0]],
+//       CHECK:   return %[[T1]]
+
+// -----
+
 // CHECK-LABEL: @test_cmpf(
 func.func @test_cmpf(%arg0 : f32) -> (i1, i1, i1, i1) {
 // CHECK-DAG:   %[[T:.*]] = arith.constant true