Historical context: `PyMlirContext::liveOperations` was an optimization meant to cut down on the number of Python object allocations and (partially) a mechanism for updating the validity of ops after transformation, e.g. while walking/transforming the AST. See the original patch [here](https://reviews.llvm.org/D87958).

Inspired by a [renewed](https://github.com/llvm/llvm-project/pull/139721#issuecomment-3217131918) interest in https://github.com/llvm/llvm-project/pull/139721 (which has become a little stale...).

<p align="center">
  <img width="504" height="375" alt="image" src="https://github.com/user-attachments/assets/0daad562-d3d1-4876-8d01-5dba382ab186" />
</p>

In the previous go-around (https://github.com/llvm/llvm-project/pull/92631) there were two issues, both of which have now been resolved:

1. Ops that were "fetched" under a root op which has since been transformed are no longer reported as invalid. We simply "[formally forbid](https://github.com/llvm/llvm-project/pull/92631#issuecomment-2119397018)" this;
2. `Module._CAPICreate(module_capsule)` must now be followed by a `module._clear_mlir_module()` to prevent double-freeing of the actual `ModuleOp` object (i.e. calling the dtor on the `OwningOpRef<ModuleOp>`):

   ```python
   module = ...
   module_dup = Module._CAPICreate(module._CAPIPtr)
   module._clear_mlir_module()
   ```

   - **The alternative choice** here is to remove the `Module._CAPICreate` API altogether and replace it with something like `Module._move(module)`, which would do both `Module._CAPICreate` and `module._clear_mlir_module`.

Note, the other approach I explored last year was a [weakref system](https://github.com/llvm/llvm-project/pull/97340) for `mlir::Operation`, which would effectively hoist this `liveOperations` thing into MLIR core. Possibly doable, but I now believe it's a bad idea.
The other potentially breaking change is that `is`, which checks object identity rather than value equality, will now report `False`, because we are always allocating new Python objects (i.e., that's the whole point of this change). Users wanting to check equality for `Operation` and `Module` should use `==`.
169 lines
4.5 KiB
Python
169 lines
4.5 KiB
Python
# RUN: %PYTHON %s | FileCheck %s
|
|
|
|
import gc
|
|
from tempfile import NamedTemporaryFile
|
|
from mlir.ir import *
|
|
|
|
|
|
def run(f):
    """Execute test callable `f` immediately and hand it back unchanged.

    Prints a banner line carrying the callable's name, invokes it, forces a
    garbage-collection pass, and then asserts that
    `Context._get_live_count()` reports zero (presumably the number of
    still-live MLIR contexts). Intended for use as a decorator.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    f()
    # Collect before inspecting the live count so that unreachable objects
    # from the test body have been reclaimed.
    gc.collect()
    live_count = Context._get_live_count()
    assert live_count == 0
    return f
|
|
|
|
|
|
# Verify successful parse.
|
|
# CHECK-LABEL: TEST: testParseSuccess
|
|
# CHECK: module @successfulParse
|
|
@run
def testParseSuccess():
    """Parse a trivial named module from a string and print it back."""
    context = Context()
    asm_text = r"""module @successfulParse {}"""
    parsed = Module.parse(asm_text, context)
    assert parsed.context is context
    print("CLEAR CONTEXT")
    # Drop the local reference: the module itself must keep the context alive.
    context = None
    gc.collect()
    # dump() only writes to stderr; it is called to confirm that it works.
    parsed.dump()
    print(str(parsed))
|
|
|
|
|
|
# Verify successful parse from file.
|
|
# CHECK-LABEL: TEST: testParseFromFileSuccess
|
|
# CHECK: module @successfulParse
|
|
@run
def testParseFromFileSuccess():
    """Write a trivial module to a temporary file, parse it by path, print it."""
    context = Context()
    with NamedTemporaryFile(mode="w") as asm_file:
        asm_file.write(r"""module @successfulParse {}""")
        # Flush so the text has been written out before the file is parsed
        # through its name.
        asm_file.flush()
        parsed = Module.parseFile(asm_file.name, context)
        assert parsed.context is context
        print("CLEAR CONTEXT")
        # Drop the local reference: the module itself must keep the context
        # alive.
        context = None
        gc.collect()
        parsed.operation.verify()
        print(str(parsed))
|
|
|
|
|
|
# Verify parse error.
|
|
# CHECK-LABEL: TEST: testParseError
|
|
# CHECK: testParseError: <
|
|
# CHECK: Unable to parse module assembly:
|
|
# CHECK: error: "-":1:1: expected operation name in quotes
|
|
# CHECK: >
|
|
@run
def testParseError():
    """Feed malformed assembly to the parser and print the resulting error."""
    context = Context()
    malformed_asm = r"""}SYNTAX ERROR{"""
    try:
        module = Module.parse(malformed_asm, context)
    except MLIRError as e:
        # The exception text is wrapped in angle brackets in the output.
        print(f"testParseError: <{e}>")
    else:
        # Reached only if parsing unexpectedly succeeded.
        print("Exception not produced")
|
|
|
|
|
|
# Verify creation of an empty module.
|
|
# CHECK-LABEL: TEST: testCreateEmpty
|
|
# CHECK: module {
|
|
@run
def testCreateEmpty():
    """Create an empty module at an unknown location and print it."""
    context = Context()
    unknown_loc = Location.unknown(context)
    empty_module = Module.create(unknown_loc)
    print("CLEAR CONTEXT")
    # Drop the local reference: the module itself must keep the context alive.
    context = None
    gc.collect()
    print(str(empty_module))
|
|
|
|
|
|
# Verify round-trip of ASM that contains unicode.
|
|
# Note that this does not test that the print path converts unicode properly
|
|
# because MLIR asm always normalizes it to the hex encoding.
|
|
# CHECK-LABEL: TEST: testRoundtripUnicode
|
|
# CHECK: func private @roundtripUnicode()
|
|
# CHECK: foo = "\F0\9F\98\8A"
|
|
@run
def testRoundtripUnicode():
    """Parse assembly whose attribute holds a non-ASCII character, then print it."""
    context = Context()
    # Raw string so the emoji reaches the parser exactly as written.
    asm_text = r"""
    func.func private @roundtripUnicode() attributes { foo = "😊" }
  """
    parsed = Module.parse(asm_text, context)
    print(str(parsed))
|
|
|
|
|
|
# Verify round-trip of ASM that contains unicode when it is passed as `bytes`.
|
|
# Note that this does not test that the print path converts unicode properly
|
|
# because MLIR asm always normalizes it to the hex encoding.
|
|
# CHECK-LABEL: TEST: testRoundtripBinary
|
|
# CHECK: func private @roundtripUnicode()
|
|
# CHECK: foo = "\F0\9F\98\8A"
|
|
@run
def testRoundtripBinary():
    """Parse assembly, fetch it back as `bytes`, re-parse the bytes, and print."""
    with Context():
        asm_text = r"""
      func.func private @roundtripUnicode() attributes { foo = "😊" }
    """
        mod = Module.parse(asm_text)
        # binary=True is expected to produce a bytes object, not a str.
        asm_bytes = mod.operation.get_asm(binary=True)
        assert isinstance(asm_bytes, bytes)
        # Rebind the same name so the first module is released here.
        mod = Module.parse(asm_bytes)
        print(mod)
|
|
|
|
|
|
# Tests that module.operation works and returns distinct but equal instances.
|
|
# CHECK-LABEL: TEST: testModuleOperation
|
|
@run
def testModuleOperation():
    """Exercise `Module.operation` across repeated access and module release."""
    context = Context()
    parsed = Module.parse(r"""module @successfulParse {}""", context)
    first_op = parsed.operation
    # CHECK: module @successfulParse
    print(first_op)

    # Every access hands back a separate Python object, yet the two wrappers
    # compare equal.
    second_op = parsed.operation
    assert first_op is not second_op
    assert first_op == second_op

    # Release one wrapper, collect, then fetch the operation once more.
    first_op = parsed.operation
    first_op = None
    gc.collect()
    first_op = parsed.operation

    # The operation wrapper must remain usable after the module reference
    # has been dropped.
    parsed = None
    gc.collect()
    print(first_op)

    # Release the remaining wrappers and collect.
    first_op = None
    second_op = None
    gc.collect()
|
|
|
|
|
|
# CHECK-LABEL: TEST: testModuleCapsule
|
|
@run
def testModuleCapsule():
    """Rebuild a module from its capsule, then clear the first wrapper."""
    context = Context()
    original = Module.parse(r"""module @successfulParse {}""", context)
    # CHECK: "mlir.ir.Module._CAPIPtr"
    capsule = original._CAPIPtr
    print(capsule)
    duplicate = Module._CAPICreate(capsule)
    # Two separate Python wrappers that nevertheless compare equal.
    assert original is not duplicate
    assert original == duplicate
    # Once the first wrapper is cleared it must stop comparing equal to the
    # duplicate, while the duplicate still reports the same context.
    original._clear_mlir_module()
    assert original != duplicate
    assert duplicate.context is context
    # Drop every reference and collect.
    original = None
    capsule = None
    duplicate = None
    gc.collect()
|