Twice 8542514e5c
[MLIR][Python] Allow passing dialect as a class keyword argument (#182465)
Previously, new ops were defined with the pattern
`class MyOp(MyInt.Operation)`.

This change adds a second pattern, `class MyOp(Operation, dialect=MyInt)`,
which allows more flexible composition: an op base class that is not tied to
any dialect can be reused across dialects. For example:
```python
class BinOpBase(Operation): # it can be used in any dialect!
  res: Result[Any]
  lhs: Operand[Any]
  rhs: Operand[Any]
  
class MyInt(Dialect, name="myint"):
  pass

class AddOp(BinOpBase, dialect=MyInt, name="add"):
  ...
```
2026-02-22 18:52:57 +08:00

522 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# RUN: %PYTHON %s 2>&1 | FileCheck %s
from mlir.ir import *
from mlir.dialects import arith
from mlir.dialects.ext import *
from typing import Any, Optional, Sequence, TypeVar, Union
import sys
def run(f):
    """Print a ``TEST:`` banner naming *f*, then invoke *f*.

    Applied as a decorator, so each test function runs as soon as it is
    defined. Nothing is returned, so the decorated name is rebound to None.
    """
    banner_name = f.__name__
    print("\nTEST:", banner_name)
    f()
# CHECK: TEST: testMyInt
@run
def testMyInt():
    """Build a minimal `myint` dialect from Python classes and exercise it.

    Two ways of attaching an op to a dialect are covered: subclassing
    ``MyInt.Operation`` (ConstantOp) and passing ``dialect=MyInt`` as a class
    keyword argument (AddOp). The test checks the generated IRDL module, op
    construction and printing, verification, generated accessors and class
    attributes, and the synthesized ``__init__`` signatures.
    """

    class MyInt(Dialect, name="myint"):
        pass

    # Constraint used in the field annotations below; the name is rebound to
    # a concrete signless i32 type once a Context is active.
    i32 = IntegerType[32]

    # Dialect membership via inheritance from the dialect's Operation base.
    class ConstantOp(MyInt.Operation, name="constant"):
        value: IntegerAttr
        cst: Result[i32]

    # Dialect membership via the `dialect=` class keyword argument.
    class AddOp(Operation, dialect=MyInt, name="add"):
        lhs: Operand[i32]
        rhs: Operand[i32]
        res: Result[i32]

    # CHECK: irdl.dialect @myint {
    # CHECK: irdl.operation @constant {
    # CHECK: %0 = irdl.base "#builtin.integer"
    # CHECK: irdl.attributes {"value" = %0}
    # CHECK: %1 = irdl.is i32
    # CHECK: irdl.results(cst: %1)
    # CHECK: }
    # CHECK: irdl.operation @add {
    # CHECK: %0 = irdl.is i32
    # CHECK: irdl.operands(lhs: %0, rhs: %0)
    # CHECK: irdl.results(res: %0)
    # CHECK: }
    # CHECK: }
    with Context(), Location.unknown():
        MyInt.load()
        print(MyInt._mlir_module)

        # Both ops are registered with the dialect, in declaration order.
        # CHECK: ['constant', 'add']
        print([i._op_name for i in MyInt.operations])

        # Concrete type (created inside the active Context) for attributes.
        i32 = IntegerType.get_signless(32)

        module = Module.create()
        with InsertionPoint(module.body):
            two = ConstantOp(IntegerAttr.get(i32, 2))
            three = ConstantOp(IntegerAttr.get(i32, 3))
            # Op instances are passed directly where operands are expected.
            add1 = AddOp(two, three)
            add2 = AddOp(add1, two)
            add3 = AddOp(add2, three)

        # CHECK: %0 = "myint.constant"() {value = 2 : i32} : () -> i32
        # CHECK: %1 = "myint.constant"() {value = 3 : i32} : () -> i32
        # CHECK: %2 = "myint.add"(%0, %1) : (i32, i32) -> i32
        # CHECK: %3 = "myint.add"(%2, %0) : (i32, i32) -> i32
        # CHECK: %4 = "myint.add"(%3, %1) : (i32, i32) -> i32
        print(module)
        assert module.operation.verify()

        # Constructors return instances of the user-defined op classes.
        # CHECK: AddOp
        print(type(add1).__name__)
        # CHECK: ConstantOp
        print(type(two).__name__)
        # CHECK: myint.add
        print(add1.OPERATION_NAME)
        # AddOp has only single (non-optional, non-variadic) operands and
        # results; its segment class attributes print as None.
        # CHECK: None
        print(add1._ODS_OPERAND_SEGMENTS)
        # CHECK: None
        print(add1._ODS_RESULT_SEGMENTS)
        # AddOp declares no regions.
        # CHECK: (0, True)
        print(add1._ODS_REGIONS)
        # Operand accessors return values whose `.owner` is the defining op.
        # CHECK: %0 = "myint.constant"() {value = 2 : i32} : () -> i32
        print(add1.lhs.owner)
        # CHECK: %1 = "myint.constant"() {value = 3 : i32} : () -> i32
        print(add1.rhs.owner)
        # Attribute and result accessors generated from the annotations.
        # CHECK: 2 : i32
        print(two.value)
        # CHECK: OpResult(%0
        print(two.cst)
        # Synthesized constructor signatures: positional-only self, the
        # declared operands/attributes, then keyword-only loc/ip.
        # CHECK: (self, /, lhs, rhs, *, loc=None, ip=None)
        print(AddOp.__init__.__signature__)
        # CHECK: (self, /, value, *, loc=None, ip=None)
        print(ConstantOp.__init__.__signature__)
# CHECK: TEST: testExtDialect
@run
def testExtDialect():
    """Exercise Python-defined ops with richer operand/result/attribute specs.

    Covers union type constraints, `Any`, attributes, optional and variadic
    operands/results, fields inherited from a base op class, and TypeVar-based
    constraints. For each op the test checks the generated IRDL, the
    synthesized ``__init__`` signature and segment tables, construction and
    printing, accessors, and the TypeError raised for unbindable arguments.
    """

    class Test(Dialect, name="ext_test"):
        pass

    i32 = IntegerType[32]

    class ConstraintOp(Test.Operation, name="constraint"):
        a: Operand[i32 | IntegerType[64]]
        b: Operand[Any]
        # Here we use `F32Type[()]` instead of just `F32Type`
        # because of an existing issue in IRDL implementation
        # where `irdl.base` cannot exist in `irdl.any_of`.
        c: Operand[F32Type[()] | i32]
        d: Operand[Any]
        x: IntegerAttr
        y: FloatAttr

    class OptionalOp(Test.Operation, name="optional"):
        a: Operand[i32]
        b: Optional[Operand[i32]]
        out1: Result[i32]
        out2: Result[i32] | None
        out3: Result[i32]

    class Optional2Op(Test.Operation, name="optional2"):
        a: Optional[Operand[i32]]
        b: Optional[Result[i32]]

    class VariadicOp(Test.Operation, name="variadic"):
        a: Operand[i32]
        b: Optional[Operand[i32]]
        c: Sequence[Operand[i32]]
        out1: Sequence[Result[i32]]
        out2: Sequence[Result[i32]]
        out3: Optional[Result[i32]]
        out4: Result[i32]

    class Variadic2Op(Test.Operation, name="variadic2"):
        a: Sequence[Operand[i32]]
        b: Sequence[Result[i32]]

    # Fields declared on a base class are inherited by the concrete op and
    # come before the subclass's own fields.
    class MixedOpBase(Test.Operation):
        out: Result[i32]
        in1: Operand[i32]

    class MixedOp(MixedOpBase, name="mixed"):
        in2: IntegerAttr
        in3: Optional[Operand[i32]]
        in4: IntegerAttr
        in5: Operand[i32]

    # An unbound TypeVar becomes `irdl.any`; a bound TypeVar uses its bound.
    T = TypeVar("T")
    U = TypeVar("U", bound=IntegerType[32] | IntegerType[64])
    V = TypeVar("V", bound=Union[IntegerType[8], IntegerType[16]])

    class TypeVarOp(Test.Operation, name="type_var"):
        in1: Operand[T]
        in2: Operand[T]
        in3: Operand[U]
        in4: Operand[U | V]
        in5: Operand[V]

    # CHECK: irdl.dialect @ext_test {
    # CHECK: irdl.operation @constraint {
    # CHECK: %0 = irdl.is i32
    # CHECK: %1 = irdl.is i64
    # CHECK: %2 = irdl.any_of(%0, %1)
    # CHECK: %3 = irdl.any
    # CHECK: %4 = irdl.is f32
    # CHECK: %5 = irdl.any_of(%4, %0)
    # CHECK: %6 = irdl.any
    # CHECK: irdl.operands(a: %2, b: %3, c: %5, d: %6)
    # CHECK: %7 = irdl.base "#builtin.integer"
    # CHECK: %8 = irdl.base "#builtin.float"
    # CHECK: irdl.attributes {"x" = %7, "y" = %8}
    # CHECK: }
    # CHECK: irdl.operation @optional {
    # CHECK: %0 = irdl.is i32
    # CHECK: irdl.operands(a: %0, b: optional %0)
    # CHECK: irdl.results(out1: %0, out2: optional %0, out3: %0)
    # CHECK: }
    # CHECK: irdl.operation @optional2 {
    # CHECK: %0 = irdl.is i32
    # CHECK: irdl.operands(a: optional %0)
    # CHECK: irdl.results(b: optional %0)
    # CHECK: }
    # CHECK: irdl.operation @variadic {
    # CHECK: %0 = irdl.is i32
    # CHECK: irdl.operands(a: %0, b: optional %0, c: variadic %0)
    # CHECK: irdl.results(out1: variadic %0, out2: variadic %0, out3: optional %0, out4: %0)
    # CHECK: }
    # CHECK: irdl.operation @variadic2 {
    # CHECK: %0 = irdl.is i32
    # CHECK: irdl.operands(a: variadic %0)
    # CHECK: irdl.results(b: variadic %0)
    # CHECK: }
    # CHECK: irdl.operation @mixed {
    # CHECK: %0 = irdl.is i32
    # CHECK: irdl.operands(in1: %0, in3: optional %0, in5: %0)
    # CHECK: %1 = irdl.base "#builtin.integer"
    # CHECK: %2 = irdl.base "#builtin.integer"
    # CHECK: irdl.attributes {"in2" = %1, "in4" = %2}
    # CHECK: irdl.results(out: %0)
    # CHECK: }
    # CHECK: irdl.operation @type_var {
    # CHECK: %0 = irdl.any
    # CHECK: %1 = irdl.is i32
    # CHECK: %2 = irdl.is i64
    # CHECK: %3 = irdl.any_of(%1, %2)
    # CHECK: %4 = irdl.is i8
    # CHECK: %5 = irdl.is i16
    # CHECK: %6 = irdl.any_of(%4, %5)
    # CHECK: %7 = irdl.any_of(%3, %6)
    # CHECK: irdl.operands(in1: %0, in2: %0, in3: %3, in4: %7, in5: %6)
    # CHECK: }
    # CHECK: }
    with Context(), Location.unknown():
        Test.load()
        print(Test._mlir_module)

        # Synthesized constructor signatures. Optional operands/results become
        # keyword-only arguments that default to None.
        # CHECK: (self, /, a, b, c, d, x, y, *, loc=None, ip=None)
        print(ConstraintOp.__init__.__signature__)
        # CHECK: (self, /, out1, out3, a, *, out2=None, b=None, loc=None, ip=None)
        print(OptionalOp.__init__.__signature__)
        # CHECK: (self, /, *, b=None, a=None, loc=None, ip=None)
        print(Optional2Op.__init__.__signature__)
        # CHECK: (self, /, out1, out2, out4, a, c, *, out3=None, b=None, loc=None, ip=None)
        print(VariadicOp.__init__.__signature__)
        # CHECK: (self, /, b, a, *, loc=None, ip=None)
        print(Variadic2Op.__init__.__signature__)
        # CHECK: (self, /, in1, in2, in4, in5, *, in3=None, loc=None, ip=None)
        print(MixedOp.__init__.__signature__)

        # Segment specs per operand/result: 1 = single, 0 = optional,
        # -1 = variadic; None when every operand (or result) is single.
        # CHECK: None None
        print(ConstraintOp._ODS_OPERAND_SEGMENTS, ConstraintOp._ODS_RESULT_SEGMENTS)
        # CHECK: [1, 0] [1, 0, 1]
        print(OptionalOp._ODS_OPERAND_SEGMENTS, OptionalOp._ODS_RESULT_SEGMENTS)
        # CHECK: [0] [0]
        print(Optional2Op._ODS_OPERAND_SEGMENTS, Optional2Op._ODS_RESULT_SEGMENTS)
        # CHECK: [1, 0, -1] [-1, -1, 0, 1]
        print(VariadicOp._ODS_OPERAND_SEGMENTS, VariadicOp._ODS_RESULT_SEGMENTS)
        # CHECK: [-1] [-1]
        print(Variadic2Op._ODS_OPERAND_SEGMENTS, Variadic2Op._ODS_RESULT_SEGMENTS)

        # Concrete types/attributes (created inside the active Context).
        i32 = IntegerType.get_signless(32)
        f32 = F32Type.get()
        iattr = IntegerAttr.get(i32, 2)
        fattr = FloatAttr.get_f32(2.3)

        module = Module.create()
        with InsertionPoint(module.body):
            ione = arith.constant(i32, 1)
            fone = arith.constant(f32, 1.2)
            # Operands `b`/`d` accept any type; `c` accepts f32 or i32.
            # CHECK: "ext_test.constraint"(%c1_i32, %c1_i32, %cst, %c1_i32) {x = 2 : i32, y = 2.300000e+00 : f32} : (i32, i32, f32, i32) -> ()
            c1 = ConstraintOp(ione, ione, fone, ione, iattr, fattr)
            # CHECK: "ext_test.constraint"(%c1_i32, %cst, %cst, %cst) {x = 2 : i32, y = 2.300000e+00 : f32} : (i32, f32, f32, f32) -> ()
            ConstraintOp(ione, fone, fone, fone, iattr, fattr)
            # CHECK: "ext_test.constraint"(%c1_i32, %cst, %c1_i32, %cst) {x = 2 : i32, y = 2.300000e+00 : f32} : (i32, f32, i32, f32) -> ()
            ConstraintOp(ione, fone, ione, fone, iattr, fattr)
            # Omitted optionals are recorded as 0 in the segment-size arrays.
            # CHECK: %0:2 = "ext_test.optional"(%c1_i32) {operandSegmentSizes = array<i32: 1, 0>, resultSegmentSizes = array<i32: 1, 0, 1>} : (i32) -> (i32, i32)
            o1 = OptionalOp(i32, i32, ione)
            # CHECK: %1:3 = "ext_test.optional"(%c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 1>, resultSegmentSizes = array<i32: 1, 1, 1>} : (i32, i32) -> (i32, i32, i32)
            o2 = OptionalOp(i32, i32, ione, out2=i32, b=ione)
            # CHECK: "ext_test.optional2"() {operandSegmentSizes = array<i32: 0>, resultSegmentSizes = array<i32: 0>} : () -> ()
            o3 = Optional2Op()
            # CHECK: %2 = "ext_test.optional2"() {operandSegmentSizes = array<i32: 0>, resultSegmentSizes = array<i32: 1>} : () -> i32
            o4 = Optional2Op(b=i32)
            # CHECK: "ext_test.optional2"(%c1_i32) {operandSegmentSizes = array<i32: 1>, resultSegmentSizes = array<i32: 0>} : (i32) -> ()
            o5 = Optional2Op(a=ione)
            # CHECK: %3 = "ext_test.optional2"(%c1_i32) {operandSegmentSizes = array<i32: 1>, resultSegmentSizes = array<i32: 1>} : (i32) -> i32
            o6 = Optional2Op(b=i32, a=ione)
            # Variadic operands/results are passed as lists (possibly empty).
            # CHECK: %4:4 = "ext_test.variadic"(%c1_i32, %c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 0, 2>, resultSegmentSizes = array<i32: 1, 2, 0, 1>} : (i32, i32, i32) -> (i32, i32, i32, i32)
            v1 = VariadicOp([i32], [i32, i32], i32, ione, [ione, ione])
            # CHECK: %5:5 = "ext_test.variadic"(%c1_i32, %c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 1, 1>, resultSegmentSizes = array<i32: 1, 2, 1, 1>} : (i32, i32, i32) -> (i32, i32, i32, i32, i32)
            v2 = VariadicOp([i32], [i32, i32], i32, ione, [ione], out3=i32, b=ione)
            # CHECK: %6:4 = "ext_test.variadic"(%c1_i32) {operandSegmentSizes = array<i32: 1, 0, 0>, resultSegmentSizes = array<i32: 2, 1, 0, 1>} : (i32) -> (i32, i32, i32, i32)
            v3 = VariadicOp([i32, i32], [i32], i32, ione, [])
            # CHECK: "ext_test.variadic2"() {operandSegmentSizes = array<i32: 0>, resultSegmentSizes = array<i32: 0>} : () -> ()
            v4 = Variadic2Op([], [])
            # CHECK: "ext_test.variadic2"(%c1_i32, %c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 3>, resultSegmentSizes = array<i32: 0>} : (i32, i32, i32) -> ()
            v5 = Variadic2Op([], [ione, ione, ione])
            # CHECK: %7:2 = "ext_test.variadic2"(%c1_i32) {operandSegmentSizes = array<i32: 1>, resultSegmentSizes = array<i32: 2>} : (i32) -> (i32, i32)
            v6 = Variadic2Op([i32, i32], [ione])
            # CHECK: %8 = "ext_test.mixed"(%c1_i32, %c1_i32) {in2 = 2 : i32, in4 = 2 : i32, operandSegmentSizes = array<i32: 1, 0, 1>} : (i32, i32) -> i32
            m1 = MixedOp(ione, iattr, iattr, ione)
            # CHECK: %9 = "ext_test.mixed"(%c1_i32, %c1_i32, %c1_i32) {in2 = 2 : i32, in4 = 2 : i32, operandSegmentSizes = array<i32: 1, 1, 1>} : (i32, i32, i32) -> i32
            m2 = MixedOp(ione, iattr, iattr, ione, in3=ione)
        print(module)
        assert module.operation.verify()

        # Accessors: absent optionals read as None, variadics as sequences.
        # CHECK: OpResult(%c1_i32
        print(c1.a)
        # CHECK: 2 : i32
        print(c1.x)
        # CHECK: OpResult(%c1_i32
        print(o1.a)
        # CHECK: None
        print(o1.b)
        # CHECK: OpResult(%c1_i32
        print(o2.b)
        # CHECK: 0
        print(o1.out1.result_number)
        # CHECK: None
        print(o1.out2)
        # CHECK: 0
        print(o2.out1.result_number)
        # CHECK: 1
        print(o2.out2.result_number)
        # CHECK: None
        print(o3.a)
        # CHECK: OpResult(%c1_i32
        print(o5.a)
        # CHECK: ['OpResult(%c1_i32 = arith.constant 1 : i32)', 'OpResult(%c1_i32 = arith.constant 1 : i32)']
        print([str(i) for i in v1.c])
        # CHECK: ['OpResult(%c1_i32 = arith.constant 1 : i32)']
        print([str(i) for i in v2.c])
        # CHECK: []
        print([str(i) for i in v3.c])
        # CHECK: 0 0
        print(len(v4.a), len(v4.b))
        # CHECK: 3 0
        print(len(v5.a), len(v5.b))
        # CHECK: 1 2
        print(len(v6.a), len(v6.b))

        # Cases that violate the synthesized constructor signature: argument
        # binding fails with a TypeError.
        module = Module.create()
        with InsertionPoint(module.body):
            try:
                ConstraintOp(ione, ione, fone, ione, iattr)
            except TypeError as e:
                # CHECK: missing a required argument: 'y'
                print(e)
            try:
                ConstraintOp(ione, ione, fone, ione, iattr, fattr, ione)
            except TypeError as e:
                # CHECK: too many positional arguments
                print(e)
# CHECK: TEST: testExtDialectWithRegion
@run
def testExtDialectWithRegion():
    """Exercise region-holding ops, dynamic traits, and custom verifiers.

    Defines an `ext_region` dialect with an `if` op (two regions), a `yield`
    terminator restricted to `if` parents by a dynamic trait and checked by a
    per-op verifier, and a `no_term` op whose region needs no terminator.
    Checks the generated IRDL, a verifying module, and three verification
    failures.
    """

    # Dynamic trait: verification fails, with an error emitted at the op's
    # location, unless the op's parent is an IfOp. IfOp is defined later in
    # this function; the name is resolved when the verifier actually runs.
    class ParentIsIfTrait(DynamicOpTrait):
        @staticmethod
        def verify_invariants(op) -> bool:
            if not isinstance(op.parent.opview, IfOp):
                op.location.emit_error(
                    f"{op.name} should be put inside {IfOp.OPERATION_NAME}"
                )
                return False
            return True

    class TestRegion(Dialect, name="ext_region"):
        pass

    # `Region`-annotated fields become the op's regions, in declaration order.
    class IfOp(TestRegion.Operation, name="if"):
        cond: Operand[IntegerType[1]]
        result: Result[Any]
        then: Region
        else_: Region

    # Terminator that must sit directly inside an IfOp region.
    class YieldOp(
        TestRegion.Operation, name="yield", traits=[IsTerminatorTrait, ParentIsIfTrait]
    ):
        value: Operand[Any]

        # Per-op verifier: the yielded value's type must equal the type of the
        # parent op's first result.
        # NOTE(review): indexes parent.results[0] without guarding; for a
        # misplaced yield (e.g. at module level) this relies on ParentIsIfTrait
        # failing first. Confirm the verification order.
        def verify_invariants(self) -> bool:
            if self.parent.results[0].type != self.value.type:
                self.location.emit_error(
                    "result type mismatch between YieldOp and its parent IfOp"
                )
                return False
            return True

    # NoTerminatorTrait lets this op's region block end without a terminator.
    class NoTermOp(TestRegion.Operation, name="no_term", traits=[NoTerminatorTrait]):
        body: Region

    with Context(), Location.unknown():
        TestRegion.load()
        # CHECK: irdl.dialect @ext_region {
        # CHECK: irdl.operation @if {
        # CHECK: %0 = irdl.is i1
        # CHECK: irdl.operands(cond: %0)
        # CHECK: %1 = irdl.any
        # CHECK: irdl.results(result: %1)
        # CHECK: %2 = irdl.region
        # CHECK: %3 = irdl.region
        # CHECK: irdl.regions(then: %2, else_: %3)
        # CHECK: }
        # CHECK: irdl.operation @yield {
        # CHECK: %0 = irdl.any
        # CHECK: irdl.operands(value: %0)
        # CHECK: }
        # CHECK: irdl.operation @no_term {
        # CHECK: %0 = irdl.region
        # CHECK: irdl.regions(body: %0)
        # CHECK: }
        # CHECK: }
        print(TestRegion._mlir_module)
        # Regions do not appear in the constructor signature.
        # CHECK: (self, /, result, cond, *, loc=None, ip=None)
        print(IfOp.__init__.__signature__)
        # CHECK: None None
        print(IfOp._ODS_OPERAND_SEGMENTS, IfOp._ODS_RESULT_SEGMENTS)
        # IfOp declares two regions.
        # CHECK: (2, True)
        print(IfOp._ODS_REGIONS)

        # Well-formed module: each IfOp region has one block ending in a
        # matching YieldOp, plus a NoTermOp whose block has no terminator.
        module = Module.create()
        with InsertionPoint(module.body):
            i1 = IntegerType.get_signless(1)
            i32 = IntegerType.get_signless(32)
            cond = arith.constant(i1, 1)
            if_ = IfOp(i32, cond)
            # Regions are created empty; blocks are appended explicitly.
            if_.then.blocks.append()
            if_.else_.blocks.append()
            with InsertionPoint(if_.then.blocks[0]):
                v = arith.constant(i32, 2)
                YieldOp(v)
            with InsertionPoint(if_.else_.blocks[0]):
                v = arith.constant(i32, 3)
                YieldOp(v)
            nt = NoTermOp()
            nt.body.blocks.append()
            with InsertionPoint(nt.body.blocks[0]):
                arith.constant(i32, 4)
                # No terminator here
        assert module.operation.verify()
        # CHECK: module {
        # CHECK: %true = arith.constant true
        # CHECK: %0 = "ext_region.if"(%true) ({
        # CHECK: %c2_i32 = arith.constant 2 : i32
        # CHECK: "ext_region.yield"(%c2_i32) : (i32) -> ()
        # CHECK: }, {
        # CHECK: %c3_i32 = arith.constant 3 : i32
        # CHECK: "ext_region.yield"(%c3_i32) : (i32) -> ()
        # CHECK: }) : (i1) -> i32
        # CHECK: "ext_region.no_term"() ({
        # CHECK: %c4_i32 = arith.constant 4 : i32
        # CHECK: }) : () -> ()
        # CHECK: }
        print(module)
        # CHECK: %c2_i32 = arith.constant 2 : i32
        print(if_.then.blocks[0])
        # CHECK: %c3_i32 = arith.constant 3 : i32
        print(if_.else_.blocks[0])

        # CHECK-LABEL: Testing violation cases
        print("Testing violation cases:")

        # Violation 1: IfOp region blocks lack a terminator.
        module = Module.create()
        with InsertionPoint(module.body):
            i1 = IntegerType.get_signless(1)
            i32 = IntegerType.get_signless(32)
            cond = arith.constant(i1, 1)
            if_ = IfOp(i32, cond)
            if_.then.blocks.append()
            if_.else_.blocks.append()
            with InsertionPoint(if_.then.blocks[0]):
                v = arith.constant(i32, 2)
            with InsertionPoint(if_.else_.blocks[0]):
                v = arith.constant(i32, 3)
        try:
            module.operation.verify()
        except Exception as e:
            # CHECK: Verification failed:
            # CHECK: block with no terminator
            print(e)

        # Violation 2: YieldOp placed at module level, rejected by
        # ParentIsIfTrait.
        module = Module.create()
        with InsertionPoint(module.body):
            v = arith.constant(i32, 2)
            YieldOp(v)
        try:
            module.operation.verify()
        except Exception as e:
            # CHECK: Verification failed:
            # CHECK: ext_region.yield should be put inside ext_region.if
            print(e)

        # Violation 3: IfOp declared to produce i1 while its yields produce
        # i32, rejected by YieldOp.verify_invariants.
        module = Module.create()
        with InsertionPoint(module.body):
            i1 = IntegerType.get_signless(1)
            i32 = IntegerType.get_signless(32)
            cond = arith.constant(i1, 1)
            if_ = IfOp(i1, cond)
            if_.then.blocks.append()
            if_.else_.blocks.append()
            with InsertionPoint(if_.then.blocks[0]):
                v = arith.constant(i32, 2)
                YieldOp(v)
            with InsertionPoint(if_.else_.blocks[0]):
                v = arith.constant(i32, 3)
                YieldOp(v)
        try:
            module.operation.verify()
        except Exception as e:
            # CHECK: Verification failed:
            # CHECK: result type mismatch
            print(e)