[MLIR][Python] Add more field specifiers to Python-defined operations (#188064)
This PR adds two new field specifiers (`operand` and `attribute`) and
extends the existing one (`result`):
- `default_factory` parameter is added for `result` and `attribute` to
specify default value via a lambda/function
- `kw_only` parameter is added for all these three specifiers, to make a
field a keyword-only parameter (without giving a default value).
```python
def result(
*,
infer_type: bool = False,
default_factory: Optional[Callable[[], Any]] = None,
kw_only: bool = False,
) -> Any: ...
def operand(
*,
kw_only: bool = False,
) -> Any: ...
def attribute(
*,
default_factory: Optional[Callable[[], Any]] = None,
kw_only: bool = False,
) -> Any: ...
```
Examples about how to use them:
```python
class OperandSpecifierOp(TestFieldSpecifiers.Operation, name="operand_specifier"):
a: Operand[IntegerType[32]] = operand()
b: Optional[Operand[IntegerType[32]]] = None
c: Operand[IntegerType[32]] = operand(kw_only=True)
class ResultSpecifierOp(TestFieldSpecifiers.Operation, name="result_specifier"):
a: Result[IntegerType[32]] = result()
b: Result[IntegerType[16]] = result(infer_type=True)
c: Result[IntegerType] = result(
default_factory=lambda: IntegerType.get_signless(8)
)
d: Sequence[Result[IntegerType]] = result(default_factory=list)
e: Result[IntegerType[32]] = result(kw_only=True)
class AttributeSpecifierOp(
TestFieldSpecifiers.Operation, name="attribute_specifier"
):
a: IntegerAttr = attribute()
b: IntegerAttr = attribute(
default_factory=lambda: IntegerAttr.get(IntegerType.get_signless(32), 42)
)
c: StringAttr["a"] | StringAttr["b"] = attribute(
default_factory=lambda: StringAttr.get("a")
)
d: IntegerAttr = attribute(kw_only=True)
```
---------
Co-authored-by: Rolf Morel <rolfmorel@gmail.com>
This commit is contained in:
parent
3b76b85b15
commit
e568136e94
@ -18,6 +18,7 @@ from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from inspect import Parameter, Signature
|
||||
from types import UnionType
|
||||
from enum import Enum
|
||||
from . import irdl
|
||||
from ._ods_common import _cext, segmented_accessor
|
||||
from .irdl import Variadicity
|
||||
@ -35,6 +36,8 @@ __all__ = [
|
||||
"Type",
|
||||
"Attribute",
|
||||
"result",
|
||||
"operand",
|
||||
"attribute",
|
||||
]
|
||||
|
||||
Operand = ir.Value
|
||||
@ -108,25 +111,71 @@ class ConstraintLoweringContext:
|
||||
|
||||
@dataclass
|
||||
class FieldSpecifier:
|
||||
type_: Any = None
|
||||
infer_type: bool = False
|
||||
default_is_none: bool = False
|
||||
default_factory: Optional[Callable[[], Any]] = None
|
||||
kw_only: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.infer_type and self.default_is_none:
|
||||
raise ValueError(
|
||||
"a field cannot be marked with both infer_type and default_is_none"
|
||||
)
|
||||
|
||||
def kw_only(self) -> bool:
|
||||
return self.default_is_none or self.infer_type
|
||||
@property
|
||||
def param_kind(self):
|
||||
if self.default_is_none or self.default_factory or self.infer_type:
|
||||
return ParameterKind.KEYWORD_ONLY_WITH_DEFAULT
|
||||
if self.kw_only:
|
||||
return ParameterKind.KEYWORD_ONLY_WITHOUT_DEFAULT
|
||||
return ParameterKind.POSITIONAL_OR_KEYWORD
|
||||
|
||||
|
||||
def result(*, infer_type: bool = False) -> Any:
|
||||
def result(
|
||||
*,
|
||||
infer_type: bool = False,
|
||||
default_factory: Optional[Callable[[], Any]] = None,
|
||||
kw_only: bool = False,
|
||||
) -> Result:
|
||||
"""
|
||||
A field specifier for `Result` definitions.
|
||||
"""
|
||||
if infer_type and default_factory:
|
||||
raise ValueError(
|
||||
"a result field cannot have both infer_type and default_factory"
|
||||
)
|
||||
|
||||
return FieldSpecifier(infer_type=infer_type)
|
||||
return FieldSpecifier(
|
||||
type_=Result,
|
||||
infer_type=infer_type,
|
||||
default_factory=default_factory,
|
||||
kw_only=kw_only,
|
||||
)
|
||||
|
||||
|
||||
def operand(
|
||||
*,
|
||||
kw_only: bool = False,
|
||||
) -> Operand:
|
||||
"""
|
||||
A field specifier for `Operand` definitions.
|
||||
"""
|
||||
|
||||
return FieldSpecifier(
|
||||
type_=Operand,
|
||||
kw_only=kw_only,
|
||||
)
|
||||
|
||||
|
||||
def attribute(
|
||||
*,
|
||||
default_factory: Optional[Callable[[], Any]] = None,
|
||||
kw_only: bool = False,
|
||||
) -> ir.Attribute:
|
||||
"""
|
||||
A field specifier for attribute definitions.
|
||||
"""
|
||||
|
||||
return FieldSpecifier(
|
||||
type_=Attribute,
|
||||
default_factory=default_factory,
|
||||
kw_only=kw_only,
|
||||
)
|
||||
|
||||
|
||||
def infer_type_impl(type_) -> Callable[[], ir.Type]:
|
||||
@ -149,6 +198,12 @@ def infer_type_impl(type_) -> Callable[[], ir.Type]:
|
||||
raise TypeError(f"unsupported type for inferring: {type_}")
|
||||
|
||||
|
||||
class ParameterKind(Enum):
|
||||
POSITIONAL_OR_KEYWORD = 1
|
||||
KEYWORD_ONLY_WITHOUT_DEFAULT = 2
|
||||
KEYWORD_ONLY_WITH_DEFAULT = 3
|
||||
|
||||
|
||||
@dataclass
|
||||
class FieldDef:
|
||||
"""
|
||||
@ -159,7 +214,7 @@ class FieldDef:
|
||||
variadicity: Variadicity
|
||||
constraint: Any
|
||||
|
||||
kw_only: bool = False
|
||||
param_kind: ParameterKind = ParameterKind.POSITIONAL_OR_KEYWORD
|
||||
|
||||
@staticmethod
|
||||
def from_type_hint(name, type_, specifier) -> "FieldDef":
|
||||
@ -173,46 +228,72 @@ class FieldDef:
|
||||
|
||||
origin = get_origin(type_)
|
||||
if origin is ir.OpResult:
|
||||
if specifier.type_ and specifier.type_ is not Result:
|
||||
raise TypeError(
|
||||
f"only `result` field specifier can be used for result fields"
|
||||
)
|
||||
constraint = get_args(type_)[0]
|
||||
return ResultDef(
|
||||
name,
|
||||
variadicity,
|
||||
constraint,
|
||||
kw_only=specifier.kw_only(),
|
||||
param_kind=specifier.param_kind,
|
||||
default_factory=specifier.default_factory,
|
||||
default_is_none=specifier.default_is_none,
|
||||
infer_type=(
|
||||
infer_type_impl(constraint) if specifier.infer_type else None
|
||||
),
|
||||
)
|
||||
elif origin is ir.Value:
|
||||
if specifier.type_ and specifier.type_ is not Operand:
|
||||
raise TypeError(
|
||||
f"only `operand` field specifier can be used for operand fields"
|
||||
)
|
||||
return OperandDef(
|
||||
name,
|
||||
variadicity,
|
||||
get_args(type_)[0],
|
||||
kw_only=specifier.kw_only(),
|
||||
param_kind=specifier.param_kind,
|
||||
default_is_none=specifier.default_is_none,
|
||||
)
|
||||
elif type_ is ir.Region:
|
||||
if specifier.type_ and specifier.type_ is not Region:
|
||||
raise TypeError(
|
||||
f"this field specifier can not be used for region fields"
|
||||
)
|
||||
return RegionDef(name, variadicity, Any)
|
||||
return AttributeDef(name, variadicity, type_)
|
||||
|
||||
if specifier.type_ and specifier.type_ is not Attribute:
|
||||
raise TypeError(
|
||||
f"only `attribute` field specifier can be used for attribute fields"
|
||||
)
|
||||
return AttributeDef(
|
||||
name,
|
||||
variadicity,
|
||||
type_,
|
||||
param_kind=specifier.param_kind,
|
||||
default_factory=specifier.default_factory,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OperandDef(FieldDef):
|
||||
default_is_none: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.variadicity != Variadicity.optional and self.kw_only:
|
||||
raise ValueError(f"only optional operand can be a keyword parameter")
|
||||
if self.variadicity != Variadicity.optional and self.default_is_none:
|
||||
raise ValueError(f"only optional operand can be set to None")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResultDef(FieldDef):
|
||||
infer_type: Callable[[], ir.Type] | None = None
|
||||
default_factory: Optional[Callable[[], Any]] = None
|
||||
default_is_none: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if (
|
||||
self.variadicity != Variadicity.optional
|
||||
and not self.infer_type
|
||||
and self.kw_only
|
||||
):
|
||||
raise ValueError(f"only optional result can be a keyword parameter")
|
||||
if self.variadicity != Variadicity.optional and self.default_is_none:
|
||||
raise ValueError(f"only optional result can be set to None")
|
||||
|
||||
if self.infer_type and self.variadicity != Variadicity.single:
|
||||
raise ValueError(
|
||||
@ -226,15 +307,33 @@ class ResultDef(FieldDef):
|
||||
if self.infer_type:
|
||||
return self.infer_type()
|
||||
|
||||
if self.default_factory:
|
||||
return self.default_factory()
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class AttributeDef(FieldDef):
|
||||
default_factory: Optional[Callable[[], Any]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.variadicity != Variadicity.single:
|
||||
raise ValueError("optional attribute is not currently supported")
|
||||
if (
|
||||
self.param_kind == ParameterKind.KEYWORD_ONLY_WITH_DEFAULT
|
||||
and not self.default_factory
|
||||
):
|
||||
raise ValueError(f"only optional attribute can be set to None")
|
||||
|
||||
def process_attr(self, attr):
|
||||
if attr:
|
||||
return attr
|
||||
|
||||
if self.default_factory:
|
||||
return self.default_factory()
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -415,10 +514,15 @@ class Operation(ir.OpView):
|
||||
params = [Parameter("self", Parameter.POSITIONAL_ONLY)]
|
||||
|
||||
for i in args:
|
||||
if i.kw_only:
|
||||
params.append(Parameter(i.name, Parameter.KEYWORD_ONLY, default=None))
|
||||
else:
|
||||
params.append(Parameter(i.name, Parameter.POSITIONAL_OR_KEYWORD))
|
||||
match i.param_kind:
|
||||
case ParameterKind.POSITIONAL_OR_KEYWORD:
|
||||
params.append(Parameter(i.name, Parameter.POSITIONAL_OR_KEYWORD))
|
||||
case ParameterKind.KEYWORD_ONLY_WITH_DEFAULT:
|
||||
params.append(
|
||||
Parameter(i.name, Parameter.KEYWORD_ONLY, default=None)
|
||||
)
|
||||
case ParameterKind.KEYWORD_ONLY_WITHOUT_DEFAULT:
|
||||
params.append(Parameter(i.name, Parameter.KEYWORD_ONLY))
|
||||
|
||||
params.append(Parameter("loc", Parameter.KEYWORD_ONLY, default=None))
|
||||
params.append(Parameter("ip", Parameter.KEYWORD_ONLY, default=None))
|
||||
@ -439,9 +543,7 @@ class Operation(ir.OpView):
|
||||
_operands = [args[operand.name] for operand in operands]
|
||||
_results = [result.process_type(args[result.name]) for result in results]
|
||||
_attributes = dict(
|
||||
(attr.name, args[attr.name])
|
||||
for attr in attrs
|
||||
if args[attr.name] is not None
|
||||
(attr.name, attr.process_attr(args[attr.name])) for attr in attrs
|
||||
)
|
||||
_regions = len(regions) or None
|
||||
_ods_successors = None
|
||||
|
||||
@ -753,7 +753,49 @@ def testExtDialectWithInvalidOp():
|
||||
a: Operand[IntegerType[32]] = None
|
||||
|
||||
except ValueError as e:
|
||||
# CHECK: only optional operand can be a keyword parameter
|
||||
# CHECK: only optional operand can be set to None
|
||||
print(e)
|
||||
|
||||
try:
|
||||
|
||||
class AssignNoneOnnAttributeOp(
|
||||
TestInvalid.Operation, name="assign_none_on_attribute"
|
||||
):
|
||||
a: IntegerAttr = None
|
||||
|
||||
except ValueError as e:
|
||||
# CHECK: only optional attribute can be set to None
|
||||
print(e)
|
||||
|
||||
try:
|
||||
|
||||
class CannotInferTypeOp(TestInvalid.Operation, name="cannot_infer_type"):
|
||||
a: Result[IntegerType] = result(infer_type=True)
|
||||
|
||||
except TypeError as e:
|
||||
# CHECK: unsupported type for inferring
|
||||
print(e)
|
||||
|
||||
try:
|
||||
|
||||
class WrongFieldSpecifierOp(
|
||||
TestInvalid.Operation, name="wrong_field_specifier"
|
||||
):
|
||||
a: Result[IntegerType] = operand()
|
||||
|
||||
except TypeError as e:
|
||||
# CHECK: only `result` field specifier can be used for result fields
|
||||
print(e)
|
||||
|
||||
try:
|
||||
|
||||
class WrongFieldSpecifierOp2(
|
||||
TestInvalid.Operation, name="wrong_field_specifier2"
|
||||
):
|
||||
a: IntegerAttr = operand()
|
||||
|
||||
except TypeError as e:
|
||||
# CHECK: only `attribute` field specifier can be used for attribute fields
|
||||
print(e)
|
||||
|
||||
|
||||
@ -796,3 +838,59 @@ def testExtDialectWithAttrInOp():
|
||||
# CHECK: "ext_attr_in_op.op_with_attr"() {a = 42 : i32, b = i32} : () -> ()
|
||||
# CHECK: "ext_attr_in_op.op_with_attr"() {a = "hello", b = i64} : () -> ()
|
||||
print(module)
|
||||
|
||||
|
||||
@run
|
||||
def testExtDialectFieldSpecifiers():
|
||||
class TestFieldSpecifiers(Dialect, name="ext_field_specifiers"):
|
||||
pass
|
||||
|
||||
class OperandSpecifierOp(TestFieldSpecifiers.Operation, name="operand_specifier"):
|
||||
a: Operand[IntegerType[32]] = operand()
|
||||
b: Optional[Operand[IntegerType[32]]] = None
|
||||
c: Operand[IntegerType[32]] = operand(kw_only=True)
|
||||
|
||||
class ResultSpecifierOp(TestFieldSpecifiers.Operation, name="result_specifier"):
|
||||
a: Result[IntegerType[32]] = result()
|
||||
b: Result[IntegerType[16]] = result(infer_type=True)
|
||||
c: Result[IntegerType] = result(
|
||||
default_factory=lambda: IntegerType.get_signless(8)
|
||||
)
|
||||
d: Sequence[Result[IntegerType]] = result(default_factory=list)
|
||||
e: Result[IntegerType[32]] = result(kw_only=True)
|
||||
|
||||
class AttributeSpecifierOp(
|
||||
TestFieldSpecifiers.Operation, name="attribute_specifier"
|
||||
):
|
||||
a: IntegerAttr = attribute()
|
||||
b: IntegerAttr = attribute(
|
||||
default_factory=lambda: IntegerAttr.get(IntegerType.get_signless(32), 42)
|
||||
)
|
||||
c: StringAttr["a"] | StringAttr["b"] = attribute(
|
||||
default_factory=lambda: StringAttr.get("a"), kw_only=True
|
||||
)
|
||||
d: IntegerAttr = attribute(kw_only=True)
|
||||
|
||||
with Context(), Location.unknown():
|
||||
TestFieldSpecifiers.load()
|
||||
|
||||
# CHECK: (self, /, a, *, b=None, c, loc=None, ip=None)
|
||||
print(OperandSpecifierOp.__init__.__signature__)
|
||||
# CHECK: (self, /, a, *, b=None, c=None, d=None, e, loc=None, ip=None)
|
||||
print(ResultSpecifierOp.__init__.__signature__)
|
||||
# CHECK: (self, /, a, *, b=None, c=None, d, loc=None, ip=None)
|
||||
print(AttributeSpecifierOp.__init__.__signature__)
|
||||
|
||||
module = Module.create()
|
||||
i32 = IntegerType.get_signless(32)
|
||||
with InsertionPoint(module.body):
|
||||
one = arith.constant(i32, 1)
|
||||
OperandSpecifierOp(one, c=one)
|
||||
ResultSpecifierOp(i32, e=i32)
|
||||
AttributeSpecifierOp(IntegerAttr.get(i32, 43), d=IntegerAttr.get(i32, 100))
|
||||
|
||||
assert module.operation.verify()
|
||||
# CHECK: "ext_field_specifiers.operand_specifier"(%c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 0, 1>} : (i32, i32) -> ()
|
||||
# CHECK: %0:4 = "ext_field_specifiers.result_specifier"() {resultSegmentSizes = array<i32: 1, 1, 1, 0, 1>} : () -> (i32, i16, i8, i32)
|
||||
# CHECK: "ext_field_specifiers.attribute_specifier"() {a = 43 : i32, b = 42 : i32, c = "a", d = 100 : i32} : () -> ()
|
||||
print(module)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user