[mlir][Python] use canonical Python isinstance instead of Type.isinstance (#172892)

We've been able to do `isinstance(x, Type)` for a quite a while now
(since
bfb1ba7526)
so remove `Type.isinstance` and the the special-casing
(`_is_integer_type`, `_is_floating_point_type`, `_is_index_type`) in
some places (and therefore support various `fp8`, `fp6`, `fp4` types).
This commit is contained in:
Maksim Levental 2026-01-05 16:07:24 -05:00 committed by GitHub
parent e826168a24
commit fb8bbd4ed8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 220 additions and 212 deletions

View File

@ -30,6 +30,8 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLType(MlirType type);
MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLAttributeType(MlirType type);
MLIR_CAPI_EXPORTED MlirTypeID mlirPDLAttributeTypeGetTypeID(void);
MLIR_CAPI_EXPORTED MlirType mlirPDLAttributeTypeGet(MlirContext ctx);
//===---------------------------------------------------------------------===//
@ -38,6 +40,8 @@ MLIR_CAPI_EXPORTED MlirType mlirPDLAttributeTypeGet(MlirContext ctx);
MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLOperationType(MlirType type);
MLIR_CAPI_EXPORTED MlirTypeID mlirPDLOperationTypeGetTypeID(void);
MLIR_CAPI_EXPORTED MlirType mlirPDLOperationTypeGet(MlirContext ctx);
//===---------------------------------------------------------------------===//
@ -46,6 +50,8 @@ MLIR_CAPI_EXPORTED MlirType mlirPDLOperationTypeGet(MlirContext ctx);
MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLRangeType(MlirType type);
MLIR_CAPI_EXPORTED MlirTypeID mlirPDLRangeTypeGetTypeID(void);
MLIR_CAPI_EXPORTED MlirType mlirPDLRangeTypeGet(MlirType elementType);
MLIR_CAPI_EXPORTED MlirType mlirPDLRangeTypeGetElementType(MlirType type);
@ -56,6 +62,8 @@ MLIR_CAPI_EXPORTED MlirType mlirPDLRangeTypeGetElementType(MlirType type);
MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLTypeType(MlirType type);
MLIR_CAPI_EXPORTED MlirTypeID mlirPDLTypeTypeGetTypeID(void);
MLIR_CAPI_EXPORTED MlirType mlirPDLTypeTypeGet(MlirContext ctx);
//===---------------------------------------------------------------------===//
@ -64,6 +72,8 @@ MLIR_CAPI_EXPORTED MlirType mlirPDLTypeTypeGet(MlirContext ctx);
MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLValueType(MlirType type);
MLIR_CAPI_EXPORTED MlirTypeID mlirPDLValueTypeGetTypeID(void);
MLIR_CAPI_EXPORTED MlirType mlirPDLValueTypeGet(MlirContext ctx);
#ifdef __cplusplus

View File

@ -103,6 +103,8 @@ mlirQuantizedTypeCastExpressedToStorageType(MlirType type, MlirType candidate);
/// Returns `true` if the given type is an AnyQuantizedType.
MLIR_CAPI_EXPORTED bool mlirTypeIsAAnyQuantizedType(MlirType type);
MLIR_CAPI_EXPORTED MlirTypeID mlirAnyQuantizedTypeGetTypeID(void);
/// Creates an instance of AnyQuantizedType with the given parameters in the
/// same context as `storageType` and returns it. The instance is owned by the
/// context.
@ -119,6 +121,8 @@ MLIR_CAPI_EXPORTED MlirType mlirAnyQuantizedTypeGet(unsigned flags,
/// Returns `true` if the given type is a UniformQuantizedType.
MLIR_CAPI_EXPORTED bool mlirTypeIsAUniformQuantizedType(MlirType type);
MLIR_CAPI_EXPORTED MlirTypeID mlirUniformQuantizedTypeGetTypeID(void);
/// Creates an instance of UniformQuantizedType with the given parameters in the
/// same context as `storageType` and returns it. The instance is owned by the
/// context.
@ -142,6 +146,8 @@ MLIR_CAPI_EXPORTED bool mlirUniformQuantizedTypeIsFixedPoint(MlirType type);
/// Returns `true` if the given type is a UniformQuantizedPerAxisType.
MLIR_CAPI_EXPORTED bool mlirTypeIsAUniformQuantizedPerAxisType(MlirType type);
MLIR_CAPI_EXPORTED MlirTypeID mlirUniformQuantizedPerAxisTypeGetTypeID(void);
/// Creates an instance of UniformQuantizedPerAxisType with the given parameters
/// in the same context as `storageType` and returns it. `scales` and
/// `zeroPoints` point to `nDims` number of elements. The instance is owned
@ -180,6 +186,8 @@ mlirUniformQuantizedPerAxisTypeIsFixedPoint(MlirType type);
MLIR_CAPI_EXPORTED bool
mlirTypeIsAUniformQuantizedSubChannelType(MlirType type);
MLIR_CAPI_EXPORTED MlirTypeID mlirUniformQuantizedSubChannelTypeGetTypeID(void);
/// Creates a UniformQuantizedSubChannelType with the given parameters.
///
/// The type is owned by the context. `scalesAttr` and `zeroPointsAttr` must be
@ -220,6 +228,8 @@ mlirUniformQuantizedSubChannelTypeGetZeroPoints(MlirType type);
/// Returns `true` if the given type is a CalibratedQuantizedType.
MLIR_CAPI_EXPORTED bool mlirTypeIsACalibratedQuantizedType(MlirType type);
MLIR_CAPI_EXPORTED MlirTypeID mlirCalibratedQuantizedTypeGetTypeID(void);
/// Creates an instance of CalibratedQuantizedType with the given parameters
/// in the same context as `expressedType` and returns it. The instance is owned
/// by the context.

View File

@ -957,12 +957,6 @@ public:
auto cls = ClassTy(m, DerivedTy::pyClassName);
cls.def(nanobind::init<PyType &>(), nanobind::keep_alive<0, 1>(),
nanobind::arg("cast_from_type"));
cls.def_static(
"isinstance",
[](PyType &otherType) -> bool {
return DerivedTy::isaFunction(otherType);
},
nanobind::arg("other"));
cls.def_prop_ro_static(
"static_typeid",
[](nanobind::object & /*class*/) {
@ -1094,12 +1088,6 @@ public:
}
cls.def(nanobind::init<PyAttribute &>(), nanobind::keep_alive<0, 1>(),
nanobind::arg("cast_from_attr"));
cls.def_static(
"isinstance",
[](PyAttribute &otherAttr) -> bool {
return DerivedTy::isaFunction(otherAttr);
},
nanobind::arg("other"));
cls.def_prop_ro(
"type",
[](PyAttribute &attr) -> nanobind::typed<nanobind::object, PyType> {
@ -1555,12 +1543,6 @@ public:
.c_str()));
cls.def(nanobind::init<PyValue &>(), nanobind::keep_alive<0, 1>(),
nanobind::arg("value"));
cls.def_static(
"isinstance",
[](PyValue &otherValue) -> bool {
return DerivedTy::isaFunction(otherValue);
},
nanobind::arg("other_value"));
cls.def(
MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
[](DerivedTy &self) -> nanobind::typed<nanobind::object, DerivedTy> {

View File

@ -39,6 +39,8 @@ struct PDLType : PyConcreteType<PDLType> {
struct AttributeType : PyConcreteType<AttributeType> {
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPDLAttributeType;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirPDLAttributeTypeGetTypeID;
static constexpr const char *pyClassName = "AttributeType";
using Base::Base;
@ -60,6 +62,8 @@ struct AttributeType : PyConcreteType<AttributeType> {
struct OperationType : PyConcreteType<OperationType> {
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPDLOperationType;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirPDLOperationTypeGetTypeID;
static constexpr const char *pyClassName = "OperationType";
using Base::Base;
@ -81,6 +85,8 @@ struct OperationType : PyConcreteType<OperationType> {
struct RangeType : PyConcreteType<RangeType> {
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPDLRangeType;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirPDLRangeTypeGetTypeID;
static constexpr const char *pyClassName = "RangeType";
using Base::Base;
@ -109,6 +115,8 @@ struct RangeType : PyConcreteType<RangeType> {
struct TypeType : PyConcreteType<TypeType> {
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPDLTypeType;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirPDLTypeTypeGetTypeID;
static constexpr const char *pyClassName = "TypeType";
using Base::Base;
@ -130,6 +138,8 @@ struct TypeType : PyConcreteType<TypeType> {
struct ValueType : PyConcreteType<ValueType> {
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPDLValueType;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirPDLValueTypeGetTypeID;
static constexpr const char *pyClassName = "ValueType";
using Base::Base;

View File

@ -192,6 +192,8 @@ struct QuantizedType : PyConcreteType<QuantizedType> {
struct AnyQuantizedType : PyConcreteType<AnyQuantizedType, QuantizedType> {
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAAnyQuantizedType;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirAnyQuantizedTypeGetTypeID;
static constexpr const char *pyClassName = "AnyQuantizedType";
using Base::Base;
@ -221,6 +223,8 @@ struct AnyQuantizedType : PyConcreteType<AnyQuantizedType, QuantizedType> {
struct UniformQuantizedType
: PyConcreteType<UniformQuantizedType, QuantizedType> {
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUniformQuantizedType;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirUniformQuantizedTypeGetTypeID;
static constexpr const char *pyClassName = "UniformQuantizedType";
using Base::Base;
@ -273,6 +277,8 @@ struct UniformQuantizedPerAxisType
: PyConcreteType<UniformQuantizedPerAxisType, QuantizedType> {
static constexpr IsAFunctionTy isaFunction =
mlirTypeIsAUniformQuantizedPerAxisType;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirUniformQuantizedPerAxisTypeGetTypeID;
static constexpr const char *pyClassName = "UniformQuantizedPerAxisType";
using Base::Base;
@ -357,6 +363,8 @@ struct UniformQuantizedSubChannelType
: PyConcreteType<UniformQuantizedSubChannelType, QuantizedType> {
static constexpr IsAFunctionTy isaFunction =
mlirTypeIsAUniformQuantizedSubChannelType;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirUniformQuantizedSubChannelTypeGetTypeID;
static constexpr const char *pyClassName = "UniformQuantizedSubChannelType";
using Base::Base;
@ -448,6 +456,8 @@ struct CalibratedQuantizedType
: PyConcreteType<CalibratedQuantizedType, QuantizedType> {
static constexpr IsAFunctionTy isaFunction =
mlirTypeIsACalibratedQuantizedType;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirCalibratedQuantizedTypeGetTypeID;
static constexpr const char *pyClassName = "CalibratedQuantizedType";
using Base::Base;

View File

@ -118,12 +118,6 @@ public:
static void bind(nb::module_ &m) {
auto cls = ClassTy(m, DerivedTy::pyClassName);
cls.def(nb::init<PyAffineExpr &>(), nb::arg("expr"));
cls.def_static(
"isinstance",
[](PyAffineExpr &otherAffineExpr) -> bool {
return DerivedTy::isaFunction(otherAffineExpr);
},
nb::arg("other"));
DerivedTy::bindDerived(cls);
}

View File

@ -32,6 +32,10 @@ bool mlirTypeIsAPDLAttributeType(MlirType type) {
return isa<pdl::AttributeType>(unwrap(type));
}
MlirTypeID mlirPDLAttributeTypeGetTypeID(void) {
return wrap(pdl::AttributeType::getTypeID());
}
MlirType mlirPDLAttributeTypeGet(MlirContext ctx) {
return wrap(pdl::AttributeType::get(unwrap(ctx)));
}
@ -44,6 +48,10 @@ bool mlirTypeIsAPDLOperationType(MlirType type) {
return isa<pdl::OperationType>(unwrap(type));
}
MlirTypeID mlirPDLOperationTypeGetTypeID(void) {
return wrap(pdl::OperationType::getTypeID());
}
MlirType mlirPDLOperationTypeGet(MlirContext ctx) {
return wrap(pdl::OperationType::get(unwrap(ctx)));
}
@ -56,6 +64,10 @@ bool mlirTypeIsAPDLRangeType(MlirType type) {
return isa<pdl::RangeType>(unwrap(type));
}
MlirTypeID mlirPDLRangeTypeGetTypeID(void) {
return wrap(pdl::RangeType::getTypeID());
}
MlirType mlirPDLRangeTypeGet(MlirType elementType) {
return wrap(pdl::RangeType::get(unwrap(elementType)));
}
@ -72,6 +84,10 @@ bool mlirTypeIsAPDLTypeType(MlirType type) {
return isa<pdl::TypeType>(unwrap(type));
}
MlirTypeID mlirPDLTypeTypeGetTypeID(void) {
return wrap(pdl::TypeType::getTypeID());
}
MlirType mlirPDLTypeTypeGet(MlirContext ctx) {
return wrap(pdl::TypeType::get(unwrap(ctx)));
}
@ -84,6 +100,10 @@ bool mlirTypeIsAPDLValueType(MlirType type) {
return isa<pdl::ValueType>(unwrap(type));
}
MlirTypeID mlirPDLValueTypeGetTypeID(void) {
return wrap(pdl::ValueType::getTypeID());
}
MlirType mlirPDLValueTypeGet(MlirContext ctx) {
return wrap(pdl::ValueType::get(unwrap(ctx)));
}

View File

@ -113,6 +113,10 @@ bool mlirTypeIsAAnyQuantizedType(MlirType type) {
return isa<quant::AnyQuantizedType>(unwrap(type));
}
MlirTypeID mlirAnyQuantizedTypeGetTypeID(void) {
return wrap(quant::AnyQuantizedType::getTypeID());
}
MlirType mlirAnyQuantizedTypeGet(unsigned flags, MlirType storageType,
MlirType expressedType, int64_t storageTypeMin,
int64_t storageTypeMax) {
@ -129,6 +133,10 @@ bool mlirTypeIsAUniformQuantizedType(MlirType type) {
return isa<quant::UniformQuantizedType>(unwrap(type));
}
MlirTypeID mlirUniformQuantizedTypeGetTypeID(void) {
return wrap(quant::UniformQuantizedType::getTypeID());
}
MlirType mlirUniformQuantizedTypeGet(unsigned flags, MlirType storageType,
MlirType expressedType, double scale,
int64_t zeroPoint, int64_t storageTypeMin,
@ -158,6 +166,10 @@ bool mlirTypeIsAUniformQuantizedPerAxisType(MlirType type) {
return isa<quant::UniformQuantizedPerAxisType>(unwrap(type));
}
MlirTypeID mlirUniformQuantizedPerAxisTypeGetTypeID(void) {
return wrap(quant::UniformQuantizedPerAxisType::getTypeID());
}
MlirType mlirUniformQuantizedPerAxisTypeGet(
unsigned flags, MlirType storageType, MlirType expressedType,
intptr_t nDims, double *scales, int64_t *zeroPoints,
@ -203,6 +215,10 @@ bool mlirTypeIsAUniformQuantizedSubChannelType(MlirType type) {
return isa<quant::UniformQuantizedSubChannelType>(unwrap(type));
}
MlirTypeID mlirUniformQuantizedSubChannelTypeGetTypeID(void) {
return wrap(quant::UniformQuantizedSubChannelType::getTypeID());
}
MlirType mlirUniformQuantizedSubChannelTypeGet(
unsigned flags, MlirType storageType, MlirType expressedType,
MlirAttribute scalesAttr, MlirAttribute zeroPointsAttr, intptr_t nDims,
@ -258,6 +274,10 @@ bool mlirTypeIsACalibratedQuantizedType(MlirType type) {
return isa<quant::CalibratedQuantizedType>(unwrap(type));
}
MlirTypeID mlirCalibratedQuantizedTypeGetTypeID(void) {
return wrap(quant::CalibratedQuantizedType::getTypeID());
}
MlirType mlirCalibratedQuantizedTypeGet(MlirType expressedType, double min,
double max) {
return wrap(

View File

@ -21,26 +21,6 @@ except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
def _isa(obj: Any, cls: type):
try:
cls(obj)
except ValueError:
return False
return True
def _is_any_of(obj: Any, classes: List[type]):
return any(_isa(obj, cls) for cls in classes)
def _is_integer_like_type(type: Type):
return _is_any_of(type, [IntegerType, IndexType])
def _is_float_type(type: Type):
return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type])
@_ods_cext.register_operation(_Dialect, replace=True)
class ConstantOp(ConstantOp):
"""Specialization for the constant op class."""
@ -96,9 +76,9 @@ class ConstantOp(ConstantOp):
@property
def literal_value(self) -> Union[int, float]:
if _is_integer_like_type(self.type):
if isinstance(self.type, (IntegerType, IndexType)):
return IntegerAttr(self.value).value
elif _is_float_type(self.type):
elif isinstance(self.type, FloatType):
return FloatAttr(self.value).value
else:
raise ValueError("only integer and float constants have literal values")

View File

@ -412,9 +412,9 @@ class _BodyBuilder:
)
if operand.type == to_type:
return operand
if _is_integer_type(to_type):
if isinstance(to_type, IntegerType):
return self._cast_to_integer(to_type, operand, is_unsigned_cast)
elif _is_floating_point_type(to_type):
elif isinstance(to_type, FloatType):
return self._cast_to_floating_point(to_type, operand, is_unsigned_cast)
def _cast_to_integer(
@ -422,11 +422,11 @@ class _BodyBuilder:
) -> Value:
to_width = IntegerType(to_type).width
operand_type = operand.type
if _is_floating_point_type(operand_type):
if isinstance(operand_type, FloatType):
if is_unsigned_cast:
return arith.FPToUIOp(to_type, operand).result
return arith.FPToSIOp(to_type, operand).result
if _is_index_type(operand_type):
if isinstance(operand_type, IndexType):
return arith.IndexCastOp(to_type, operand).result
# Assume integer.
from_width = IntegerType(operand_type).width
@ -444,13 +444,15 @@ class _BodyBuilder:
self, to_type: Type, operand: Value, is_unsigned_cast: bool
) -> Value:
operand_type = operand.type
if _is_integer_type(operand_type):
if isinstance(operand_type, IntegerType):
if is_unsigned_cast:
return arith.UIToFPOp(to_type, operand).result
return arith.SIToFPOp(to_type, operand).result
# Assume FloatType.
to_width = _get_floating_point_width(to_type)
from_width = _get_floating_point_width(operand_type)
assert isinstance(to_type, FloatType)
assert isinstance(operand_type, FloatType)
to_width = to_type.width
from_width = operand_type.width
if to_width > from_width:
return arith.ExtFOp(to_type, operand).result
elif to_width < from_width:
@ -466,89 +468,89 @@ class _BodyBuilder:
return self._cast(type_var_name, operand, True)
def _unary_exp(self, x: Value) -> Value:
if _is_floating_point_type(x.type):
if isinstance(x.type, FloatType):
return math.ExpOp(x).result
raise NotImplementedError("Unsupported 'exp' operand: {x}")
def _unary_log(self, x: Value) -> Value:
if _is_floating_point_type(x.type):
if isinstance(x.type, FloatType):
return math.LogOp(x).result
raise NotImplementedError("Unsupported 'log' operand: {x}")
def _unary_abs(self, x: Value) -> Value:
if _is_floating_point_type(x.type):
if isinstance(x.type, FloatType):
return math.AbsFOp(x).result
raise NotImplementedError("Unsupported 'abs' operand: {x}")
def _unary_ceil(self, x: Value) -> Value:
if _is_floating_point_type(x.type):
if isinstance(x.type, FloatType):
return math.CeilOp(x).result
raise NotImplementedError("Unsupported 'ceil' operand: {x}")
def _unary_floor(self, x: Value) -> Value:
if _is_floating_point_type(x.type):
if isinstance(x.type, FloatType):
return math.FloorOp(x).result
raise NotImplementedError("Unsupported 'floor' operand: {x}")
def _unary_negf(self, x: Value) -> Value:
if _is_floating_point_type(x.type):
if isinstance(x.type, FloatType):
return arith.NegFOp(x).result
if _is_complex_type(x.type):
if isinstance(x.type, ComplexType):
return complex.NegOp(x).result
raise NotImplementedError("Unsupported 'negf' operand: {x}")
def _binary_add(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
if isinstance(lhs.type, FloatType):
return arith.AddFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
if isinstance(lhs.type, IntegerType) or isinstance(lhs.type, IndexType):
return arith.AddIOp(lhs, rhs).result
if _is_complex_type(lhs.type):
if isinstance(lhs.type, ComplexType):
return complex.AddOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'add' operands: {lhs}, {rhs}")
def _binary_sub(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
if isinstance(lhs.type, FloatType):
return arith.SubFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
if isinstance(lhs.type, IntegerType) or isinstance(lhs.type, IndexType):
return arith.SubIOp(lhs, rhs).result
if _is_complex_type(lhs.type):
if isinstance(lhs.type, ComplexType):
return complex.SubOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'sub' operands: {lhs}, {rhs}")
def _binary_mul(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
if isinstance(lhs.type, FloatType):
return arith.MulFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
if isinstance(lhs.type, IntegerType) or isinstance(lhs.type, IndexType):
return arith.MulIOp(lhs, rhs).result
if _is_complex_type(lhs.type):
if isinstance(lhs.type, ComplexType):
return complex.MulOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'mul' operands: {lhs}, {rhs}")
def _binary_max_signed(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
if isinstance(lhs.type, FloatType):
return arith.MaximumFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
if isinstance(lhs.type, IntegerType) or isinstance(lhs.type, IndexType):
return arith.MaxSIOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'max' operands: {lhs}, {rhs}")
def _binary_max_unsigned(self, lhs: Value, rhs: Value) -> Value:
if (
_is_integer_type(lhs.type) and not _is_bool_type(lhs.type)
) or _is_index_type(lhs.type):
isinstance(lhs.type, IntegerType) and not _is_bool_type(lhs.type)
) or isinstance(lhs.type, IndexType):
return arith.MaxUIOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'max_unsigned' operands: {lhs}, {rhs}")
def _binary_min_signed(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
if isinstance(lhs.type, FloatType):
return arith.MinimumFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
if isinstance(lhs.type, IntegerType) or isinstance(lhs.type, IndexType):
return arith.MinSIOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'min' operands: {lhs}, {rhs}")
def _binary_min_unsigned(self, lhs: Value, rhs: Value) -> Value:
if (
_is_integer_type(lhs.type) and not _is_bool_type(lhs.type)
) or _is_index_type(lhs.type):
isinstance(lhs.type, IntegerType) and not _is_bool_type(lhs.type)
) or isinstance(lhs.type, IndexType):
return arith.MinUIOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'min_unsigned' operands: {lhs}, {rhs}")
@ -611,44 +613,7 @@ def _add_type_mapping(
block_arg_types.append(element_or_self_type)
def _is_complex_type(t: Type) -> bool:
return ComplexType.isinstance(t)
def _is_floating_point_type(t: Type) -> bool:
# TODO: Create a FloatType in the Python API and implement the switch
# there.
return (
F64Type.isinstance(t)
or F32Type.isinstance(t)
or F16Type.isinstance(t)
or BF16Type.isinstance(t)
)
def _is_integer_type(t: Type) -> bool:
return IntegerType.isinstance(t)
def _is_index_type(t: Type) -> bool:
return IndexType.isinstance(t)
def _is_bool_type(t: Type) -> bool:
if not IntegerType.isinstance(t):
if not isinstance(t, IntegerType):
return False
return IntegerType(t).width == 1
def _get_floating_point_width(t: Type) -> int:
# TODO: Create a FloatType in the Python API and implement the switch
# there.
if F64Type.isinstance(t):
return 64
if F32Type.isinstance(t):
return 32
if F16Type.isinstance(t):
return 16
if BF16Type.isinstance(t):
return 16
raise NotImplementedError(f"Unhandled floating point type switch {t}")
return t.width == 1

View File

@ -8,15 +8,22 @@ from typing import Optional
from ._memref_ops_gen import *
from ._memref_ops_gen import _Dialect
from ._ods_common import _dispatch_mixed_values, MixedValues
from .arith import ConstantOp, _is_integer_like_type
from ..ir import Value, MemRefType, StridedLayoutAttr, ShapedType, Operation
from ..ir import (
IndexType,
IntegerType,
MemRefType,
ShapedType,
StridedLayoutAttr,
Value,
)
from . import arith
def _is_constant_int_like(i):
return (
isinstance(i, Value)
and isinstance(i.owner, ConstantOp)
and _is_integer_like_type(i.type)
and isinstance(i.owner, arith.ConstantOp)
and isinstance(i.type, (IntegerType, IndexType))
)

View File

@ -233,7 +233,7 @@ def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResu
// CHECK: _ods_result_type_source_attr = attributes["type"]
// CHECK: _ods_derived_result_type = (
// CHECK: _ods_ir.TypeAttr(_ods_result_type_source_attr).value
// CHECK: if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else
// CHECK: if isinstance(_ods_result_type_source_attr, _ods_ir.TypeAttr) else
// CHECK: _ods_result_type_source_attr.type)
// CHECK: results = [_ods_derived_result_type] * 2
let arguments = (ins TypeAttr:$type);

View File

@ -42,10 +42,10 @@ def testFastMathFlags():
def testArithValue():
def _binary_op(lhs, rhs, op: str) -> "ArithValue":
op = op.capitalize()
if arith._is_float_type(lhs.type) and arith._is_float_type(rhs.type):
if isinstance(lhs.type, FloatType) and isinstance(rhs.type, FloatType):
op += "F"
elif arith._is_integer_like_type(lhs.type) and arith._is_integer_like_type(
lhs.type
elif isinstance(lhs.type, (IntegerType, IndexType)) and isinstance(
lhs.type, (IntegerType, IndexType)
):
op += "I"
else:

View File

@ -17,17 +17,17 @@ def test_attribute_type():
parsedType = Type.parse("!pdl.attribute")
constructedType = pdl.AttributeType.get()
assert pdl.AttributeType.isinstance(parsedType)
assert not pdl.OperationType.isinstance(parsedType)
assert not pdl.RangeType.isinstance(parsedType)
assert not pdl.TypeType.isinstance(parsedType)
assert not pdl.ValueType.isinstance(parsedType)
assert isinstance(parsedType, pdl.AttributeType)
assert not isinstance(parsedType, pdl.OperationType)
assert not isinstance(parsedType, pdl.RangeType)
assert not isinstance(parsedType, pdl.TypeType)
assert not isinstance(parsedType, pdl.ValueType)
assert pdl.AttributeType.isinstance(constructedType)
assert not pdl.OperationType.isinstance(constructedType)
assert not pdl.RangeType.isinstance(constructedType)
assert not pdl.TypeType.isinstance(constructedType)
assert not pdl.ValueType.isinstance(constructedType)
assert isinstance(constructedType, pdl.AttributeType)
assert not isinstance(constructedType, pdl.OperationType)
assert not isinstance(constructedType, pdl.RangeType)
assert not isinstance(constructedType, pdl.TypeType)
assert not isinstance(constructedType, pdl.ValueType)
assert parsedType == constructedType
@ -44,17 +44,17 @@ def test_operation_type():
parsedType = Type.parse("!pdl.operation")
constructedType = pdl.OperationType.get()
assert not pdl.AttributeType.isinstance(parsedType)
assert pdl.OperationType.isinstance(parsedType)
assert not pdl.RangeType.isinstance(parsedType)
assert not pdl.TypeType.isinstance(parsedType)
assert not pdl.ValueType.isinstance(parsedType)
assert not isinstance(parsedType, pdl.AttributeType)
assert isinstance(parsedType, pdl.OperationType)
assert not isinstance(parsedType, pdl.RangeType)
assert not isinstance(parsedType, pdl.TypeType)
assert not isinstance(parsedType, pdl.ValueType)
assert not pdl.AttributeType.isinstance(constructedType)
assert pdl.OperationType.isinstance(constructedType)
assert not pdl.RangeType.isinstance(constructedType)
assert not pdl.TypeType.isinstance(constructedType)
assert not pdl.ValueType.isinstance(constructedType)
assert not isinstance(constructedType, pdl.AttributeType)
assert isinstance(constructedType, pdl.OperationType)
assert not isinstance(constructedType, pdl.RangeType)
assert not isinstance(constructedType, pdl.TypeType)
assert not isinstance(constructedType, pdl.ValueType)
assert parsedType == constructedType
@ -73,17 +73,17 @@ def test_range_type():
constructedType = pdl.RangeType.get(typeType)
elementType = constructedType.element_type
assert not pdl.AttributeType.isinstance(parsedType)
assert not pdl.OperationType.isinstance(parsedType)
assert pdl.RangeType.isinstance(parsedType)
assert not pdl.TypeType.isinstance(parsedType)
assert not pdl.ValueType.isinstance(parsedType)
assert not isinstance(parsedType, pdl.AttributeType)
assert not isinstance(parsedType, pdl.OperationType)
assert isinstance(parsedType, pdl.RangeType)
assert not isinstance(parsedType, pdl.TypeType)
assert not isinstance(parsedType, pdl.ValueType)
assert not pdl.AttributeType.isinstance(constructedType)
assert not pdl.OperationType.isinstance(constructedType)
assert pdl.RangeType.isinstance(constructedType)
assert not pdl.TypeType.isinstance(constructedType)
assert not pdl.ValueType.isinstance(constructedType)
assert not isinstance(constructedType, pdl.AttributeType)
assert not isinstance(constructedType, pdl.OperationType)
assert isinstance(constructedType, pdl.RangeType)
assert not isinstance(constructedType, pdl.TypeType)
assert not isinstance(constructedType, pdl.ValueType)
assert parsedType == constructedType
assert elementType == typeType
@ -103,17 +103,17 @@ def test_type_type():
parsedType = Type.parse("!pdl.type")
constructedType = pdl.TypeType.get()
assert not pdl.AttributeType.isinstance(parsedType)
assert not pdl.OperationType.isinstance(parsedType)
assert not pdl.RangeType.isinstance(parsedType)
assert pdl.TypeType.isinstance(parsedType)
assert not pdl.ValueType.isinstance(parsedType)
assert not isinstance(parsedType, pdl.AttributeType)
assert not isinstance(parsedType, pdl.OperationType)
assert not isinstance(parsedType, pdl.RangeType)
assert isinstance(parsedType, pdl.TypeType)
assert not isinstance(parsedType, pdl.ValueType)
assert not pdl.AttributeType.isinstance(constructedType)
assert not pdl.OperationType.isinstance(constructedType)
assert not pdl.RangeType.isinstance(constructedType)
assert pdl.TypeType.isinstance(constructedType)
assert not pdl.ValueType.isinstance(constructedType)
assert not isinstance(constructedType, pdl.AttributeType)
assert not isinstance(constructedType, pdl.OperationType)
assert not isinstance(constructedType, pdl.RangeType)
assert isinstance(constructedType, pdl.TypeType)
assert not isinstance(constructedType, pdl.ValueType)
assert parsedType == constructedType
@ -130,17 +130,17 @@ def test_value_type():
parsedType = Type.parse("!pdl.value")
constructedType = pdl.ValueType.get()
assert not pdl.AttributeType.isinstance(parsedType)
assert not pdl.OperationType.isinstance(parsedType)
assert not pdl.RangeType.isinstance(parsedType)
assert not pdl.TypeType.isinstance(parsedType)
assert pdl.ValueType.isinstance(parsedType)
assert not isinstance(parsedType, pdl.AttributeType)
assert not isinstance(parsedType, pdl.OperationType)
assert not isinstance(parsedType, pdl.RangeType)
assert not isinstance(parsedType, pdl.TypeType)
assert isinstance(parsedType, pdl.ValueType)
assert not pdl.AttributeType.isinstance(constructedType)
assert not pdl.OperationType.isinstance(constructedType)
assert not pdl.RangeType.isinstance(constructedType)
assert not pdl.TypeType.isinstance(constructedType)
assert pdl.ValueType.isinstance(constructedType)
assert not isinstance(constructedType, pdl.AttributeType)
assert not isinstance(constructedType, pdl.OperationType)
assert not isinstance(constructedType, pdl.RangeType)
assert not isinstance(constructedType, pdl.TypeType)
assert isinstance(constructedType, pdl.ValueType)
assert parsedType == constructedType

View File

@ -24,23 +24,23 @@ def test_type_hierarchy():
)
calibrated = Type.parse("!quant.calibrated<f32<-0.998:1.2321>>")
assert not quant.QuantizedType.isinstance(i8)
assert quant.QuantizedType.isinstance(any)
assert quant.QuantizedType.isinstance(uniform)
assert quant.QuantizedType.isinstance(per_axis)
assert quant.QuantizedType.isinstance(sub_channel)
assert quant.QuantizedType.isinstance(calibrated)
assert not isinstance(i8, quant.QuantizedType)
assert isinstance(any, quant.QuantizedType)
assert isinstance(uniform, quant.QuantizedType)
assert isinstance(per_axis, quant.QuantizedType)
assert isinstance(sub_channel, quant.QuantizedType)
assert isinstance(calibrated, quant.QuantizedType)
assert quant.AnyQuantizedType.isinstance(any)
assert quant.UniformQuantizedType.isinstance(uniform)
assert quant.UniformQuantizedPerAxisType.isinstance(per_axis)
assert quant.UniformQuantizedSubChannelType.isinstance(sub_channel)
assert quant.CalibratedQuantizedType.isinstance(calibrated)
assert isinstance(any, quant.AnyQuantizedType)
assert isinstance(uniform, quant.UniformQuantizedType)
assert isinstance(per_axis, quant.UniformQuantizedPerAxisType)
assert isinstance(sub_channel, quant.UniformQuantizedSubChannelType)
assert isinstance(calibrated, quant.CalibratedQuantizedType)
assert not quant.AnyQuantizedType.isinstance(uniform)
assert not quant.UniformQuantizedType.isinstance(per_axis)
assert not quant.UniformQuantizedType.isinstance(sub_channel)
assert not quant.UniformQuantizedPerAxisType.isinstance(sub_channel)
assert not isinstance(uniform, quant.AnyQuantizedType)
assert not isinstance(per_axis, quant.UniformQuantizedType)
assert not isinstance(sub_channel, quant.UniformQuantizedType)
assert not isinstance(sub_channel, quant.UniformQuantizedPerAxisType)
# CHECK-LABEL: TEST: test_any_quantized_type

View File

@ -354,21 +354,21 @@ def testIsInstance():
mul = AffineMulExpr.get(d1, c2)
# CHECK: True
print(AffineDimExpr.isinstance(d1))
print(isinstance(d1, AffineDimExpr))
# CHECK: False
print(AffineConstantExpr.isinstance(d1))
print(isinstance(d1, AffineConstantExpr))
# CHECK: True
print(AffineConstantExpr.isinstance(c2))
print(isinstance(c2, AffineConstantExpr))
# CHECK: False
print(AffineMulExpr.isinstance(c2))
print(isinstance(c2, AffineMulExpr))
# CHECK: True
print(AffineAddExpr.isinstance(add))
print(isinstance(add, AffineAddExpr))
# CHECK: False
print(AffineMulExpr.isinstance(add))
print(isinstance(add, AffineMulExpr))
# CHECK: True
print(AffineMulExpr.isinstance(mul))
print(isinstance(mul, AffineMulExpr))
# CHECK: False
print(AffineAddExpr.isinstance(mul))
print(isinstance(mul, AffineAddExpr))
# CHECK-LABEL: TEST: testCompose

View File

@ -94,10 +94,10 @@ def testAttrIsInstance():
with Context():
a1 = Attribute.parse("42")
a2 = Attribute.parse("[42]")
assert IntegerAttr.isinstance(a1)
assert not IntegerAttr.isinstance(a2)
assert not ArrayAttr.isinstance(a1)
assert ArrayAttr.isinstance(a2)
assert isinstance(a1, IntegerAttr)
assert not isinstance(a2, IntegerAttr)
assert not isinstance(a1, ArrayAttr)
assert isinstance(a2, ArrayAttr)
# CHECK-LABEL: TEST: testAttrEqDoesNotRaise

View File

@ -34,7 +34,7 @@ def testInferLocations():
# Test nesting of loc_tracebacks().
with loc_tracebacks():
# fmt: off
# CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]arith.py":65:12 to :76) at callsite("constant"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]arith.py":110:40 to :81) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:14 to :48) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:1 to :4))))))
# CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]arith.py":45:12 to :76) at callsite("constant"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]arith.py":90:40 to :81) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:14 to :48) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:1 to :4))))))
# fmt: on
print(one.location)

View File

@ -97,15 +97,15 @@ def testTypeIsInstance():
t1 = Type.parse("i32", ctx)
t2 = Type.parse("f32", ctx)
# CHECK: True
print(IntegerType.isinstance(t1))
print(isinstance(t1, IntegerType))
# CHECK: False
print(F32Type.isinstance(t1))
print(isinstance(t1, F32Type))
# CHECK: False
print(FloatType.isinstance(t1))
print(isinstance(t1, FloatType))
# CHECK: True
print(F32Type.isinstance(t2))
print(isinstance(t2, F32Type))
# CHECK: True
print(FloatType.isinstance(t2))
print(isinstance(t2, FloatType))
# CHECK-LABEL: TEST: testFloatTypeSubclasses

View File

@ -69,12 +69,12 @@ def testValueIsInstance():
ctx,
)
func = module.body.operations[0]
assert BlockArgument.isinstance(func.regions[0].blocks[0].arguments[0])
assert not OpResult.isinstance(func.regions[0].blocks[0].arguments[0])
assert isinstance(func.regions[0].blocks[0].arguments[0], BlockArgument)
assert not isinstance(func.regions[0].blocks[0].arguments[0], OpResult)
op = func.regions[0].blocks[0].operations[0]
assert not BlockArgument.isinstance(op.results[0])
assert OpResult.isinstance(op.results[0])
assert not isinstance(op.results[0], BlockArgument)
assert isinstance(op.results[0], OpResult)
# CHECK-LABEL: TEST: testValueHash

View File

@ -886,7 +886,7 @@ constexpr const char *firstAttrDerivedResultTypeTemplate =
_ods_result_type_source_attr = attributes["{0}"]
_ods_derived_result_type = (
_ods_ir.TypeAttr(_ods_result_type_source_attr).value
if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else
if isinstance(_ods_result_type_source_attr, _ods_ir.TypeAttr) else
_ods_result_type_source_attr.type)
results = [_ods_derived_result_type] * {1})Py";