Twice e95dabef96
[MLIR][Python] Support attribute definitions in Python-defined dialects (#183907)
This PR is quite similar to
https://github.com/llvm/llvm-project/pull/182805.

We added basic support for attribute definitions in Python-defined
dialects, including:

- IRDL codegen for attribute definitions
- Attr builders like `MyAttr.get(..)` and attr parameter accessors (e.g.
`my_attr.param1`)
- Use Python-defined attrs in Python-defined operations

Assisted by GitHub Copilot.
2026-03-02 09:57:25 +08:00

697 lines
25 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# RUN: %PYTHON %s 2>&1 | FileCheck %s
from mlir.ir import *
from mlir.dialects import arith
from mlir.dialects.ext import *
from mlir import ir
from typing import Any, Optional, Sequence, TypeVar, Union
import sys
def run(f):
    """Print a banner for test function ``f`` and execute it immediately.

    Intended to be used as a decorator on each test function in this file:
    the ``TEST: <name>`` banner printed here is what the per-test FileCheck
    label directives match against.

    Args:
        f: A zero-argument callable; its ``__name__`` is printed in the banner.

    Returns:
        ``f`` itself, so the decorated name stays bound to the function
        instead of being rebound to ``None``.
    """
    print("\nTEST:", f.__name__)
    f()
    return f
# CHECK: TEST: testMyInt
@run
def testMyInt():
    """Define a small `myint` dialect in Python and exercise its two ops.

    Covers: the IRDL module printed after loading the dialect, op
    construction, the printed IR, class/op-name metadata, operand / attribute /
    result accessors, generated `__init__` signatures, and the op adaptor.
    Expected output is pinned by the FileCheck directives next to each print.
    """

    class MyInt(Dialect, name="myint"):
        pass

    # Subscripted form used inside the field annotations below; `i32` is
    # rebound to a concrete type via `get_signless` once a Context is active.
    i32 = IntegerType[32]

    # Op declared through the dialect's own `Operation` base class.
    class ConstantOp(MyInt.Operation, name="constant"):
        value: IntegerAttr
        cst: Result[i32]

    # Op declared through the generic `Operation` base with an explicit
    # `dialect=` keyword; both declaration styles show up in
    # `MyInt.operations` (see the printed list below).
    class AddOp(Operation, dialect=MyInt, name="add"):
        lhs: Operand[i32]
        rhs: Operand[i32]
        res: Result[i32]

    # CHECK: irdl.dialect @myint {
    # CHECK: irdl.operation @constant {
    # CHECK: %0 = irdl.base "#builtin.integer"
    # CHECK: irdl.attributes {"value" = %0}
    # CHECK: %1 = irdl.is i32
    # CHECK: irdl.results(cst: %1)
    # CHECK: }
    # CHECK: irdl.operation @add {
    # CHECK: %0 = irdl.is i32
    # CHECK: irdl.operands(lhs: %0, rhs: %0)
    # CHECK: irdl.results(res: %0)
    # CHECK: }
    # CHECK: }
    with Context(), Location.unknown():
        MyInt.load()
        print(MyInt._mlir_module)

        # CHECK: ['constant', 'add']
        print([i._op_name for i in MyInt.operations])

        # From here on `i32` is a concrete type usable with `IntegerAttr.get`.
        i32 = IntegerType.get_signless(32)

        module = Module.create()
        with InsertionPoint(module.body):
            two = ConstantOp(IntegerAttr.get(i32, 2))
            three = ConstantOp(IntegerAttr.get(i32, 3))
            # Op instances are passed directly as operands here.
            add1 = AddOp(two, three)
            add2 = AddOp(add1, two)
            add3 = AddOp(add2, three)

        # CHECK: %0 = "myint.constant"() {value = 2 : i32} : () -> i32
        # CHECK: %1 = "myint.constant"() {value = 3 : i32} : () -> i32
        # CHECK: %2 = "myint.add"(%0, %1) : (i32, i32) -> i32
        # CHECK: %3 = "myint.add"(%2, %0) : (i32, i32) -> i32
        # CHECK: %4 = "myint.add"(%3, %1) : (i32, i32) -> i32
        print(module)
        assert module.operation.verify()

        # Class-level metadata of the Python-defined op classes.
        # CHECK: AddOp
        print(type(add1).__name__)
        # CHECK: ConstantOp
        print(type(two).__name__)
        # CHECK: myint.add
        print(add1.OPERATION_NAME)
        # CHECK: None
        print(add1._ODS_OPERAND_SEGMENTS)
        # CHECK: None
        print(add1._ODS_RESULT_SEGMENTS)
        # CHECK: (0, True)
        print(add1._ODS_REGIONS)

        # Named accessors for operands, attributes and results.
        # CHECK: %0 = "myint.constant"() {value = 2 : i32} : () -> i32
        print(add1.lhs.owner)
        # CHECK: %1 = "myint.constant"() {value = 3 : i32} : () -> i32
        print(add1.rhs.owner)
        # CHECK: 2 : i32
        print(two.value)
        # CHECK: OpResult(%0
        print(two.cst)

        # Generated constructor signatures.
        # CHECK: (self, /, lhs, rhs, *, loc=None, ip=None)
        print(AddOp.__init__.__signature__)
        # CHECK: (self, /, value, *, loc=None, ip=None)
        print(ConstantOp.__init__.__signature__)

        # Adaptor class exposed on the op and built from an operand list.
        # CHECK: True
        print(issubclass(AddOp.Adaptor, OpAdaptor))
        adaptor1 = AddOp.Adaptor(list(add1.operands), add1)
        # CHECK: myint.add
        print(adaptor1.OPERATION_NAME)
        # CHECK: OpResult(%0 = "myint.constant"() {value = 2 : i32} : () -> i32)
        print(adaptor1.lhs)
        # CHECK: OpResult(%1 = "myint.constant"() {value = 3 : i32} : () -> i32)
        print(adaptor1.rhs)
# CHECK: TEST: testExtDialect
@run
def testExtDialect():
    """Exercise operand/result/attribute declarations of Python-defined ops.

    Declares ops in an `ext_test` dialect that use union constraints, `Any`,
    optional and variadic operands/results, inherited fields, and `TypeVar`
    constraints. It then checks the IRDL printed after loading the dialect,
    the generated `__init__` signatures, the `_ODS_*_SEGMENTS` values, the
    printed IR for a series of constructed ops, the named accessors, and the
    `TypeError` messages produced by wrong constructor calls. Expected output
    is pinned by the FileCheck directives next to each print.
    """

    class Test(Dialect, name="ext_test"):
        pass

    # Subscripted form used inside the field annotations below; `i32` is
    # rebound to a concrete type via `get_signless` once a Context is active.
    i32 = IntegerType[32]

    class ConstraintOp(Test.Operation, name="constraint"):
        a: Operand[i32 | IntegerType[64]]
        b: Operand[Any]
        # Here we use `F32Type[()]` instead of just `F32Type`
        # because of an existing issue in IRDL implementation
        # where `irdl.base` cannot exist in `irdl.any_of`.
        c: Operand[F32Type[()] | i32]
        d: Operand[Any]
        x: IntegerAttr
        y: FloatAttr

    # Optional fields are spelled both as `Optional[...]` and as `... | None`.
    class OptionalOp(Test.Operation, name="optional"):
        a: Operand[i32]
        b: Optional[Operand[i32]]
        out1: Result[i32]
        out2: Result[i32] | None
        out3: Result[i32]

    class Optional2Op(Test.Operation, name="optional2"):
        a: Optional[Operand[i32]]
        b: Optional[Result[i32]]

    # `Sequence[...]` fields appear as `variadic` in the expected IRDL below.
    class VariadicOp(Test.Operation, name="variadic"):
        a: Operand[i32]
        b: Optional[Operand[i32]]
        c: Sequence[Operand[i32]]
        out1: Sequence[Result[i32]]
        out2: Sequence[Result[i32]]
        out3: Optional[Result[i32]]
        out4: Result[i32]

    class Variadic2Op(Test.Operation, name="variadic2"):
        a: Sequence[Operand[i32]]
        b: Sequence[Result[i32]]

    # Base class without a `name=`; its fields are inherited by `MixedOp`
    # (the expected `@mixed` IRDL below lists `in1` and `out`).
    class MixedOpBase(Test.Operation):
        out: Result[i32]
        in1: Operand[i32]

    class MixedOp(MixedOpBase, name="mixed"):
        in2: IntegerAttr
        in3: Optional[Operand[i32]]
        in4: IntegerAttr
        in5: Operand[i32]

    # Type variables: unbounded, bound to a `|` union, bound to a `Union[...]`.
    T = TypeVar("T")
    U = TypeVar("U", bound=IntegerType[32] | IntegerType[64])
    V = TypeVar("V", bound=Union[IntegerType[8], IntegerType[16]])

    class TypeVarOp(Test.Operation, name="type_var"):
        in1: Operand[T]
        in2: Operand[T]
        in3: Operand[U]
        in4: Operand[U | V]
        in5: Operand[V]

    # CHECK: irdl.dialect @ext_test {
    # CHECK: irdl.operation @constraint {
    # CHECK: %0 = irdl.is i32
    # CHECK: %1 = irdl.is i64
    # CHECK: %2 = irdl.any_of(%0, %1)
    # CHECK: %3 = irdl.any
    # CHECK: %4 = irdl.is f32
    # CHECK: %5 = irdl.any_of(%4, %0)
    # CHECK: %6 = irdl.any
    # CHECK: irdl.operands(a: %2, b: %3, c: %5, d: %6)
    # CHECK: %7 = irdl.base "#builtin.integer"
    # CHECK: %8 = irdl.base "#builtin.float"
    # CHECK: irdl.attributes {"x" = %7, "y" = %8}
    # CHECK: }
    # CHECK: irdl.operation @optional {
    # CHECK: %0 = irdl.is i32
    # CHECK: irdl.operands(a: %0, b: optional %0)
    # CHECK: irdl.results(out1: %0, out2: optional %0, out3: %0)
    # CHECK: }
    # CHECK: irdl.operation @optional2 {
    # CHECK: %0 = irdl.is i32
    # CHECK: irdl.operands(a: optional %0)
    # CHECK: irdl.results(b: optional %0)
    # CHECK: }
    # CHECK: irdl.operation @variadic {
    # CHECK: %0 = irdl.is i32
    # CHECK: irdl.operands(a: %0, b: optional %0, c: variadic %0)
    # CHECK: irdl.results(out1: variadic %0, out2: variadic %0, out3: optional %0, out4: %0)
    # CHECK: }
    # CHECK: irdl.operation @variadic2 {
    # CHECK: %0 = irdl.is i32
    # CHECK: irdl.operands(a: variadic %0)
    # CHECK: irdl.results(b: variadic %0)
    # CHECK: }
    # CHECK: irdl.operation @mixed {
    # CHECK: %0 = irdl.is i32
    # CHECK: irdl.operands(in1: %0, in3: optional %0, in5: %0)
    # CHECK: %1 = irdl.base "#builtin.integer"
    # CHECK: %2 = irdl.base "#builtin.integer"
    # CHECK: irdl.attributes {"in2" = %1, "in4" = %2}
    # CHECK: irdl.results(out: %0)
    # CHECK: }
    # CHECK: irdl.operation @type_var {
    # CHECK: %0 = irdl.any
    # CHECK: %1 = irdl.is i32
    # CHECK: %2 = irdl.is i64
    # CHECK: %3 = irdl.any_of(%1, %2)
    # CHECK: %4 = irdl.is i8
    # CHECK: %5 = irdl.is i16
    # CHECK: %6 = irdl.any_of(%4, %5)
    # CHECK: %7 = irdl.any_of(%3, %6)
    # CHECK: irdl.operands(in1: %0, in2: %0, in3: %3, in4: %7, in5: %6)
    # CHECK: }
    # CHECK: }
    with Context(), Location.unknown():
        Test.load()
        print(Test._mlir_module)

        # Generated constructor signatures.
        # CHECK: (self, /, a, b, c, d, x, y, *, loc=None, ip=None)
        print(ConstraintOp.__init__.__signature__)
        # CHECK: (self, /, out1, out3, a, *, out2=None, b=None, loc=None, ip=None)
        print(OptionalOp.__init__.__signature__)
        # CHECK: (self, /, *, b=None, a=None, loc=None, ip=None)
        print(Optional2Op.__init__.__signature__)
        # CHECK: (self, /, out1, out2, out4, a, c, *, out3=None, b=None, loc=None, ip=None)
        print(VariadicOp.__init__.__signature__)
        # CHECK: (self, /, b, a, *, loc=None, ip=None)
        print(Variadic2Op.__init__.__signature__)
        # CHECK: (self, /, in1, in2, in4, in5, *, in3=None, loc=None, ip=None)
        print(MixedOp.__init__.__signature__)

        # Operand / result segment descriptors of each op class.
        # CHECK: None None
        print(ConstraintOp._ODS_OPERAND_SEGMENTS, ConstraintOp._ODS_RESULT_SEGMENTS)
        # CHECK: [1, 0] [1, 0, 1]
        print(OptionalOp._ODS_OPERAND_SEGMENTS, OptionalOp._ODS_RESULT_SEGMENTS)
        # CHECK: [0] [0]
        print(Optional2Op._ODS_OPERAND_SEGMENTS, Optional2Op._ODS_RESULT_SEGMENTS)
        # CHECK: [1, 0, -1] [-1, -1, 0, 1]
        print(VariadicOp._ODS_OPERAND_SEGMENTS, VariadicOp._ODS_RESULT_SEGMENTS)
        # CHECK: [-1] [-1]
        print(Variadic2Op._ODS_OPERAND_SEGMENTS, Variadic2Op._ODS_RESULT_SEGMENTS)

        # Concrete types and attributes used to build ops below.
        i32 = IntegerType.get_signless(32)
        i64 = IntegerType.get_signless(64)
        f32 = F32Type.get()

        iattr = IntegerAttr.get(i32, 2)
        fattr = FloatAttr.get_f32(2.3)

        module = Module.create()
        with InsertionPoint(module.body):
            ione = arith.constant(i32, 1)
            fone = arith.constant(f32, 1.2)

            # CHECK: "ext_test.constraint"(%c1_i32, %c1_i32, %cst, %c1_i32) {x = 2 : i32, y = 2.300000e+00 : f32} : (i32, i32, f32, i32) -> ()
            c1 = ConstraintOp(ione, ione, fone, ione, iattr, fattr)
            # CHECK: "ext_test.constraint"(%c1_i32, %cst, %cst, %cst) {x = 2 : i32, y = 2.300000e+00 : f32} : (i32, f32, f32, f32) -> ()
            ConstraintOp(ione, fone, fone, fone, iattr, fattr)
            # CHECK: ext_test.constraint"(%c1_i32, %cst, %c1_i32, %cst) {x = 2 : i32, y = 2.300000e+00 : f32} : (i32, f32, i32, f32) -> ()
            ConstraintOp(ione, fone, ione, fone, iattr, fattr)

            # CHECK: %0:2 = "ext_test.optional"(%c1_i32) {operandSegmentSizes = array<i32: 1, 0>, resultSegmentSizes = array<i32: 1, 0, 1>} : (i32) -> (i32, i32)
            o1 = OptionalOp(i32, i32, ione)
            # CHECK: %1:3 = "ext_test.optional"(%c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 1>, resultSegmentSizes = array<i32: 1, 1, 1>} : (i32, i32) -> (i32, i32, i32)
            o2 = OptionalOp(i32, i32, ione, out2=i32, b=ione)
            # CHECK: ext_test.optional2"() {operandSegmentSizes = array<i32: 0>, resultSegmentSizes = array<i32: 0>} : () -> ()
            o3 = Optional2Op()
            # CHECK: %2 = "ext_test.optional2"() {operandSegmentSizes = array<i32: 0>, resultSegmentSizes = array<i32: 1>} : () -> i32
            o4 = Optional2Op(b=i32)
            # CHECK: "ext_test.optional2"(%c1_i32) {operandSegmentSizes = array<i32: 1>, resultSegmentSizes = array<i32: 0>} : (i32) -> ()
            o5 = Optional2Op(a=ione)
            # CHECK: %3 = "ext_test.optional2"(%c1_i32) {operandSegmentSizes = array<i32: 1>, resultSegmentSizes = array<i32: 1>} : (i32) -> i32
            o6 = Optional2Op(b=i32, a=ione)

            # CHECK: %4:4 = "ext_test.variadic"(%c1_i32, %c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 0, 2>, resultSegmentSizes = array<i32: 1, 2, 0, 1>} : (i32, i32, i32) -> (i32, i32, i32, i32)
            v1 = VariadicOp([i32], [i32, i32], i32, ione, [ione, ione])
            # CHECK: %5:5 = "ext_test.variadic"(%c1_i32, %c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 1, 1>, resultSegmentSizes = array<i32: 1, 2, 1, 1>} : (i32, i32, i32) -> (i32, i32, i32, i32, i32)
            v2 = VariadicOp([i32], [i32, i32], i32, ione, [ione], out3=i32, b=ione)
            # CHECK: %6:4 = "ext_test.variadic"(%c1_i32) {operandSegmentSizes = array<i32: 1, 0, 0>, resultSegmentSizes = array<i32: 2, 1, 0, 1>} : (i32) -> (i32, i32, i32, i32)
            v3 = VariadicOp([i32, i32], [i32], i32, ione, [])
            # CHECK: "ext_test.variadic2"() {operandSegmentSizes = array<i32: 0>, resultSegmentSizes = array<i32: 0>} : () -> ()
            v4 = Variadic2Op([], [])
            # CHECK: "ext_test.variadic2"(%c1_i32, %c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 3>, resultSegmentSizes = array<i32: 0>} : (i32, i32, i32) -> ()
            v5 = Variadic2Op([], [ione, ione, ione])
            # CHECK: %7:2 = "ext_test.variadic2"(%c1_i32) {operandSegmentSizes = array<i32: 1>, resultSegmentSizes = array<i32: 2>} : (i32) -> (i32, i32)
            v6 = Variadic2Op([i32, i32], [ione])

            # CHECK: %8 = "ext_test.mixed"(%c1_i32, %c1_i32) {in2 = 2 : i32, in4 = 2 : i32, operandSegmentSizes = array<i32: 1, 0, 1>} : (i32, i32) -> i32
            m1 = MixedOp(ione, iattr, iattr, ione)
            # CHECK: %9 = "ext_test.mixed"(%c1_i32, %c1_i32, %c1_i32) {in2 = 2 : i32, in4 = 2 : i32, operandSegmentSizes = array<i32: 1, 1, 1>} : (i32, i32, i32) -> i32
            m2 = MixedOp(ione, iattr, iattr, ione, in3=ione)

        print(module)
        assert module.operation.verify()

        # Named accessors: absent optional fields read back as None, variadic
        # fields read back as sequences.
        # CHECK: OpResult(%c1_i32
        print(c1.a)
        # CHECK: 2 : i32
        print(c1.x)
        # CHECK: OpResult(%c1_i32
        print(o1.a)
        # CHECK: None
        print(o1.b)
        # CHECK: OpResult(%c1_i32
        print(o2.b)
        # CHECK: 0
        print(o1.out1.result_number)
        # CHECK: None
        print(o1.out2)
        # CHECK: 0
        print(o2.out1.result_number)
        # CHECK: 1
        print(o2.out2.result_number)
        # CHECK: None
        print(o3.a)
        # CHECK: OpResult(%c1_i32
        print(o5.a)
        # CHECK: ['OpResult(%c1_i32 = arith.constant 1 : i32)', 'OpResult(%c1_i32 = arith.constant 1 : i32)']
        print([str(i) for i in v1.c])
        # CHECK: ['OpResult(%c1_i32 = arith.constant 1 : i32)']
        print([str(i) for i in v2.c])
        # CHECK: []
        print([str(i) for i in v3.c])
        # CHECK: 0 0
        print(len(v4.a), len(v4.b))
        # CHECK: 3 0
        print(len(v5.a), len(v5.b))
        # CHECK: 1 2
        print(len(v6.a), len(v6.b))

        # cases to violate constraints
        # Each wrong call must raise TypeError; the printed message is what
        # FileCheck verifies, so a call that stops raising fails the test.
        module = Module.create()
        with InsertionPoint(module.body):
            try:
                # One argument short: attribute `y` is not supplied.
                c1 = ConstraintOp(ione, ione, fone, ione, iattr)
            except TypeError as e:
                # CHECK: missing a required argument: 'y'
                print(e)

            try:
                # One positional argument more than the signature accepts.
                c2 = ConstraintOp(ione, ione, fone, ione, iattr, fattr, ione)
            except TypeError as e:
                # CHECK: too many positional arguments
                print(e)
# CHECK: TEST: testExtDialectWithRegion
@run
def testExtDialectWithRegion():
    """Exercise region-carrying ops, op traits and custom verifiers.

    Declares an `ext_region` dialect with an `if` op (two regions), a `yield`
    terminator with both a trait-based and a method-based verifier, and a
    `no_term` op whose region needs no terminator. Checks the generated IRDL,
    a well-formed module, and three modules that are expected to fail
    verification with specific messages.
    """

    class ParentIsIfTrait(DynamicOpTrait):
        @staticmethod
        def verify_invariants(op) -> bool:
            # `IfOp` is defined further down in the enclosing function; the
            # name is only looked up when this verifier actually runs.
            if not isinstance(op.parent.opview, IfOp):
                op.location.emit_error(
                    f"{op.name} should be put inside {IfOp.OPERATION_NAME}"
                )
                return False
            return True

    class TestRegion(Dialect, name="ext_region"):
        pass

    class IfOp(TestRegion.Operation, name="if"):
        cond: Operand[IntegerType[1]]
        result: Result[Any]
        then: Region
        else_: Region

    class YieldOp(
        TestRegion.Operation, name="yield", traits=[IsTerminatorTrait, ParentIsIfTrait]
    ):
        value: Operand[Any]

        def verify_invariants(self) -> bool:
            # The yielded value's type must equal the type of the parent
            # op's first result.
            if self.parent.results[0].type != self.value.type:
                self.location.emit_error(
                    "result type mismatch between YieldOp and its parent IfOp"
                )
                return False
            return True

    class NoTermOp(TestRegion.Operation, name="no_term", traits=[NoTerminatorTrait]):
        body: Region

    with Context(), Location.unknown():
        TestRegion.load()
        # CHECK: irdl.dialect @ext_region {
        # CHECK: irdl.operation @if {
        # CHECK: %0 = irdl.is i1
        # CHECK: irdl.operands(cond: %0)
        # CHECK: %1 = irdl.any
        # CHECK: irdl.results(result: %1)
        # CHECK: %2 = irdl.region
        # CHECK: %3 = irdl.region
        # CHECK: irdl.regions(then: %2, else_: %3)
        # CHECK: }
        # CHECK: irdl.operation @yield {
        # CHECK: %0 = irdl.any
        # CHECK: irdl.operands(value: %0)
        # CHECK: }
        # CHECK: irdl.operation @no_term {
        # CHECK: %0 = irdl.region
        # CHECK: irdl.regions(body: %0)
        # CHECK: }
        # CHECK: }
        print(TestRegion._mlir_module)

        # CHECK: (self, /, result, cond, *, loc=None, ip=None)
        print(IfOp.__init__.__signature__)
        # CHECK: None None
        print(IfOp._ODS_OPERAND_SEGMENTS, IfOp._ODS_RESULT_SEGMENTS)
        # CHECK: (2, True)
        print(IfOp._ODS_REGIONS)

        # Well-formed case: both `if` regions end in a `yield` of the result
        # type, and the `no_term` region has no terminator.
        module = Module.create()
        with InsertionPoint(module.body):
            i1 = IntegerType.get_signless(1)
            i32 = IntegerType.get_signless(32)
            cond = arith.constant(i1, 1)

            if_ = IfOp(i32, cond)
            # Regions start without blocks; add one block to each.
            if_.then.blocks.append()
            if_.else_.blocks.append()
            with InsertionPoint(if_.then.blocks[0]):
                v = arith.constant(i32, 2)
                YieldOp(v)
            with InsertionPoint(if_.else_.blocks[0]):
                v = arith.constant(i32, 3)
                YieldOp(v)

            nt = NoTermOp()
            nt.body.blocks.append()
            with InsertionPoint(nt.body.blocks[0]):
                arith.constant(i32, 4)
                # No terminator here

        assert module.operation.verify()
        # CHECK: module {
        # CHECK: %true = arith.constant true
        # CHECK: %0 = "ext_region.if"(%true) ({
        # CHECK: %c2_i32 = arith.constant 2 : i32
        # CHECK: "ext_region.yield"(%c2_i32) : (i32) -> ()
        # CHECK: }, {
        # CHECK: %c3_i32 = arith.constant 3 : i32
        # CHECK: "ext_region.yield"(%c3_i32) : (i32) -> ()
        # CHECK: }) : (i1) -> i32
        # CHECK: "ext_region.no_term"() ({
        # CHECK: %c4_i32 = arith.constant 4 : i32
        # CHECK: }) : () -> ()
        # CHECK: }
        print(module)

        # Region accessors expose the blocks populated above.
        # CHECK: %c2_i32 = arith.constant 2 : i32
        print(if_.then.blocks[0])
        # CHECK: %c3_i32 = arith.constant 3 : i32
        print(if_.else_.blocks[0])

        # CHECK-LABEL: Testing violation cases
        print("Testing violation cases:")

        # Violation 1: the `if` regions contain no `yield`, so their blocks
        # have no terminator.
        module = Module.create()
        with InsertionPoint(module.body):
            i1 = IntegerType.get_signless(1)
            i32 = IntegerType.get_signless(32)
            cond = arith.constant(i1, 1)

            if_ = IfOp(i32, cond)
            if_.then.blocks.append()
            if_.else_.blocks.append()
            with InsertionPoint(if_.then.blocks[0]):
                v = arith.constant(i32, 2)
            with InsertionPoint(if_.else_.blocks[0]):
                v = arith.constant(i32, 3)

        try:
            module.operation.verify()
        except Exception as e:
            # CHECK: Verification failed:
            # CHECK: block with no terminator
            print(e)

        # Violation 2: `yield` placed directly in the module body, so its
        # parent is not an `if` op (rejected by `ParentIsIfTrait`).
        module = Module.create()
        with InsertionPoint(module.body):
            v = arith.constant(i32, 2)
            YieldOp(v)

        try:
            module.operation.verify()
        except Exception as e:
            # CHECK: Verification failed:
            # CHECK: ext_region.yield should be put inside ext_region.if
            print(e)

        # Violation 3: the `if` result type is i1 while both regions yield an
        # i32 value (rejected by `YieldOp.verify_invariants`).
        module = Module.create()
        with InsertionPoint(module.body):
            i1 = IntegerType.get_signless(1)
            i32 = IntegerType.get_signless(32)
            cond = arith.constant(i1, 1)

            if_ = IfOp(i1, cond)
            if_.then.blocks.append()
            if_.else_.blocks.append()
            with InsertionPoint(if_.then.blocks[0]):
                v = arith.constant(i32, 2)
                YieldOp(v)
            with InsertionPoint(if_.else_.blocks[0]):
                v = arith.constant(i32, 3)
                YieldOp(v)

        try:
            module.operation.verify()
        except Exception as e:
            # CHECK: Verification failed:
            # CHECK: result type mismatch
            print(e)
# CHECK: TEST: testExtDialectWithType
@run
def testExtDialectWithType():
    """Exercise type definitions in a Python-defined dialect.

    Declares an `ext_type` dialect with a two-parameter `array` type and two
    ops returning it (one unconstrained, one with both parameters fixed).
    Checks the generated IRDL, `Array.get`, the parameter accessors,
    downcasting from a generic `Type`, and the printed IR.
    """

    class TestType(Dialect, name="ext_type"):
        pass

    class Array(TestType.Type, name="array"):
        elem_type: IntegerType[32] | IntegerType[64]
        length: IntegerAttr

    # Result may be any instance of `Array`.
    class MakeArrayOp(TestType.Operation, name="make_array"):
        arr: Result[Array]

    # Result is `Array` with both parameters fixed (i32 element, length
    # `3 : i32`); the op is constructed below without arguments.
    class MakeArray3Op(TestType.Operation, name="make_array3"):
        arr: Result[Array[IntegerType[32], IntegerAttr[IntegerType[32], 3]]]

    with Context(), Location.unknown():
        TestType.load()
        # CHECK: irdl.dialect @ext_type {
        # CHECK: irdl.type @array {
        # CHECK: %0 = irdl.is i32
        # CHECK: %1 = irdl.is i64
        # CHECK: %2 = irdl.any_of(%0, %1)
        # CHECK: %3 = irdl.base "#builtin.integer"
        # CHECK: irdl.parameters(elem_type: %2, length: %3)
        # CHECK: }
        # CHECK: irdl.operation @make_array {
        # CHECK: %0 = irdl.base @ext_type::@array
        # CHECK: irdl.results(arr: %0)
        # CHECK: }
        # CHECK: irdl.operation @make_array3 {
        # CHECK: %0 = irdl.is i32
        # CHECK: %1 = irdl.is 3 : i32
        # CHECK: %2 = irdl.parametric @ext_type::@array<%0, %1>
        # CHECK: irdl.results(arr: %2)
        # CHECK: }
        # CHECK: }
        print(TestType._mlir_module)

        # CHECK: ext_type.array
        print(Array.type_name)

        i32 = IntegerType.get_signless(32)
        i64 = IntegerType.get_signless(64)
        # Build two instances through the `get` builder.
        a4 = Array.get(i32, IntegerAttr.get(i32, 4))
        a6 = Array.get(i64, IntegerAttr.get(i32, 6))
        # CHECK: !ext_type.array<i32, 4 : i32>
        print(a4)
        # CHECK: !ext_type.array<i64, 6 : i32>
        print(a6)

        # Parameter accessors named after the declared fields.
        # CHECK: i32
        print(a4.elem_type)
        # CHECK: 4 : i32
        print(a4.length)
        # CHECK: i64
        print(a6.elem_type)
        # CHECK: 6 : i32
        print(a6.length)

        # Wrapping in a generic `Type` and downcasting yields the local
        # `Array` class again.
        # CHECK: <locals>.Array
        print(type(Type(a4).maybe_downcast()))

        module = Module.create()
        with InsertionPoint(module.body):
            MakeArrayOp(a4)
            MakeArrayOp(a6)
            MakeArray3Op()

        # CHECK: %0 = "ext_type.make_array"() : () -> !ext_type.array<i32, 4 : i32>
        # CHECK: %1 = "ext_type.make_array"() : () -> !ext_type.array<i64, 6 : i32>
        # CHECK: %2 = "ext_type.make_array3"() : () -> !ext_type.array<i32, 3 : i32>
        assert module.operation.verify()
        print(module)
# CHECK: TEST: testExtDialectWithAttr
@run
def testExtDialectWithAttr():
    """Exercise attribute definitions in a Python-defined dialect.

    Declares an `ext_attr` dialect with two attributes (`pair`, `str_pair`)
    and two ops that use them as op attributes. Checks the generated IRDL,
    `attr_name`, the `get` builders, the parameter accessors, downcasting
    from a generic `ir.Attribute`, and the printed IR.
    """

    class TestAttr(Dialect, name="ext_attr"):
        pass

    class IntPair(TestAttr.Attribute, name="pair"):
        first: IntegerAttr
        second: IntegerAttr

    class StrPair(TestAttr.Attribute, name="str_pair"):
        first: StringAttr
        second: StringAttr

    class Op1(TestAttr.Operation, name="op1"):
        pair: IntPair

    class Op2(TestAttr.Operation, name="op2"):
        # Any `StrPair` instance.
        pair: StrPair
        # `StrPair` with both parameters fixed to "a" and "b".
        pair2: StrPair[StringAttr["a"], StringAttr["b"]]

    with Context(), Location.unknown():
        TestAttr.load()
        # CHECK: irdl.dialect @ext_attr {
        # CHECK: irdl.attribute @pair {
        # CHECK: %0 = irdl.base "#builtin.integer"
        # CHECK: %1 = irdl.base "#builtin.integer"
        # CHECK: irdl.parameters(first: %0, second: %1)
        # CHECK: }
        # CHECK: irdl.attribute @str_pair {
        # CHECK: %0 = irdl.base "#builtin.string"
        # CHECK: %1 = irdl.base "#builtin.string"
        # CHECK: irdl.parameters(first: %0, second: %1)
        # CHECK: }
        # CHECK: irdl.operation @op1 {
        # CHECK: %0 = irdl.base @ext_attr::@pair
        # CHECK: irdl.attributes {"pair" = %0}
        # CHECK: }
        # CHECK: irdl.operation @op2 {
        # CHECK: %0 = irdl.base @ext_attr::@str_pair
        # CHECK: %1 = irdl.is "a"
        # CHECK: %2 = irdl.is "b"
        # CHECK: %3 = irdl.parametric @ext_attr::@str_pair<%1, %2>
        # CHECK: irdl.attributes {"pair" = %0, "pair2" = %3}
        # CHECK: }
        # CHECK: }
        print(TestAttr._mlir_module)

        # CHECK: ext_attr.pair
        print(IntPair.attr_name)
        # CHECK: ext_attr.str_pair
        print(StrPair.attr_name)

        # Build one instance of each attribute through the `get` builder.
        ip = IntPair.get(
            IntegerAttr.get(IntegerType.get_signless(32), 1),
            IntegerAttr.get(IntegerType.get_signless(32), 2),
        )
        sp = StrPair.get(StringAttr.get("hello"), StringAttr.get("world"))
        # CHECK: #ext_attr.pair<1 : i32, 2 : i32>
        print(ip)
        # CHECK: #ext_attr.str_pair<"hello", "world">
        print(sp)

        # Parameter accessors named after the declared fields.
        # CHECK: "hello"
        print(sp.first)
        # CHECK: "world"
        print(sp.second)

        # Wrapping in a generic `ir.Attribute` and downcasting yields the
        # local `StrPair` class again.
        sp2 = ir.Attribute(sp).maybe_downcast()
        # CHECK: <locals>.StrPair
        print(type(sp2))
        # CHECK: #ext_attr.str_pair<"hello", "world">
        print(str(sp2))

        module = Module.create()
        with InsertionPoint(module.body):
            Op1(ip)
            # `pair2` only accepts the ("a", "b") instance.
            p2 = StrPair.get(StringAttr.get("a"), StringAttr.get("b"))
            Op2(sp, p2)

        assert module.operation.verify()
        # CHECK: "ext_attr.op1"() {pair = #ext_attr.pair<1 : i32, 2 : i32>} : () -> ()
        # CHECK: "ext_attr.op2"() {pair = #ext_attr.str_pair<"hello", "world">, pair2 = #ext_attr.str_pair<"a", "b">} : () -> ()
        print(module)