# What
This PR adds a shared library `MLIRPythonSupport` which contains all of
the CRTP classes like `PyConcreteValue`, `PyConcreteType`,
`PyConcreteAttribute`, as well as other useful code like `Defaulting*`,
enabling their reuse in downstream projects. Downstream projects
can now do
```c++
struct PyTestType : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteType<PyTestType> {
  // ...
};
class PyTestAttr : public mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteAttribute<PyTestAttr> {
  // ...
};
NB_MODULE(_mlirPythonTestNanobind, m) {
  PyTestType::bind(m);
  PyTestAttr::bind(m);
}
```
instead of using the discordant alternative
`mlir_type_subclass`/`mlir_attr_subclass` (same goes for
`PyConcreteValue`/`mlir_value_subclass`).
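For readers less familiar with the pattern, here is a minimal, standalone sketch of the CRTP shape that a `PyConcreteType<Derived>` subclass plugs into. It is illustrative only: it uses neither nanobind nor the MLIR C API, and every name in it (`OpaqueType`, `Module`, `PyTypeBase`, `ConcreteType`, `isa`, `pyClassName`, `canCastFrom`, `bind`, `exampleUsage`) is a hypothetical stand-in, not the actual signature in `IRCore.h`.

```cpp
#include <map>
#include <string>
#include <typeindex>
#include <typeinfo>

// Hypothetical stand-in for the opaque C API handle (MlirType in real code).
struct OpaqueType {
  std::string mnemonic;
};

// Hypothetical stand-in for a Python module: Python class name -> C++ type.
using Module = std::map<std::string, std::type_index>;

// Shared, non-template base: every concrete wrapper derives from it, which is
// what lets a single Python base class (e.g. `Type`) cover all of them.
struct PyTypeBase {
  OpaqueType type;
};

// CRTP base: generic logic reaches the derived class's static members through
// `Derived`, so each concrete type supplies only a name and an `isa` check.
template <typename Derived>
struct ConcreteType : PyTypeBase {
  // Checked "downcast" from any opaque type; fails if `isa` rejects it.
  static bool canCastFrom(const OpaqueType &t) { return Derived::isa(t); }

  // "Binds" Derived into the module under its Python name.
  // Returns false if that name was already taken.
  static bool bind(Module &m) {
    return m.emplace(Derived::pyClassName, std::type_index(typeid(Derived)))
        .second;
  }
};

// A downstream type: only the derived-specific pieces are written by hand.
struct PyTestType : ConcreteType<PyTestType> {
  static constexpr const char *pyClassName = "TestType";
  static bool isa(const OpaqueType &t) {
    return t.mnemonic == "python_test.test_type";
  }
};

// Usage: returns true if binding and checked casts behave as described.
inline bool exampleUsage() {
  Module m;
  bool first = PyTestType::bind(m);
  bool second = PyTestType::bind(m);  // same Python name again: rejected
  return first && !second &&
         PyTestType::canCastFrom(OpaqueType{"python_test.test_type"}) &&
         !PyTestType::canCastFrom(OpaqueType{"i8"});
}
```

The idea the sketch tries to capture is that the generic base does the shared work once, while each concrete type contributes only a name and an `isa` predicate.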
# Why
This PR is mostly code motion (plus CMake changes), but before describing
the changes I want to state the goals/benefits:
1. Currently, upstream "core" extensions and "dialect" extensions ([all
of the `Dialect*` extensions
here](d7c734b5a1/mlir/lib/Bindings/Python))
form a two-tier system:
**a**. [core
extensions](https://github.com/llvm/llvm-project/blob/main/mlir/lib/Bindings/Python/IRTypes.cpp#L361)
enjoy first-class support for type inference[^3], type stub
generation, and ease of implementation, while dialect extensions [have
poorer support](https://reviews.llvm.org/D150927): incorrect type stub
generation and a much more tedious (boilerplate-heavy) implementation;
**b**. Crucially, this two-tiered system is reflected in the fact that
**the two sets of types/attributes are not in the same Python object
hierarchy**. To wit: `isinstance(..., Type)` and `isinstance(...,
Attribute)` are not supported for the dialect extensions[^2];
**c**. Since these types are not exposed in public headers, downstream
users (dialect extensions or not) cannot write functions that overload
on e.g. `PyFloat8*Type` - that's quite a [useful
feature](fdbee98df8/cpp_ext/TorchOps.cpp (L29-L69))!
2. The dialect extensions incur a sizeable performance penalty relative
to the core extensions in that every single trip across the wire (either
`python->cpp` or `cpp->python`) requires work in addition to nanobind's
own casting/construction pipeline;
**a**. When going from `python->cpp`, [we extract the capsule object
from the Python
object](https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h#L219C24-L219C46)
and then extract from the capsule the `Mlir*` opaque struct/ptr. This
side isn't so onerous;
**b**. When going from `cpp->python` we make long-hand calls to Python's
`import` APIs and construct the Python object using `_CAPICreate`. Note that
there are at least 2 `attr` calls incurred in addition to `_CAPICreate`;
this is already much more [efficiently handled by nanobind
itself](4ba51fcf79/src/nb_internals.h (L381-L382))!
3. This division blocks various features: in some configurations[^1] we
trigger a circular import bug because "dialect" types and attributes
perform an [import of the root `_mlir`
module](bd9651bf78/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h (L585))
when they are created (i.e., when the Python classes themselves are
defined, not merely when instances of them are created). This blocks type
stub generation for dialect extensions (and is the reason we currently only
generate type stubs for `_mlir`).
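Point 1c is easiest to see in plain C++. Below is a minimal sketch, assuming simplified, hypothetical wrapper structs (`PyType`, `PyFloat8E4M3FNType`, `PyFloat8E5M2Type`) and a hypothetical `describe` function rather than the real classes: once concrete wrapper types are visible in a header, a downstream function can be overloaded on them, and the common base acts as a fallback.

```cpp
#include <string>

// Hypothetical, heavily simplified stand-ins for the concrete wrapper classes
// exposed in public headers. Only the inheritance shape matters here.
struct PyType {
  std::string text;  // printed form of the type, e.g. "i8"
};
struct PyFloat8E4M3FNType : PyType {};
struct PyFloat8E5M2Type : PyType {};

// Overload set keyed on the concrete C++ wrapper type. Overload resolution
// prefers the exact derived type and falls back to the common base.
inline std::string describe(const PyFloat8E4M3FNType &) { return "fp8 e4m3fn"; }
inline std::string describe(const PyFloat8E5M2Type &) { return "fp8 e5m2"; }
inline std::string describe(const PyType &t) { return "generic: " + t.text; }

// Usage: returns true if each argument picks the expected overload.
inline bool exampleUsage() {
  return describe(PyFloat8E4M3FNType{}) == "fp8 e4m3fn" &&
         describe(PyFloat8E5M2Type{}) == "fp8 e5m2" &&
         describe(PyType{"i8"}) == "generic: i8";
}
```

When such a function is exposed through a binding library that supports overloads (nanobind does), each C++ overload would correspond to a Python overload selected by the argument's class, which is only possible if the concrete classes are visible to the downstream code.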
# How
Previously this was not done (or possible) because of "ODR" (One Definition
Rule) issues, but I have resolved those issues; the basic idea for how we
solve this is to "move things we want to share into shared libraries":
1. Move IRCore (stuff like `PyConcreteValue`, `PyConcreteType`,
`PyConcreteAttribute`) into `MLIRPythonSupport`;
- Note that we also move the rest of the things in `IRModule.h` (renamed to
`IRCore.h`) because `PyConcreteValue`, `PyConcreteType`,
`PyConcreteAttribute` depend on them. This makes for a bigger PR than
one would hope for, but ultimately I think we should give people access
to these classes to use as they see fit (specifically to inherit from, but
also to use liberally in bindings signatures instead of the opaque `Mlir*`
struct wrappers).
2. Put all of this code into a nested namespace
`MLIR_BINDINGS_PYTHON_DOMAIN` which is determined by a compile-time
define (and tied to `MLIR_BINDINGS_PYTHON_NB_DOMAIN`). This is necessary
in order to prevent conflicts on both symbol names **and** typeids
(necessary for nanobind to not double-register bound types) between
multiple bindings libraries (e.g., `torch-mlir` and `jax`). Note that
[nanobind doesn't support `module_local` like
pybind11](https://nanobind.readthedocs.io/en/latest/porting.html#removed-features).
It does support `NB_DOMAIN`, but that is not sufficient for
disambiguating typeids across projects (to wit: we currently define
`NB_DOMAIN` and it was still necessary to move everything to a nested
namespace);
3. Build the [nanobind library itself as a shared
object](https://github.com/wjakob/nanobind/blob/master/cmake/nanobind-config.cmake#L127)
(and link it to both the extensions and `MLIRPythonSupport`).
4. CMake changes to make this work in-tree, out-of-tree, downstream,
upstream, etc.
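The effect of the nested `MLIR_BINDINGS_PYTHON_DOMAIN` namespace can be sketched in a single translation unit. The sketch below is an illustration, not the PR's code: it re-defines the macro to simulate two projects compiling the same support code, and uses a toy registry keyed by `std::type_index` (`typeRegistry`, `registerType`, `exampleUsage` and the `project_a`/`project_b` domain names are all hypothetical) to show that the two `PyType` classes no longer share a type identity.

```cpp
#include <map>
#include <string>
#include <typeindex>
#include <typeinfo>

// Simulate two bindings projects compiling the same support code with
// different domain values. Re-defining the macro between the two blocks
// stands in for compiling the same header in two separate projects.
#define MLIR_BINDINGS_PYTHON_DOMAIN project_a
namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN {
struct PyType {};
} // namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN
#undef MLIR_BINDINGS_PYTHON_DOMAIN

#define MLIR_BINDINGS_PYTHON_DOMAIN project_b
namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN {
struct PyType {};
} // namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN
#undef MLIR_BINDINGS_PYTHON_DOMAIN

// Toy registry keyed by C++ type identity, loosely modeled on how a binding
// library maps C++ types to Python types.
inline std::map<std::type_index, std::string> &typeRegistry() {
  static std::map<std::type_index, std::string> registry;
  return registry;
}

// Returns false if this C++ type was already registered.
template <typename T>
bool registerType(const std::string &pyName) {
  return typeRegistry().emplace(std::type_index(typeid(T)), pyName).second;
}

// Usage: returns true if both domains coexist and duplicates are rejected.
inline bool exampleUsage() {
  using A = mlir::python::project_a::PyType;
  using B = mlir::python::project_b::PyType;
  bool distinct = std::type_index(typeid(A)) != std::type_index(typeid(B));
  bool addedA = registerType<A>("project_a.ir.Type");
  bool addedB = registerType<B>("project_b.ir.Type");
  bool addedAgain = registerType<A>("duplicate");  // same C++ type: rejected
  return distinct && addedA && addedB && !addedAgain;
}
```

In a real build the macro value would come from a compile-time define (e.g. a `-D` flag set by CMake) rather than from `#define`/`#undef` in one file.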
# Testing
Three tests are added or updated here:
1. `PythonTestModuleNanobind` is ported to use
`PyConcreteType<PyTestType>` instead of `mlir_type_subclass` and
`PyConcreteAttribute<PyTestAttr>` instead of `mlir_attr_subclass`,
verifying this works for non-core extensions in-tree;
2. `StandaloneExtensionNanobind` is ported to use `struct PyCustomType :
mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteType<PyCustomType>`
instead of `mlir_type_subclass` verifying this works for non-core
extensions out-of-tree;
3. `StandaloneExtensionNanobind`'s `smoketest` is extended to also load
another bindings package (namely `mlir`) verifying
`MLIR_BINDINGS_PYTHON_DOMAIN` successfully disambiguates symbols and
typeids.
I have also tested this downstream
(https://github.com/llvm/eudsl/pull/287) as well as run the following
builder bots:
mlir-nvidia-gcc7:
https://lab.llvm.org/buildbot/#/buildrequests/6654424?redirect_to_build=true
I have also tested against IREE:
https://github.com/iree-org/iree/pull/21916
# Integration
It is highly recommended to set the CMake var
`MLIR_BINDINGS_PYTHON_NB_DOMAIN` (which will also determine
`MLIR_BINDINGS_PYTHON_DOMAIN`) to something unique for each downstream.
This can also be passed explicitly to `add_mlir_python_modules` if your
project builds multiple bindings packages. I added a `WARNING` to this
effect in `AddMLIRPython.cmake`.
[^3]: Python values being typed correctly when exiting from cpp;
[^1]: Specifically when the modules are imported using `importlib`,
which occurs with nanobind's
[stubgen](https://github.com/wjakob/nanobind/blob/master/src/stubgen.py#L965);
[^2]: The workaround we implemented was a class method for the dialect
bindings called `Class.isinstance(...)`;
# RUN: %PYTHON %s | FileCheck %s
import sys
import typing
from typing import Union, Optional

from mlir.ir import *
import mlir.dialects.func as func
import mlir.dialects.python_test as test
import mlir.dialects.tensor as tensor
import mlir.dialects.arith as arith

from mlir._mlir_libs._mlirPythonTestNanobind import (
    TestAttr,
    TestType,
    TestTensorValue,
    TestIntegerRankedTensorType,
)

test.register_python_test_dialect(get_dialect_registry())


def run(f):
    print("\nTEST:", f.__name__)
    f()
    return f


# CHECK-LABEL: TEST: testAttributes
@run
def testAttributes():
    with Context() as ctx, Location.unknown():
        #
        # Check op construction with attributes.
        #

        i32 = IntegerType.get_signless(32)
        one = IntegerAttr.get(i32, 1)
        two = IntegerAttr.get(i32, 2)
        unit = UnitAttr.get()

        # CHECK: python_test.attributed_op {
        # CHECK-DAG: mandatory_i32 = 1 : i32
        # CHECK-DAG: optional_i32 = 2 : i32
        # CHECK-DAG: unit
        # CHECK: }
        op = test.AttributedOp(one, optional_i32=two, unit=unit)
        print(f"{op}")

        # CHECK: python_test.attributed_op {
        # CHECK: mandatory_i32 = 2 : i32
        # CHECK: }
        op2 = test.AttributedOp(two)
        print(f"{op2}")

        #
        # Check generic "attributes" access and mutation.
        #

        assert "additional" not in op.attributes

        # CHECK: python_test.attributed_op {
        # CHECK-DAG: additional = 1 : i32
        # CHECK-DAG: mandatory_i32 = 2 : i32
        # CHECK: }
        op2.attributes["additional"] = one
        print(f"{op2}")

        # CHECK: python_test.attributed_op {
        # CHECK-DAG: additional = 2 : i32
        # CHECK-DAG: mandatory_i32 = 2 : i32
        # CHECK: }
        op2.attributes["additional"] = two
        print(f"{op2}")

        # CHECK: python_test.attributed_op {
        # CHECK-NOT: additional = 2 : i32
        # CHECK: mandatory_i32 = 2 : i32
        # CHECK: }
        del op2.attributes["additional"]
        print(f"{op2}")

        try:
            print(op.attributes["additional"])
        except KeyError:
            pass
        else:
            assert False, "expected KeyError on unknown attribute key"

        #
        # Check accessors to defined attributes.
        #

        # CHECK: Mandatory: 1
        # CHECK: Optional: 2
        # CHECK: Unit: True
        print(f"Mandatory: {op.mandatory_i32.value}")
        print(f"Optional: {op.optional_i32.value}")
        print(f"Unit: {op.unit}")

        # CHECK: Mandatory: 2
        # CHECK: Optional: None
        # CHECK: Unit: False
        print(f"Mandatory: {op2.mandatory_i32.value}")
        print(f"Optional: {op2.optional_i32}")
        print(f"Unit: {op2.unit}")

        # CHECK: Mandatory: 2
        # CHECK: Optional: None
        # CHECK: Unit: False
        op.mandatory_i32 = two
        op.optional_i32 = None
        op.unit = False
        print(f"Mandatory: {op.mandatory_i32.value}")
        print(f"Optional: {op.optional_i32}")
        print(f"Unit: {op.unit}")
        assert "optional_i32" not in op.attributes
        assert "unit" not in op.attributes

        try:
            op.mandatory_i32 = None
        except ValueError:
            pass
        else:
            assert False, "expected ValueError on setting a mandatory attribute to None"

        # CHECK: Optional: 2
        op.optional_i32 = two
        print(f"Optional: {op.optional_i32.value}")

        # CHECK: Optional: None
        del op.optional_i32
        print(f"Optional: {op.optional_i32}")

        # CHECK: Unit: False
        op.unit = None
        print(f"Unit: {op.unit}")
        assert "unit" not in op.attributes

        # CHECK: Unit: True
        op.unit = True
        print(f"Unit: {op.unit}")

        # CHECK: Unit: False
        del op.unit
        print(f"Unit: {op.unit}")


# CHECK-LABEL: TEST: attrBuilder
@run
def attrBuilder():
    with Context() as ctx, Location.unknown():
        # CHECK: python_test.attributes_op
        op = test.AttributesOp(
            # CHECK-DAG: x_affinemap = affine_map<() -> (2)>
            x_affinemap=AffineMap.get_constant(2),
            # CHECK-DAG: x_affinemaparr = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
            x_affinemaparr=[AffineMap.get_identity(3)],
            # CHECK-DAG: x_arr = [true, "x"]
            x_arr=[BoolAttr.get(True), StringAttr.get("x")],
            x_boolarr=[False, True],  # CHECK-DAG: x_boolarr = [false, true]
            x_bool=True,  # CHECK-DAG: x_bool = true
            x_dboolarr=[True, False],  # CHECK-DAG: x_dboolarr = array<i1: true, false>
            x_df16arr=[21, 22],  # CHECK-DAG: x_df16arr = array<i16: 21, 22>
            # CHECK-DAG: x_df32arr = array<f32: 2.300000e+01, 2.400000e+01>
            x_df32arr=[23, 24],
            # CHECK-DAG: x_df64arr = array<f64: 2.500000e+01, 2.600000e+01>
            x_df64arr=[25, 26],
            x_di32arr=[0, 1],  # CHECK-DAG: x_di32arr = array<i32: 0, 1>
            # CHECK-DAG: x_di64arr = array<i64: 1, 2>
            x_di64arr=[1, 2],
            x_di8arr=[2, 3],  # CHECK-DAG: x_di8arr = array<i8: 2, 3>
            # CHECK-DAG: x_dictarr = [{a = false}]
            x_dictarr=[{"a": BoolAttr.get(False)}],
            x_dict={"b": BoolAttr.get(True)},  # CHECK-DAG: x_dict = {b = true}
            x_f32=-2.25,  # CHECK-DAG: x_f32 = -2.250000e+00 : f32
            # CHECK-DAG: x_f32arr = [2.000000e+00 : f32, 3.000000e+00 : f32]
            x_f32arr=[2.0, 3.0],
            x_f64=4.25,  # CHECK-DAG: x_f64 = 4.250000e+00 : f64
            x_f64arr=[4.0, 8.0],  # CHECK-DAG: x_f64arr = [4.000000e+00, 8.000000e+00]
            # CHECK-DAG: x_f64elems = dense<[8.000000e+00, 1.600000e+01]> : tensor<2xf64>
            x_f64elems=[8.0, 16.0],
            # CHECK-DAG: x_flatsymrefarr = [@symbol1, @symbol2]
            x_flatsymrefarr=["symbol1", "symbol2"],
            x_flatsymref="symbol3",  # CHECK-DAG: x_flatsymref = @symbol3
            x_i1=0,  # CHECK-DAG: x_i1 = false
            x_i16=42,  # CHECK-DAG: x_i16 = 42 : i16
            x_i32=6,  # CHECK-DAG: x_i32 = 6 : i32
            x_i32arr=[4, 5],  # CHECK-DAG: x_i32arr = [4 : i32, 5 : i32]
            x_i32elems=[5, 6],  # CHECK-DAG: x_i32elems = dense<[5, 6]> : tensor<2xi32>
            x_i64=9,  # CHECK-DAG: x_i64 = 9 : i64
            x_i64arr=[7, 8],  # CHECK-DAG: x_i64arr = [7, 8]
            x_i64elems=[8, 9],  # CHECK-DAG: x_i64elems = dense<[8, 9]> : tensor<2xi64>
            x_i64svecarr=[10, 11],  # CHECK-DAG: x_i64svecarr = [10, 11]
            x_i8=11,  # CHECK-DAG: x_i8 = 11 : i8
            x_idx=10,  # CHECK-DAG: x_idx = 10 : index
            # CHECK-DAG: x_idxelems = dense<[11, 12]> : tensor<2xindex>
            x_idxelems=[11, 12],
            # CHECK-DAG: x_idxlistarr = [{{\[}}13], [14, 15]]
            x_idxlistarr=[[13], [14, 15]],
            x_si1=-1,  # CHECK-DAG: x_si1 = -1 : si1
            x_si16=-2,  # CHECK-DAG: x_si16 = -2 : si16
            x_si32=-3,  # CHECK-DAG: x_si32 = -3 : si32
            x_si64=-123,  # CHECK-DAG: x_si64 = -123 : si64
            x_si8=-4,  # CHECK-DAG: x_si8 = -4 : si8
            x_strarr=["hello", "world"],  # CHECK-DAG: x_strarr = ["hello", "world"]
            x_str="hello world!",  # CHECK-DAG: x_str = "hello world!"
            # CHECK-DAG: x_symrefarr = [@flatsym, @deep::@sym]
            x_symrefarr=["flatsym", ["deep", "sym"]],
            x_symref=["deep", "sym2"],  # CHECK-DAG: x_symref = @deep::@sym2
            x_sym="symbol",  # CHECK-DAG: x_sym = "symbol"
            x_typearr=[F32Type.get()],  # CHECK-DAG: x_typearr = [f32]
            x_type=F64Type.get(),  # CHECK-DAG: x_type = f64
            x_ui1=1,  # CHECK-DAG: x_ui1 = 1 : ui1
            x_ui16=2,  # CHECK-DAG: x_ui16 = 2 : ui16
            x_ui32=3,  # CHECK-DAG: x_ui32 = 3 : ui32
            x_ui64=4,  # CHECK-DAG: x_ui64 = 4 : ui64
            x_ui8=5,  # CHECK-DAG: x_ui8 = 5 : ui8
            x_unit=True,  # CHECK-DAG: x_unit
        )
        op.verify()
        op.print(use_local_scope=True)

        # fmt: off
        assert typing.get_type_hints(test.AttributesOp.x_affinemaparr.fset)["value"] is ArrayAttr
        assert typing.get_type_hints(test.AttributesOp.x_affinemaparr.fget)["return"] is ArrayAttr
        assert type(op.x_affinemaparr) is typing.get_type_hints(test.AttributesOp.x_affinemaparr.fget)["return"]

        assert typing.get_type_hints(test.AttributesOp.x_affinemap.fset)["value"] is AffineMapAttr
        assert typing.get_type_hints(test.AttributesOp.x_affinemap.fget)["return"] is AffineMapAttr
        assert type(op.x_affinemap) is typing.get_type_hints(test.AttributesOp.x_affinemap.fget)["return"]

        assert typing.get_type_hints(test.AttributesOp.x_arr.fset)["value"] is ArrayAttr
        assert typing.get_type_hints(test.AttributesOp.x_arr.fget)["return"] is ArrayAttr
        assert type(op.x_arr) is typing.get_type_hints(test.AttributesOp.x_arr.fget)["return"]

        assert typing.get_type_hints(test.AttributesOp.x_boolarr.fset)["value"] is ArrayAttr
        assert typing.get_type_hints(test.AttributesOp.x_boolarr.fget)["return"] is ArrayAttr
        assert type(op.x_boolarr) is typing.get_type_hints(test.AttributesOp.x_boolarr.fget)["return"]

        assert typing.get_type_hints(test.AttributesOp.x_bool.fset)["value"] is BoolAttr
        assert typing.get_type_hints(test.AttributesOp.x_bool.fget)["return"] is BoolAttr
        assert type(op.x_bool) is typing.get_type_hints(test.AttributesOp.x_bool.fget)["return"]

        assert typing.get_type_hints(test.AttributesOp.x_dboolarr.fset)["value"] is DenseBoolArrayAttr
        assert typing.get_type_hints(test.AttributesOp.x_dboolarr.fget)["return"] is DenseBoolArrayAttr
        assert type(op.x_dboolarr) is typing.get_type_hints(test.AttributesOp.x_dboolarr.fget)["return"]

        assert typing.get_type_hints(test.AttributesOp.x_df32arr.fset)["value"] is DenseF32ArrayAttr
        assert typing.get_type_hints(test.AttributesOp.x_df32arr.fget)["return"] is DenseF32ArrayAttr
        assert type(op.x_df32arr) is typing.get_type_hints(test.AttributesOp.x_df32arr.fget)["return"]

        assert typing.get_type_hints(test.AttributesOp.x_df64arr.fset)["value"] is DenseF64ArrayAttr
        assert typing.get_type_hints(test.AttributesOp.x_df64arr.fget)["return"] is DenseF64ArrayAttr
        assert type(op.x_df64arr) is typing.get_type_hints(test.AttributesOp.x_df64arr.fget)["return"]

        assert typing.get_type_hints(test.AttributesOp.x_df16arr.fset)["value"] is DenseI16ArrayAttr
        assert typing.get_type_hints(test.AttributesOp.x_df16arr.fget)["return"] is DenseI16ArrayAttr
        assert type(op.x_df16arr) is typing.get_type_hints(test.AttributesOp.x_df16arr.fget)["return"]

        assert typing.get_type_hints(test.AttributesOp.x_di32arr.fset)["value"] is DenseI32ArrayAttr
        assert typing.get_type_hints(test.AttributesOp.x_di32arr.fget)["return"] is DenseI32ArrayAttr
        assert type(op.x_di32arr) is typing.get_type_hints(test.AttributesOp.x_di32arr.fget)["return"]

        assert typing.get_type_hints(test.AttributesOp.x_di64arr.fset)["value"] is DenseI64ArrayAttr
        assert typing.get_type_hints(test.AttributesOp.x_di64arr.fget)["return"] is DenseI64ArrayAttr
        assert type(op.x_di64arr) is typing.get_type_hints(test.AttributesOp.x_di64arr.fget)["return"]

        assert typing.get_type_hints(test.AttributesOp.x_di8arr.fset)["value"] is DenseI8ArrayAttr
        assert typing.get_type_hints(test.AttributesOp.x_di8arr.fget)["return"] is DenseI8ArrayAttr
        assert type(op.x_di8arr) is typing.get_type_hints(test.AttributesOp.x_di8arr.fget)["return"]

        assert typing.get_type_hints(test.AttributesOp.x_dictarr.fset)["value"] is ArrayAttr
        assert typing.get_type_hints(test.AttributesOp.x_dictarr.fget)["return"] is ArrayAttr
        assert type(op.x_dictarr) is typing.get_type_hints(test.AttributesOp.x_dictarr.fget)["return"]

        assert typing.get_type_hints(test.AttributesOp.x_dict.fset)["value"] is DictAttr
        assert typing.get_type_hints(test.AttributesOp.x_dict.fget)["return"] is DictAttr
        assert type(op.x_dict) is typing.get_type_hints(test.AttributesOp.x_dict.fget)["return"]

        assert typing.get_type_hints(test.AttributesOp.x_f32arr.fset)["value"] is ArrayAttr
        assert typing.get_type_hints(test.AttributesOp.x_f32arr.fget)["return"] is ArrayAttr
        assert type(op.x_f32arr) is typing.get_type_hints(test.AttributesOp.x_f32arr.fget)["return"]

        assert typing.get_type_hints(test.AttributesOp.x_f32.fset)["value"] is FloatAttr
        assert typing.get_type_hints(test.AttributesOp.x_f32.fget)["return"] is FloatAttr
        assert type(op.x_f32) is typing.get_type_hints(test.AttributesOp.x_f32.fget)["return"]

        assert typing.get_type_hints(test.AttributesOp.x_f64arr.fset)["value"] is ArrayAttr
        assert typing.get_type_hints(test.AttributesOp.x_f64arr.fget)["return"] is ArrayAttr
        assert type(op.x_f64arr) is typing.get_type_hints(test.AttributesOp.x_f64arr.fget)["return"]

        assert typing.get_type_hints(test.AttributesOp.x_f64.fset)["value"] is FloatAttr
        assert typing.get_type_hints(test.AttributesOp.x_f64.fget)["return"] is FloatAttr
        assert type(op.x_f64) is typing.get_type_hints(test.AttributesOp.x_f64.fget)["return"]

        assert typing.get_type_hints(test.AttributesOp.x_f64elems.fset)["value"] is DenseFPElementsAttr
        assert typing.get_type_hints(test.AttributesOp.x_f64elems.fget)["return"] is DenseFPElementsAttr
        assert type(op.x_f64elems) is typing.get_type_hints(test.AttributesOp.x_f64elems.fget)["return"]

        assert typing.get_type_hints(test.AttributesOp.x_flatsymrefarr.fset)["value"] is ArrayAttr
        assert typing.get_type_hints(test.AttributesOp.x_flatsymrefarr.fget)["return"] is ArrayAttr
        assert type(op.x_flatsymrefarr) is typing.get_type_hints(test.AttributesOp.x_flatsymrefarr.fget)["return"]

        assert typing.get_type_hints(test.AttributesOp.x_flatsymref.fset)["value"] is FlatSymbolRefAttr
        assert typing.get_type_hints(test.AttributesOp.x_flatsymref.fget)["return"] is FlatSymbolRefAttr
        assert type(op.x_flatsymref) is typing.get_type_hints(test.AttributesOp.x_flatsymref.fget)["return"]

        assert typing.get_type_hints(test.AttributesOp.x_i16.fset)["value"] is IntegerAttr
        assert typing.get_type_hints(test.AttributesOp.x_i16.fget)["return"] is IntegerAttr
        assert type(op.x_i16) is typing.get_type_hints(test.AttributesOp.x_i16.fget)["return"]

        assert typing.get_type_hints(test.AttributesOp.x_i1.fset)["value"] is BoolAttr
        assert typing.get_type_hints(test.AttributesOp.x_i1.fget)["return"] is BoolAttr
        assert type(op.x_i1) is typing.get_type_hints(test.AttributesOp.x_i1.fget)["return"]

        assert typing.get_type_hints(test.AttributesOp.x_i32arr.fset)["value"] is ArrayAttr
        assert typing.get_type_hints(test.AttributesOp.x_i32arr.fget)["return"] is ArrayAttr
        assert type(op.x_i32arr) is typing.get_type_hints(test.AttributesOp.x_i32arr.fget)["return"]

        assert typing.get_type_hints(test.AttributesOp.x_i32.fset)["value"] is IntegerAttr
        assert typing.get_type_hints(test.AttributesOp.x_i32.fget)["return"] is IntegerAttr
        assert type(op.x_i32) is typing.get_type_hints(test.AttributesOp.x_i32.fget)["return"]

        assert typing.get_type_hints(test.AttributesOp.x_i32elems.fset)["value"] is DenseIntElementsAttr
        assert typing.get_type_hints(test.AttributesOp.x_i32elems.fget)["return"] is DenseIntElementsAttr
        assert type(op.x_i32elems) is typing.get_type_hints(test.AttributesOp.x_i32elems.fget)["return"]

        assert typing.get_type_hints(test.AttributesOp.x_i64arr.fset)["value"] is ArrayAttr
        assert typing.get_type_hints(test.AttributesOp.x_i64arr.fget)["return"] is ArrayAttr
        assert type(op.x_i64arr) is typing.get_type_hints(test.AttributesOp.x_i64arr.fget)["return"]

        assert typing.get_type_hints(test.AttributesOp.x_i64.fset)["value"] is IntegerAttr
        assert typing.get_type_hints(test.AttributesOp.x_i64.fget)["return"] is IntegerAttr
        assert type(op.x_i64) is typing.get_type_hints(test.AttributesOp.x_i64.fget)["return"]

        assert typing.get_type_hints(test.AttributesOp.x_i64elems.fset)["value"] is DenseIntElementsAttr
        assert typing.get_type_hints(test.AttributesOp.x_i64elems.fget)["return"] is DenseIntElementsAttr
        assert type(op.x_i64elems) is typing.get_type_hints(test.AttributesOp.x_i64elems.fget)["return"]

        assert typing.get_type_hints(test.AttributesOp.x_i64svecarr.fset)["value"] is ArrayAttr
        assert typing.get_type_hints(test.AttributesOp.x_i64svecarr.fget)["return"] is ArrayAttr
        assert type(op.x_i64svecarr) is typing.get_type_hints(test.AttributesOp.x_i64svecarr.fget)["return"]

        assert typing.get_type_hints(test.AttributesOp.x_i8.fset)["value"] is IntegerAttr
        assert typing.get_type_hints(test.AttributesOp.x_i8.fget)["return"] is IntegerAttr
        assert type(op.x_i8) is typing.get_type_hints(test.AttributesOp.x_i8.fget)["return"]

        assert typing.get_type_hints(test.AttributesOp.x_idx.fset)["value"] is IntegerAttr
        assert typing.get_type_hints(test.AttributesOp.x_idx.fget)["return"] is IntegerAttr
        assert type(op.x_idx) is typing.get_type_hints(test.AttributesOp.x_idx.fget)["return"]

        assert typing.get_type_hints(test.AttributesOp.x_idxelems.fset)["value"] is DenseIntElementsAttr
        assert typing.get_type_hints(test.AttributesOp.x_idxelems.fget)["return"] is DenseIntElementsAttr
        assert type(op.x_idxelems) is typing.get_type_hints(test.AttributesOp.x_idxelems.fget)["return"]

        assert typing.get_type_hints(test.AttributesOp.x_idxlistarr.fset)["value"] is ArrayAttr
        assert typing.get_type_hints(test.AttributesOp.x_idxlistarr.fget)["return"] is ArrayAttr
        assert type(op.x_idxlistarr) is typing.get_type_hints(test.AttributesOp.x_idxlistarr.fget)["return"]

        assert typing.get_type_hints(test.AttributesOp.x_si16.fset)["value"] is IntegerAttr
        assert typing.get_type_hints(test.AttributesOp.x_si16.fget)["return"] is IntegerAttr
        assert type(op.x_si16) is typing.get_type_hints(test.AttributesOp.x_si16.fget)["return"]

        assert typing.get_type_hints(test.AttributesOp.x_si1.fset)["value"] is IntegerAttr
        assert typing.get_type_hints(test.AttributesOp.x_si1.fget)["return"] is IntegerAttr
        assert type(op.x_si1) is typing.get_type_hints(test.AttributesOp.x_si1.fget)["return"]

        assert typing.get_type_hints(test.AttributesOp.x_si32.fset)["value"] is IntegerAttr
        assert typing.get_type_hints(test.AttributesOp.x_si32.fget)["return"] is IntegerAttr
        assert type(op.x_si32) is typing.get_type_hints(test.AttributesOp.x_si32.fget)["return"]

        assert typing.get_type_hints(test.AttributesOp.x_si64.fset)["value"] is IntegerAttr
        assert typing.get_type_hints(test.AttributesOp.x_si64.fget)["return"] is IntegerAttr
        assert type(op.x_si64) is typing.get_type_hints(test.AttributesOp.x_si64.fget)["return"]

        assert typing.get_type_hints(test.AttributesOp.x_si8.fset)["value"] is IntegerAttr
        assert typing.get_type_hints(test.AttributesOp.x_si8.fget)["return"] is IntegerAttr
        assert type(op.x_si8) is typing.get_type_hints(test.AttributesOp.x_si8.fget)["return"]

        assert typing.get_type_hints(test.AttributesOp.x_strarr.fset)["value"] is ArrayAttr
        assert typing.get_type_hints(test.AttributesOp.x_strarr.fget)["return"] is ArrayAttr
        assert type(op.x_strarr) is typing.get_type_hints(test.AttributesOp.x_strarr.fget)["return"]

        assert typing.get_type_hints(test.AttributesOp.x_str.fset)["value"] is StringAttr
        assert typing.get_type_hints(test.AttributesOp.x_str.fget)["return"] is StringAttr
        assert type(op.x_str) is typing.get_type_hints(test.AttributesOp.x_str.fget)["return"]

        assert typing.get_type_hints(test.AttributesOp.x_sym.fset)["value"] is StringAttr
        assert typing.get_type_hints(test.AttributesOp.x_sym.fget)["return"] is StringAttr
        assert type(op.x_sym) is typing.get_type_hints(test.AttributesOp.x_sym.fget)["return"]

        assert typing.get_type_hints(test.AttributesOp.x_symrefarr.fset)["value"] is ArrayAttr
        assert typing.get_type_hints(test.AttributesOp.x_symrefarr.fget)["return"] is ArrayAttr
        assert type(op.x_symrefarr) is typing.get_type_hints(test.AttributesOp.x_symrefarr.fget)["return"]

        assert typing.get_type_hints(test.AttributesOp.x_symref.fset)["value"] is SymbolRefAttr
        assert typing.get_type_hints(test.AttributesOp.x_symref.fget)["return"] is SymbolRefAttr
        assert type(op.x_symref) is typing.get_type_hints(test.AttributesOp.x_symref.fget)["return"]

        assert typing.get_type_hints(test.AttributesOp.x_typearr.fset)["value"] is ArrayAttr
        assert typing.get_type_hints(test.AttributesOp.x_typearr.fget)["return"] is ArrayAttr
        assert type(op.x_typearr) is typing.get_type_hints(test.AttributesOp.x_typearr.fget)["return"]

        assert typing.get_type_hints(test.AttributesOp.x_type.fset)["value"] is TypeAttr
        assert typing.get_type_hints(test.AttributesOp.x_type.fget)["return"] is TypeAttr
        assert type(op.x_type) is typing.get_type_hints(test.AttributesOp.x_type.fget)["return"]

        assert typing.get_type_hints(test.AttributesOp.x_ui16.fset)["value"] is IntegerAttr
        assert typing.get_type_hints(test.AttributesOp.x_ui16.fget)["return"] is IntegerAttr
        assert type(op.x_ui16) is typing.get_type_hints(test.AttributesOp.x_ui16.fget)["return"]

        assert typing.get_type_hints(test.AttributesOp.x_ui1.fset)["value"] is IntegerAttr
        assert typing.get_type_hints(test.AttributesOp.x_ui1.fget)["return"] is IntegerAttr
        assert type(op.x_ui1) is typing.get_type_hints(test.AttributesOp.x_ui1.fget)["return"]

        assert typing.get_type_hints(test.AttributesOp.x_ui32.fset)["value"] is IntegerAttr
        assert typing.get_type_hints(test.AttributesOp.x_ui32.fget)["return"] is IntegerAttr
        assert type(op.x_ui32) is typing.get_type_hints(test.AttributesOp.x_ui32.fget)["return"]

        assert typing.get_type_hints(test.AttributesOp.x_ui64.fset)["value"] is IntegerAttr
        assert typing.get_type_hints(test.AttributesOp.x_ui64.fget)["return"] is IntegerAttr
        assert type(op.x_ui64) is typing.get_type_hints(test.AttributesOp.x_ui64.fget)["return"]

        assert typing.get_type_hints(test.AttributesOp.x_ui8.fset)["value"] is IntegerAttr
        assert typing.get_type_hints(test.AttributesOp.x_ui8.fget)["return"] is IntegerAttr
        assert type(op.x_ui8) is typing.get_type_hints(test.AttributesOp.x_ui8.fget)["return"]
        # fmt: on


# CHECK-LABEL: TEST: inferReturnTypes
@run
def inferReturnTypes():
    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            op = test.InferResultsOp()
            dummy = test.DummyOp()

        # CHECK: [Type(i32), Type(i64)]
        iface = InferTypeOpInterface(op)
        print(iface.inferReturnTypes())

        # CHECK: [Type(i32), Type(i64)]
        iface_static = InferTypeOpInterface(test.InferResultsOp)
        print(iface.inferReturnTypes())

        assert isinstance(iface.opview, test.InferResultsOp)
        assert iface.opview == iface.operation.opview

        try:
            iface_static.opview
        except TypeError:
            pass
        else:
            assert False, (
                "not expected to be able to obtain an opview from a static" " interface"
            )

        try:
            InferTypeOpInterface(dummy)
        except ValueError:
            pass
        else:
            assert False, "not expected dummy op to implement the interface"

        try:
            InferTypeOpInterface(test.DummyOp)
        except ValueError:
            pass
        else:
            assert False, "not expected dummy op class to implement the interface"


# CHECK-LABEL: TEST: resultTypesDefinedByTraits
@run
def resultTypesDefinedByTraits():
    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            inferred = test.InferResultsOp()

            # CHECK: i32 i64
            print(inferred.single.type, inferred.doubled.type)

            same = test.SameOperandAndResultTypeOp([inferred.results[0]])
            # CHECK-COUNT-2: i32
            print(same.one.type)
            print(same.two.type)
            assert (
                typing.get_type_hints(test.SameOperandAndResultTypeOp.one.fget)[
                    "return"
                ]
                is OpResult
            )
            assert type(same.one) is OpResult

            first_type_attr = test.FirstAttrDeriveTypeAttrOp(
                inferred.results[1], TypeAttr.get(IndexType.get())
            )
            # CHECK-COUNT-2: index
            print(first_type_attr.one.type)
            print(first_type_attr.two.type)

            first_attr = test.FirstAttrDeriveAttrOp(FloatAttr.get(F32Type.get(), 3.14))
            # CHECK-COUNT-3: f32
            print(first_attr.one.type)
            print(first_attr.two.type)
            print(first_attr.three.type)

            implied = test.InferResultsImpliedOp()
            # CHECK: i32
            print(implied.integer.type)
            # CHECK: f64
            print(implied.flt.type)
            # CHECK: index
            print(implied.index.type)

            # provide the result types to avoid inferring them
            f64 = F64Type.get()
            no_imply = test.InferResultsImpliedOp(results=[f64, f64, f64])
            # CHECK-COUNT-3: f64
            print(no_imply.integer.type, no_imply.flt.type, no_imply.index.type)

            no_infer = test.InferResultsOp(results=[F32Type.get(), IndexType.get()])
            # CHECK: f32 index
            print(no_infer.single.type, no_infer.doubled.type)


# CHECK-LABEL: TEST: testOptionalOperandOp
|
|
@run
|
|
def testOptionalOperandOp():
|
|
with Context() as ctx, Location.unknown():
|
|
module = Module.create()
|
|
with InsertionPoint(module.body):
|
|
op1 = test.OptionalOperandOp()
|
|
# CHECK: op1.input is None: True
|
|
print(f"op1.input is None: {op1.input is None}")
|
|
assert (
|
|
typing.get_type_hints(test.OptionalOperandOp.input.fget)["return"]
|
|
is Optional[Value]
|
|
)
|
|
assert (
|
|
typing.get_type_hints(test.OptionalOperandOp.result.fget)["return"]
|
|
== OpResult[IntegerType]
|
|
)
|
|
assert type(op1.result) is OpResult
|
|
|
|
op2 = test.OptionalOperandOp(input=op1)
|
|
# CHECK: op2.input is None: False
|
|
print(f"op2.input is None: {op2.input is None}")


# CHECK-LABEL: TEST: testCustomAttribute
@run
def testCustomAttribute():
    with Context() as ctx, Location.unknown():
        a = TestAttr.get()
        # CHECK: #python_test.test_attr
        print(a)

        # CHECK: python_test.custom_attributed_op {
        # CHECK: #python_test.test_attr
        # CHECK: }
        op2 = test.CustomAttributedOp(a)
        print(f"{op2}")

        # CHECK: #python_test.test_attr
        print(f"{op2.test_attr}")

        # CHECK: TestAttr(#python_test.test_attr)
        print(repr(op2.test_attr))

        # The following cast must not assert.
        b = TestAttr(a)

        unit = UnitAttr.get()
        try:
            TestAttr(unit)
        except ValueError as e:
            assert "Cannot cast attribute to TestAttr" in str(e)
        else:
            raise

        # The following must trigger a TypeError from our adaptors and must not
        # crash.
        try:
            TestAttr(42)
        except TypeError as e:
            assert (
                "__init__(): incompatible function arguments. The following argument types are supported"
                in str(e)
            )
            assert (
                "__init__(self, cast_from_attr: mlir._mlir_libs._mlir.ir.Attribute) -> None"
                in str(e)
            )
            assert (
                "Invoked with types: mlir._mlir_libs._mlirPythonTestNanobind.TestAttr, int"
                in str(e)
            )
        else:
            raise

        # The following must trigger a TypeError from nanobind (therefore, not
        # checking its message) and must not crash.
        try:
            TestAttr(42, 56)
        except TypeError:
            pass
        else:
            raise


@run
def testCustomType():
    with Context() as ctx:
        a = TestType.get()
        # CHECK: !python_test.test_type
        print(a)

        # The following cast must not assert.
        b = TestType(a)
        # Instances of custom types should have typeids.
        assert isinstance(b.typeid, TypeID)

        i8 = IntegerType.get_signless(8)
        try:
            TestType(i8)
        except ValueError as e:
            assert "Cannot cast type to TestType" in str(e)
        else:
            raise

        # The following must trigger a TypeError from our adaptors and must not
        # crash.
        try:
            TestType(42)
        except TypeError as e:
            assert (
                "__init__(): incompatible function arguments. The following argument types are supported"
                in str(e)
            )
            assert (
                "__init__(self, cast_from_type: mlir._mlir_libs._mlir.ir.Type) -> None"
                in str(e)
            )
            assert (
                "Invoked with types: mlir._mlir_libs._mlirPythonTestNanobind.TestType, int"
                in str(e)
            )
        else:
            raise

        # The following must trigger a TypeError from nanobind (therefore, not
        # checking its message) and must not crash.
        try:
            TestType(42, 56)
        except TypeError:
            pass
        else:
            raise


# CHECK-LABEL: TEST: testValue
@run
def testValue():
    # Check that Value is a generic class at runtime.
    assert hasattr(Value, "__class_getitem__")


# CHECK-LABEL: TEST: testTensorValue
@run
def testTensorValue():
    with Context() as ctx, Location.unknown():
        i8 = IntegerType.get_signless(8)

        class Tensor(TestTensorValue):
            def __str__(self):
                return super().__str__().replace("Value", "Tensor")

        module = Module.create()
        with InsertionPoint(module.body):
            t = tensor.EmptyOp([10, 10], i8).result

            # CHECK: Value(%{{.*}} = tensor.empty() : tensor<10x10xi8>)
            print(Value(t))

            tt = Tensor(t)
            # CHECK: Tensor(%{{.*}} = tensor.empty() : tensor<10x10xi8>)
            print(tt)

            # CHECK: False
            print(tt.is_null())

            # Classes of custom types that inherit from concrete types should have
            # static_typeid
            assert isinstance(TestIntegerRankedTensorType.static_typeid, TypeID)
            # And it should be equal to the in-tree concrete type
            assert TestIntegerRankedTensorType.static_typeid == t.type.typeid

            d = tensor.EmptyOp([1, 2, 3], IntegerType.get_signless(5)).result
            # CHECK: Value(%{{.*}} = tensor.empty() : tensor<1x2x3xi5>)
            print(d)
            # CHECK: TestTensorValue
            print(repr(d))


# CHECK-LABEL: TEST: inferReturnTypeComponents
@run
def inferReturnTypeComponents():
    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        i32 = IntegerType.get_signless(32)
        with InsertionPoint(module.body):
            resultType = UnrankedTensorType.get(i32)
            operandTypes = [
                RankedTensorType.get([1, 3, 10, 10], i32),
                UnrankedTensorType.get(i32),
            ]
            f = func.FuncOp(
                "test_inferReturnTypeComponents", (operandTypes, [resultType])
            )
            entry_block = Block.create_at_start(f.operation.regions[0], operandTypes)
            with InsertionPoint(entry_block):
                ranked_op = test.InferShapedTypeComponentsOp(
                    resultType, entry_block.arguments[0]
                )
                unranked_op = test.InferShapedTypeComponentsOp(
                    resultType, entry_block.arguments[1]
                )

        # CHECK: has rank: True
        # CHECK: rank: 4
        # CHECK: element type: i32
        # CHECK: shape: [1, 3, 10, 10]
        iface = InferShapedTypeOpInterface(ranked_op)
        shaped_type_components = iface.inferReturnTypeComponents(
            operands=[ranked_op.operand]
        )[0]
        print("has rank:", shaped_type_components.has_rank)
        print("rank:", shaped_type_components.rank)
        print("element type:", shaped_type_components.element_type)
        print("shape:", shaped_type_components.shape)

        # CHECK: has rank: False
        # CHECK: rank: None
        # CHECK: element type: i32
        # CHECK: shape: None
        iface = InferShapedTypeOpInterface(unranked_op)
        shaped_type_components = iface.inferReturnTypeComponents(
            operands=[unranked_op.operand]
        )[0]
        print("has rank:", shaped_type_components.has_rank)
        print("rank:", shaped_type_components.rank)
        print("element type:", shaped_type_components.element_type)
        print("shape:", shaped_type_components.shape)


# CHECK-LABEL: TEST: testCustomTypeTypeCaster
@run
def testCustomTypeTypeCaster():
    with Context() as ctx, Location.unknown():
        a = TestType.get()
        assert a.typeid is not None

        b = Type.parse("!python_test.test_type")
        # CHECK: !python_test.test_type
        print(b)
        # CHECK: TestType(!python_test.test_type)
        print(repr(b))

        c = TestIntegerRankedTensorType.get([10, 10], 5)
        # CHECK: tensor<10x10xi5>
        print(c)
        # CHECK: TestIntegerRankedTensorType(tensor<10x10xi5>)
        print(repr(c))

        # CHECK: Type caster is already registered
        try:

            @register_type_caster(c.typeid)
            def type_caster(pytype):
                return TestIntegerRankedTensorType(pytype)

        except RuntimeError as e:
            print(e)

        # The python_test dialect registers a caster for RankedTensorType in its
        # extension (nanobind) module, so this registration (successfully) replaces
        # that one. To be safe, the original caster is restored further down.
        @register_type_caster(c.typeid, replace=True)
        def type_caster(pytype):
            return RankedTensorType(pytype)

        d = tensor.EmptyOp([10, 10], IntegerType.get_signless(5)).result
        # CHECK: tensor<10x10xi5>
        print(d.type)
        # CHECK: ranked tensor type RankedTensorType(tensor<10x10xi5>)
        print("ranked tensor type", repr(d.type))

        @register_type_caster(c.typeid, replace=True)
        def type_caster(pytype):
            return TestIntegerRankedTensorType(pytype)

        d = tensor.EmptyOp([10, 10], IntegerType.get_signless(5)).result
        # CHECK: tensor<10x10xi5>
        print(d.type)
        # CHECK: TestIntegerRankedTensorType(tensor<10x10xi5>)
        print(repr(d.type))


# CHECK-LABEL: TEST: testInferTypeOpInterface
@run
def testInferTypeOpInterface():
    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            i64 = IntegerType.get_signless(64)
            zero = arith.ConstantOp(i64, 0)

            one_operand = test.InferResultsVariadicInputsOp(single=zero, doubled=None)
            # CHECK: i32
            print(one_operand.result.type)

            two_operands = test.InferResultsVariadicInputsOp(single=zero, doubled=zero)
            # CHECK: f32
            print(two_operands.result.type)

            assert (
                typing.get_type_hints(test.infer_results_variadic_inputs_op)["return"]
                is OpResult
            )
            assert (
                type(test.infer_results_variadic_inputs_op(single=zero, doubled=zero))
                is OpResult
            )


# CHECK-LABEL: TEST: testVariadicOperandAccess
@run
def testVariadicOperandAccess():
    def values(lst):
        return [str(e) for e in lst]

    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            i32 = IntegerType.get_signless(32)
            zero = arith.ConstantOp(i32, 0)
            one = arith.ConstantOp(i32, 1)
            two = arith.ConstantOp(i32, 2)
            three = arith.ConstantOp(i32, 3)
            four = arith.ConstantOp(i32, 4)

            variadic_operands = test.SameVariadicOperandSizeOp(
                [zero, one], two, [three, four]
            )
            # CHECK: Value(%{{.*}} = arith.constant 2 : i32)
            print(variadic_operands.non_variadic)
            assert (
                typing.get_type_hints(test.SameVariadicOperandSizeOp.non_variadic.fget)[
                    "return"
                ]
                is Value
            )
            assert type(variadic_operands.non_variadic) is Value

            # CHECK: ['Value(%{{.*}} = arith.constant 0 : i32)', 'Value(%{{.*}} = arith.constant 1 : i32)']
            print(values(variadic_operands.variadic1))
            assert (
                typing.get_type_hints(test.SameVariadicOperandSizeOp.variadic1.fget)[
                    "return"
                ]
                is OpOperandList
            )
            assert type(variadic_operands.variadic1) is OpOperandList

            # CHECK: ['Value(%{{.*}} = arith.constant 3 : i32)', 'Value(%{{.*}} = arith.constant 4 : i32)']
            print(values(variadic_operands.variadic2))
            assert type(variadic_operands.variadic2) is OpOperandList

            assert (
                typing.get_type_hints(test.same_variadic_operand)["return"]
                is test.SameVariadicOperandSizeOp
            )
            assert (
                type(test.same_variadic_operand([zero, one], two, [three, four]))
                is test.SameVariadicOperandSizeOp
            )


# CHECK-LABEL: TEST: testVariadicResultAccess
@run
def testVariadicResultAccess():
    def types(lst):
        return [e.type for e in lst]

    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            i = [IntegerType.get_signless(k) for k in range(7)]

            # Test Variadic-Fixed-Variadic
            op = test.SameVariadicResultSizeOpVFV([i[0], i[1]], i[2], [i[3], i[4]])
            # CHECK: i2
            print(op.non_variadic.type)
            # CHECK: [IntegerType(i0), IntegerType(i1)]
            print(types(op.variadic1))
            # CHECK: [IntegerType(i3), IntegerType(i4)]
            print(types(op.variadic2))

            assert (
                typing.get_type_hints(test.same_variadic_result_vfv)["return"]
                == Union[OpResult, OpResultList, test.SameVariadicResultSizeOpVFV]
            )
            assert (
                type(test.same_variadic_result_vfv([i[0], i[1]], i[2], [i[3], i[4]]))
                is OpResultList
            )

            # Test Variadic-Variadic-Variadic
            op = test.SameVariadicResultSizeOpVVV(
                [i[0], i[1]], [i[2], i[3]], [i[4], i[5]]
            )
            # CHECK: [IntegerType(i0), IntegerType(i1)]
            print(types(op.variadic1))
            # CHECK: [IntegerType(i2), IntegerType(i3)]
            print(types(op.variadic2))
            # CHECK: [IntegerType(i4), IntegerType(i5)]
            print(types(op.variadic3))

            # Test Fixed-Fixed-Variadic
            op = test.SameVariadicResultSizeOpFFV(i[0], i[1], [i[2], i[3], i[4]])
            # CHECK: i0
            print(op.non_variadic1.type)
            # CHECK: i1
            print(op.non_variadic2.type)
            # CHECK: [IntegerType(i2), IntegerType(i3), IntegerType(i4)]
            print(types(op.variadic))
            assert (
                typing.get_type_hints(test.SameVariadicResultSizeOpFFV.variadic.fget)[
                    "return"
                ]
                is OpResultList
            )
            assert type(op.variadic) is OpResultList

            # Test Variadic-Variadic-Fixed
            op = test.SameVariadicResultSizeOpVVF(
                [i[0], i[1], i[2]], [i[3], i[4], i[5]], i[6]
            )
            # CHECK: [IntegerType(i0), IntegerType(i1), IntegerType(i2)]
            print(types(op.variadic1))
            # CHECK: [IntegerType(i3), IntegerType(i4), IntegerType(i5)]
            print(types(op.variadic2))
            # CHECK: i6
            print(op.non_variadic.type)

            # Test Fixed-Variadic-Fixed-Variadic-Fixed
            op = test.SameVariadicResultSizeOpFVFVF(
                i[0], [i[1], i[2]], i[3], [i[4], i[5]], i[6]
            )
            # CHECK: i0
            print(op.non_variadic1.type)
            # CHECK: [IntegerType(i1), IntegerType(i2)]
            print(types(op.variadic1))
            # CHECK: i3
            print(op.non_variadic2.type)
            # CHECK: [IntegerType(i4), IntegerType(i5)]
            print(types(op.variadic2))
            # CHECK: i6
            print(op.non_variadic3.type)

            # Test Fixed-Variadic-Fixed-Variadic-Fixed - Variadic group size 0
            op = test.SameVariadicResultSizeOpFVFVF(i[0], [], i[1], [], i[2])
            # CHECK: i0
            print(op.non_variadic1.type)
            # CHECK: []
            print(types(op.variadic1))
            # CHECK: i1
            print(op.non_variadic2.type)
            # CHECK: []
            print(types(op.variadic2))
            # CHECK: i2
            print(op.non_variadic3.type)

            # Test Fixed-Variadic-Fixed-Variadic-Fixed - Variadic group size 1
            op = test.SameVariadicResultSizeOpFVFVF(i[0], [i[1]], i[2], [i[3]], i[4])
            # CHECK: i0
            print(op.non_variadic1.type)
            # CHECK: [IntegerType(i1)]
            print(types(op.variadic1))
            # CHECK: i2
            print(op.non_variadic2.type)
            # CHECK: [IntegerType(i3)]
            print(types(op.variadic2))
            # CHECK: i4
            print(op.non_variadic3.type)

            assert (
                typing.get_type_hints(test.results_variadic)["return"]
                == Union[OpResult, OpResultList, test.ResultsVariadicOp]
            )
            assert type(test.results_variadic([i[0]])) is OpResult
            op_res_variadic = test.ResultsVariadicOp([i[0]])
            assert (
                typing.get_type_hints(test.ResultsVariadicOp.res.fget)["return"]
                is OpResultList
            )
            assert type(op_res_variadic.res) is OpResultList


# CHECK-LABEL: TEST: testVariadicAndNormalRegionOp
@run
def testVariadicAndNormalRegionOp():
    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            region_op = test.VariadicAndNormalRegionOp(2)
            assert (
                typing.get_type_hints(test.VariadicAndNormalRegionOp.region.fget)[
                    "return"
                ]
                is Region
            )
            assert type(region_op.region) is Region
            assert (
                typing.get_type_hints(test.VariadicAndNormalRegionOp.variadic.fget)[
                    "return"
                ]
                is RegionSequence
            )
            assert type(region_op.variadic) is RegionSequence

            assert isinstance(region_op.opview, OpView)
            assert isinstance(region_op.operation.opview, OpView)