diff --git a/mlir/include/mlir-c/Dialect/PDL.h b/mlir/include/mlir-c/Dialect/PDL.h index 6ad2e2da62d8..d04f69e391b1 100644 --- a/mlir/include/mlir-c/Dialect/PDL.h +++ b/mlir/include/mlir-c/Dialect/PDL.h @@ -30,6 +30,8 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLType(MlirType type); MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLAttributeType(MlirType type); +MLIR_CAPI_EXPORTED MlirTypeID mlirPDLAttributeTypeGetTypeID(void); + MLIR_CAPI_EXPORTED MlirType mlirPDLAttributeTypeGet(MlirContext ctx); //===---------------------------------------------------------------------===// @@ -38,6 +40,8 @@ MLIR_CAPI_EXPORTED MlirType mlirPDLAttributeTypeGet(MlirContext ctx); MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLOperationType(MlirType type); +MLIR_CAPI_EXPORTED MlirTypeID mlirPDLOperationTypeGetTypeID(void); + MLIR_CAPI_EXPORTED MlirType mlirPDLOperationTypeGet(MlirContext ctx); //===---------------------------------------------------------------------===// @@ -46,6 +50,8 @@ MLIR_CAPI_EXPORTED MlirType mlirPDLOperationTypeGet(MlirContext ctx); MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLRangeType(MlirType type); +MLIR_CAPI_EXPORTED MlirTypeID mlirPDLRangeTypeGetTypeID(void); + MLIR_CAPI_EXPORTED MlirType mlirPDLRangeTypeGet(MlirType elementType); MLIR_CAPI_EXPORTED MlirType mlirPDLRangeTypeGetElementType(MlirType type); @@ -56,6 +62,8 @@ MLIR_CAPI_EXPORTED MlirType mlirPDLRangeTypeGetElementType(MlirType type); MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLTypeType(MlirType type); +MLIR_CAPI_EXPORTED MlirTypeID mlirPDLTypeTypeGetTypeID(void); + MLIR_CAPI_EXPORTED MlirType mlirPDLTypeTypeGet(MlirContext ctx); //===---------------------------------------------------------------------===// @@ -64,6 +72,8 @@ MLIR_CAPI_EXPORTED MlirType mlirPDLTypeTypeGet(MlirContext ctx); MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLValueType(MlirType type); +MLIR_CAPI_EXPORTED MlirTypeID mlirPDLValueTypeGetTypeID(void); + MLIR_CAPI_EXPORTED MlirType mlirPDLValueTypeGet(MlirContext ctx); #ifdef __cplusplus diff --git a/mlir/include/mlir-c/Dialect/Quant.h b/mlir/include/mlir-c/Dialect/Quant.h index dc0989e53344..f961c01d5dc2 100644 --- a/mlir/include/mlir-c/Dialect/Quant.h +++ b/mlir/include/mlir-c/Dialect/Quant.h @@ -103,6 +103,8 @@ mlirQuantizedTypeCastExpressedToStorageType(MlirType type, MlirType candidate); /// Returns `true` if the given type is an AnyQuantizedType. MLIR_CAPI_EXPORTED bool mlirTypeIsAAnyQuantizedType(MlirType type); +MLIR_CAPI_EXPORTED MlirTypeID mlirAnyQuantizedTypeGetTypeID(void); + /// Creates an instance of AnyQuantizedType with the given parameters in the /// same context as `storageType` and returns it. The instance is owned by the /// context. @@ -119,6 +121,8 @@ MLIR_CAPI_EXPORTED MlirType mlirAnyQuantizedTypeGet(unsigned flags, /// Returns `true` if the given type is a UniformQuantizedType. MLIR_CAPI_EXPORTED bool mlirTypeIsAUniformQuantizedType(MlirType type); +MLIR_CAPI_EXPORTED MlirTypeID mlirUniformQuantizedTypeGetTypeID(void); + /// Creates an instance of UniformQuantizedType with the given parameters in the /// same context as `storageType` and returns it. The instance is owned by the /// context. @@ -142,6 +146,8 @@ MLIR_CAPI_EXPORTED bool mlirUniformQuantizedTypeIsFixedPoint(MlirType type); /// Returns `true` if the given type is a UniformQuantizedPerAxisType. MLIR_CAPI_EXPORTED bool mlirTypeIsAUniformQuantizedPerAxisType(MlirType type); +MLIR_CAPI_EXPORTED MlirTypeID mlirUniformQuantizedPerAxisTypeGetTypeID(void); + /// Creates an instance of UniformQuantizedPerAxisType with the given parameters /// in the same context as `storageType` and returns it. `scales` and /// `zeroPoints` point to `nDims` number of elements. The instance is owned @@ -180,6 +186,8 @@ mlirUniformQuantizedPerAxisTypeIsFixedPoint(MlirType type); MLIR_CAPI_EXPORTED bool mlirTypeIsAUniformQuantizedSubChannelType(MlirType type); +MLIR_CAPI_EXPORTED MlirTypeID mlirUniformQuantizedSubChannelTypeGetTypeID(void); + /// Creates a UniformQuantizedSubChannelType with the given parameters. /// /// The type is owned by the context. `scalesAttr` and `zeroPointsAttr` must be @@ -220,6 +228,8 @@ mlirUniformQuantizedSubChannelTypeGetZeroPoints(MlirType type); /// Returns `true` if the given type is a CalibratedQuantizedType. MLIR_CAPI_EXPORTED bool mlirTypeIsACalibratedQuantizedType(MlirType type); +MLIR_CAPI_EXPORTED MlirTypeID mlirCalibratedQuantizedTypeGetTypeID(void); + /// Creates an instance of CalibratedQuantizedType with the given parameters /// in the same context as `expressedType` and returns it. The instance is owned /// by the context. diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h index 4ff5061b945a..4930ce5ca6b8 100644 --- a/mlir/include/mlir/Bindings/Python/IRCore.h +++ b/mlir/include/mlir/Bindings/Python/IRCore.h @@ -957,12 +957,6 @@ public: auto cls = ClassTy(m, DerivedTy::pyClassName); cls.def(nanobind::init(), nanobind::keep_alive<0, 1>(), nanobind::arg("cast_from_type")); - cls.def_static( - "isinstance", - [](PyType &otherType) -> bool { - return DerivedTy::isaFunction(otherType); - }, - nanobind::arg("other")); cls.def_prop_ro_static( "static_typeid", [](nanobind::object & /*class*/) { @@ -1094,12 +1088,6 @@ public: } cls.def(nanobind::init(), nanobind::keep_alive<0, 1>(), nanobind::arg("cast_from_attr")); - cls.def_static( - "isinstance", - [](PyAttribute &otherAttr) -> bool { - return DerivedTy::isaFunction(otherAttr); - }, - nanobind::arg("other")); cls.def_prop_ro( "type", [](PyAttribute &attr) -> nanobind::typed { @@ -1555,12 +1543,6 @@ public: .c_str())); cls.def(nanobind::init(), nanobind::keep_alive<0, 1>(), nanobind::arg("value")); - cls.def_static( - "isinstance", - [](PyValue &otherValue) -> bool { - return DerivedTy::isaFunction(otherValue); - }, - nanobind::arg("other_value")); cls.def( MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, [](DerivedTy &self) -> nanobind::typed { diff --git a/mlir/lib/Bindings/Python/DialectPDL.cpp b/mlir/lib/Bindings/Python/DialectPDL.cpp index ac72734ea5c2..17d0a8312701 100644 --- a/mlir/lib/Bindings/Python/DialectPDL.cpp +++ b/mlir/lib/Bindings/Python/DialectPDL.cpp @@ -39,6 +39,8 @@ struct PDLType : PyConcreteType { struct AttributeType : PyConcreteType { static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPDLAttributeType; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirPDLAttributeTypeGetTypeID; static constexpr const char *pyClassName = "AttributeType"; using Base::Base; @@ -60,6 +62,8 @@ struct AttributeType : PyConcreteType { struct OperationType : PyConcreteType { static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPDLOperationType; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirPDLOperationTypeGetTypeID; static constexpr const char *pyClassName = "OperationType"; using Base::Base; @@ -81,6 +85,8 @@ struct OperationType : PyConcreteType { struct RangeType : PyConcreteType { static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPDLRangeType; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirPDLRangeTypeGetTypeID; static constexpr const char *pyClassName = "RangeType"; using Base::Base; @@ -109,6 +115,8 @@ struct RangeType : PyConcreteType { struct TypeType : PyConcreteType { static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPDLTypeType; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirPDLTypeTypeGetTypeID; static constexpr const char *pyClassName = "TypeType"; using Base::Base; @@ -130,6 +138,8 @@ struct TypeType : PyConcreteType { struct ValueType : PyConcreteType { static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPDLValueType; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirPDLValueTypeGetTypeID; static constexpr const char *pyClassName = "ValueType"; using Base::Base; diff --git a/mlir/lib/Bindings/Python/DialectQuant.cpp b/mlir/lib/Bindings/Python/DialectQuant.cpp index 9c6a15c97134..3a6a91f3058a 100644 --- a/mlir/lib/Bindings/Python/DialectQuant.cpp +++ b/mlir/lib/Bindings/Python/DialectQuant.cpp @@ -192,6 +192,8 @@ struct QuantizedType : PyConcreteType { struct AnyQuantizedType : PyConcreteType { static constexpr IsAFunctionTy isaFunction = mlirTypeIsAAnyQuantizedType; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirAnyQuantizedTypeGetTypeID; static constexpr const char *pyClassName = "AnyQuantizedType"; using Base::Base; @@ -221,6 +223,8 @@ struct AnyQuantizedType : PyConcreteType { struct UniformQuantizedType : PyConcreteType { static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUniformQuantizedType; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirUniformQuantizedTypeGetTypeID; static constexpr const char *pyClassName = "UniformQuantizedType"; using Base::Base; @@ -273,6 +277,8 @@ struct UniformQuantizedPerAxisType : PyConcreteType { static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUniformQuantizedPerAxisType; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirUniformQuantizedPerAxisTypeGetTypeID; static constexpr const char *pyClassName = "UniformQuantizedPerAxisType"; using Base::Base; @@ -357,6 +363,8 @@ struct UniformQuantizedSubChannelType : PyConcreteType { static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUniformQuantizedSubChannelType; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirUniformQuantizedSubChannelTypeGetTypeID; static constexpr const char *pyClassName = "UniformQuantizedSubChannelType"; using Base::Base; @@ -448,6 +456,8 @@ struct CalibratedQuantizedType : PyConcreteType { static constexpr IsAFunctionTy isaFunction = mlirTypeIsACalibratedQuantizedType; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirCalibratedQuantizedTypeGetTypeID; static constexpr const char *pyClassName = "CalibratedQuantizedType"; using Base::Base; diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp index ce235470bbdc..b3d15ee59566 100644 --- a/mlir/lib/Bindings/Python/IRAffine.cpp +++ b/mlir/lib/Bindings/Python/IRAffine.cpp @@ -118,12 +118,6 @@ public: static void bind(nb::module_ &m) { auto cls = ClassTy(m, DerivedTy::pyClassName); cls.def(nb::init(), nb::arg("expr")); - cls.def_static( - "isinstance", - [](PyAffineExpr &otherAffineExpr) -> bool { - return DerivedTy::isaFunction(otherAffineExpr); - }, - nb::arg("other")); DerivedTy::bindDerived(cls); } diff --git a/mlir/lib/CAPI/Dialect/PDL.cpp b/mlir/lib/CAPI/Dialect/PDL.cpp index bd8b13c6516e..88cd6056480f 100644 --- a/mlir/lib/CAPI/Dialect/PDL.cpp +++ b/mlir/lib/CAPI/Dialect/PDL.cpp @@ -32,6 +32,10 @@ bool mlirTypeIsAPDLAttributeType(MlirType type) { return isa(unwrap(type)); } +MlirTypeID mlirPDLAttributeTypeGetTypeID(void) { + return wrap(pdl::AttributeType::getTypeID()); +} + MlirType mlirPDLAttributeTypeGet(MlirContext ctx) { return wrap(pdl::AttributeType::get(unwrap(ctx))); } @@ -44,6 +48,10 @@ bool mlirTypeIsAPDLOperationType(MlirType type) { return isa(unwrap(type)); } +MlirTypeID mlirPDLOperationTypeGetTypeID(void) { + return wrap(pdl::OperationType::getTypeID()); +} + MlirType mlirPDLOperationTypeGet(MlirContext ctx) { return wrap(pdl::OperationType::get(unwrap(ctx))); } @@ -56,6 +64,10 @@ bool mlirTypeIsAPDLRangeType(MlirType type) { return isa(unwrap(type)); } +MlirTypeID mlirPDLRangeTypeGetTypeID(void) { + return wrap(pdl::RangeType::getTypeID()); +} + MlirType mlirPDLRangeTypeGet(MlirType elementType) { return wrap(pdl::RangeType::get(unwrap(elementType))); } @@ -72,6 +84,10 @@ bool mlirTypeIsAPDLTypeType(MlirType type) { return isa(unwrap(type)); } +MlirTypeID mlirPDLTypeTypeGetTypeID(void) { + return wrap(pdl::TypeType::getTypeID()); +} + MlirType mlirPDLTypeTypeGet(MlirContext ctx) { return wrap(pdl::TypeType::get(unwrap(ctx))); } @@ -84,6 +100,10 @@ bool mlirTypeIsAPDLValueType(MlirType type) { return isa(unwrap(type)); } +MlirTypeID mlirPDLValueTypeGetTypeID(void) { + return wrap(pdl::ValueType::getTypeID()); +} + MlirType mlirPDLValueTypeGet(MlirContext ctx) { return wrap(pdl::ValueType::get(unwrap(ctx))); } diff --git a/mlir/lib/CAPI/Dialect/Quant.cpp b/mlir/lib/CAPI/Dialect/Quant.cpp index 01a6a948f1dc..840051caab85 100644 --- a/mlir/lib/CAPI/Dialect/Quant.cpp +++ b/mlir/lib/CAPI/Dialect/Quant.cpp @@ -113,6 +113,10 @@ bool mlirTypeIsAAnyQuantizedType(MlirType type) { return isa(unwrap(type)); } +MlirTypeID mlirAnyQuantizedTypeGetTypeID(void) { + return wrap(quant::AnyQuantizedType::getTypeID()); +} + MlirType mlirAnyQuantizedTypeGet(unsigned flags, MlirType storageType, MlirType expressedType, int64_t storageTypeMin, int64_t storageTypeMax) { @@ -129,6 +133,10 @@ bool mlirTypeIsAUniformQuantizedType(MlirType type) { return isa(unwrap(type)); } +MlirTypeID mlirUniformQuantizedTypeGetTypeID(void) { + return wrap(quant::UniformQuantizedType::getTypeID()); +} + MlirType mlirUniformQuantizedTypeGet(unsigned flags, MlirType storageType, MlirType expressedType, double scale, int64_t zeroPoint, int64_t storageTypeMin, @@ -158,6 +166,10 @@ bool mlirTypeIsAUniformQuantizedPerAxisType(MlirType type) { return isa(unwrap(type)); } +MlirTypeID mlirUniformQuantizedPerAxisTypeGetTypeID(void) { + return wrap(quant::UniformQuantizedPerAxisType::getTypeID()); +} + MlirType mlirUniformQuantizedPerAxisTypeGet( unsigned flags, MlirType storageType, MlirType expressedType, intptr_t nDims, double *scales, int64_t *zeroPoints, @@ -203,6 +215,10 @@ bool mlirTypeIsAUniformQuantizedSubChannelType(MlirType type) { return isa(unwrap(type)); } +MlirTypeID mlirUniformQuantizedSubChannelTypeGetTypeID(void) { + return wrap(quant::UniformQuantizedSubChannelType::getTypeID()); +} + MlirType mlirUniformQuantizedSubChannelTypeGet( unsigned flags, MlirType storageType, MlirType expressedType, MlirAttribute scalesAttr, MlirAttribute zeroPointsAttr, intptr_t nDims, @@ -258,6 +274,10 @@ bool mlirTypeIsACalibratedQuantizedType(MlirType type) { return isa(unwrap(type)); } +MlirTypeID mlirCalibratedQuantizedTypeGetTypeID(void) { + return wrap(quant::CalibratedQuantizedType::getTypeID()); +} + MlirType mlirCalibratedQuantizedTypeGet(MlirType expressedType, double min, double max) { return wrap( diff --git a/mlir/python/mlir/dialects/arith.py b/mlir/python/mlir/dialects/arith.py index 88e8502a29ea..555fb4c5ef3e 100644 --- a/mlir/python/mlir/dialects/arith.py +++ b/mlir/python/mlir/dialects/arith.py @@ -21,26 +21,6 @@ except ImportError as e: raise RuntimeError("Error loading imports from extension module") from e -def _isa(obj: Any, cls: type): - try: - cls(obj) - except ValueError: - return False - return True - - -def _is_any_of(obj: Any, classes: List[type]): - return any(_isa(obj, cls) for cls in classes) - - -def _is_integer_like_type(type: Type): - return _is_any_of(type, [IntegerType, IndexType]) - - -def _is_float_type(type: Type): - return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type]) - - @_ods_cext.register_operation(_Dialect, replace=True) class ConstantOp(ConstantOp): """Specialization for the constant op class.""" @@ -96,9 +76,9 @@ class ConstantOp(ConstantOp): @property def literal_value(self) -> Union[int, float]: - if _is_integer_like_type(self.type): + if isinstance(self.type, (IntegerType, IndexType)): return IntegerAttr(self.value).value - elif _is_float_type(self.type): + elif isinstance(self.type, FloatType): return FloatAttr(self.value).value else: raise ValueError("only integer and float constants have literal values") diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py index fb2570c7bb49..af90f3f8c4e3 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -412,9 +412,9 @@ class _BodyBuilder: ) if operand.type == to_type: return operand - if _is_integer_type(to_type): + if isinstance(to_type, IntegerType): return self._cast_to_integer(to_type, operand, is_unsigned_cast) - elif _is_floating_point_type(to_type): + elif isinstance(to_type, FloatType): return self._cast_to_floating_point(to_type, operand, is_unsigned_cast) def _cast_to_integer( @@ -422,11 +422,11 @@ class _BodyBuilder: ) -> Value: to_width = IntegerType(to_type).width operand_type = operand.type - if _is_floating_point_type(operand_type): + if isinstance(operand_type, FloatType): if is_unsigned_cast: return arith.FPToUIOp(to_type, operand).result return arith.FPToSIOp(to_type, operand).result - if _is_index_type(operand_type): + if isinstance(operand_type, IndexType): return arith.IndexCastOp(to_type, operand).result # Assume integer. from_width = IntegerType(operand_type).width @@ -444,13 +444,15 @@ class _BodyBuilder: self, to_type: Type, operand: Value, is_unsigned_cast: bool ) -> Value: operand_type = operand.type - if _is_integer_type(operand_type): + if isinstance(operand_type, IntegerType): if is_unsigned_cast: return arith.UIToFPOp(to_type, operand).result return arith.SIToFPOp(to_type, operand).result # Assume FloatType. - to_width = _get_floating_point_width(to_type) - from_width = _get_floating_point_width(operand_type) + assert isinstance(to_type, FloatType) + assert isinstance(operand_type, FloatType) + to_width = to_type.width + from_width = operand_type.width if to_width > from_width: return arith.ExtFOp(to_type, operand).result elif to_width < from_width: @@ -466,89 +468,89 @@ class _BodyBuilder: return self._cast(type_var_name, operand, True) def _unary_exp(self, x: Value) -> Value: - if _is_floating_point_type(x.type): + if isinstance(x.type, FloatType): return math.ExpOp(x).result raise NotImplementedError("Unsupported 'exp' operand: {x}") def _unary_log(self, x: Value) -> Value: - if _is_floating_point_type(x.type): + if isinstance(x.type, FloatType): return math.LogOp(x).result raise NotImplementedError("Unsupported 'log' operand: {x}") def _unary_abs(self, x: Value) -> Value: - if _is_floating_point_type(x.type): + if isinstance(x.type, FloatType): return math.AbsFOp(x).result raise NotImplementedError("Unsupported 'abs' operand: {x}") def _unary_ceil(self, x: Value) -> Value: - if _is_floating_point_type(x.type): + if isinstance(x.type, FloatType): return math.CeilOp(x).result raise NotImplementedError("Unsupported 'ceil' operand: {x}") def _unary_floor(self, x: Value) -> Value: - if _is_floating_point_type(x.type): + if isinstance(x.type, FloatType): return math.FloorOp(x).result raise NotImplementedError("Unsupported 'floor' operand: {x}") def _unary_negf(self, x: Value) -> Value: - if _is_floating_point_type(x.type): + if isinstance(x.type, FloatType): return arith.NegFOp(x).result - if _is_complex_type(x.type): + if isinstance(x.type, ComplexType): return complex.NegOp(x).result raise NotImplementedError("Unsupported 'negf' operand: {x}") def _binary_add(self, lhs: Value, rhs: Value) -> Value: - if _is_floating_point_type(lhs.type): + if isinstance(lhs.type, FloatType): return arith.AddFOp(lhs, rhs).result - if _is_integer_type(lhs.type) or _is_index_type(lhs.type): + if isinstance(lhs.type, IntegerType) or isinstance(lhs.type, IndexType): return arith.AddIOp(lhs, rhs).result - if _is_complex_type(lhs.type): + if isinstance(lhs.type, ComplexType): return complex.AddOp(lhs, rhs).result raise NotImplementedError("Unsupported 'add' operands: {lhs}, {rhs}") def _binary_sub(self, lhs: Value, rhs: Value) -> Value: - if _is_floating_point_type(lhs.type): + if isinstance(lhs.type, FloatType): return arith.SubFOp(lhs, rhs).result - if _is_integer_type(lhs.type) or _is_index_type(lhs.type): + if isinstance(lhs.type, IntegerType) or isinstance(lhs.type, IndexType): return arith.SubIOp(lhs, rhs).result - if _is_complex_type(lhs.type): + if isinstance(lhs.type, ComplexType): return complex.SubOp(lhs, rhs).result raise NotImplementedError("Unsupported 'sub' operands: {lhs}, {rhs}") def _binary_mul(self, lhs: Value, rhs: Value) -> Value: - if _is_floating_point_type(lhs.type): + if isinstance(lhs.type, FloatType): return arith.MulFOp(lhs, rhs).result - if _is_integer_type(lhs.type) or _is_index_type(lhs.type): + if isinstance(lhs.type, IntegerType) or isinstance(lhs.type, IndexType): return arith.MulIOp(lhs, rhs).result - if _is_complex_type(lhs.type): + if isinstance(lhs.type, ComplexType): return complex.MulOp(lhs, rhs).result raise NotImplementedError("Unsupported 'mul' operands: {lhs}, {rhs}") def _binary_max_signed(self, lhs: Value, rhs: Value) -> Value: - if _is_floating_point_type(lhs.type): + if isinstance(lhs.type, FloatType): return arith.MaximumFOp(lhs, rhs).result - if _is_integer_type(lhs.type) or _is_index_type(lhs.type): + if isinstance(lhs.type, IntegerType) or isinstance(lhs.type, IndexType): return arith.MaxSIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'max' operands: {lhs}, {rhs}") def _binary_max_unsigned(self, lhs: Value, rhs: Value) -> Value: if ( - _is_integer_type(lhs.type) and not _is_bool_type(lhs.type) - ) or _is_index_type(lhs.type): + isinstance(lhs.type, IntegerType) and not _is_bool_type(lhs.type) + ) or isinstance(lhs.type, IndexType): return arith.MaxUIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'max_unsigned' operands: {lhs}, {rhs}") def _binary_min_signed(self, lhs: Value, rhs: Value) -> Value: - if _is_floating_point_type(lhs.type): + if isinstance(lhs.type, FloatType): return arith.MinimumFOp(lhs, rhs).result - if _is_integer_type(lhs.type) or _is_index_type(lhs.type): + if isinstance(lhs.type, IntegerType) or isinstance(lhs.type, IndexType): return arith.MinSIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'min' operands: {lhs}, {rhs}") def _binary_min_unsigned(self, lhs: Value, rhs: Value) -> Value: if ( - _is_integer_type(lhs.type) and not _is_bool_type(lhs.type) - ) or _is_index_type(lhs.type): + isinstance(lhs.type, IntegerType) and not _is_bool_type(lhs.type) + ) or isinstance(lhs.type, IndexType): return arith.MinUIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'min_unsigned' operands: {lhs}, {rhs}") @@ -611,44 +613,7 @@ def _add_type_mapping( block_arg_types.append(element_or_self_type) -def _is_complex_type(t: Type) -> bool: - return ComplexType.isinstance(t) - - -def _is_floating_point_type(t: Type) -> bool: - # TODO: Create a FloatType in the Python API and implement the switch - # there. - return ( - F64Type.isinstance(t) - or F32Type.isinstance(t) - or F16Type.isinstance(t) - or BF16Type.isinstance(t) - ) - - -def _is_integer_type(t: Type) -> bool: - return IntegerType.isinstance(t) - - -def _is_index_type(t: Type) -> bool: - return IndexType.isinstance(t) - - def _is_bool_type(t: Type) -> bool: - if not IntegerType.isinstance(t): + if not isinstance(t, IntegerType): return False - return IntegerType(t).width == 1 - - -def _get_floating_point_width(t: Type) -> int: - # TODO: Create a FloatType in the Python API and implement the switch - # there. - if F64Type.isinstance(t): - return 64 - if F32Type.isinstance(t): - return 32 - if F16Type.isinstance(t): - return 16 - if BF16Type.isinstance(t): - return 16 - raise NotImplementedError(f"Unhandled floating point type switch {t}") + return t.width == 1 diff --git a/mlir/python/mlir/dialects/memref.py b/mlir/python/mlir/dialects/memref.py index 910a2356ca0e..34f00a3292b7 100644 --- a/mlir/python/mlir/dialects/memref.py +++ b/mlir/python/mlir/dialects/memref.py @@ -8,15 +8,22 @@ from typing import Optional from ._memref_ops_gen import * from ._memref_ops_gen import _Dialect from ._ods_common import _dispatch_mixed_values, MixedValues -from .arith import ConstantOp, _is_integer_like_type -from ..ir import Value, MemRefType, StridedLayoutAttr, ShapedType, Operation +from ..ir import ( + IndexType, + IntegerType, + MemRefType, + ShapedType, + StridedLayoutAttr, + Value, +) +from . import arith def _is_constant_int_like(i): return ( isinstance(i, Value) - and isinstance(i.owner, ConstantOp) - and _is_integer_like_type(i.type) + and isinstance(i.owner, arith.ConstantOp) + and isinstance(i.type, (IntegerType, IndexType)) ) diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td index ff16ad8ca0cd..929851724ba7 100644 --- a/mlir/test/mlir-tblgen/op-python-bindings.td +++ b/mlir/test/mlir-tblgen/op-python-bindings.td @@ -233,7 +233,7 @@ def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResu // CHECK: _ods_result_type_source_attr = attributes["type"] // CHECK: _ods_derived_result_type = ( // CHECK: _ods_ir.TypeAttr(_ods_result_type_source_attr).value - // CHECK: if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else + // CHECK: if isinstance(_ods_result_type_source_attr, _ods_ir.TypeAttr) else // CHECK: _ods_result_type_source_attr.type) // CHECK: results = [_ods_derived_result_type] * 2 let arguments = (ins TypeAttr:$type); diff --git a/mlir/test/python/dialects/arith_dialect.py b/mlir/test/python/dialects/arith_dialect.py index c9af5e7b46db..a4cfb3024023 100644 --- a/mlir/test/python/dialects/arith_dialect.py +++ b/mlir/test/python/dialects/arith_dialect.py @@ -42,10 +42,10 @@ def testFastMathFlags(): def testArithValue(): def _binary_op(lhs, rhs, op: str) -> "ArithValue": op = op.capitalize() - if arith._is_float_type(lhs.type) and arith._is_float_type(rhs.type): + if isinstance(lhs.type, FloatType) and isinstance(rhs.type, FloatType): op += "F" - elif arith._is_integer_like_type(lhs.type) and arith._is_integer_like_type( - lhs.type + elif isinstance(lhs.type, (IntegerType, IndexType)) and isinstance( + lhs.type, (IntegerType, IndexType) ): op += "I" else: diff --git a/mlir/test/python/dialects/pdl_types.py b/mlir/test/python/dialects/pdl_types.py index f75428d295c9..58c9e74e95bf 100644 --- a/mlir/test/python/dialects/pdl_types.py +++ b/mlir/test/python/dialects/pdl_types.py @@ -17,17 +17,17 @@ def test_attribute_type(): parsedType = Type.parse("!pdl.attribute") constructedType = pdl.AttributeType.get() - assert pdl.AttributeType.isinstance(parsedType) - assert not pdl.OperationType.isinstance(parsedType) - assert not pdl.RangeType.isinstance(parsedType) - assert not pdl.TypeType.isinstance(parsedType) - assert not pdl.ValueType.isinstance(parsedType) + assert isinstance(parsedType, pdl.AttributeType) + assert not isinstance(parsedType, pdl.OperationType) + assert not isinstance(parsedType, pdl.RangeType) + assert not isinstance(parsedType, pdl.TypeType) + assert not isinstance(parsedType, pdl.ValueType) - assert pdl.AttributeType.isinstance(constructedType) - assert not pdl.OperationType.isinstance(constructedType) - assert not pdl.RangeType.isinstance(constructedType) - assert not pdl.TypeType.isinstance(constructedType) - assert not pdl.ValueType.isinstance(constructedType) + assert isinstance(constructedType, pdl.AttributeType) + assert not isinstance(constructedType, pdl.OperationType) + assert not isinstance(constructedType, pdl.RangeType) + assert not isinstance(constructedType, pdl.TypeType) + assert not isinstance(constructedType, pdl.ValueType) assert parsedType == constructedType @@ -44,17 +44,17 @@ def test_operation_type(): parsedType = Type.parse("!pdl.operation") constructedType = pdl.OperationType.get() - assert not pdl.AttributeType.isinstance(parsedType) - assert pdl.OperationType.isinstance(parsedType) - assert not pdl.RangeType.isinstance(parsedType) - assert not pdl.TypeType.isinstance(parsedType) - assert not pdl.ValueType.isinstance(parsedType) + assert not isinstance(parsedType, pdl.AttributeType) + assert isinstance(parsedType, pdl.OperationType) + assert not isinstance(parsedType, pdl.RangeType) + assert not isinstance(parsedType, pdl.TypeType) + assert not isinstance(parsedType, pdl.ValueType) - assert not pdl.AttributeType.isinstance(constructedType) - assert pdl.OperationType.isinstance(constructedType) - assert not pdl.RangeType.isinstance(constructedType) - assert not pdl.TypeType.isinstance(constructedType) - assert not pdl.ValueType.isinstance(constructedType) + assert not isinstance(constructedType, pdl.AttributeType) + assert isinstance(constructedType, pdl.OperationType) + assert not isinstance(constructedType, pdl.RangeType) + assert not isinstance(constructedType, pdl.TypeType) + assert not isinstance(constructedType, pdl.ValueType) assert parsedType == constructedType @@ -73,17 +73,17 @@ def test_range_type(): constructedType = pdl.RangeType.get(typeType) elementType = constructedType.element_type - assert not pdl.AttributeType.isinstance(parsedType) - assert not pdl.OperationType.isinstance(parsedType) - assert pdl.RangeType.isinstance(parsedType) - assert not pdl.TypeType.isinstance(parsedType) - assert not pdl.ValueType.isinstance(parsedType) + assert not isinstance(parsedType, pdl.AttributeType) + assert not isinstance(parsedType, pdl.OperationType) + assert isinstance(parsedType, pdl.RangeType) + assert not isinstance(parsedType, pdl.TypeType) + assert not isinstance(parsedType, pdl.ValueType) - assert not pdl.AttributeType.isinstance(constructedType) - assert not pdl.OperationType.isinstance(constructedType) - assert pdl.RangeType.isinstance(constructedType) - assert not pdl.TypeType.isinstance(constructedType) - assert not pdl.ValueType.isinstance(constructedType) + assert not isinstance(constructedType, pdl.AttributeType) + assert not isinstance(constructedType, pdl.OperationType) + assert isinstance(constructedType, pdl.RangeType) + assert not isinstance(constructedType, pdl.TypeType) + assert not isinstance(constructedType, pdl.ValueType) assert parsedType == constructedType assert elementType == typeType @@ -103,17 +103,17 @@ def test_type_type(): parsedType = Type.parse("!pdl.type") constructedType = pdl.TypeType.get() - assert not pdl.AttributeType.isinstance(parsedType) - assert not pdl.OperationType.isinstance(parsedType) - assert not pdl.RangeType.isinstance(parsedType) - assert pdl.TypeType.isinstance(parsedType) - assert not pdl.ValueType.isinstance(parsedType) + assert not isinstance(parsedType, pdl.AttributeType) + assert not isinstance(parsedType, pdl.OperationType) + assert not isinstance(parsedType, pdl.RangeType) + assert isinstance(parsedType, pdl.TypeType) + assert not isinstance(parsedType, pdl.ValueType) - assert not pdl.AttributeType.isinstance(constructedType) - assert not pdl.OperationType.isinstance(constructedType) - assert not pdl.RangeType.isinstance(constructedType) - assert pdl.TypeType.isinstance(constructedType) - assert not pdl.ValueType.isinstance(constructedType) + assert not isinstance(constructedType, pdl.AttributeType) + assert not isinstance(constructedType, pdl.OperationType) + assert not isinstance(constructedType, pdl.RangeType) + assert isinstance(constructedType, pdl.TypeType) + assert not isinstance(constructedType, pdl.ValueType) assert parsedType == constructedType @@ -130,17 +130,17 @@ def test_value_type(): parsedType = Type.parse("!pdl.value") constructedType = pdl.ValueType.get() - assert not pdl.AttributeType.isinstance(parsedType) - assert not pdl.OperationType.isinstance(parsedType) - assert not pdl.RangeType.isinstance(parsedType) - assert not pdl.TypeType.isinstance(parsedType) - assert pdl.ValueType.isinstance(parsedType) + assert not isinstance(parsedType, pdl.AttributeType) + assert not isinstance(parsedType, pdl.OperationType) + assert not isinstance(parsedType, pdl.RangeType) + assert not isinstance(parsedType, pdl.TypeType) + assert isinstance(parsedType, pdl.ValueType) - assert not pdl.AttributeType.isinstance(constructedType) - assert not pdl.OperationType.isinstance(constructedType) - assert not pdl.RangeType.isinstance(constructedType) - assert not pdl.TypeType.isinstance(constructedType) - assert pdl.ValueType.isinstance(constructedType) + assert not isinstance(constructedType, pdl.AttributeType) + assert not isinstance(constructedType, pdl.OperationType) + assert not isinstance(constructedType, pdl.RangeType) + assert not isinstance(constructedType, pdl.TypeType) + assert isinstance(constructedType, pdl.ValueType) assert parsedType == constructedType diff --git a/mlir/test/python/dialects/quant.py b/mlir/test/python/dialects/quant.py index 57c528da7b9e..f40d1224c9e4 100644 --- a/mlir/test/python/dialects/quant.py +++ b/mlir/test/python/dialects/quant.py @@ -24,23 +24,23 @@ def test_type_hierarchy(): ) calibrated = Type.parse("!quant.calibrated>") - assert not quant.QuantizedType.isinstance(i8) - assert quant.QuantizedType.isinstance(any) - assert quant.QuantizedType.isinstance(uniform) - assert quant.QuantizedType.isinstance(per_axis) - assert quant.QuantizedType.isinstance(sub_channel) - assert quant.QuantizedType.isinstance(calibrated) + assert not isinstance(i8, quant.QuantizedType) + assert isinstance(any, quant.QuantizedType) + assert isinstance(uniform, quant.QuantizedType) + assert isinstance(per_axis, quant.QuantizedType) + assert isinstance(sub_channel, quant.QuantizedType) + assert isinstance(calibrated, quant.QuantizedType) - assert quant.AnyQuantizedType.isinstance(any) - assert quant.UniformQuantizedType.isinstance(uniform) - assert quant.UniformQuantizedPerAxisType.isinstance(per_axis) - assert quant.UniformQuantizedSubChannelType.isinstance(sub_channel) - assert quant.CalibratedQuantizedType.isinstance(calibrated) + assert isinstance(any, quant.AnyQuantizedType) + assert isinstance(uniform, quant.UniformQuantizedType) + assert isinstance(per_axis, quant.UniformQuantizedPerAxisType) + assert isinstance(sub_channel, quant.UniformQuantizedSubChannelType) + assert isinstance(calibrated, quant.CalibratedQuantizedType) - assert not quant.AnyQuantizedType.isinstance(uniform) - assert not quant.UniformQuantizedType.isinstance(per_axis) - assert not quant.UniformQuantizedType.isinstance(sub_channel) - assert not quant.UniformQuantizedPerAxisType.isinstance(sub_channel) + assert not isinstance(uniform, quant.AnyQuantizedType) + assert not isinstance(per_axis, quant.UniformQuantizedType) + assert not isinstance(sub_channel, quant.UniformQuantizedType) + assert not isinstance(sub_channel, quant.UniformQuantizedPerAxisType) # CHECK-LABEL: TEST: test_any_quantized_type diff --git a/mlir/test/python/ir/affine_expr.py b/mlir/test/python/ir/affine_expr.py index c2a2ab3509ca..82c509efdf8f 100644 --- a/mlir/test/python/ir/affine_expr.py +++ b/mlir/test/python/ir/affine_expr.py @@ -354,21 +354,21 @@ def testIsInstance(): mul = AffineMulExpr.get(d1, c2) # CHECK: True - print(AffineDimExpr.isinstance(d1)) + print(isinstance(d1, AffineDimExpr)) # CHECK: False - print(AffineConstantExpr.isinstance(d1)) + print(isinstance(d1, AffineConstantExpr)) # CHECK: True - print(AffineConstantExpr.isinstance(c2)) + print(isinstance(c2, AffineConstantExpr)) # CHECK: False - print(AffineMulExpr.isinstance(c2)) + print(isinstance(c2, AffineMulExpr)) # CHECK: True - print(AffineAddExpr.isinstance(add)) + print(isinstance(add, AffineAddExpr)) # CHECK: False - print(AffineMulExpr.isinstance(add)) + print(isinstance(add, AffineMulExpr)) # CHECK: True - print(AffineMulExpr.isinstance(mul)) + print(isinstance(mul, AffineMulExpr)) # CHECK: False - print(AffineAddExpr.isinstance(mul)) + print(isinstance(mul, AffineAddExpr)) # CHECK-LABEL: TEST: testCompose diff --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py index 2f3c4460d3f5..5ab671bd4d29 100644 --- a/mlir/test/python/ir/attributes.py +++ b/mlir/test/python/ir/attributes.py @@ -94,10 +94,10 @@ def testAttrIsInstance(): with Context(): a1 = Attribute.parse("42") a2 = Attribute.parse("[42]") - assert IntegerAttr.isinstance(a1) - assert not IntegerAttr.isinstance(a2) - assert not ArrayAttr.isinstance(a1) - assert ArrayAttr.isinstance(a2) + assert isinstance(a1, IntegerAttr) + assert not isinstance(a2, IntegerAttr) + assert not isinstance(a1, ArrayAttr) + assert isinstance(a2, ArrayAttr) # CHECK-LABEL: TEST: testAttrEqDoesNotRaise diff --git a/mlir/test/python/ir/auto_location.py b/mlir/test/python/ir/auto_location.py index 1747c66aa663..6448a88dc177 100644 --- a/mlir/test/python/ir/auto_location.py +++ b/mlir/test/python/ir/auto_location.py @@ -34,7 +34,7 @@ def testInferLocations(): # Test nesting of loc_tracebacks(). with loc_tracebacks(): # fmt: off - # CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]arith.py":65:12 to :76) at callsite("constant"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]arith.py":110:40 to :81) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:14 to :48) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:4 to :7) at ""("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:1 to :4)))))) + # CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]arith.py":45:12 to :76) at callsite("constant"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]arith.py":90:40 to :81) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:14 to :48) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:4 to :7) at ""("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:1 to :4)))))) # fmt: on print(one.location) diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py index 54863253fc77..aa1665a4020f 100644 --- a/mlir/test/python/ir/builtin_types.py +++ b/mlir/test/python/ir/builtin_types.py @@ -97,15 +97,15 @@ def testTypeIsInstance(): t1 = Type.parse("i32", ctx) t2 = Type.parse("f32", ctx) # CHECK: True - print(IntegerType.isinstance(t1)) + print(isinstance(t1, IntegerType)) # CHECK: False - print(F32Type.isinstance(t1)) + print(isinstance(t1, F32Type)) # CHECK: False - print(FloatType.isinstance(t1)) + print(isinstance(t1, FloatType)) # CHECK: True - print(F32Type.isinstance(t2)) + print(isinstance(t2, F32Type)) # CHECK: True - print(FloatType.isinstance(t2)) + print(isinstance(t2, FloatType)) # CHECK-LABEL: TEST: testFloatTypeSubclasses diff --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py index 4a241afb8e89..45efb880bab4 100644 --- a/mlir/test/python/ir/value.py +++ b/mlir/test/python/ir/value.py @@ -69,12 +69,12 @@ def testValueIsInstance(): ctx, ) func = module.body.operations[0] - assert BlockArgument.isinstance(func.regions[0].blocks[0].arguments[0]) - assert not OpResult.isinstance(func.regions[0].blocks[0].arguments[0]) + assert isinstance(func.regions[0].blocks[0].arguments[0], BlockArgument) + assert not isinstance(func.regions[0].blocks[0].arguments[0], OpResult) op = func.regions[0].blocks[0].operations[0] - assert not BlockArgument.isinstance(op.results[0]) - assert OpResult.isinstance(op.results[0]) + assert not isinstance(op.results[0], BlockArgument) + assert isinstance(op.results[0], OpResult) # CHECK-LABEL: TEST: testValueHash diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp index 2c33f4efac3a..6545559ff1b1 100644 --- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp @@ -886,7 +886,7 @@ constexpr const char *firstAttrDerivedResultTypeTemplate = _ods_result_type_source_attr = attributes["{0}"] _ods_derived_result_type = ( _ods_ir.TypeAttr(_ods_result_type_source_attr).value - if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else + if isinstance(_ods_result_type_source_attr, _ods_ir.TypeAttr) else _ods_result_type_source_attr.type) results = [_ods_derived_result_type] * {1})Py";