[mlir][arith] Fold (a * b) / b -> a (#121534)

If overflow flags allow it.

Alive2 check: https://alive2.llvm.org/ce/z/5XWjWE
This commit is contained in:
Ivan Butygin 2025-01-03 18:02:59 +01:00 committed by GitHub
parent fa56e8bb64
commit 1cade86997
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 88 additions and 0 deletions

View File

@ -580,11 +580,31 @@ void arith::MulUIExtendedOp::getCanonicalizationPatterns(
// DivUIOp
//===----------------------------------------------------------------------===//
/// Fold `(a * b) / b -> a`
static Value foldDivMul(Value lhs, Value rhs,
arith::IntegerOverflowFlags ovfFlags) {
auto mul = lhs.getDefiningOp<mlir::arith::MulIOp>();
if (!mul || !bitEnumContainsAll(mul.getOverflowFlags(), ovfFlags))
return {};
if (mul.getLhs() == rhs)
return mul.getRhs();
if (mul.getRhs() == rhs)
return mul.getLhs();
return {};
}
OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
// divui (x, 1) -> x.
if (matchPattern(adaptor.getRhs(), m_One()))
return getLhs();
// (a * b) / b -> a
if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nuw))
return val;
// Don't fold if it would require a division by zero.
bool div0 = false;
auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
@ -621,6 +641,10 @@ OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
if (matchPattern(adaptor.getRhs(), m_One()))
return getLhs();
// (a * b) / b -> a
if (Value val = foldDivMul(getLhs(), getRhs(), IntegerOverflowFlags::nsw))
return val;
// Don't fold if it would overflow or if it requires a division by zero.
bool overflowOrDiv0 = false;
auto result = constFoldBinaryOp<IntegerAttr>(

View File

@ -2060,6 +2060,70 @@ func.func @test_divf1(%arg0 : f32, %arg1 : f32) -> (f32) {
// -----
func.func @fold_divui_of_muli_0(%arg0 : index, %arg1 : index) -> index {
%0 = arith.muli %arg0, %arg1 overflow<nuw> : index
%1 = arith.divui %0, %arg0 : index
return %1 : index
}
// CHECK-LABEL: func @fold_divui_of_muli_0(
// CHECK-SAME: %[[ARG0:.+]]: index,
// CHECK-SAME: %[[ARG1:.+]]: index)
// CHECK: return %[[ARG1]]
func.func @fold_divui_of_muli_1(%arg0 : index, %arg1 : index) -> index {
%0 = arith.muli %arg0, %arg1 overflow<nuw> : index
%1 = arith.divui %0, %arg1 : index
return %1 : index
}
// CHECK-LABEL: func @fold_divui_of_muli_1(
// CHECK-SAME: %[[ARG0:.+]]: index,
// CHECK-SAME: %[[ARG1:.+]]: index)
// CHECK: return %[[ARG0]]
func.func @fold_divsi_of_muli_0(%arg0 : index, %arg1 : index) -> index {
%0 = arith.muli %arg0, %arg1 overflow<nsw> : index
%1 = arith.divsi %0, %arg0 : index
return %1 : index
}
// CHECK-LABEL: func @fold_divsi_of_muli_0(
// CHECK-SAME: %[[ARG0:.+]]: index,
// CHECK-SAME: %[[ARG1:.+]]: index)
// CHECK: return %[[ARG1]]
func.func @fold_divsi_of_muli_1(%arg0 : index, %arg1 : index) -> index {
%0 = arith.muli %arg0, %arg1 overflow<nsw> : index
%1 = arith.divsi %0, %arg1 : index
return %1 : index
}
// CHECK-LABEL: func @fold_divsi_of_muli_1(
// CHECK-SAME: %[[ARG0:.+]]: index,
// CHECK-SAME: %[[ARG1:.+]]: index)
// CHECK: return %[[ARG0]]
// Do not fold divui(mul(a, v), v) -> a with nuw attribute.
func.func @no_fold_divui_of_muli(%arg0 : index, %arg1 : index) -> index {
%0 = arith.muli %arg0, %arg1 : index
%1 = arith.divui %0, %arg0 : index
return %1 : index
}
// CHECK-LABEL: func @no_fold_divui_of_muli
// CHECK: %[[T0:.+]] = arith.muli
// CHECK: %[[T1:.+]] = arith.divui %[[T0]],
// CHECK: return %[[T1]]
// Do not fold divsi(mul(a, v), v) -> a with nuw attribute.
func.func @no_fold_divsi_of_muli(%arg0 : index, %arg1 : index) -> index {
%0 = arith.muli %arg0, %arg1 : index
%1 = arith.divsi %0, %arg0 : index
return %1 : index
}
// CHECK-LABEL: func @no_fold_divsi_of_muli
// CHECK: %[[T0:.+]] = arith.muli
// CHECK: %[[T1:.+]] = arith.divsi %[[T0]],
// CHECK: return %[[T1]]
// -----
// CHECK-LABEL: @test_cmpf(
func.func @test_cmpf(%arg0 : f32) -> (i1, i1, i1, i1) {
// CHECK-DAG: %[[T:.*]] = arith.constant true