//===- Rewrite.cpp - Rewrite ----------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "Rewrite.h" #include "mlir-c/IR.h" #include "mlir-c/Rewrite.h" #include "mlir-c/Support.h" #include "mlir/Bindings/Python/IRCore.h" // clang-format off #include "mlir/Bindings/Python/Nanobind.h" #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. // clang-format on #include "mlir/Config/mlir-config.h" #include "nanobind/nanobind.h" namespace nb = nanobind; using namespace mlir; using namespace nb::literals; using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN; namespace mlir { namespace python { namespace MLIR_BINDINGS_PYTHON_DOMAIN { class PyPatternRewriter { public: PyPatternRewriter(MlirPatternRewriter rewriter) : base(mlirPatternRewriterAsBase(rewriter)), ctx(PyMlirContext::forContext(mlirRewriterBaseGetContext(base))) {} PyInsertionPoint getInsertionPoint() const { MlirBlock block = mlirRewriterBaseGetInsertionBlock(base); MlirOperation op = mlirRewriterBaseGetOperationAfterInsertion(base); if (mlirOperationIsNull(op)) { MlirOperation owner = mlirBlockGetParentOperation(block); auto parent = PyOperation::forOperation(ctx, owner); return PyInsertionPoint(PyBlock(parent, block)); } return PyInsertionPoint(PyOperation::forOperation(ctx, op)); } void replaceOp(MlirOperation op, MlirOperation newOp) { mlirRewriterBaseReplaceOpWithOperation(base, op, newOp); } void replaceOp(MlirOperation op, const std::vector &values) { mlirRewriterBaseReplaceOpWithValues(base, op, values.size(), values.data()); } void eraseOp(const PyOperation &op) { mlirRewriterBaseEraseOp(base, op); } private: MlirRewriterBase base; PyMlirContextRef ctx; }; struct PyMlirPDLResultList : MlirPDLResultList {}; #if MLIR_ENABLE_PDL_IN_PATTERNMATCH static nb::object objectFromPDLValue(MlirPDLValue value) { if (MlirValue v = mlirPDLValueAsValue(value); !mlirValueIsNull(v)) return nb::cast(v); if (MlirOperation v = mlirPDLValueAsOperation(value); !mlirOperationIsNull(v)) return nb::cast(v); if (MlirAttribute v = mlirPDLValueAsAttribute(value); !mlirAttributeIsNull(v)) return nb::cast(v); if (MlirType v = mlirPDLValueAsType(value); !mlirTypeIsNull(v)) return nb::cast(v); throw std::runtime_error("unsupported PDL value type"); } static std::vector objectsFromPDLValues(size_t nValues, MlirPDLValue *values) { std::vector args; args.reserve(nValues); for (size_t i = 0; i < nValues; ++i) args.push_back(objectFromPDLValue(values[i])); return args; } // Convert the Python object to a boolean. // If it evaluates to False, treat it as success; // otherwise, treat it as failure. // Note that None is considered success. static MlirLogicalResult logicalResultFromObject(const nb::object &obj) { if (obj.is_none()) return mlirLogicalResultSuccess(); return nb::cast(obj) ? mlirLogicalResultFailure() : mlirLogicalResultSuccess(); } /// Owning Wrapper around a PDLPatternModule. class PyPDLPatternModule { public: PyPDLPatternModule(MlirPDLPatternModule module) : module(module) {} PyPDLPatternModule(PyPDLPatternModule &&other) noexcept : module(other.module) { other.module.ptr = nullptr; } ~PyPDLPatternModule() { if (module.ptr != nullptr) mlirPDLPatternModuleDestroy(module); } MlirPDLPatternModule get() { return module; } void registerRewriteFunction(const std::string &name, const nb::callable &fn) { mlirPDLPatternModuleRegisterRewriteFunction( get(), mlirStringRefCreate(name.data(), name.size()), [](MlirPatternRewriter rewriter, MlirPDLResultList results, size_t nValues, MlirPDLValue *values, void *userData) -> MlirLogicalResult { nb::handle f = nb::handle(static_cast(userData)); return logicalResultFromObject( f(PyPatternRewriter(rewriter), PyMlirPDLResultList{results.ptr}, objectsFromPDLValues(nValues, values))); }, fn.ptr()); } void registerConstraintFunction(const std::string &name, const nb::callable &fn) { mlirPDLPatternModuleRegisterConstraintFunction( get(), mlirStringRefCreate(name.data(), name.size()), [](MlirPatternRewriter rewriter, MlirPDLResultList results, size_t nValues, MlirPDLValue *values, void *userData) -> MlirLogicalResult { nb::handle f = nb::handle(static_cast(userData)); return logicalResultFromObject( f(PyPatternRewriter(rewriter), PyMlirPDLResultList{results.ptr}, objectsFromPDLValues(nValues, values))); }, fn.ptr()); } private: MlirPDLPatternModule module; }; #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH /// Owning Wrapper around a FrozenRewritePatternSet. class PyFrozenRewritePatternSet { public: PyFrozenRewritePatternSet(MlirFrozenRewritePatternSet set) : set(set) {} PyFrozenRewritePatternSet(PyFrozenRewritePatternSet &&other) noexcept : set(other.set) { other.set.ptr = nullptr; } ~PyFrozenRewritePatternSet() { if (set.ptr != nullptr) mlirFrozenRewritePatternSetDestroy(set); } MlirFrozenRewritePatternSet get() { return set; } nb::object getCapsule() { return nb::steal( mlirPythonFrozenRewritePatternSetToCapsule(get())); } static nb::object createFromCapsule(const nb::object &capsule) { MlirFrozenRewritePatternSet rawPm = mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr()); if (rawPm.ptr == nullptr) throw nb::python_error(); return nb::cast(PyFrozenRewritePatternSet(rawPm), nb::rv_policy::move); } private: MlirFrozenRewritePatternSet set; }; class PyRewritePatternSet { public: PyRewritePatternSet(MlirContext ctx) : set(mlirRewritePatternSetCreate(ctx)), ctx(ctx) {} ~PyRewritePatternSet() { if (set.ptr) mlirRewritePatternSetDestroy(set); } void add(MlirStringRef rootName, unsigned benefit, const nb::callable &matchAndRewrite) { MlirRewritePatternCallbacks callbacks; callbacks.construct = [](void *userData) { nb::handle(static_cast(userData)).inc_ref(); }; callbacks.destruct = [](void *userData) { nb::handle(static_cast(userData)).dec_ref(); }; callbacks.matchAndRewrite = [](MlirRewritePattern, MlirOperation op, MlirPatternRewriter rewriter, void *userData) -> MlirLogicalResult { nb::handle f(static_cast(userData)); PyMlirContextRef ctx = PyMlirContext::forContext(mlirOperationGetContext(op)); nb::object opView = PyOperation::forOperation(ctx, op)->createOpView(); nb::object res = f(opView, PyPatternRewriter(rewriter)); return logicalResultFromObject(res); }; MlirRewritePattern pattern = mlirOpRewritePatternCreate( rootName, benefit, ctx, callbacks, matchAndRewrite.ptr(), /* nGeneratedNames */ 0, /* generatedNames */ nullptr); mlirRewritePatternSetAdd(set, pattern); } PyFrozenRewritePatternSet freeze() { MlirRewritePatternSet s = set; set.ptr = nullptr; return mlirFreezeRewritePattern(s); } private: MlirRewritePatternSet set; MlirContext ctx; }; enum PyGreedyRewriteStrictness : std::underlying_type_t< MlirGreedyRewriteStrictness> { MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP = MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP, MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_AND_NEW_OPS = MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_AND_NEW_OPS, MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS = MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS, }; enum PyGreedySimplifyRegionLevel : std::underlying_type_t< MlirGreedySimplifyRegionLevel> { MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_DISABLED = MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_DISABLED, MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL = MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL, MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_AGGRESSIVE = MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_AGGRESSIVE }; /// Owning Wrapper around a GreedyRewriteDriverConfig. class PyGreedyRewriteDriverConfig { public: PyGreedyRewriteDriverConfig() : config(mlirGreedyRewriteDriverConfigCreate()) {} PyGreedyRewriteDriverConfig(PyGreedyRewriteDriverConfig &&other) noexcept : config(other.config) { other.config.ptr = nullptr; } ~PyGreedyRewriteDriverConfig() { if (config.ptr != nullptr) mlirGreedyRewriteDriverConfigDestroy(config); } MlirGreedyRewriteDriverConfig get() { return config; } void setMaxIterations(int64_t maxIterations) { mlirGreedyRewriteDriverConfigSetMaxIterations(config, maxIterations); } void setMaxNumRewrites(int64_t maxNumRewrites) { mlirGreedyRewriteDriverConfigSetMaxNumRewrites(config, maxNumRewrites); } void setUseTopDownTraversal(bool useTopDownTraversal) { mlirGreedyRewriteDriverConfigSetUseTopDownTraversal(config, useTopDownTraversal); } void enableFolding(bool enable) { mlirGreedyRewriteDriverConfigEnableFolding(config, enable); } void setStrictness(PyGreedyRewriteStrictness strictness) { mlirGreedyRewriteDriverConfigSetStrictness( config, static_cast(strictness)); } void setRegionSimplificationLevel(PyGreedySimplifyRegionLevel level) { mlirGreedyRewriteDriverConfigSetRegionSimplificationLevel( config, static_cast(level)); } void enableConstantCSE(bool enable) { mlirGreedyRewriteDriverConfigEnableConstantCSE(config, enable); } int64_t getMaxIterations() { return mlirGreedyRewriteDriverConfigGetMaxIterations(config); } int64_t getMaxNumRewrites() { return mlirGreedyRewriteDriverConfigGetMaxNumRewrites(config); } bool getUseTopDownTraversal() { return mlirGreedyRewriteDriverConfigGetUseTopDownTraversal(config); } bool isFoldingEnabled() { return mlirGreedyRewriteDriverConfigIsFoldingEnabled(config); } PyGreedyRewriteStrictness getStrictness() { return static_cast( mlirGreedyRewriteDriverConfigGetStrictness(config)); } PyGreedySimplifyRegionLevel getRegionSimplificationLevel() { return static_cast( mlirGreedyRewriteDriverConfigGetRegionSimplificationLevel(config)); } bool isConstantCSEEnabled() { return mlirGreedyRewriteDriverConfigIsConstantCSEEnabled(config); } private: MlirGreedyRewriteDriverConfig config; }; /// Create the `mlir.rewrite` here. void populateRewriteSubmodule(nb::module_ &m) { // Enum definitions nb::enum_(m, "GreedyRewriteStrictness") .value("ANY_OP", MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP) .value("EXISTING_AND_NEW_OPS", MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_AND_NEW_OPS) .value("EXISTING_OPS", MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS); nb::enum_(m, "GreedySimplifyRegionLevel") .value("DISABLED", MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_DISABLED) .value("NORMAL", MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL) .value("AGGRESSIVE", MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_AGGRESSIVE); //---------------------------------------------------------------------------- // Mapping of the PatternRewriter //---------------------------------------------------------------------------- nb::class_(m, "PatternRewriter") .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint, "The current insertion point of the PatternRewriter.") .def( "replace_op", [](PyPatternRewriter &self, PyOperationBase &op, PyOperationBase &newOp) { self.replaceOp(op.getOperation(), newOp.getOperation()); }, "Replace an operation with a new operation.", nb::arg("op"), nb::arg("new_op")) .def( "replace_op", [](PyPatternRewriter &self, PyOperationBase &op, const std::vector &values) { std::vector values_(values.size()); std::copy(values.begin(), values.end(), values_.begin()); self.replaceOp(op.getOperation(), values_); }, "Replace an operation with a list of values.", nb::arg("op"), nb::arg("values")) .def("erase_op", &PyPatternRewriter::eraseOp, "Erase an operation.", nb::arg("op")); //---------------------------------------------------------------------------- // Mapping of the RewritePatternSet //---------------------------------------------------------------------------- nb::class_(m, "RewritePatternSet") .def( "__init__", [](PyRewritePatternSet &self, DefaultingPyMlirContext context) { new (&self) PyRewritePatternSet(context.get()->get()); }, "context"_a = nb::none()) .def( "add", [](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn, unsigned benefit) { std::string opName; if (root.is_type()) { opName = nb::cast(root.attr("OPERATION_NAME")); } else if (nb::isinstance(root)) { opName = nb::cast(root); } else { throw nb::type_error( "the root argument must be a type or a string"); } self.add(mlirStringRefCreate(opName.data(), opName.size()), benefit, fn); }, "root"_a, "fn"_a, "benefit"_a = 1, // clang-format off nb::sig("def add(self, root: type | str, fn: typing.Callable[[" MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", PatternRewriter], typing.Any], benefit: int = 1) -> None"), // clang-format on R"( Add a new rewrite pattern on the specified root operation, using the provided callable for matching and rewriting, and assign it the given benefit. Args: root: The root operation to which this pattern applies. This may be either an OpView subclass (e.g., ``arith.AddIOp``) or an operation name string (e.g., ``"arith.addi"``). fn: The callable to use for matching and rewriting, which takes an operation and a pattern rewriter as arguments. The match is considered successful iff the callable returns a value where ``bool(value)`` is ``False`` (e.g. ``None``). If possible, the operation is cast to its corresponding OpView subclass before being passed to the callable. benefit: The benefit of the pattern, defaulting to 1.)") .def("freeze", &PyRewritePatternSet::freeze, "Freeze the pattern set into a frozen one."); //---------------------------------------------------------------------------- // Mapping of the PDLResultList and PDLModule //---------------------------------------------------------------------------- #if MLIR_ENABLE_PDL_IN_PATTERNMATCH nb::class_(m, "PDLResultList") .def("append", [](PyMlirPDLResultList results, const PyValue &value) { mlirPDLResultListPushBackValue(results, value); }) .def("append", [](PyMlirPDLResultList results, const PyOperation &op) { mlirPDLResultListPushBackOperation(results, op); }) .def("append", [](PyMlirPDLResultList results, const PyType &type) { mlirPDLResultListPushBackType(results, type); }) .def("append", [](PyMlirPDLResultList results, const PyAttribute &attr) { mlirPDLResultListPushBackAttribute(results, attr); }); nb::class_(m, "PDLModule") .def( "__init__", [](PyPDLPatternModule &self, PyModule &module) { new (&self) PyPDLPatternModule( mlirPDLPatternModuleFromModule(module.get())); }, "module"_a, "Create a PDL module from the given module.") .def( "__init__", [](PyPDLPatternModule &self, PyModule &module) { new (&self) PyPDLPatternModule( mlirPDLPatternModuleFromModule(module.get())); }, "module"_a, "Create a PDL module from the given module.") .def( "freeze", [](PyPDLPatternModule &self) { return PyFrozenRewritePatternSet(mlirFreezeRewritePattern( mlirRewritePatternSetFromPDLPatternModule(self.get()))); }, nb::keep_alive<0, 1>()) .def( "register_rewrite_function", [](PyPDLPatternModule &self, const std::string &name, const nb::callable &fn) { self.registerRewriteFunction(name, fn); }, nb::keep_alive<1, 3>()) .def( "register_constraint_function", [](PyPDLPatternModule &self, const std::string &name, const nb::callable &fn) { self.registerConstraintFunction(name, fn); }, nb::keep_alive<1, 3>()); #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH nb::class_(m, "GreedyRewriteDriverConfig") .def(nb::init<>(), "Create a greedy rewrite driver config with defaults") .def_prop_rw("max_iterations", &PyGreedyRewriteDriverConfig::getMaxIterations, &PyGreedyRewriteDriverConfig::setMaxIterations, "Maximum number of iterations") .def_prop_rw("max_num_rewrites", &PyGreedyRewriteDriverConfig::getMaxNumRewrites, &PyGreedyRewriteDriverConfig::setMaxNumRewrites, "Maximum number of rewrites per iteration") .def_prop_rw("use_top_down_traversal", &PyGreedyRewriteDriverConfig::getUseTopDownTraversal, &PyGreedyRewriteDriverConfig::setUseTopDownTraversal, "Whether to use top-down traversal") .def_prop_rw("enable_folding", &PyGreedyRewriteDriverConfig::isFoldingEnabled, &PyGreedyRewriteDriverConfig::enableFolding, "Enable or disable folding") .def_prop_rw("strictness", &PyGreedyRewriteDriverConfig::getStrictness, &PyGreedyRewriteDriverConfig::setStrictness, "Rewrite strictness level") .def_prop_rw("region_simplification_level", &PyGreedyRewriteDriverConfig::getRegionSimplificationLevel, &PyGreedyRewriteDriverConfig::setRegionSimplificationLevel, "Region simplification level") .def_prop_rw("enable_constant_cse", &PyGreedyRewriteDriverConfig::isConstantCSEEnabled, &PyGreedyRewriteDriverConfig::enableConstantCSE, "Enable or disable constant CSE"); nb::class_(m, "FrozenRewritePatternSet") .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyFrozenRewritePatternSet::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyFrozenRewritePatternSet::createFromCapsule); m.def( "apply_patterns_and_fold_greedily", [](PyModule &module, PyFrozenRewritePatternSet &set) { auto status = mlirApplyPatternsAndFoldGreedily( module.get(), set.get(), mlirGreedyRewriteDriverConfigCreate()); if (mlirLogicalResultIsFailure(status)) throw std::runtime_error("pattern application failed to converge"); }, "module"_a, "set"_a, "Applys the given patterns to the given module greedily while folding " "results.") .def( "apply_patterns_and_fold_greedily", [](PyOperationBase &op, PyFrozenRewritePatternSet &set) { auto status = mlirApplyPatternsAndFoldGreedilyWithOp( op.getOperation(), set.get(), mlirGreedyRewriteDriverConfigCreate()); if (mlirLogicalResultIsFailure(status)) throw std::runtime_error( "pattern application failed to converge"); }, "op"_a, "set"_a, "Applys the given patterns to the given op greedily while folding " "results.") .def( "walk_and_apply_patterns", [](PyOperationBase &op, PyFrozenRewritePatternSet &set) { mlirWalkAndApplyPatterns(op.getOperation(), set.get()); }, "op"_a, "set"_a, "Applies the given patterns to the given op by a fast walk-based " "driver."); } } // namespace MLIR_BINDINGS_PYTHON_DOMAIN } // namespace python } // namespace mlir