Twice 972aa597de
[MLIR][Python] Make traits declarative in python-defined operations (#180748)
This adds support for two new syntaxes in Python-defined dialects.

First is that traits can now be declared in class parameters, e.g.
```python
class ParentIsIfTrait(DynamicOpTrait): #define a python-side trait
    @staticmethod
    def verify_invariants(op) -> bool:
        if not isinstance(op.parent.opview, IfOp):
            op.location.emit_error(
                f"{op.name} should be put inside {IfOp.OPERATION_NAME}"
            )
            return False
        return True

class YieldOp( # attach two traits: IsTerminatorTrait, ParentIsIfTrait
    TestRegion.Operation, name="yield", traits=[IsTerminatorTrait, ParentIsIfTrait]
):
    ...
```

Second, users can directly define
`verify_invariants`/`verify_region_invariants` methods in the operation
class to add custom verification logic. This is implemented via
traits.
```python
class YieldOp(TestRegion.Operation, name="yield", ...):
    value: Operand[Any]

    def verify_invariants(self) -> bool: # define a method directly
        if self.parent.results[0].type != self.value.type:
            self.location.emit_error(
                "result type mismatch between YieldOp and its parent IfOp"
            )
            return False
        return True
```

Previously we used `verify`/`verify_region` as method names (in
yesterday's PR #179705), but in this PR they are renamed to
`verify_invariants`/`verify_region_invariants` because the
newly-added `verify` method conflicts with the `ir.OpView.verify`
method:
- `verify_invariants` only attaches **additional** verification
logic, whereas `OpView.verify` constructs an OperationVerifier and does
full verification of an operation, so the semantics of the two are not
the same. We should not shadow the `OpView.verify` method by defining a
new, semantically different `verify` method.
- it would confuse users to have two `verify` methods with
different meanings.
- if users do not define the `verify` method in their Python-defined
operation, `DynamicOpTraits.attach(opname, MyOpCls)` still performs the
attachment (because `hasattr("verify")` returns `True`) and segfaults
(because we cannot attach `OpView.verify`).

---------

Co-authored-by: Rolf Morel <rolfmorel@gmail.com>
2026-02-11 20:39:58 +08:00

522 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# RUN: %PYTHON %s 2>&1 | FileCheck %s
from mlir.ir import *
from mlir.dialects import arith
from mlir.dialects.ext import *
from typing import Any, Optional, Sequence, TypeVar, Union
import sys
def run(f):
    """Print a ``TEST: <name>`` banner for *f* and then call it immediately.

    Used as a decorator on the test functions in this file, so each test
    executes at definition time. Nothing is returned, which means the
    decorated name is rebound to ``None`` afterwards.
    """
    # The banner starts with a newline so it always begins on its own line
    # in the captured output.
    print("\nTEST:", f.__name__)
    f()
# CHECK: TEST: testMyInt
@run
def testMyInt():
    """Declare a small ``myint`` dialect and exercise its two operations.

    The dialect has a ``constant`` op (one integer attribute, one i32 result)
    and an ``add`` op (two i32 operands, one i32 result). The test loads the
    dialect, prints the module generated for it, builds a few ops, verifies
    the resulting module, and prints the op class attributes, accessors and
    generated ``__init__`` signatures so that FileCheck can match them.
    """

    class MyInt(Dialect, name="myint"):
        pass

    # Constraint form of i32, used in the field annotations below.
    i32 = IntegerType[32]

    class ConstantOp(MyInt.Operation, name="constant"):
        value: IntegerAttr
        cst: Result[i32]

    class AddOp(MyInt.Operation, name="add"):
        lhs: Operand[i32]
        rhs: Operand[i32]
        res: Result[i32]

    # CHECK: irdl.dialect @myint {
    # CHECK: irdl.operation @constant {
    # CHECK: %0 = irdl.base "#builtin.integer"
    # CHECK: irdl.attributes {"value" = %0}
    # CHECK: %1 = irdl.is i32
    # CHECK: irdl.results(cst: %1)
    # CHECK: }
    # CHECK: irdl.operation @add {
    # CHECK: %0 = irdl.is i32
    # CHECK: irdl.operands(lhs: %0, rhs: %0)
    # CHECK: irdl.results(res: %0)
    # CHECK: }
    # CHECK: }
    with Context(), Location.unknown():
        MyInt.load()
        print(MyInt._mlir_module)

        # CHECK: ['constant', 'add']
        print([i._op_name for i in MyInt.operations])

        # From here on `i32` is rebound to a concrete type instance, which is
        # what the attribute builder below is given.
        i32 = IntegerType.get_signless(32)

        module = Module.create()
        with InsertionPoint(module.body):
            two = ConstantOp(IntegerAttr.get(i32, 2))
            three = ConstantOp(IntegerAttr.get(i32, 3))
            add1 = AddOp(two, three)
            add2 = AddOp(add1, two)
            add3 = AddOp(add2, three)

        # CHECK: %0 = "myint.constant"() {value = 2 : i32} : () -> i32
        # CHECK: %1 = "myint.constant"() {value = 3 : i32} : () -> i32
        # CHECK: %2 = "myint.add"(%0, %1) : (i32, i32) -> i32
        # CHECK: %3 = "myint.add"(%2, %0) : (i32, i32) -> i32
        # CHECK: %4 = "myint.add"(%3, %1) : (i32, i32) -> i32
        print(module)
        assert module.operation.verify()

        # CHECK: AddOp
        print(type(add1).__name__)
        # CHECK: ConstantOp
        print(type(two).__name__)
        # CHECK: myint.add
        print(add1.OPERATION_NAME)
        # CHECK: None
        print(add1._ODS_OPERAND_SEGMENTS)
        # CHECK: None
        print(add1._ODS_RESULT_SEGMENTS)
        # CHECK: (0, True)
        print(add1._ODS_REGIONS)
        # CHECK: %0 = "myint.constant"() {value = 2 : i32} : () -> i32
        print(add1.lhs.owner)
        # CHECK: %1 = "myint.constant"() {value = 3 : i32} : () -> i32
        print(add1.rhs.owner)
        # CHECK: 2 : i32
        print(two.value)
        # CHECK: OpResult(%0
        print(two.cst)

        # CHECK: (self, /, lhs, rhs, *, loc=None, ip=None)
        print(AddOp.__init__.__signature__)
        # CHECK: (self, /, value, *, loc=None, ip=None)
        print(ConstantOp.__init__.__signature__)
# CHECK: TEST: testExtDialect
@run
def testExtDialect():
    """Exercise field declarations of the ``ext_test`` dialect.

    Covers union and ``Any`` constraints, optional and variadic operands and
    results, fields inherited from an unnamed base operation class, and
    ``TypeVar``-based constraints. After loading the dialect the test prints
    the generated module, the ``__init__`` signatures and the operand/result
    segment specs, builds a series of ops, verifies the module, prints the
    accessors, and finally checks the ``TypeError`` messages produced when an
    op is constructed with too few or too many arguments.
    """

    class Test(Dialect, name="ext_test"):
        pass

    # Constraint form of i32, used in the field annotations below.
    i32 = IntegerType[32]

    class ConstraintOp(Test.Operation, name="constraint"):
        a: Operand[i32 | IntegerType[64]]
        b: Operand[Any]
        # Here we use `F32Type[()]` instead of just `F32Type`
        # because of an existing issue in IRDL implementation
        # where `irdl.base` cannot exist in `irdl.any_of`.
        c: Operand[F32Type[()] | i32]
        d: Operand[Any]
        x: IntegerAttr
        y: FloatAttr

    class OptionalOp(Test.Operation, name="optional"):
        a: Operand[i32]
        b: Optional[Operand[i32]]
        out1: Result[i32]
        out2: Result[i32] | None
        out3: Result[i32]

    class Optional2Op(Test.Operation, name="optional2"):
        a: Optional[Operand[i32]]
        b: Optional[Result[i32]]

    class VariadicOp(Test.Operation, name="variadic"):
        a: Operand[i32]
        b: Optional[Operand[i32]]
        c: Sequence[Operand[i32]]
        out1: Sequence[Result[i32]]
        out2: Sequence[Result[i32]]
        out3: Optional[Result[i32]]
        out4: Result[i32]

    class Variadic2Op(Test.Operation, name="variadic2"):
        a: Sequence[Operand[i32]]
        b: Sequence[Result[i32]]

    # No `name=` here: this class only contributes fields to its subclass.
    class MixedOpBase(Test.Operation):
        out: Result[i32]
        in1: Operand[i32]

    class MixedOp(MixedOpBase, name="mixed"):
        in2: IntegerAttr
        in3: Optional[Operand[i32]]
        in4: IntegerAttr
        in5: Operand[i32]

    T = TypeVar("T")
    U = TypeVar("U", bound=IntegerType[32] | IntegerType[64])
    V = TypeVar("V", bound=Union[IntegerType[8], IntegerType[16]])

    class TypeVarOp(Test.Operation, name="type_var"):
        in1: Operand[T]
        in2: Operand[T]
        in3: Operand[U]
        in4: Operand[U | V]
        in5: Operand[V]

    # CHECK: irdl.dialect @ext_test {
    # CHECK: irdl.operation @constraint {
    # CHECK: %0 = irdl.is i32
    # CHECK: %1 = irdl.is i64
    # CHECK: %2 = irdl.any_of(%0, %1)
    # CHECK: %3 = irdl.any
    # CHECK: %4 = irdl.is f32
    # CHECK: %5 = irdl.any_of(%4, %0)
    # CHECK: %6 = irdl.any
    # CHECK: irdl.operands(a: %2, b: %3, c: %5, d: %6)
    # CHECK: %7 = irdl.base "#builtin.integer"
    # CHECK: %8 = irdl.base "#builtin.float"
    # CHECK: irdl.attributes {"x" = %7, "y" = %8}
    # CHECK: }
    # CHECK: irdl.operation @optional {
    # CHECK: %0 = irdl.is i32
    # CHECK: irdl.operands(a: %0, b: optional %0)
    # CHECK: irdl.results(out1: %0, out2: optional %0, out3: %0)
    # CHECK: }
    # CHECK: irdl.operation @optional2 {
    # CHECK: %0 = irdl.is i32
    # CHECK: irdl.operands(a: optional %0)
    # CHECK: irdl.results(b: optional %0)
    # CHECK: }
    # CHECK: irdl.operation @variadic {
    # CHECK: %0 = irdl.is i32
    # CHECK: irdl.operands(a: %0, b: optional %0, c: variadic %0)
    # CHECK: irdl.results(out1: variadic %0, out2: variadic %0, out3: optional %0, out4: %0)
    # CHECK: }
    # CHECK: irdl.operation @variadic2 {
    # CHECK: %0 = irdl.is i32
    # CHECK: irdl.operands(a: variadic %0)
    # CHECK: irdl.results(b: variadic %0)
    # CHECK: }
    # CHECK: irdl.operation @mixed {
    # CHECK: %0 = irdl.is i32
    # CHECK: irdl.operands(in1: %0, in3: optional %0, in5: %0)
    # CHECK: %1 = irdl.base "#builtin.integer"
    # CHECK: %2 = irdl.base "#builtin.integer"
    # CHECK: irdl.attributes {"in2" = %1, "in4" = %2}
    # CHECK: irdl.results(out: %0)
    # CHECK: }
    # CHECK: irdl.operation @type_var {
    # CHECK: %0 = irdl.any
    # CHECK: %1 = irdl.is i32
    # CHECK: %2 = irdl.is i64
    # CHECK: %3 = irdl.any_of(%1, %2)
    # CHECK: %4 = irdl.is i8
    # CHECK: %5 = irdl.is i16
    # CHECK: %6 = irdl.any_of(%4, %5)
    # CHECK: %7 = irdl.any_of(%3, %6)
    # CHECK: irdl.operands(in1: %0, in2: %0, in3: %3, in4: %7, in5: %6)
    # CHECK: }
    # CHECK: }
    with Context(), Location.unknown():
        Test.load()
        print(Test._mlir_module)

        # CHECK: (self, /, a, b, c, d, x, y, *, loc=None, ip=None)
        print(ConstraintOp.__init__.__signature__)
        # CHECK: (self, /, out1, out3, a, *, out2=None, b=None, loc=None, ip=None)
        print(OptionalOp.__init__.__signature__)
        # CHECK: (self, /, *, b=None, a=None, loc=None, ip=None)
        print(Optional2Op.__init__.__signature__)
        # CHECK: (self, /, out1, out2, out4, a, c, *, out3=None, b=None, loc=None, ip=None)
        print(VariadicOp.__init__.__signature__)
        # CHECK: (self, /, b, a, *, loc=None, ip=None)
        print(Variadic2Op.__init__.__signature__)
        # CHECK: (self, /, in1, in2, in4, in5, *, in3=None, loc=None, ip=None)
        print(MixedOp.__init__.__signature__)

        # CHECK: None None
        print(ConstraintOp._ODS_OPERAND_SEGMENTS, ConstraintOp._ODS_RESULT_SEGMENTS)
        # CHECK: [1, 0] [1, 0, 1]
        print(OptionalOp._ODS_OPERAND_SEGMENTS, OptionalOp._ODS_RESULT_SEGMENTS)
        # CHECK: [0] [0]
        print(Optional2Op._ODS_OPERAND_SEGMENTS, Optional2Op._ODS_RESULT_SEGMENTS)
        # CHECK: [1, 0, -1] [-1, -1, 0, 1]
        print(VariadicOp._ODS_OPERAND_SEGMENTS, VariadicOp._ODS_RESULT_SEGMENTS)
        # CHECK: [-1] [-1]
        print(Variadic2Op._ODS_OPERAND_SEGMENTS, Variadic2Op._ODS_RESULT_SEGMENTS)

        # From here on the names refer to concrete type instances, which are
        # what the op builders below are given for results.
        i32 = IntegerType.get_signless(32)
        i64 = IntegerType.get_signless(64)
        f32 = F32Type.get()

        iattr = IntegerAttr.get(i32, 2)
        fattr = FloatAttr.get_f32(2.3)

        module = Module.create()
        with InsertionPoint(module.body):
            ione = arith.constant(i32, 1)
            fone = arith.constant(f32, 1.2)

            # CHECK: "ext_test.constraint"(%c1_i32, %c1_i32, %cst, %c1_i32) {x = 2 : i32, y = 2.300000e+00 : f32} : (i32, i32, f32, i32) -> ()
            c1 = ConstraintOp(ione, ione, fone, ione, iattr, fattr)
            # CHECK: "ext_test.constraint"(%c1_i32, %cst, %cst, %cst) {x = 2 : i32, y = 2.300000e+00 : f32} : (i32, f32, f32, f32) -> ()
            ConstraintOp(ione, fone, fone, fone, iattr, fattr)
            # CHECK: ext_test.constraint"(%c1_i32, %cst, %c1_i32, %cst) {x = 2 : i32, y = 2.300000e+00 : f32} : (i32, f32, i32, f32) -> ()
            ConstraintOp(ione, fone, ione, fone, iattr, fattr)

            # CHECK: %0:2 = "ext_test.optional"(%c1_i32) {operandSegmentSizes = array<i32: 1, 0>, resultSegmentSizes = array<i32: 1, 0, 1>} : (i32) -> (i32, i32)
            o1 = OptionalOp(i32, i32, ione)
            # CHECK: %1:3 = "ext_test.optional"(%c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 1>, resultSegmentSizes = array<i32: 1, 1, 1>} : (i32, i32) -> (i32, i32, i32)
            o2 = OptionalOp(i32, i32, ione, out2=i32, b=ione)

            # CHECK: ext_test.optional2"() {operandSegmentSizes = array<i32: 0>, resultSegmentSizes = array<i32: 0>} : () -> ()
            o3 = Optional2Op()
            # CHECK: %2 = "ext_test.optional2"() {operandSegmentSizes = array<i32: 0>, resultSegmentSizes = array<i32: 1>} : () -> i32
            o4 = Optional2Op(b=i32)
            # CHECK: "ext_test.optional2"(%c1_i32) {operandSegmentSizes = array<i32: 1>, resultSegmentSizes = array<i32: 0>} : (i32) -> ()
            o5 = Optional2Op(a=ione)
            # CHECK: %3 = "ext_test.optional2"(%c1_i32) {operandSegmentSizes = array<i32: 1>, resultSegmentSizes = array<i32: 1>} : (i32) -> i32
            o6 = Optional2Op(b=i32, a=ione)

            # CHECK: %4:4 = "ext_test.variadic"(%c1_i32, %c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 0, 2>, resultSegmentSizes = array<i32: 1, 2, 0, 1>} : (i32, i32, i32) -> (i32, i32, i32, i32)
            v1 = VariadicOp([i32], [i32, i32], i32, ione, [ione, ione])
            # CHECK: %5:5 = "ext_test.variadic"(%c1_i32, %c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 1, 1>, resultSegmentSizes = array<i32: 1, 2, 1, 1>} : (i32, i32, i32) -> (i32, i32, i32, i32, i32)
            v2 = VariadicOp([i32], [i32, i32], i32, ione, [ione], out3=i32, b=ione)
            # CHECK: %6:4 = "ext_test.variadic"(%c1_i32) {operandSegmentSizes = array<i32: 1, 0, 0>, resultSegmentSizes = array<i32: 2, 1, 0, 1>} : (i32) -> (i32, i32, i32, i32)
            v3 = VariadicOp([i32, i32], [i32], i32, ione, [])

            # CHECK: "ext_test.variadic2"() {operandSegmentSizes = array<i32: 0>, resultSegmentSizes = array<i32: 0>} : () -> ()
            v4 = Variadic2Op([], [])
            # CHECK: "ext_test.variadic2"(%c1_i32, %c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 3>, resultSegmentSizes = array<i32: 0>} : (i32, i32, i32) -> ()
            v5 = Variadic2Op([], [ione, ione, ione])
            # CHECK: %7:2 = "ext_test.variadic2"(%c1_i32) {operandSegmentSizes = array<i32: 1>, resultSegmentSizes = array<i32: 2>} : (i32) -> (i32, i32)
            v6 = Variadic2Op([i32, i32], [ione])

            # CHECK: %8 = "ext_test.mixed"(%c1_i32, %c1_i32) {in2 = 2 : i32, in4 = 2 : i32, operandSegmentSizes = array<i32: 1, 0, 1>} : (i32, i32) -> i32
            m1 = MixedOp(ione, iattr, iattr, ione)
            # CHECK: %9 = "ext_test.mixed"(%c1_i32, %c1_i32, %c1_i32) {in2 = 2 : i32, in4 = 2 : i32, operandSegmentSizes = array<i32: 1, 1, 1>} : (i32, i32, i32) -> i32
            m2 = MixedOp(ione, iattr, iattr, ione, in3=ione)

        print(module)
        assert module.operation.verify()

        # CHECK: OpResult(%c1_i32
        print(c1.a)
        # CHECK: 2 : i32
        print(c1.x)
        # CHECK: OpResult(%c1_i32
        print(o1.a)
        # CHECK: None
        print(o1.b)
        # CHECK: OpResult(%c1_i32
        print(o2.b)
        # CHECK: 0
        print(o1.out1.result_number)
        # CHECK: None
        print(o1.out2)
        # CHECK: 0
        print(o2.out1.result_number)
        # CHECK: 1
        print(o2.out2.result_number)
        # CHECK: None
        print(o3.a)
        # CHECK: OpResult(%c1_i32
        print(o5.a)
        # CHECK: ['OpResult(%c1_i32 = arith.constant 1 : i32)', 'OpResult(%c1_i32 = arith.constant 1 : i32)']
        print([str(i) for i in v1.c])
        # CHECK: ['OpResult(%c1_i32 = arith.constant 1 : i32)']
        print([str(i) for i in v2.c])
        # CHECK: []
        print([str(i) for i in v3.c])
        # CHECK: 0 0
        print(len(v4.a), len(v4.b))
        # CHECK: 3 0
        print(len(v5.a), len(v5.b))
        # CHECK: 1 2
        print(len(v6.a), len(v6.b))

        # cases to violate constraints
        module = Module.create()
        with InsertionPoint(module.body):
            # One argument short: the `y` attribute is not supplied.
            try:
                c1 = ConstraintOp(ione, ione, fone, ione, iattr)
            except TypeError as e:
                # CHECK: missing a required argument: 'y'
                print(e)

            # One positional argument more than the op declares.
            try:
                c2 = ConstraintOp(ione, ione, fone, ione, iattr, fattr, ione)
            except TypeError as e:
                # CHECK: too many positional arguments
                print(e)
# CHECK: TEST: testExtDialectWithRegion
@run
def testExtDialectWithRegion():
    """Exercise region-carrying ops and declarative traits in ``ext_region``.

    Defines an ``if`` op with two regions, a ``yield`` op that carries both a
    Python-defined trait and its own ``verify_invariants`` method, and a
    ``no_term`` op declared with ``NoTerminatorTrait``. The test first builds
    a module that verifies successfully and prints it, then builds three
    modules that are expected to fail verification and prints the resulting
    error text for FileCheck to match.
    """

    class ParentIsIfTrait(DynamicOpTrait):
        @staticmethod
        def verify_invariants(op) -> bool:
            # `IfOp` is defined further down in the enclosing function; the
            # name is only looked up when this verifier actually runs.
            if not isinstance(op.parent.opview, IfOp):
                op.location.emit_error(
                    f"{op.name} should be put inside {IfOp.OPERATION_NAME}"
                )
                return False
            return True

    class TestRegion(Dialect, name="ext_region"):
        pass

    class IfOp(TestRegion.Operation, name="if"):
        cond: Operand[IntegerType[1]]
        result: Result[Any]
        then: Region
        else_: Region

    class YieldOp(
        TestRegion.Operation, name="yield", traits=[IsTerminatorTrait, ParentIsIfTrait]
    ):
        value: Operand[Any]

        def verify_invariants(self) -> bool:
            # The yielded value must have the type of the parent's first
            # result.
            if self.parent.results[0].type != self.value.type:
                self.location.emit_error(
                    "result type mismatch between YieldOp and its parent IfOp"
                )
                return False
            return True

    class NoTermOp(TestRegion.Operation, name="no_term", traits=[NoTerminatorTrait]):
        body: Region

    with Context(), Location.unknown():
        TestRegion.load()
        # CHECK: irdl.dialect @ext_region {
        # CHECK: irdl.operation @if {
        # CHECK: %0 = irdl.is i1
        # CHECK: irdl.operands(cond: %0)
        # CHECK: %1 = irdl.any
        # CHECK: irdl.results(result: %1)
        # CHECK: %2 = irdl.region
        # CHECK: %3 = irdl.region
        # CHECK: irdl.regions(then: %2, else_: %3)
        # CHECK: }
        # CHECK: irdl.operation @yield {
        # CHECK: %0 = irdl.any
        # CHECK: irdl.operands(value: %0)
        # CHECK: }
        # CHECK: irdl.operation @no_term {
        # CHECK: %0 = irdl.region
        # CHECK: irdl.regions(body: %0)
        # CHECK: }
        # CHECK: }
        print(TestRegion._mlir_module)

        # CHECK: (self, /, result, cond, *, loc=None, ip=None)
        print(IfOp.__init__.__signature__)
        # CHECK: None None
        print(IfOp._ODS_OPERAND_SEGMENTS, IfOp._ODS_RESULT_SEGMENTS)
        # CHECK: (2, True)
        print(IfOp._ODS_REGIONS)

        # Well-formed module: both branches yield an i32, matching the `if`
        # result type, and the `no_term` body has no terminator.
        module = Module.create()
        with InsertionPoint(module.body):
            i1 = IntegerType.get_signless(1)
            i32 = IntegerType.get_signless(32)
            cond = arith.constant(i1, 1)

            if_ = IfOp(i32, cond)
            if_.then.blocks.append()
            if_.else_.blocks.append()

            with InsertionPoint(if_.then.blocks[0]):
                v = arith.constant(i32, 2)
                YieldOp(v)

            with InsertionPoint(if_.else_.blocks[0]):
                v = arith.constant(i32, 3)
                YieldOp(v)

            nt = NoTermOp()
            nt.body.blocks.append()
            with InsertionPoint(nt.body.blocks[0]):
                arith.constant(i32, 4)
                # No terminator here

        assert module.operation.verify()

        # CHECK: module {
        # CHECK: %true = arith.constant true
        # CHECK: %0 = "ext_region.if"(%true) ({
        # CHECK: %c2_i32 = arith.constant 2 : i32
        # CHECK: "ext_region.yield"(%c2_i32) : (i32) -> ()
        # CHECK: }, {
        # CHECK: %c3_i32 = arith.constant 3 : i32
        # CHECK: "ext_region.yield"(%c3_i32) : (i32) -> ()
        # CHECK: }) : (i1) -> i32
        # CHECK: "ext_region.no_term"() ({
        # CHECK: %c4_i32 = arith.constant 4 : i32
        # CHECK: }) : () -> ()
        # CHECK: }
        print(module)

        # CHECK: %c2_i32 = arith.constant 2 : i32
        print(if_.then.blocks[0])
        # CHECK: %c3_i32 = arith.constant 3 : i32
        print(if_.else_.blocks[0])

        # CHECK-LABEL: Testing violation cases
        print("Testing violation cases:")

        # Violation 1: the branch blocks of the `if` have no `yield` op.
        module = Module.create()
        with InsertionPoint(module.body):
            i1 = IntegerType.get_signless(1)
            i32 = IntegerType.get_signless(32)
            cond = arith.constant(i1, 1)

            if_ = IfOp(i32, cond)
            if_.then.blocks.append()
            if_.else_.blocks.append()

            with InsertionPoint(if_.then.blocks[0]):
                v = arith.constant(i32, 2)

            with InsertionPoint(if_.else_.blocks[0]):
                v = arith.constant(i32, 3)

        try:
            module.operation.verify()
        except Exception as e:
            # CHECK: Verification failed:
            # CHECK: block with no terminator
            print(e)

        # Violation 2: a `yield` placed directly in the module body instead
        # of inside an `if`. `i32` is still bound from the block above.
        module = Module.create()
        with InsertionPoint(module.body):
            v = arith.constant(i32, 2)
            YieldOp(v)

        try:
            module.operation.verify()
        except Exception as e:
            # CHECK: Verification failed:
            # CHECK: ext_region.yield should be put inside ext_region.if
            print(e)

        # Violation 3: the `if` result is i1 while both branches yield i32.
        module = Module.create()
        with InsertionPoint(module.body):
            i1 = IntegerType.get_signless(1)
            i32 = IntegerType.get_signless(32)
            cond = arith.constant(i1, 1)

            if_ = IfOp(i1, cond)
            if_.then.blocks.append()
            if_.else_.blocks.append()

            with InsertionPoint(if_.then.blocks[0]):
                v = arith.constant(i32, 2)
                YieldOp(v)

            with InsertionPoint(if_.else_.blocks[0]):
                v = arith.constant(i32, 3)
                YieldOp(v)

        try:
            module.operation.verify()
        except Exception as e:
            # CHECK: Verification failed:
            # CHECK: result type mismatch
            print(e)