[MLIR][Python] Support region in python-defined dialects (#179086)
This PR adds basic support for defining regions in Python-defined
dialects. Example usage:
```python
class TestRegion(Dialect, name="ext_region"):
pass
class IfOp(TestRegion.Operation, name="if"):
cond: Operand[IntegerType[1]]
then: Region
else_: Region
```
Current limitations:
* We can’t specify region constraints yet (e.g., number of blocks or
block argument types). This will be addressed as a follow-up task.
* We can’t mark an op as a `Terminator` or `NoTerminator` yet. This
depends on `DynamicOpTraits` (#177735) and Python-side trait API
support, and will be implemented in a follow-up PR.
This is the first PR after splitting off #179032.
This is a follow-up PR of #169045.
---------
Co-authored-by: Rolf Morel <rolfmorel@gmail.com>
This commit is contained in:
parent
fcc4231ac5
commit
cb274ea176
@ -29,10 +29,13 @@ __all__ = [
|
||||
"Dialect",
|
||||
"Operand",
|
||||
"Result",
|
||||
"Region",
|
||||
"Operation",
|
||||
]
|
||||
|
||||
Operand = ir.Value
|
||||
Result = ir.OpResult
|
||||
Region = ir.Region
|
||||
|
||||
|
||||
class ConstraintLoweringContext:
|
||||
@ -102,7 +105,6 @@ class FieldDef:
|
||||
"""
|
||||
|
||||
name: str
|
||||
constraint: Any
|
||||
variadicity: Variadicity
|
||||
|
||||
@staticmethod
|
||||
@ -117,38 +119,50 @@ class FieldDef:
|
||||
|
||||
origin = get_origin(type_)
|
||||
if origin is ir.OpResult:
|
||||
return ResultDef(name, get_args(type_)[0], variadicity)
|
||||
return ResultDef(name, variadicity, get_args(type_)[0])
|
||||
elif origin is ir.Value:
|
||||
return OperandDef(name, get_args(type_)[0], variadicity)
|
||||
return OperandDef(name, variadicity, get_args(type_)[0])
|
||||
elif issubclass(origin or type_, ir.Attribute):
|
||||
return AttributeDef(name, type_, variadicity)
|
||||
return AttributeDef(name, variadicity, type_)
|
||||
elif type_ is ir.Region:
|
||||
return RegionDef(name, variadicity)
|
||||
raise TypeError(f"unsupported type in operation definition: {type_}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class OperandDef(FieldDef):
|
||||
pass
|
||||
constraint: Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResultDef(FieldDef):
|
||||
pass
|
||||
constraint: Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class AttributeDef(FieldDef):
|
||||
constraint: Any
|
||||
|
||||
def __post_init__(self):
|
||||
if self.variadicity != Variadicity.single:
|
||||
raise ValueError("optional attribute is not supported in IRDL")
|
||||
raise ValueError("optional attribute is not currently supported")
|
||||
|
||||
|
||||
@dataclass
|
||||
class RegionDef(FieldDef):
|
||||
def __post_init__(self):
|
||||
if self.variadicity != Variadicity.single:
|
||||
raise ValueError("optional region is not currently supported")
|
||||
|
||||
|
||||
def partition_fields(
|
||||
fields: List[FieldDef],
|
||||
) -> Tuple[List[OperandDef], List[AttributeDef], List[ResultDef]]:
|
||||
) -> Tuple[List[OperandDef], List[AttributeDef], List[ResultDef], List[RegionDef]]:
|
||||
operands = [i for i in fields if isinstance(i, OperandDef)]
|
||||
attrs = [i for i in fields if isinstance(i, AttributeDef)]
|
||||
results = [i for i in fields if isinstance(i, ResultDef)]
|
||||
return operands, attrs, results
|
||||
regions = [i for i in fields if isinstance(i, RegionDef)]
|
||||
return operands, attrs, results, regions
|
||||
|
||||
|
||||
def normalize_value_range(
|
||||
@ -216,6 +230,11 @@ class Operation(ir.OpView):
|
||||
if not name:
|
||||
return
|
||||
|
||||
if not hasattr(cls, "_dialect_name") or not hasattr(cls, "_dialect_obj"):
|
||||
raise RuntimeError(
|
||||
"Operation subclasses must inherit from a Dialect's Operation subclass"
|
||||
)
|
||||
|
||||
op_name = name
|
||||
cls._op_name = op_name
|
||||
dialect_name = cls._dialect_name
|
||||
@ -223,10 +242,11 @@ class Operation(ir.OpView):
|
||||
|
||||
cls._generate_class_attributes(dialect_name, op_name, fields)
|
||||
cls._generate_init_method(fields)
|
||||
operands, attrs, results = partition_fields(fields)
|
||||
operands, attrs, results, regions = partition_fields(fields)
|
||||
cls._generate_attr_properties(attrs)
|
||||
cls._generate_operand_properties(operands)
|
||||
cls._generate_result_properties(results)
|
||||
cls._generate_region_properties(regions)
|
||||
|
||||
dialect_obj.operations.append(cls)
|
||||
|
||||
@ -254,7 +274,9 @@ class Operation(ir.OpView):
|
||||
)
|
||||
# results are placed at the beginning of the parameter list,
|
||||
# but operands and attributes can appear in any relative order.
|
||||
args = result_args + [i for i in fields if not isinstance(i, ResultDef)]
|
||||
args = result_args + [
|
||||
i for i in fields if not isinstance(i, ResultDef | RegionDef)
|
||||
]
|
||||
positional_args = [
|
||||
i.name for i in args if i.variadicity != Variadicity.optional
|
||||
]
|
||||
@ -272,7 +294,7 @@ class Operation(ir.OpView):
|
||||
|
||||
@classmethod
|
||||
def _generate_init_method(cls, fields: List[FieldDef]) -> None:
|
||||
operands, attrs, results = partition_fields(fields)
|
||||
operands, attrs, results, regions = partition_fields(fields)
|
||||
inferred_types = [infer_type(i.constraint) for i in results]
|
||||
|
||||
# we infer result types only when all result types can be inferred
|
||||
@ -299,7 +321,7 @@ class Operation(ir.OpView):
|
||||
for attr in attrs
|
||||
if args[attr.name] is not None
|
||||
)
|
||||
_regions = None
|
||||
_regions = len(regions) or None
|
||||
_ods_successors = None
|
||||
self = args["self"]
|
||||
super(Operation, self).__init__(
|
||||
@ -323,13 +345,13 @@ class Operation(ir.OpView):
|
||||
def _generate_class_attributes(
|
||||
cls, dialect_name: str, op_name: str, fields: List[FieldDef]
|
||||
) -> None:
|
||||
operands, attrs, results = partition_fields(fields)
|
||||
operands, attrs, results, regions = partition_fields(fields)
|
||||
|
||||
operand_segments = cls._generate_segments(operands)
|
||||
result_segments = cls._generate_segments(results)
|
||||
|
||||
cls.OPERATION_NAME = f"{dialect_name}.{op_name}"
|
||||
cls._ODS_REGIONS = (0, True)
|
||||
cls._ODS_REGIONS = (len(regions), True)
|
||||
cls._ODS_OPERAND_SEGMENTS = operand_segments
|
||||
cls._ODS_RESULT_SEGMENTS = result_segments
|
||||
|
||||
@ -342,6 +364,15 @@ class Operation(ir.OpView):
|
||||
property(lambda self, name=attr.name: self.attributes[name]),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _generate_region_properties(cls, regions: List[RegionDef]) -> None:
|
||||
for i, region in enumerate(regions):
|
||||
setattr(
|
||||
cls,
|
||||
region.name,
|
||||
property(lambda self, i=i: self.regions[i]),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _generate_operand_properties(cls, operands: List[OperandDef]) -> None:
|
||||
for i, operand in enumerate(operands):
|
||||
@ -379,7 +410,7 @@ class Operation(ir.OpView):
|
||||
@classmethod
|
||||
def _emit_operation(cls) -> None:
|
||||
ctx = ConstraintLoweringContext()
|
||||
operands, attrs, results = partition_fields(cls._fields)
|
||||
operands, attrs, results, regions = partition_fields(cls._fields)
|
||||
|
||||
op = irdl.operation_(cls._op_name)
|
||||
with ir.InsertionPoint(op.body):
|
||||
@ -400,6 +431,11 @@ class Operation(ir.OpView):
|
||||
[i.name for i in results],
|
||||
[i.variadicity for i in results],
|
||||
)
|
||||
if regions:
|
||||
irdl.regions_(
|
||||
[irdl.region([]) for _ in regions],
|
||||
[i.name for i in regions],
|
||||
)
|
||||
|
||||
|
||||
class Dialect(ir.Dialect):
|
||||
|
||||
@ -76,6 +76,8 @@ def testMyInt():
|
||||
print(add1._ODS_OPERAND_SEGMENTS)
|
||||
# CHECK: None
|
||||
print(add1._ODS_RESULT_SEGMENTS)
|
||||
# CHECK: (0, True)
|
||||
print(add1._ODS_REGIONS)
|
||||
# CHECK: %0 = "myint.constant"() {value = 2 : i32} : () -> i32
|
||||
print(add1.lhs.owner)
|
||||
# CHECK: %1 = "myint.constant"() {value = 3 : i32} : () -> i32
|
||||
@ -338,3 +340,73 @@ def testExtDialect():
|
||||
except TypeError as e:
|
||||
# CHECK:too many positional arguments
|
||||
print(e)
|
||||
|
||||
|
||||
# CHECK: TEST: testExtDialectWithRegion
|
||||
@run
|
||||
def testExtDialectWithRegion():
|
||||
class TestRegion(Dialect, name="ext_region"):
|
||||
pass
|
||||
|
||||
class IfOp(TestRegion.Operation, name="if"):
|
||||
cond: Operand[IntegerType[1]]
|
||||
then: Region
|
||||
else_: Region
|
||||
|
||||
with Context(), Location.unknown():
|
||||
TestRegion.load()
|
||||
# CHECK: irdl.dialect @ext_region {
|
||||
# CHECK: irdl.operation @if {
|
||||
# CHECK: %0 = irdl.is i1
|
||||
# CHECK: irdl.operands(cond: %0)
|
||||
# CHECK: %1 = irdl.region
|
||||
# CHECK: %2 = irdl.region
|
||||
# CHECK: irdl.regions(then: %1, else_: %2)
|
||||
# CHECK: }
|
||||
print(TestRegion._mlir_module)
|
||||
|
||||
# CHECK: (self, /, cond, *, loc=None, ip=None)
|
||||
print(IfOp.__init__.__signature__)
|
||||
|
||||
# CHECK: None None
|
||||
print(IfOp._ODS_OPERAND_SEGMENTS, IfOp._ODS_RESULT_SEGMENTS)
|
||||
# CHECK: (2, True)
|
||||
print(IfOp._ODS_REGIONS)
|
||||
|
||||
from mlir.dialects import llvm
|
||||
|
||||
module = Module.create()
|
||||
with InsertionPoint(module.body):
|
||||
i1 = IntegerType.get_signless(1)
|
||||
i32 = IntegerType.get_signless(32)
|
||||
cond = arith.constant(i1, 1)
|
||||
|
||||
if_ = IfOp(cond)
|
||||
if_.then.blocks.append()
|
||||
if_.else_.blocks.append()
|
||||
|
||||
with InsertionPoint(if_.then.blocks[0]):
|
||||
v = arith.constant(i32, 2)
|
||||
llvm.unreachable()
|
||||
|
||||
with InsertionPoint(if_.else_.blocks[0]):
|
||||
v = arith.constant(i32, 3)
|
||||
llvm.unreachable()
|
||||
|
||||
assert module.operation.verify()
|
||||
# CHECK: module {
|
||||
# CHECK: %true = arith.constant true
|
||||
# CHECK: "ext_region.if"(%true) ({
|
||||
# CHECK: %c2_i32 = arith.constant 2 : i32
|
||||
# CHECK: llvm.unreachable
|
||||
# CHECK: }, {
|
||||
# CHECK: %c3_i32 = arith.constant 3 : i32
|
||||
# CHECK: llvm.unreachable
|
||||
# CHECK: }) : (i1) -> ()
|
||||
# CHECK: }
|
||||
print(module)
|
||||
|
||||
# CHECK: %c2_i32 = arith.constant 2 : i32
|
||||
print(if_.then.blocks[0])
|
||||
# CHECK: %c3_i32 = arith.constant 3 : i32
|
||||
print(if_.else_.blocks[0])
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user