[mlir][Python] use canonical Python isinstance instead of Type.isinstance (#172892)
We've been able to do `isinstance(x, Type)` for a quite a while now
(since
bfb1ba7526)
so remove `Type.isinstance` and the the special-casing
(`_is_integer_type`, `_is_floating_point_type`, `_is_index_type`) in
some places (and therefore support various `fp8`, `fp6`, `fp4` types).
This commit is contained in:
parent
e826168a24
commit
fb8bbd4ed8
@ -30,6 +30,8 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLType(MlirType type);
|
||||
|
||||
MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLAttributeType(MlirType type);
|
||||
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirPDLAttributeTypeGetTypeID(void);
|
||||
|
||||
MLIR_CAPI_EXPORTED MlirType mlirPDLAttributeTypeGet(MlirContext ctx);
|
||||
|
||||
//===---------------------------------------------------------------------===//
|
||||
@ -38,6 +40,8 @@ MLIR_CAPI_EXPORTED MlirType mlirPDLAttributeTypeGet(MlirContext ctx);
|
||||
|
||||
MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLOperationType(MlirType type);
|
||||
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirPDLOperationTypeGetTypeID(void);
|
||||
|
||||
MLIR_CAPI_EXPORTED MlirType mlirPDLOperationTypeGet(MlirContext ctx);
|
||||
|
||||
//===---------------------------------------------------------------------===//
|
||||
@ -46,6 +50,8 @@ MLIR_CAPI_EXPORTED MlirType mlirPDLOperationTypeGet(MlirContext ctx);
|
||||
|
||||
MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLRangeType(MlirType type);
|
||||
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirPDLRangeTypeGetTypeID(void);
|
||||
|
||||
MLIR_CAPI_EXPORTED MlirType mlirPDLRangeTypeGet(MlirType elementType);
|
||||
|
||||
MLIR_CAPI_EXPORTED MlirType mlirPDLRangeTypeGetElementType(MlirType type);
|
||||
@ -56,6 +62,8 @@ MLIR_CAPI_EXPORTED MlirType mlirPDLRangeTypeGetElementType(MlirType type);
|
||||
|
||||
MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLTypeType(MlirType type);
|
||||
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirPDLTypeTypeGetTypeID(void);
|
||||
|
||||
MLIR_CAPI_EXPORTED MlirType mlirPDLTypeTypeGet(MlirContext ctx);
|
||||
|
||||
//===---------------------------------------------------------------------===//
|
||||
@ -64,6 +72,8 @@ MLIR_CAPI_EXPORTED MlirType mlirPDLTypeTypeGet(MlirContext ctx);
|
||||
|
||||
MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLValueType(MlirType type);
|
||||
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirPDLValueTypeGetTypeID(void);
|
||||
|
||||
MLIR_CAPI_EXPORTED MlirType mlirPDLValueTypeGet(MlirContext ctx);
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
||||
@ -103,6 +103,8 @@ mlirQuantizedTypeCastExpressedToStorageType(MlirType type, MlirType candidate);
|
||||
/// Returns `true` if the given type is an AnyQuantizedType.
|
||||
MLIR_CAPI_EXPORTED bool mlirTypeIsAAnyQuantizedType(MlirType type);
|
||||
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirAnyQuantizedTypeGetTypeID(void);
|
||||
|
||||
/// Creates an instance of AnyQuantizedType with the given parameters in the
|
||||
/// same context as `storageType` and returns it. The instance is owned by the
|
||||
/// context.
|
||||
@ -119,6 +121,8 @@ MLIR_CAPI_EXPORTED MlirType mlirAnyQuantizedTypeGet(unsigned flags,
|
||||
/// Returns `true` if the given type is a UniformQuantizedType.
|
||||
MLIR_CAPI_EXPORTED bool mlirTypeIsAUniformQuantizedType(MlirType type);
|
||||
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirUniformQuantizedTypeGetTypeID(void);
|
||||
|
||||
/// Creates an instance of UniformQuantizedType with the given parameters in the
|
||||
/// same context as `storageType` and returns it. The instance is owned by the
|
||||
/// context.
|
||||
@ -142,6 +146,8 @@ MLIR_CAPI_EXPORTED bool mlirUniformQuantizedTypeIsFixedPoint(MlirType type);
|
||||
/// Returns `true` if the given type is a UniformQuantizedPerAxisType.
|
||||
MLIR_CAPI_EXPORTED bool mlirTypeIsAUniformQuantizedPerAxisType(MlirType type);
|
||||
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirUniformQuantizedPerAxisTypeGetTypeID(void);
|
||||
|
||||
/// Creates an instance of UniformQuantizedPerAxisType with the given parameters
|
||||
/// in the same context as `storageType` and returns it. `scales` and
|
||||
/// `zeroPoints` point to `nDims` number of elements. The instance is owned
|
||||
@ -180,6 +186,8 @@ mlirUniformQuantizedPerAxisTypeIsFixedPoint(MlirType type);
|
||||
MLIR_CAPI_EXPORTED bool
|
||||
mlirTypeIsAUniformQuantizedSubChannelType(MlirType type);
|
||||
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirUniformQuantizedSubChannelTypeGetTypeID(void);
|
||||
|
||||
/// Creates a UniformQuantizedSubChannelType with the given parameters.
|
||||
///
|
||||
/// The type is owned by the context. `scalesAttr` and `zeroPointsAttr` must be
|
||||
@ -220,6 +228,8 @@ mlirUniformQuantizedSubChannelTypeGetZeroPoints(MlirType type);
|
||||
/// Returns `true` if the given type is a CalibratedQuantizedType.
|
||||
MLIR_CAPI_EXPORTED bool mlirTypeIsACalibratedQuantizedType(MlirType type);
|
||||
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirCalibratedQuantizedTypeGetTypeID(void);
|
||||
|
||||
/// Creates an instance of CalibratedQuantizedType with the given parameters
|
||||
/// in the same context as `expressedType` and returns it. The instance is owned
|
||||
/// by the context.
|
||||
|
||||
@ -957,12 +957,6 @@ public:
|
||||
auto cls = ClassTy(m, DerivedTy::pyClassName);
|
||||
cls.def(nanobind::init<PyType &>(), nanobind::keep_alive<0, 1>(),
|
||||
nanobind::arg("cast_from_type"));
|
||||
cls.def_static(
|
||||
"isinstance",
|
||||
[](PyType &otherType) -> bool {
|
||||
return DerivedTy::isaFunction(otherType);
|
||||
},
|
||||
nanobind::arg("other"));
|
||||
cls.def_prop_ro_static(
|
||||
"static_typeid",
|
||||
[](nanobind::object & /*class*/) {
|
||||
@ -1094,12 +1088,6 @@ public:
|
||||
}
|
||||
cls.def(nanobind::init<PyAttribute &>(), nanobind::keep_alive<0, 1>(),
|
||||
nanobind::arg("cast_from_attr"));
|
||||
cls.def_static(
|
||||
"isinstance",
|
||||
[](PyAttribute &otherAttr) -> bool {
|
||||
return DerivedTy::isaFunction(otherAttr);
|
||||
},
|
||||
nanobind::arg("other"));
|
||||
cls.def_prop_ro(
|
||||
"type",
|
||||
[](PyAttribute &attr) -> nanobind::typed<nanobind::object, PyType> {
|
||||
@ -1555,12 +1543,6 @@ public:
|
||||
.c_str()));
|
||||
cls.def(nanobind::init<PyValue &>(), nanobind::keep_alive<0, 1>(),
|
||||
nanobind::arg("value"));
|
||||
cls.def_static(
|
||||
"isinstance",
|
||||
[](PyValue &otherValue) -> bool {
|
||||
return DerivedTy::isaFunction(otherValue);
|
||||
},
|
||||
nanobind::arg("other_value"));
|
||||
cls.def(
|
||||
MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
|
||||
[](DerivedTy &self) -> nanobind::typed<nanobind::object, DerivedTy> {
|
||||
|
||||
@ -39,6 +39,8 @@ struct PDLType : PyConcreteType<PDLType> {
|
||||
|
||||
struct AttributeType : PyConcreteType<AttributeType> {
|
||||
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPDLAttributeType;
|
||||
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
|
||||
mlirPDLAttributeTypeGetTypeID;
|
||||
static constexpr const char *pyClassName = "AttributeType";
|
||||
using Base::Base;
|
||||
|
||||
@ -60,6 +62,8 @@ struct AttributeType : PyConcreteType<AttributeType> {
|
||||
|
||||
struct OperationType : PyConcreteType<OperationType> {
|
||||
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPDLOperationType;
|
||||
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
|
||||
mlirPDLOperationTypeGetTypeID;
|
||||
static constexpr const char *pyClassName = "OperationType";
|
||||
using Base::Base;
|
||||
|
||||
@ -81,6 +85,8 @@ struct OperationType : PyConcreteType<OperationType> {
|
||||
|
||||
struct RangeType : PyConcreteType<RangeType> {
|
||||
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPDLRangeType;
|
||||
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
|
||||
mlirPDLRangeTypeGetTypeID;
|
||||
static constexpr const char *pyClassName = "RangeType";
|
||||
using Base::Base;
|
||||
|
||||
@ -109,6 +115,8 @@ struct RangeType : PyConcreteType<RangeType> {
|
||||
|
||||
struct TypeType : PyConcreteType<TypeType> {
|
||||
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPDLTypeType;
|
||||
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
|
||||
mlirPDLTypeTypeGetTypeID;
|
||||
static constexpr const char *pyClassName = "TypeType";
|
||||
using Base::Base;
|
||||
|
||||
@ -130,6 +138,8 @@ struct TypeType : PyConcreteType<TypeType> {
|
||||
|
||||
struct ValueType : PyConcreteType<ValueType> {
|
||||
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPDLValueType;
|
||||
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
|
||||
mlirPDLValueTypeGetTypeID;
|
||||
static constexpr const char *pyClassName = "ValueType";
|
||||
using Base::Base;
|
||||
|
||||
|
||||
@ -192,6 +192,8 @@ struct QuantizedType : PyConcreteType<QuantizedType> {
|
||||
|
||||
struct AnyQuantizedType : PyConcreteType<AnyQuantizedType, QuantizedType> {
|
||||
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAAnyQuantizedType;
|
||||
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
|
||||
mlirAnyQuantizedTypeGetTypeID;
|
||||
static constexpr const char *pyClassName = "AnyQuantizedType";
|
||||
using Base::Base;
|
||||
|
||||
@ -221,6 +223,8 @@ struct AnyQuantizedType : PyConcreteType<AnyQuantizedType, QuantizedType> {
|
||||
struct UniformQuantizedType
|
||||
: PyConcreteType<UniformQuantizedType, QuantizedType> {
|
||||
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUniformQuantizedType;
|
||||
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
|
||||
mlirUniformQuantizedTypeGetTypeID;
|
||||
static constexpr const char *pyClassName = "UniformQuantizedType";
|
||||
using Base::Base;
|
||||
|
||||
@ -273,6 +277,8 @@ struct UniformQuantizedPerAxisType
|
||||
: PyConcreteType<UniformQuantizedPerAxisType, QuantizedType> {
|
||||
static constexpr IsAFunctionTy isaFunction =
|
||||
mlirTypeIsAUniformQuantizedPerAxisType;
|
||||
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
|
||||
mlirUniformQuantizedPerAxisTypeGetTypeID;
|
||||
static constexpr const char *pyClassName = "UniformQuantizedPerAxisType";
|
||||
using Base::Base;
|
||||
|
||||
@ -357,6 +363,8 @@ struct UniformQuantizedSubChannelType
|
||||
: PyConcreteType<UniformQuantizedSubChannelType, QuantizedType> {
|
||||
static constexpr IsAFunctionTy isaFunction =
|
||||
mlirTypeIsAUniformQuantizedSubChannelType;
|
||||
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
|
||||
mlirUniformQuantizedSubChannelTypeGetTypeID;
|
||||
static constexpr const char *pyClassName = "UniformQuantizedSubChannelType";
|
||||
using Base::Base;
|
||||
|
||||
@ -448,6 +456,8 @@ struct CalibratedQuantizedType
|
||||
: PyConcreteType<CalibratedQuantizedType, QuantizedType> {
|
||||
static constexpr IsAFunctionTy isaFunction =
|
||||
mlirTypeIsACalibratedQuantizedType;
|
||||
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
|
||||
mlirCalibratedQuantizedTypeGetTypeID;
|
||||
static constexpr const char *pyClassName = "CalibratedQuantizedType";
|
||||
using Base::Base;
|
||||
|
||||
|
||||
@ -118,12 +118,6 @@ public:
|
||||
static void bind(nb::module_ &m) {
|
||||
auto cls = ClassTy(m, DerivedTy::pyClassName);
|
||||
cls.def(nb::init<PyAffineExpr &>(), nb::arg("expr"));
|
||||
cls.def_static(
|
||||
"isinstance",
|
||||
[](PyAffineExpr &otherAffineExpr) -> bool {
|
||||
return DerivedTy::isaFunction(otherAffineExpr);
|
||||
},
|
||||
nb::arg("other"));
|
||||
DerivedTy::bindDerived(cls);
|
||||
}
|
||||
|
||||
|
||||
@ -32,6 +32,10 @@ bool mlirTypeIsAPDLAttributeType(MlirType type) {
|
||||
return isa<pdl::AttributeType>(unwrap(type));
|
||||
}
|
||||
|
||||
MlirTypeID mlirPDLAttributeTypeGetTypeID(void) {
|
||||
return wrap(pdl::AttributeType::getTypeID());
|
||||
}
|
||||
|
||||
MlirType mlirPDLAttributeTypeGet(MlirContext ctx) {
|
||||
return wrap(pdl::AttributeType::get(unwrap(ctx)));
|
||||
}
|
||||
@ -44,6 +48,10 @@ bool mlirTypeIsAPDLOperationType(MlirType type) {
|
||||
return isa<pdl::OperationType>(unwrap(type));
|
||||
}
|
||||
|
||||
MlirTypeID mlirPDLOperationTypeGetTypeID(void) {
|
||||
return wrap(pdl::OperationType::getTypeID());
|
||||
}
|
||||
|
||||
MlirType mlirPDLOperationTypeGet(MlirContext ctx) {
|
||||
return wrap(pdl::OperationType::get(unwrap(ctx)));
|
||||
}
|
||||
@ -56,6 +64,10 @@ bool mlirTypeIsAPDLRangeType(MlirType type) {
|
||||
return isa<pdl::RangeType>(unwrap(type));
|
||||
}
|
||||
|
||||
MlirTypeID mlirPDLRangeTypeGetTypeID(void) {
|
||||
return wrap(pdl::RangeType::getTypeID());
|
||||
}
|
||||
|
||||
MlirType mlirPDLRangeTypeGet(MlirType elementType) {
|
||||
return wrap(pdl::RangeType::get(unwrap(elementType)));
|
||||
}
|
||||
@ -72,6 +84,10 @@ bool mlirTypeIsAPDLTypeType(MlirType type) {
|
||||
return isa<pdl::TypeType>(unwrap(type));
|
||||
}
|
||||
|
||||
MlirTypeID mlirPDLTypeTypeGetTypeID(void) {
|
||||
return wrap(pdl::TypeType::getTypeID());
|
||||
}
|
||||
|
||||
MlirType mlirPDLTypeTypeGet(MlirContext ctx) {
|
||||
return wrap(pdl::TypeType::get(unwrap(ctx)));
|
||||
}
|
||||
@ -84,6 +100,10 @@ bool mlirTypeIsAPDLValueType(MlirType type) {
|
||||
return isa<pdl::ValueType>(unwrap(type));
|
||||
}
|
||||
|
||||
MlirTypeID mlirPDLValueTypeGetTypeID(void) {
|
||||
return wrap(pdl::ValueType::getTypeID());
|
||||
}
|
||||
|
||||
MlirType mlirPDLValueTypeGet(MlirContext ctx) {
|
||||
return wrap(pdl::ValueType::get(unwrap(ctx)));
|
||||
}
|
||||
|
||||
@ -113,6 +113,10 @@ bool mlirTypeIsAAnyQuantizedType(MlirType type) {
|
||||
return isa<quant::AnyQuantizedType>(unwrap(type));
|
||||
}
|
||||
|
||||
MlirTypeID mlirAnyQuantizedTypeGetTypeID(void) {
|
||||
return wrap(quant::AnyQuantizedType::getTypeID());
|
||||
}
|
||||
|
||||
MlirType mlirAnyQuantizedTypeGet(unsigned flags, MlirType storageType,
|
||||
MlirType expressedType, int64_t storageTypeMin,
|
||||
int64_t storageTypeMax) {
|
||||
@ -129,6 +133,10 @@ bool mlirTypeIsAUniformQuantizedType(MlirType type) {
|
||||
return isa<quant::UniformQuantizedType>(unwrap(type));
|
||||
}
|
||||
|
||||
MlirTypeID mlirUniformQuantizedTypeGetTypeID(void) {
|
||||
return wrap(quant::UniformQuantizedType::getTypeID());
|
||||
}
|
||||
|
||||
MlirType mlirUniformQuantizedTypeGet(unsigned flags, MlirType storageType,
|
||||
MlirType expressedType, double scale,
|
||||
int64_t zeroPoint, int64_t storageTypeMin,
|
||||
@ -158,6 +166,10 @@ bool mlirTypeIsAUniformQuantizedPerAxisType(MlirType type) {
|
||||
return isa<quant::UniformQuantizedPerAxisType>(unwrap(type));
|
||||
}
|
||||
|
||||
MlirTypeID mlirUniformQuantizedPerAxisTypeGetTypeID(void) {
|
||||
return wrap(quant::UniformQuantizedPerAxisType::getTypeID());
|
||||
}
|
||||
|
||||
MlirType mlirUniformQuantizedPerAxisTypeGet(
|
||||
unsigned flags, MlirType storageType, MlirType expressedType,
|
||||
intptr_t nDims, double *scales, int64_t *zeroPoints,
|
||||
@ -203,6 +215,10 @@ bool mlirTypeIsAUniformQuantizedSubChannelType(MlirType type) {
|
||||
return isa<quant::UniformQuantizedSubChannelType>(unwrap(type));
|
||||
}
|
||||
|
||||
MlirTypeID mlirUniformQuantizedSubChannelTypeGetTypeID(void) {
|
||||
return wrap(quant::UniformQuantizedSubChannelType::getTypeID());
|
||||
}
|
||||
|
||||
MlirType mlirUniformQuantizedSubChannelTypeGet(
|
||||
unsigned flags, MlirType storageType, MlirType expressedType,
|
||||
MlirAttribute scalesAttr, MlirAttribute zeroPointsAttr, intptr_t nDims,
|
||||
@ -258,6 +274,10 @@ bool mlirTypeIsACalibratedQuantizedType(MlirType type) {
|
||||
return isa<quant::CalibratedQuantizedType>(unwrap(type));
|
||||
}
|
||||
|
||||
MlirTypeID mlirCalibratedQuantizedTypeGetTypeID(void) {
|
||||
return wrap(quant::CalibratedQuantizedType::getTypeID());
|
||||
}
|
||||
|
||||
MlirType mlirCalibratedQuantizedTypeGet(MlirType expressedType, double min,
|
||||
double max) {
|
||||
return wrap(
|
||||
|
||||
@ -21,26 +21,6 @@ except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
|
||||
def _isa(obj: Any, cls: type):
|
||||
try:
|
||||
cls(obj)
|
||||
except ValueError:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _is_any_of(obj: Any, classes: List[type]):
|
||||
return any(_isa(obj, cls) for cls in classes)
|
||||
|
||||
|
||||
def _is_integer_like_type(type: Type):
|
||||
return _is_any_of(type, [IntegerType, IndexType])
|
||||
|
||||
|
||||
def _is_float_type(type: Type):
|
||||
return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type])
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class ConstantOp(ConstantOp):
|
||||
"""Specialization for the constant op class."""
|
||||
@ -96,9 +76,9 @@ class ConstantOp(ConstantOp):
|
||||
|
||||
@property
|
||||
def literal_value(self) -> Union[int, float]:
|
||||
if _is_integer_like_type(self.type):
|
||||
if isinstance(self.type, (IntegerType, IndexType)):
|
||||
return IntegerAttr(self.value).value
|
||||
elif _is_float_type(self.type):
|
||||
elif isinstance(self.type, FloatType):
|
||||
return FloatAttr(self.value).value
|
||||
else:
|
||||
raise ValueError("only integer and float constants have literal values")
|
||||
|
||||
@ -412,9 +412,9 @@ class _BodyBuilder:
|
||||
)
|
||||
if operand.type == to_type:
|
||||
return operand
|
||||
if _is_integer_type(to_type):
|
||||
if isinstance(to_type, IntegerType):
|
||||
return self._cast_to_integer(to_type, operand, is_unsigned_cast)
|
||||
elif _is_floating_point_type(to_type):
|
||||
elif isinstance(to_type, FloatType):
|
||||
return self._cast_to_floating_point(to_type, operand, is_unsigned_cast)
|
||||
|
||||
def _cast_to_integer(
|
||||
@ -422,11 +422,11 @@ class _BodyBuilder:
|
||||
) -> Value:
|
||||
to_width = IntegerType(to_type).width
|
||||
operand_type = operand.type
|
||||
if _is_floating_point_type(operand_type):
|
||||
if isinstance(operand_type, FloatType):
|
||||
if is_unsigned_cast:
|
||||
return arith.FPToUIOp(to_type, operand).result
|
||||
return arith.FPToSIOp(to_type, operand).result
|
||||
if _is_index_type(operand_type):
|
||||
if isinstance(operand_type, IndexType):
|
||||
return arith.IndexCastOp(to_type, operand).result
|
||||
# Assume integer.
|
||||
from_width = IntegerType(operand_type).width
|
||||
@ -444,13 +444,15 @@ class _BodyBuilder:
|
||||
self, to_type: Type, operand: Value, is_unsigned_cast: bool
|
||||
) -> Value:
|
||||
operand_type = operand.type
|
||||
if _is_integer_type(operand_type):
|
||||
if isinstance(operand_type, IntegerType):
|
||||
if is_unsigned_cast:
|
||||
return arith.UIToFPOp(to_type, operand).result
|
||||
return arith.SIToFPOp(to_type, operand).result
|
||||
# Assume FloatType.
|
||||
to_width = _get_floating_point_width(to_type)
|
||||
from_width = _get_floating_point_width(operand_type)
|
||||
assert isinstance(to_type, FloatType)
|
||||
assert isinstance(operand_type, FloatType)
|
||||
to_width = to_type.width
|
||||
from_width = operand_type.width
|
||||
if to_width > from_width:
|
||||
return arith.ExtFOp(to_type, operand).result
|
||||
elif to_width < from_width:
|
||||
@ -466,89 +468,89 @@ class _BodyBuilder:
|
||||
return self._cast(type_var_name, operand, True)
|
||||
|
||||
def _unary_exp(self, x: Value) -> Value:
|
||||
if _is_floating_point_type(x.type):
|
||||
if isinstance(x.type, FloatType):
|
||||
return math.ExpOp(x).result
|
||||
raise NotImplementedError("Unsupported 'exp' operand: {x}")
|
||||
|
||||
def _unary_log(self, x: Value) -> Value:
|
||||
if _is_floating_point_type(x.type):
|
||||
if isinstance(x.type, FloatType):
|
||||
return math.LogOp(x).result
|
||||
raise NotImplementedError("Unsupported 'log' operand: {x}")
|
||||
|
||||
def _unary_abs(self, x: Value) -> Value:
|
||||
if _is_floating_point_type(x.type):
|
||||
if isinstance(x.type, FloatType):
|
||||
return math.AbsFOp(x).result
|
||||
raise NotImplementedError("Unsupported 'abs' operand: {x}")
|
||||
|
||||
def _unary_ceil(self, x: Value) -> Value:
|
||||
if _is_floating_point_type(x.type):
|
||||
if isinstance(x.type, FloatType):
|
||||
return math.CeilOp(x).result
|
||||
raise NotImplementedError("Unsupported 'ceil' operand: {x}")
|
||||
|
||||
def _unary_floor(self, x: Value) -> Value:
|
||||
if _is_floating_point_type(x.type):
|
||||
if isinstance(x.type, FloatType):
|
||||
return math.FloorOp(x).result
|
||||
raise NotImplementedError("Unsupported 'floor' operand: {x}")
|
||||
|
||||
def _unary_negf(self, x: Value) -> Value:
|
||||
if _is_floating_point_type(x.type):
|
||||
if isinstance(x.type, FloatType):
|
||||
return arith.NegFOp(x).result
|
||||
if _is_complex_type(x.type):
|
||||
if isinstance(x.type, ComplexType):
|
||||
return complex.NegOp(x).result
|
||||
raise NotImplementedError("Unsupported 'negf' operand: {x}")
|
||||
|
||||
def _binary_add(self, lhs: Value, rhs: Value) -> Value:
|
||||
if _is_floating_point_type(lhs.type):
|
||||
if isinstance(lhs.type, FloatType):
|
||||
return arith.AddFOp(lhs, rhs).result
|
||||
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
|
||||
if isinstance(lhs.type, IntegerType) or isinstance(lhs.type, IndexType):
|
||||
return arith.AddIOp(lhs, rhs).result
|
||||
if _is_complex_type(lhs.type):
|
||||
if isinstance(lhs.type, ComplexType):
|
||||
return complex.AddOp(lhs, rhs).result
|
||||
raise NotImplementedError("Unsupported 'add' operands: {lhs}, {rhs}")
|
||||
|
||||
def _binary_sub(self, lhs: Value, rhs: Value) -> Value:
|
||||
if _is_floating_point_type(lhs.type):
|
||||
if isinstance(lhs.type, FloatType):
|
||||
return arith.SubFOp(lhs, rhs).result
|
||||
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
|
||||
if isinstance(lhs.type, IntegerType) or isinstance(lhs.type, IndexType):
|
||||
return arith.SubIOp(lhs, rhs).result
|
||||
if _is_complex_type(lhs.type):
|
||||
if isinstance(lhs.type, ComplexType):
|
||||
return complex.SubOp(lhs, rhs).result
|
||||
raise NotImplementedError("Unsupported 'sub' operands: {lhs}, {rhs}")
|
||||
|
||||
def _binary_mul(self, lhs: Value, rhs: Value) -> Value:
|
||||
if _is_floating_point_type(lhs.type):
|
||||
if isinstance(lhs.type, FloatType):
|
||||
return arith.MulFOp(lhs, rhs).result
|
||||
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
|
||||
if isinstance(lhs.type, IntegerType) or isinstance(lhs.type, IndexType):
|
||||
return arith.MulIOp(lhs, rhs).result
|
||||
if _is_complex_type(lhs.type):
|
||||
if isinstance(lhs.type, ComplexType):
|
||||
return complex.MulOp(lhs, rhs).result
|
||||
raise NotImplementedError("Unsupported 'mul' operands: {lhs}, {rhs}")
|
||||
|
||||
def _binary_max_signed(self, lhs: Value, rhs: Value) -> Value:
|
||||
if _is_floating_point_type(lhs.type):
|
||||
if isinstance(lhs.type, FloatType):
|
||||
return arith.MaximumFOp(lhs, rhs).result
|
||||
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
|
||||
if isinstance(lhs.type, IntegerType) or isinstance(lhs.type, IndexType):
|
||||
return arith.MaxSIOp(lhs, rhs).result
|
||||
raise NotImplementedError("Unsupported 'max' operands: {lhs}, {rhs}")
|
||||
|
||||
def _binary_max_unsigned(self, lhs: Value, rhs: Value) -> Value:
|
||||
if (
|
||||
_is_integer_type(lhs.type) and not _is_bool_type(lhs.type)
|
||||
) or _is_index_type(lhs.type):
|
||||
isinstance(lhs.type, IntegerType) and not _is_bool_type(lhs.type)
|
||||
) or isinstance(lhs.type, IndexType):
|
||||
return arith.MaxUIOp(lhs, rhs).result
|
||||
raise NotImplementedError("Unsupported 'max_unsigned' operands: {lhs}, {rhs}")
|
||||
|
||||
def _binary_min_signed(self, lhs: Value, rhs: Value) -> Value:
|
||||
if _is_floating_point_type(lhs.type):
|
||||
if isinstance(lhs.type, FloatType):
|
||||
return arith.MinimumFOp(lhs, rhs).result
|
||||
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
|
||||
if isinstance(lhs.type, IntegerType) or isinstance(lhs.type, IndexType):
|
||||
return arith.MinSIOp(lhs, rhs).result
|
||||
raise NotImplementedError("Unsupported 'min' operands: {lhs}, {rhs}")
|
||||
|
||||
def _binary_min_unsigned(self, lhs: Value, rhs: Value) -> Value:
|
||||
if (
|
||||
_is_integer_type(lhs.type) and not _is_bool_type(lhs.type)
|
||||
) or _is_index_type(lhs.type):
|
||||
isinstance(lhs.type, IntegerType) and not _is_bool_type(lhs.type)
|
||||
) or isinstance(lhs.type, IndexType):
|
||||
return arith.MinUIOp(lhs, rhs).result
|
||||
raise NotImplementedError("Unsupported 'min_unsigned' operands: {lhs}, {rhs}")
|
||||
|
||||
@ -611,44 +613,7 @@ def _add_type_mapping(
|
||||
block_arg_types.append(element_or_self_type)
|
||||
|
||||
|
||||
def _is_complex_type(t: Type) -> bool:
|
||||
return ComplexType.isinstance(t)
|
||||
|
||||
|
||||
def _is_floating_point_type(t: Type) -> bool:
|
||||
# TODO: Create a FloatType in the Python API and implement the switch
|
||||
# there.
|
||||
return (
|
||||
F64Type.isinstance(t)
|
||||
or F32Type.isinstance(t)
|
||||
or F16Type.isinstance(t)
|
||||
or BF16Type.isinstance(t)
|
||||
)
|
||||
|
||||
|
||||
def _is_integer_type(t: Type) -> bool:
|
||||
return IntegerType.isinstance(t)
|
||||
|
||||
|
||||
def _is_index_type(t: Type) -> bool:
|
||||
return IndexType.isinstance(t)
|
||||
|
||||
|
||||
def _is_bool_type(t: Type) -> bool:
|
||||
if not IntegerType.isinstance(t):
|
||||
if not isinstance(t, IntegerType):
|
||||
return False
|
||||
return IntegerType(t).width == 1
|
||||
|
||||
|
||||
def _get_floating_point_width(t: Type) -> int:
|
||||
# TODO: Create a FloatType in the Python API and implement the switch
|
||||
# there.
|
||||
if F64Type.isinstance(t):
|
||||
return 64
|
||||
if F32Type.isinstance(t):
|
||||
return 32
|
||||
if F16Type.isinstance(t):
|
||||
return 16
|
||||
if BF16Type.isinstance(t):
|
||||
return 16
|
||||
raise NotImplementedError(f"Unhandled floating point type switch {t}")
|
||||
return t.width == 1
|
||||
|
||||
@ -8,15 +8,22 @@ from typing import Optional
|
||||
from ._memref_ops_gen import *
|
||||
from ._memref_ops_gen import _Dialect
|
||||
from ._ods_common import _dispatch_mixed_values, MixedValues
|
||||
from .arith import ConstantOp, _is_integer_like_type
|
||||
from ..ir import Value, MemRefType, StridedLayoutAttr, ShapedType, Operation
|
||||
from ..ir import (
|
||||
IndexType,
|
||||
IntegerType,
|
||||
MemRefType,
|
||||
ShapedType,
|
||||
StridedLayoutAttr,
|
||||
Value,
|
||||
)
|
||||
from . import arith
|
||||
|
||||
|
||||
def _is_constant_int_like(i):
|
||||
return (
|
||||
isinstance(i, Value)
|
||||
and isinstance(i.owner, ConstantOp)
|
||||
and _is_integer_like_type(i.type)
|
||||
and isinstance(i.owner, arith.ConstantOp)
|
||||
and isinstance(i.type, (IntegerType, IndexType))
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -233,7 +233,7 @@ def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResu
|
||||
// CHECK: _ods_result_type_source_attr = attributes["type"]
|
||||
// CHECK: _ods_derived_result_type = (
|
||||
// CHECK: _ods_ir.TypeAttr(_ods_result_type_source_attr).value
|
||||
// CHECK: if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else
|
||||
// CHECK: if isinstance(_ods_result_type_source_attr, _ods_ir.TypeAttr) else
|
||||
// CHECK: _ods_result_type_source_attr.type)
|
||||
// CHECK: results = [_ods_derived_result_type] * 2
|
||||
let arguments = (ins TypeAttr:$type);
|
||||
|
||||
@ -42,10 +42,10 @@ def testFastMathFlags():
|
||||
def testArithValue():
|
||||
def _binary_op(lhs, rhs, op: str) -> "ArithValue":
|
||||
op = op.capitalize()
|
||||
if arith._is_float_type(lhs.type) and arith._is_float_type(rhs.type):
|
||||
if isinstance(lhs.type, FloatType) and isinstance(rhs.type, FloatType):
|
||||
op += "F"
|
||||
elif arith._is_integer_like_type(lhs.type) and arith._is_integer_like_type(
|
||||
lhs.type
|
||||
elif isinstance(lhs.type, (IntegerType, IndexType)) and isinstance(
|
||||
lhs.type, (IntegerType, IndexType)
|
||||
):
|
||||
op += "I"
|
||||
else:
|
||||
|
||||
@ -17,17 +17,17 @@ def test_attribute_type():
|
||||
parsedType = Type.parse("!pdl.attribute")
|
||||
constructedType = pdl.AttributeType.get()
|
||||
|
||||
assert pdl.AttributeType.isinstance(parsedType)
|
||||
assert not pdl.OperationType.isinstance(parsedType)
|
||||
assert not pdl.RangeType.isinstance(parsedType)
|
||||
assert not pdl.TypeType.isinstance(parsedType)
|
||||
assert not pdl.ValueType.isinstance(parsedType)
|
||||
assert isinstance(parsedType, pdl.AttributeType)
|
||||
assert not isinstance(parsedType, pdl.OperationType)
|
||||
assert not isinstance(parsedType, pdl.RangeType)
|
||||
assert not isinstance(parsedType, pdl.TypeType)
|
||||
assert not isinstance(parsedType, pdl.ValueType)
|
||||
|
||||
assert pdl.AttributeType.isinstance(constructedType)
|
||||
assert not pdl.OperationType.isinstance(constructedType)
|
||||
assert not pdl.RangeType.isinstance(constructedType)
|
||||
assert not pdl.TypeType.isinstance(constructedType)
|
||||
assert not pdl.ValueType.isinstance(constructedType)
|
||||
assert isinstance(constructedType, pdl.AttributeType)
|
||||
assert not isinstance(constructedType, pdl.OperationType)
|
||||
assert not isinstance(constructedType, pdl.RangeType)
|
||||
assert not isinstance(constructedType, pdl.TypeType)
|
||||
assert not isinstance(constructedType, pdl.ValueType)
|
||||
|
||||
assert parsedType == constructedType
|
||||
|
||||
@ -44,17 +44,17 @@ def test_operation_type():
|
||||
parsedType = Type.parse("!pdl.operation")
|
||||
constructedType = pdl.OperationType.get()
|
||||
|
||||
assert not pdl.AttributeType.isinstance(parsedType)
|
||||
assert pdl.OperationType.isinstance(parsedType)
|
||||
assert not pdl.RangeType.isinstance(parsedType)
|
||||
assert not pdl.TypeType.isinstance(parsedType)
|
||||
assert not pdl.ValueType.isinstance(parsedType)
|
||||
assert not isinstance(parsedType, pdl.AttributeType)
|
||||
assert isinstance(parsedType, pdl.OperationType)
|
||||
assert not isinstance(parsedType, pdl.RangeType)
|
||||
assert not isinstance(parsedType, pdl.TypeType)
|
||||
assert not isinstance(parsedType, pdl.ValueType)
|
||||
|
||||
assert not pdl.AttributeType.isinstance(constructedType)
|
||||
assert pdl.OperationType.isinstance(constructedType)
|
||||
assert not pdl.RangeType.isinstance(constructedType)
|
||||
assert not pdl.TypeType.isinstance(constructedType)
|
||||
assert not pdl.ValueType.isinstance(constructedType)
|
||||
assert not isinstance(constructedType, pdl.AttributeType)
|
||||
assert isinstance(constructedType, pdl.OperationType)
|
||||
assert not isinstance(constructedType, pdl.RangeType)
|
||||
assert not isinstance(constructedType, pdl.TypeType)
|
||||
assert not isinstance(constructedType, pdl.ValueType)
|
||||
|
||||
assert parsedType == constructedType
|
||||
|
||||
@ -73,17 +73,17 @@ def test_range_type():
|
||||
constructedType = pdl.RangeType.get(typeType)
|
||||
elementType = constructedType.element_type
|
||||
|
||||
assert not pdl.AttributeType.isinstance(parsedType)
|
||||
assert not pdl.OperationType.isinstance(parsedType)
|
||||
assert pdl.RangeType.isinstance(parsedType)
|
||||
assert not pdl.TypeType.isinstance(parsedType)
|
||||
assert not pdl.ValueType.isinstance(parsedType)
|
||||
assert not isinstance(parsedType, pdl.AttributeType)
|
||||
assert not isinstance(parsedType, pdl.OperationType)
|
||||
assert isinstance(parsedType, pdl.RangeType)
|
||||
assert not isinstance(parsedType, pdl.TypeType)
|
||||
assert not isinstance(parsedType, pdl.ValueType)
|
||||
|
||||
assert not pdl.AttributeType.isinstance(constructedType)
|
||||
assert not pdl.OperationType.isinstance(constructedType)
|
||||
assert pdl.RangeType.isinstance(constructedType)
|
||||
assert not pdl.TypeType.isinstance(constructedType)
|
||||
assert not pdl.ValueType.isinstance(constructedType)
|
||||
assert not isinstance(constructedType, pdl.AttributeType)
|
||||
assert not isinstance(constructedType, pdl.OperationType)
|
||||
assert isinstance(constructedType, pdl.RangeType)
|
||||
assert not isinstance(constructedType, pdl.TypeType)
|
||||
assert not isinstance(constructedType, pdl.ValueType)
|
||||
|
||||
assert parsedType == constructedType
|
||||
assert elementType == typeType
|
||||
@ -103,17 +103,17 @@ def test_type_type():
|
||||
parsedType = Type.parse("!pdl.type")
|
||||
constructedType = pdl.TypeType.get()
|
||||
|
||||
assert not pdl.AttributeType.isinstance(parsedType)
|
||||
assert not pdl.OperationType.isinstance(parsedType)
|
||||
assert not pdl.RangeType.isinstance(parsedType)
|
||||
assert pdl.TypeType.isinstance(parsedType)
|
||||
assert not pdl.ValueType.isinstance(parsedType)
|
||||
assert not isinstance(parsedType, pdl.AttributeType)
|
||||
assert not isinstance(parsedType, pdl.OperationType)
|
||||
assert not isinstance(parsedType, pdl.RangeType)
|
||||
assert isinstance(parsedType, pdl.TypeType)
|
||||
assert not isinstance(parsedType, pdl.ValueType)
|
||||
|
||||
assert not pdl.AttributeType.isinstance(constructedType)
|
||||
assert not pdl.OperationType.isinstance(constructedType)
|
||||
assert not pdl.RangeType.isinstance(constructedType)
|
||||
assert pdl.TypeType.isinstance(constructedType)
|
||||
assert not pdl.ValueType.isinstance(constructedType)
|
||||
assert not isinstance(constructedType, pdl.AttributeType)
|
||||
assert not isinstance(constructedType, pdl.OperationType)
|
||||
assert not isinstance(constructedType, pdl.RangeType)
|
||||
assert isinstance(constructedType, pdl.TypeType)
|
||||
assert not isinstance(constructedType, pdl.ValueType)
|
||||
|
||||
assert parsedType == constructedType
|
||||
|
||||
@ -130,17 +130,17 @@ def test_value_type():
|
||||
parsedType = Type.parse("!pdl.value")
|
||||
constructedType = pdl.ValueType.get()
|
||||
|
||||
assert not pdl.AttributeType.isinstance(parsedType)
|
||||
assert not pdl.OperationType.isinstance(parsedType)
|
||||
assert not pdl.RangeType.isinstance(parsedType)
|
||||
assert not pdl.TypeType.isinstance(parsedType)
|
||||
assert pdl.ValueType.isinstance(parsedType)
|
||||
assert not isinstance(parsedType, pdl.AttributeType)
|
||||
assert not isinstance(parsedType, pdl.OperationType)
|
||||
assert not isinstance(parsedType, pdl.RangeType)
|
||||
assert not isinstance(parsedType, pdl.TypeType)
|
||||
assert isinstance(parsedType, pdl.ValueType)
|
||||
|
||||
assert not pdl.AttributeType.isinstance(constructedType)
|
||||
assert not pdl.OperationType.isinstance(constructedType)
|
||||
assert not pdl.RangeType.isinstance(constructedType)
|
||||
assert not pdl.TypeType.isinstance(constructedType)
|
||||
assert pdl.ValueType.isinstance(constructedType)
|
||||
assert not isinstance(constructedType, pdl.AttributeType)
|
||||
assert not isinstance(constructedType, pdl.OperationType)
|
||||
assert not isinstance(constructedType, pdl.RangeType)
|
||||
assert not isinstance(constructedType, pdl.TypeType)
|
||||
assert isinstance(constructedType, pdl.ValueType)
|
||||
|
||||
assert parsedType == constructedType
|
||||
|
||||
|
||||
@ -24,23 +24,23 @@ def test_type_hierarchy():
|
||||
)
|
||||
calibrated = Type.parse("!quant.calibrated<f32<-0.998:1.2321>>")
|
||||
|
||||
assert not quant.QuantizedType.isinstance(i8)
|
||||
assert quant.QuantizedType.isinstance(any)
|
||||
assert quant.QuantizedType.isinstance(uniform)
|
||||
assert quant.QuantizedType.isinstance(per_axis)
|
||||
assert quant.QuantizedType.isinstance(sub_channel)
|
||||
assert quant.QuantizedType.isinstance(calibrated)
|
||||
assert not isinstance(i8, quant.QuantizedType)
|
||||
assert isinstance(any, quant.QuantizedType)
|
||||
assert isinstance(uniform, quant.QuantizedType)
|
||||
assert isinstance(per_axis, quant.QuantizedType)
|
||||
assert isinstance(sub_channel, quant.QuantizedType)
|
||||
assert isinstance(calibrated, quant.QuantizedType)
|
||||
|
||||
assert quant.AnyQuantizedType.isinstance(any)
|
||||
assert quant.UniformQuantizedType.isinstance(uniform)
|
||||
assert quant.UniformQuantizedPerAxisType.isinstance(per_axis)
|
||||
assert quant.UniformQuantizedSubChannelType.isinstance(sub_channel)
|
||||
assert quant.CalibratedQuantizedType.isinstance(calibrated)
|
||||
assert isinstance(any, quant.AnyQuantizedType)
|
||||
assert isinstance(uniform, quant.UniformQuantizedType)
|
||||
assert isinstance(per_axis, quant.UniformQuantizedPerAxisType)
|
||||
assert isinstance(sub_channel, quant.UniformQuantizedSubChannelType)
|
||||
assert isinstance(calibrated, quant.CalibratedQuantizedType)
|
||||
|
||||
assert not quant.AnyQuantizedType.isinstance(uniform)
|
||||
assert not quant.UniformQuantizedType.isinstance(per_axis)
|
||||
assert not quant.UniformQuantizedType.isinstance(sub_channel)
|
||||
assert not quant.UniformQuantizedPerAxisType.isinstance(sub_channel)
|
||||
assert not isinstance(uniform, quant.AnyQuantizedType)
|
||||
assert not isinstance(per_axis, quant.UniformQuantizedType)
|
||||
assert not isinstance(sub_channel, quant.UniformQuantizedType)
|
||||
assert not isinstance(sub_channel, quant.UniformQuantizedPerAxisType)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: test_any_quantized_type
|
||||
|
||||
@ -354,21 +354,21 @@ def testIsInstance():
|
||||
mul = AffineMulExpr.get(d1, c2)
|
||||
|
||||
# CHECK: True
|
||||
print(AffineDimExpr.isinstance(d1))
|
||||
print(isinstance(d1, AffineDimExpr))
|
||||
# CHECK: False
|
||||
print(AffineConstantExpr.isinstance(d1))
|
||||
print(isinstance(d1, AffineConstantExpr))
|
||||
# CHECK: True
|
||||
print(AffineConstantExpr.isinstance(c2))
|
||||
print(isinstance(c2, AffineConstantExpr))
|
||||
# CHECK: False
|
||||
print(AffineMulExpr.isinstance(c2))
|
||||
print(isinstance(c2, AffineMulExpr))
|
||||
# CHECK: True
|
||||
print(AffineAddExpr.isinstance(add))
|
||||
print(isinstance(add, AffineAddExpr))
|
||||
# CHECK: False
|
||||
print(AffineMulExpr.isinstance(add))
|
||||
print(isinstance(add, AffineMulExpr))
|
||||
# CHECK: True
|
||||
print(AffineMulExpr.isinstance(mul))
|
||||
print(isinstance(mul, AffineMulExpr))
|
||||
# CHECK: False
|
||||
print(AffineAddExpr.isinstance(mul))
|
||||
print(isinstance(mul, AffineAddExpr))
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testCompose
|
||||
|
||||
@ -94,10 +94,10 @@ def testAttrIsInstance():
|
||||
with Context():
|
||||
a1 = Attribute.parse("42")
|
||||
a2 = Attribute.parse("[42]")
|
||||
assert IntegerAttr.isinstance(a1)
|
||||
assert not IntegerAttr.isinstance(a2)
|
||||
assert not ArrayAttr.isinstance(a1)
|
||||
assert ArrayAttr.isinstance(a2)
|
||||
assert isinstance(a1, IntegerAttr)
|
||||
assert not isinstance(a2, IntegerAttr)
|
||||
assert not isinstance(a1, ArrayAttr)
|
||||
assert isinstance(a2, ArrayAttr)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testAttrEqDoesNotRaise
|
||||
|
||||
@ -34,7 +34,7 @@ def testInferLocations():
|
||||
# Test nesting of loc_tracebacks().
|
||||
with loc_tracebacks():
|
||||
# fmt: off
|
||||
# CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]arith.py":65:12 to :76) at callsite("constant"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]arith.py":110:40 to :81) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:14 to :48) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:1 to :4))))))
|
||||
# CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]arith.py":45:12 to :76) at callsite("constant"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]arith.py":90:40 to :81) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:14 to :48) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":{{[0-9]+}}:1 to :4))))))
|
||||
# fmt: on
|
||||
print(one.location)
|
||||
|
||||
|
||||
@ -97,15 +97,15 @@ def testTypeIsInstance():
|
||||
t1 = Type.parse("i32", ctx)
|
||||
t2 = Type.parse("f32", ctx)
|
||||
# CHECK: True
|
||||
print(IntegerType.isinstance(t1))
|
||||
print(isinstance(t1, IntegerType))
|
||||
# CHECK: False
|
||||
print(F32Type.isinstance(t1))
|
||||
print(isinstance(t1, F32Type))
|
||||
# CHECK: False
|
||||
print(FloatType.isinstance(t1))
|
||||
print(isinstance(t1, FloatType))
|
||||
# CHECK: True
|
||||
print(F32Type.isinstance(t2))
|
||||
print(isinstance(t2, F32Type))
|
||||
# CHECK: True
|
||||
print(FloatType.isinstance(t2))
|
||||
print(isinstance(t2, FloatType))
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testFloatTypeSubclasses
|
||||
|
||||
@ -69,12 +69,12 @@ def testValueIsInstance():
|
||||
ctx,
|
||||
)
|
||||
func = module.body.operations[0]
|
||||
assert BlockArgument.isinstance(func.regions[0].blocks[0].arguments[0])
|
||||
assert not OpResult.isinstance(func.regions[0].blocks[0].arguments[0])
|
||||
assert isinstance(func.regions[0].blocks[0].arguments[0], BlockArgument)
|
||||
assert not isinstance(func.regions[0].blocks[0].arguments[0], OpResult)
|
||||
|
||||
op = func.regions[0].blocks[0].operations[0]
|
||||
assert not BlockArgument.isinstance(op.results[0])
|
||||
assert OpResult.isinstance(op.results[0])
|
||||
assert not isinstance(op.results[0], BlockArgument)
|
||||
assert isinstance(op.results[0], OpResult)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testValueHash
|
||||
|
||||
@ -886,7 +886,7 @@ constexpr const char *firstAttrDerivedResultTypeTemplate =
|
||||
_ods_result_type_source_attr = attributes["{0}"]
|
||||
_ods_derived_result_type = (
|
||||
_ods_ir.TypeAttr(_ods_result_type_source_attr).value
|
||||
if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else
|
||||
if isinstance(_ods_result_type_source_attr, _ods_ir.TypeAttr) else
|
||||
_ods_result_type_source_attr.type)
|
||||
results = [_ods_derived_result_type] * {1})Py";
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user