llvm-project/mlir/test/python/python_pass.py
Twice 7d04e37904
[MLIR][Python] Support Python-defined passes in MLIR (#156000)
Closes #155996.

This PR adds a method `add(callable, ...)` to
`mlir.passmanager.PassManager` that accepts a callable object, allowing
passes to be defined on the Python side.

Here is a simple example of Python-defined passes:
```python
from mlir.passmanager import PassManager

def demo_pass_1(op):
    # do something with op
    pass

class DemoPass:
    def __init__(self, ...):
        pass
    def __call__(self, op):
        # do something
        pass

demo_pass_2 = DemoPass(..)

pm = PassManager('any', ctx)
pm.add(demo_pass_1)
pm.add(demo_pass_2)
pm.add("registered-passes")
pm.run(..)
```

---------

Co-authored-by: cnb.bsD2OPwAgEA <QejD2DJ2eEahUVy6Zg0aZI+cnb.bsD2OPwAgEA@noreply.cnb.cool>
Co-authored-by: Maksim Levental <maksim.levental@gmail.com>
2025-09-08 18:01:23 -07:00

89 lines
2.5 KiB
Python

# RUN: %PYTHON %s 2>&1 | FileCheck %s
import gc, sys
from mlir.ir import *
from mlir.passmanager import *
from mlir.dialects.builtin import ModuleOp
from mlir.dialects import pdl
from mlir.rewrite import *
def log(*args):
    """Print *args* (space-separated, newline-terminated) to stderr and flush.

    Flushing right away keeps these messages ordered relative to other
    stderr output; the lit command at the top of this file merges stderr
    into the stream that is piped to FileCheck (``2>&1``).
    """
    print(*args, file=sys.stderr, flush=True)
def run(f):
    """Run the test function *f* immediately and check that no Context leaked.

    Intended as a decorator. It logs a header line carrying the test's name,
    which the file's label directives match on. It then calls ``f`` and
    asserts that the live MLIR Context count is back to zero.

    There is no ``return`` statement, so the decorated module-level name is
    bound to ``None``.
    """
    test_name = f.__name__
    log("\nTEST:", test_name)
    f()
    # Collect garbage first, presumably so that Context objects kept alive
    # only by reference cycles are released before the live-count check.
    gc.collect()
    assert Context._get_live_count() == 0
def make_pdl_module():
    """Build and return a module containing the PDL pattern ``addi_to_mul``.

    The pattern matches an ``arith.addi`` whose two operands and single result
    all have type ``i64`` and replaces it with an ``arith.muli`` on the same
    operands. In this file the result is wrapped as
    ``PDLModule(...).freeze()`` by ``testCustomPass``.

    Expects an MLIR Context to already be active (the caller enters
    ``with Context():`` first). The ops created here use an unknown Location.
    """
    with Location.unknown():
        pdl_module = Module.create()
        with InsertionPoint(pdl_module.body):
            # Change every arith.addi on i64 values into arith.muli.
            # `pat` and `rew` are never called explicitly. Presumably the pdl
            # decorators invoke them at definition time to emit the PDL ops
            # into the current insertion point. Verify against
            # mlir.dialects.pdl before reordering anything here.
            @pdl.pattern(benefit=1, sym_name="addi_to_mul")
            def pat():
                # Match arith.addi with two i64 operands and an i64 result.
                i64_type = pdl.TypeOp(IntegerType.get_signless(64))
                operand0 = pdl.OperandOp(i64_type)
                operand1 = pdl.OperandOp(i64_type)
                op0 = pdl.OperationOp(
                    name="arith.addi", args=[operand0, operand1], types=[i64_type]
                )

                # Replace the matched op with arith.muli on the same operands.
                @pdl.rewrite()
                def rew():
                    newOp = pdl.OperationOp(
                        name="arith.muli", args=[operand0, operand1], types=[i64_type]
                    )
                    pdl.ReplaceOp(op0, with_op=newOp)

        return pdl_module
# CHECK-LABEL: TEST: testCustomPass
@run
def testCustomPass():
    """Mix Python-defined passes with a registered pass in one PassManager.

    The pipeline has three passes: a plain function pass that only prints a
    message, a callable-object pass that applies the frozen ``addi_to_mul``
    PDL pattern, and the registered ``convert-arith-to-llvm`` pass. IR
    printing is enabled. The directives below expect, in order: the first
    pass's message, a dump after each pass, ``arith.muli`` after the PDL
    rewrite, and ``llvm.mul`` after lowering.
    """
    with Context():
        pdl_module = make_pdl_module()
        # Freeze the PDL patterns once so the callable pass below can reuse them.
        frozen = PDLModule(pdl_module).freeze()
        module = ModuleOp.parse(
            r"""
module {
func.func @add(%a: i64, %b: i64) -> i64 {
%sum = arith.addi %a, %b : i64
return %sum : i64
}
}
"""
        )

        # A plain function can be added as a pass. No explicit name is passed
        # to pm.add for it, so the dump header is expected to use the function's
        # own name (see the directive below). Renaming it breaks the test.
        def custom_pass_1(op):
            print("hello from pass 1!!!", file=sys.stderr)

        # Any callable object also works. This one applies the frozen PDL
        # patterns to the op it is given.
        class CustomPass2:
            def __call__(self, m):
                apply_patterns_and_fold_greedily(m, frozen)

        custom_pass_2 = CustomPass2()
        pm = PassManager("any")
        pm.enable_ir_printing()
        # CHECK: hello from pass 1!!!
        # CHECK-LABEL: Dump After custom_pass_1
        pm.add(custom_pass_1)
        # CHECK-LABEL: Dump After CustomPass2
        # CHECK: arith.muli
        pm.add(custom_pass_2, "CustomPass2")
        # CHECK-LABEL: Dump After ArithToLLVMConversionPass
        # CHECK: llvm.mul
        pm.add("convert-arith-to-llvm")
        pm.run(module)