diff --git a/mlir/python/mlir/dialects/ext.py b/mlir/python/mlir/dialects/ext.py index dfcd7f2d641d..89051bf9ec92 100644 --- a/mlir/python/mlir/dialects/ext.py +++ b/mlir/python/mlir/dialects/ext.py @@ -18,6 +18,7 @@ from collections.abc import Sequence from dataclasses import dataclass from inspect import Parameter, Signature from types import UnionType +from enum import Enum from . import irdl from ._ods_common import _cext, segmented_accessor from .irdl import Variadicity @@ -35,6 +36,8 @@ __all__ = [ "Type", "Attribute", "result", + "operand", + "attribute", ] Operand = ir.Value @@ -108,25 +111,71 @@ class ConstraintLoweringContext: @dataclass class FieldSpecifier: + type_: Any = None infer_type: bool = False default_is_none: bool = False + default_factory: Optional[Callable[[], Any]] = None + kw_only: bool = False - def __post_init__(self): - if self.infer_type and self.default_is_none: - raise ValueError( - "a field cannot be marked with both infer_type and default_is_none" - ) - - def kw_only(self) -> bool: - return self.default_is_none or self.infer_type + @property + def param_kind(self): + if self.default_is_none or self.default_factory or self.infer_type: + return ParameterKind.KEYWORD_ONLY_WITH_DEFAULT + if self.kw_only: + return ParameterKind.KEYWORD_ONLY_WITHOUT_DEFAULT + return ParameterKind.POSITIONAL_OR_KEYWORD -def result(*, infer_type: bool = False) -> Any: +def result( + *, + infer_type: bool = False, + default_factory: Optional[Callable[[], Any]] = None, + kw_only: bool = False, +) -> Result: """ A field specifier for `Result` definitions. """ + if infer_type and default_factory: + raise ValueError( + "a result field cannot have both infer_type and default_factory" + ) - return FieldSpecifier(infer_type=infer_type) + return FieldSpecifier( + type_=Result, + infer_type=infer_type, + default_factory=default_factory, + kw_only=kw_only, + ) + + +def operand( + *, + kw_only: bool = False, +) -> Operand: + """ + A field specifier for `Operand` definitions. + """ + + return FieldSpecifier( + type_=Operand, + kw_only=kw_only, + ) + + +def attribute( + *, + default_factory: Optional[Callable[[], Any]] = None, + kw_only: bool = False, +) -> ir.Attribute: + """ + A field specifier for attribute definitions. + """ + + return FieldSpecifier( + type_=Attribute, + default_factory=default_factory, + kw_only=kw_only, + ) def infer_type_impl(type_) -> Callable[[], ir.Type]: @@ -149,6 +198,12 @@ def infer_type_impl(type_) -> Callable[[], ir.Type]: raise TypeError(f"unsupported type for inferring: {type_}") +class ParameterKind(Enum): + POSITIONAL_OR_KEYWORD = 1 + KEYWORD_ONLY_WITHOUT_DEFAULT = 2 + KEYWORD_ONLY_WITH_DEFAULT = 3 + + @dataclass class FieldDef: """ @@ -159,7 +214,7 @@ class FieldDef: variadicity: Variadicity constraint: Any - kw_only: bool = False + param_kind: ParameterKind = ParameterKind.POSITIONAL_OR_KEYWORD @staticmethod def from_type_hint(name, type_, specifier) -> "FieldDef": @@ -173,46 +228,72 @@ class FieldDef: origin = get_origin(type_) if origin is ir.OpResult: + if specifier.type_ and specifier.type_ is not Result: + raise TypeError( + f"only `result` field specifier can be used for result fields" + ) constraint = get_args(type_)[0] return ResultDef( name, variadicity, constraint, - kw_only=specifier.kw_only(), + param_kind=specifier.param_kind, + default_factory=specifier.default_factory, + default_is_none=specifier.default_is_none, infer_type=( infer_type_impl(constraint) if specifier.infer_type else None ), ) elif origin is ir.Value: + if specifier.type_ and specifier.type_ is not Operand: + raise TypeError( + f"only `operand` field specifier can be used for operand fields" + ) return OperandDef( name, variadicity, get_args(type_)[0], - kw_only=specifier.kw_only(), + param_kind=specifier.param_kind, + default_is_none=specifier.default_is_none, ) elif type_ is ir.Region: + if specifier.type_ and specifier.type_ is not Region: + raise TypeError( + f"this field specifier can not be used for region fields" + ) return RegionDef(name, variadicity, Any) - return AttributeDef(name, variadicity, type_) + + if specifier.type_ and specifier.type_ is not Attribute: + raise TypeError( + f"only `attribute` field specifier can be used for attribute fields" + ) + return AttributeDef( + name, + variadicity, + type_, + param_kind=specifier.param_kind, + default_factory=specifier.default_factory, + ) @dataclass class OperandDef(FieldDef): + default_is_none: bool = False + def __post_init__(self): - if self.variadicity != Variadicity.optional and self.kw_only: - raise ValueError(f"only optional operand can be a keyword parameter") + if self.variadicity != Variadicity.optional and self.default_is_none: + raise ValueError(f"only optional operand can be set to None") @dataclass class ResultDef(FieldDef): infer_type: Callable[[], ir.Type] | None = None + default_factory: Optional[Callable[[], Any]] = None + default_is_none: bool = False def __post_init__(self): - if ( - self.variadicity != Variadicity.optional - and not self.infer_type - and self.kw_only - ): - raise ValueError(f"only optional result can be a keyword parameter") + if self.variadicity != Variadicity.optional and self.default_is_none: + raise ValueError(f"only optional result can be set to None") if self.infer_type and self.variadicity != Variadicity.single: raise ValueError( @@ -226,15 +307,33 @@ class ResultDef(FieldDef): if self.infer_type: return self.infer_type() + if self.default_factory: + return self.default_factory() + return None @dataclass class AttributeDef(FieldDef): + default_factory: Optional[Callable[[], Any]] = None def __post_init__(self): if self.variadicity != Variadicity.single: raise ValueError("optional attribute is not currently supported") + if ( + self.param_kind == ParameterKind.KEYWORD_ONLY_WITH_DEFAULT + and not self.default_factory + ): + raise ValueError(f"only optional attribute can be set to None") + + def process_attr(self, attr): + if attr: + return attr + + if self.default_factory: + return self.default_factory() + + return None @dataclass @@ -415,10 +514,15 @@ class Operation(ir.OpView): params = [Parameter("self", Parameter.POSITIONAL_ONLY)] for i in args: - if i.kw_only: - params.append(Parameter(i.name, Parameter.KEYWORD_ONLY, default=None)) - else: - params.append(Parameter(i.name, Parameter.POSITIONAL_OR_KEYWORD)) + match i.param_kind: + case ParameterKind.POSITIONAL_OR_KEYWORD: + params.append(Parameter(i.name, Parameter.POSITIONAL_OR_KEYWORD)) + case ParameterKind.KEYWORD_ONLY_WITH_DEFAULT: + params.append( + Parameter(i.name, Parameter.KEYWORD_ONLY, default=None) + ) + case ParameterKind.KEYWORD_ONLY_WITHOUT_DEFAULT: + params.append(Parameter(i.name, Parameter.KEYWORD_ONLY)) params.append(Parameter("loc", Parameter.KEYWORD_ONLY, default=None)) params.append(Parameter("ip", Parameter.KEYWORD_ONLY, default=None)) @@ -439,9 +543,7 @@ class Operation(ir.OpView): _operands = [args[operand.name] for operand in operands] _results = [result.process_type(args[result.name]) for result in results] _attributes = dict( - (attr.name, args[attr.name]) - for attr in attrs - if args[attr.name] is not None + (attr.name, attr.process_attr(args[attr.name])) for attr in attrs ) _regions = len(regions) or None _ods_successors = None diff --git a/mlir/test/python/dialects/ext.py b/mlir/test/python/dialects/ext.py index 90f68b876b13..8dfc74ad29d4 100644 --- a/mlir/test/python/dialects/ext.py +++ b/mlir/test/python/dialects/ext.py @@ -753,7 +753,49 @@ def testExtDialectWithInvalidOp(): a: Operand[IntegerType[32]] = None except ValueError as e: - # CHECK: only optional operand can be a keyword parameter + # CHECK: only optional operand can be set to None + print(e) + + try: + + class AssignNoneOnnAttributeOp( + TestInvalid.Operation, name="assign_none_on_attribute" + ): + a: IntegerAttr = None + + except ValueError as e: + # CHECK: only optional attribute can be set to None + print(e) + + try: + + class CannotInferTypeOp(TestInvalid.Operation, name="cannot_infer_type"): + a: Result[IntegerType] = result(infer_type=True) + + except TypeError as e: + # CHECK: unsupported type for inferring + print(e) + + try: + + class WrongFieldSpecifierOp( + TestInvalid.Operation, name="wrong_field_specifier" + ): + a: Result[IntegerType] = operand() + + except TypeError as e: + # CHECK: only `result` field specifier can be used for result fields + print(e) + + try: + + class WrongFieldSpecifierOp2( + TestInvalid.Operation, name="wrong_field_specifier2" + ): + a: IntegerAttr = operand() + + except TypeError as e: + # CHECK: only `attribute` field specifier can be used for attribute fields print(e) @@ -796,3 +838,59 @@ def testExtDialectWithAttrInOp(): # CHECK: "ext_attr_in_op.op_with_attr"() {a = 42 : i32, b = i32} : () -> () # CHECK: "ext_attr_in_op.op_with_attr"() {a = "hello", b = i64} : () -> () print(module) + + +@run +def testExtDialectFieldSpecifiers(): + class TestFieldSpecifiers(Dialect, name="ext_field_specifiers"): + pass + + class OperandSpecifierOp(TestFieldSpecifiers.Operation, name="operand_specifier"): + a: Operand[IntegerType[32]] = operand() + b: Optional[Operand[IntegerType[32]]] = None + c: Operand[IntegerType[32]] = operand(kw_only=True) + + class ResultSpecifierOp(TestFieldSpecifiers.Operation, name="result_specifier"): + a: Result[IntegerType[32]] = result() + b: Result[IntegerType[16]] = result(infer_type=True) + c: Result[IntegerType] = result( + default_factory=lambda: IntegerType.get_signless(8) + ) + d: Sequence[Result[IntegerType]] = result(default_factory=list) + e: Result[IntegerType[32]] = result(kw_only=True) + + class AttributeSpecifierOp( + TestFieldSpecifiers.Operation, name="attribute_specifier" + ): + a: IntegerAttr = attribute() + b: IntegerAttr = attribute( + default_factory=lambda: IntegerAttr.get(IntegerType.get_signless(32), 42) + ) + c: StringAttr["a"] | StringAttr["b"] = attribute( + default_factory=lambda: StringAttr.get("a"), kw_only=True + ) + d: IntegerAttr = attribute(kw_only=True) + + with Context(), Location.unknown(): + TestFieldSpecifiers.load() + + # CHECK: (self, /, a, *, b=None, c, loc=None, ip=None) + print(OperandSpecifierOp.__init__.__signature__) + # CHECK: (self, /, a, *, b=None, c=None, d=None, e, loc=None, ip=None) + print(ResultSpecifierOp.__init__.__signature__) + # CHECK: (self, /, a, *, b=None, c=None, d, loc=None, ip=None) + print(AttributeSpecifierOp.__init__.__signature__) + + module = Module.create() + i32 = IntegerType.get_signless(32) + with InsertionPoint(module.body): + one = arith.constant(i32, 1) + OperandSpecifierOp(one, c=one) + ResultSpecifierOp(i32, e=i32) + AttributeSpecifierOp(IntegerAttr.get(i32, 43), d=IntegerAttr.get(i32, 100)) + + assert module.operation.verify() + # CHECK: "ext_field_specifiers.operand_specifier"(%c1_i32, %c1_i32) {operandSegmentSizes = array} : (i32, i32) -> () + # CHECK: %0:4 = "ext_field_specifiers.result_specifier"() {resultSegmentSizes = array} : () -> (i32, i16, i8, i32) + # CHECK: "ext_field_specifiers.attribute_specifier"() {a = 43 : i32, b = 42 : i32, c = "a", d = 100 : i32} : () -> () + print(module)