From de2bac367ff9da74191bd2de130e4a81db07ae08 Mon Sep 17 00:00:00 2001 From: Matthias Guenther Date: Thu, 7 Aug 2025 08:52:03 -0700 Subject: [PATCH] [MLIR] Allow `constFoldBinaryOp` to fold `(T1, T1) -> T2` (#151410) The `constFoldBinaryOp` helper function had limited support for different input and output types, but the static type of the underlying value (e.g. `APInt`) had to match between the inputs and the output. This worked fine for int comparisons of the form `(intN, intN) -> int1`, as the static type signature was `(APInt, APInt) -> APInt`. However, float comparisons map `(floatN, floatN) -> int1`, with a static type signature of `(APFloat, APFloat) -> APInt`. This use case wasn't supported by `constFoldBinaryOp`. `constFoldBinaryOp` now accepts an optional template argument overriding the return type in case it differs from the input type. If the new template argument isn't provided, the default behavior is unchanged (i.e. the return type will be assumed to match the input type). `constFoldUnaryOp` received similar changes in order to support folding non-cast ops of the form `(T1) -> T2` (e.g. a `sign` op mapping `(floatN) -> sint32`). --- mlir/include/mlir/Dialect/CommonFolders.h | 147 +++++++++++++++----- mlir/test/Dialect/common_folders.mlir | 22 +++ mlir/test/lib/Dialect/Test/TestOps.td | 20 +++ mlir/test/lib/Dialect/Test/TestPatterns.cpp | 81 +++++++++++ 4 files changed, 237 insertions(+), 33 deletions(-) create mode 100644 mlir/test/Dialect/common_folders.mlir diff --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h index b5a12426aff8..113765157946 100644 --- a/mlir/include/mlir/Dialect/CommonFolders.h +++ b/mlir/include/mlir/Dialect/CommonFolders.h @@ -15,10 +15,16 @@ #ifndef MLIR_DIALECT_COMMONFOLDERS_H #define MLIR_DIALECT_COMMONFOLDERS_H +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributeInterfaces.h" #include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/Types.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" + +#include +#include #include namespace mlir { @@ -30,11 +36,13 @@ class PoisonAttr; /// Uses `resultType` for the type of the returned attribute. /// Optional PoisonAttr template argument allows to specify 'poison' attribute /// which will be directly propagated to result. -template (ElementValueT, ElementValueT)>> + std::optional(ElementValueT, ElementValueT)>> Attribute constFoldBinaryOpConditional(ArrayRef operands, Type resultType, CalculationT &&calculate) { @@ -65,7 +73,7 @@ Attribute constFoldBinaryOpConditional(ArrayRef operands, if (!calRes) return {}; - return AttrElementT::get(resultType, *calRes); + return ResultAttrElementT::get(resultType, *calRes); } if (isa(operands[0]) && @@ -99,7 +107,7 @@ Attribute constFoldBinaryOpConditional(ArrayRef operands, return {}; auto lhsIt = *maybeLhsIt; auto rhsIt = *maybeRhsIt; - SmallVector elementResults; + SmallVector elementResults; elementResults.reserve(lhs.getNumElements()); for (size_t i = 0, e = lhs.getNumElements(); i < e; ++i, ++lhsIt, ++rhsIt) { auto elementResult = calculate(*lhsIt, *rhsIt); @@ -119,11 +127,13 @@ Attribute constFoldBinaryOpConditional(ArrayRef operands, /// attribute. /// Optional PoisonAttr template argument allows to specify 'poison' attribute /// which will be directly propagated to result. -template (ElementValueT, ElementValueT)>> + std::optional(ElementValueT, ElementValueT)>> Attribute constFoldBinaryOpConditional(ArrayRef operands, CalculationT &&calculate) { assert(operands.size() == 2 && "binary op takes two operands"); @@ -139,64 +149,73 @@ Attribute constFoldBinaryOpConditional(ArrayRef operands, return operands[1]; } - auto getResultType = [](Attribute attr) -> Type { + auto getAttrType = [](Attribute attr) -> Type { if (auto typed = dyn_cast_or_null(attr)) return typed.getType(); return {}; }; - Type lhsType = getResultType(operands[0]); - Type rhsType = getResultType(operands[1]); + Type lhsType = getAttrType(operands[0]); + Type rhsType = getAttrType(operands[1]); if (!lhsType || !rhsType) return {}; if (lhsType != rhsType) return {}; return constFoldBinaryOpConditional( operands, lhsType, std::forward(calculate)); } template > + function_ref> Attribute constFoldBinaryOp(ArrayRef operands, Type resultType, CalculationT &&calculate) { - return constFoldBinaryOpConditional( + return constFoldBinaryOpConditional( operands, resultType, - [&](ElementValueT a, ElementValueT b) -> std::optional { - return calculate(a, b); - }); + [&](ElementValueT a, ElementValueT b) + -> std::optional { return calculate(a, b); }); } -template > + function_ref> Attribute constFoldBinaryOp(ArrayRef operands, CalculationT &&calculate) { - return constFoldBinaryOpConditional( + return constFoldBinaryOpConditional( operands, - [&](ElementValueT a, ElementValueT b) -> std::optional { - return calculate(a, b); - }); + [&](ElementValueT a, ElementValueT b) + -> std::optional { return calculate(a, b); }); } /// Performs constant folding `calculate` with element-wise behavior on the one /// attributes in `operands` and returns the result if possible. +/// Uses `resultType` for the type of the returned attribute. /// Optional PoisonAttr template argument allows to specify 'poison' attribute /// which will be directly propagated to result. -template (ElementValueT)>> + function_ref(ElementValueT)>> Attribute constFoldUnaryOpConditional(ArrayRef operands, + Type resultType, CalculationT &&calculate) { - if (!llvm::getSingleElement(operands)) + if (!resultType || !llvm::getSingleElement(operands)) return {}; static_assert( @@ -214,7 +233,7 @@ Attribute constFoldUnaryOpConditional(ArrayRef operands, auto res = calculate(op.getValue()); if (!res) return {}; - return AttrElementT::get(op.getType(), *res); + return ResultAttrElementT::get(resultType, *res); } if (isa(operands[0])) { // Both operands are splats so we can avoid expanding the values out and @@ -224,7 +243,7 @@ Attribute constFoldUnaryOpConditional(ArrayRef operands, auto elementResult = calculate(op.getSplatValue()); if (!elementResult) return {}; - return DenseElementsAttr::get(op.getType(), *elementResult); + return DenseElementsAttr::get(cast(resultType), *elementResult); } else if (isa(operands[0])) { // Operands are ElementsAttr-derived; perform an element-wise fold by // expanding the values. @@ -234,7 +253,7 @@ Attribute constFoldUnaryOpConditional(ArrayRef operands, if (!maybeOpIt) return {}; auto opIt = *maybeOpIt; - SmallVector elementResults; + SmallVector elementResults; elementResults.reserve(op.getNumElements()); for (size_t i = 0, e = op.getNumElements(); i < e; ++i, ++opIt) { auto elementResult = calculate(*opIt); @@ -242,19 +261,81 @@ Attribute constFoldUnaryOpConditional(ArrayRef operands, return {}; elementResults.push_back(*elementResult); } - return DenseElementsAttr::get(op.getShapedType(), elementResults); + return DenseElementsAttr::get(cast(resultType), elementResults); } return {}; } -template > + class ResultAttrElementT = AttrElementT, + class ResultElementValueT = typename ResultAttrElementT::ValueType, + class CalculationT = + function_ref(ElementValueT)>> +Attribute constFoldUnaryOpConditional(ArrayRef operands, + CalculationT &&calculate) { + if (!llvm::getSingleElement(operands)) + return {}; + + static_assert( + std::is_void_v || !llvm::is_incomplete_v, + "PoisonAttr is undefined, either add a dependency on UB dialect or pass " + "void as template argument to opt-out from poison semantics."); + if constexpr (!std::is_void_v) { + if (isa(operands[0])) + return operands[0]; + } + + auto getAttrType = [](Attribute attr) -> Type { + if (auto typed = dyn_cast_or_null(attr)) + return typed.getType(); + return {}; + }; + + Type operandType = getAttrType(operands[0]); + if (!operandType) + return {}; + + return constFoldUnaryOpConditional( + operands, operandType, std::forward(calculate)); +} + +template > +Attribute constFoldUnaryOp(ArrayRef operands, Type resultType, + CalculationT &&calculate) { + return constFoldUnaryOpConditional( + operands, resultType, + [&](ElementValueT a) -> std::optional { + return calculate(a); + }); +} + +template > Attribute constFoldUnaryOp(ArrayRef operands, CalculationT &&calculate) { - return constFoldUnaryOpConditional( - operands, [&](ElementValueT a) -> std::optional { + return constFoldUnaryOpConditional( + operands, [&](ElementValueT a) -> std::optional { return calculate(a); }); } diff --git a/mlir/test/Dialect/common_folders.mlir b/mlir/test/Dialect/common_folders.mlir new file mode 100644 index 000000000000..92598b493755 --- /dev/null +++ b/mlir/test/Dialect/common_folders.mlir @@ -0,0 +1,22 @@ +// RUN: mlir-opt %s --test-fold-type-converting-op --split-input-file | FileCheck %s + +// CHECK-LABEL: @test_fold_unary_op_f32_to_si32( +func.func @test_fold_unary_op_f32_to_si32() -> tensor<4x2xsi32> { + // CHECK-NEXT: %[[POSITIVE_ONE:.*]] = arith.constant dense<1> : tensor<4x2xsi32> + // CHECK-NEXT: return %[[POSITIVE_ONE]] : tensor<4x2xsi32> + %operand = arith.constant dense<5.1> : tensor<4x2xf32> + %sign = test.sign %operand : (tensor<4x2xf32>) -> tensor<4x2xsi32> + return %sign : tensor<4x2xsi32> +} + +// ----- + +// CHECK-LABEL: @test_fold_binary_op_f32_to_i1( +func.func @test_fold_binary_op_f32_to_i1() -> tensor<8xi1> { + // CHECK-NEXT: %[[FALSE:.*]] = arith.constant dense : tensor<8xi1> + // CHECK-NEXT: return %[[FALSE]] : tensor<8xi1> + %lhs = arith.constant dense<5.1> : tensor<8xf32> + %rhs = arith.constant dense<4.2> : tensor<8xf32> + %less_than = test.less_than %lhs, %rhs : (tensor<8xf32>, tensor<8xf32>) -> tensor<8xi1> + return %less_than : tensor<8xi1> +} diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 843bd30a51af..3cdd2f226687 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1169,6 +1169,26 @@ def OpP : TEST_Op<"op_p"> { let results = (outs I32); } +// Test constant-folding a pattern that maps `(F32) -> SI32`. +def SignOp : TEST_Op<"sign", [SameOperandsAndResultShape]> { + let arguments = (ins RankedTensorOf<[F32]>:$operand); + let results = (outs RankedTensorOf<[SI32]>:$result); + + let assemblyFormat = [{ + $operand attr-dict `:` functional-type(operands, results) + }]; +} + +// Test constant-folding a pattern that maps `(F32, F32) -> I1`. +def LessThanOp : TEST_Op<"less_than", [SameOperandsAndResultShape]> { + let arguments = (ins RankedTensorOf<[F32]>:$lhs, RankedTensorOf<[F32]>:$rhs); + let results = (outs RankedTensorOf<[I1]>:$result); + + let assemblyFormat = [{ + $lhs `,` $rhs attr-dict `:` functional-type(operands, results) + }]; +} + // Test same operand name enforces equality condition check. def TestEqualArgsPattern : Pat<(OpN $a, $a), (OpO $a)>; diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index 7150401bdbdc..f92f0982f85b 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -10,6 +10,7 @@ #include "TestOps.h" #include "TestTypes.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/CommonFolders.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/Transforms/FuncConversions.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -202,6 +203,66 @@ struct HoistEligibleOps : public OpRewritePattern { } }; +struct FoldSignOpF32ToSI32 : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(test::SignOp op, + PatternRewriter &rewriter) const override { + if (op->getNumOperands() != 1 || op->getNumResults() != 1) + return failure(); + + TypedAttr operandAttr; + matchPattern(op->getOperand(0), m_Constant(&operandAttr)); + if (!operandAttr) + return failure(); + + TypedAttr res = cast_or_null( + constFoldUnaryOp( + operandAttr, op.getType(), [](APFloat operand) -> APSInt { + static const APFloat zero(0.0f); + int operandSign = 0; + if (operand != zero) + operandSign = (operand < zero) ? -1 : +1; + return APSInt(APInt(32, operandSign), false); + })); + if (!res) + return failure(); + + rewriter.replaceOpWithNewOp(op, res); + return success(); + } +}; + +struct FoldLessThanOpF32ToI1 : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(test::LessThanOp op, + PatternRewriter &rewriter) const override { + if (op->getNumOperands() != 2 || op->getNumResults() != 1) + return failure(); + + TypedAttr lhsAttr; + TypedAttr rhsAttr; + matchPattern(op->getOperand(0), m_Constant(&lhsAttr)); + matchPattern(op->getOperand(1), m_Constant(&rhsAttr)); + + if (!lhsAttr || !rhsAttr) + return failure(); + + Attribute operandAttrs[2] = {lhsAttr, rhsAttr}; + TypedAttr res = cast_or_null( + constFoldBinaryOp( + operandAttrs, op.getType(), [](APFloat lhs, APFloat rhs) -> APInt { + return APInt(1, lhs < rhs); + })); + if (!res) + return failure(); + + rewriter.replaceOpWithNewOp(op, res); + return success(); + } +}; + /// This pattern moves "test.move_before_parent_op" before the parent op. struct MoveBeforeParentOp : public RewritePattern { MoveBeforeParentOp(MLIRContext *context) @@ -2226,6 +2287,24 @@ struct TestSelectiveReplacementPatternDriver (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; + +struct TestFoldTypeConvertingOp + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFoldTypeConvertingOp) + + StringRef getArgument() const final { return "test-fold-type-converting-op"; } + StringRef getDescription() const final { + return "Test helper functions for folding ops whose input and output types " + "differ, e.g. float comparisons of the form `(f32, f32) -> i1`."; + } + void runOnOperation() override { + MLIRContext *context = &getContext(); + mlir::RewritePatternSet patterns(context); + patterns.add(context); + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) + signalPassFailure(); + } +}; } // namespace //===----------------------------------------------------------------------===// @@ -2256,6 +2335,8 @@ void registerPatternsTestPass() { PassRegistration(); PassRegistration(); + + PassRegistration(); } } // namespace test } // namespace mlir