Maksim Levental f0ef5dba6d
[mlir][Python] create MLIRPythonSupport (#171775)
# What

This PR adds a shared library `MLIRPythonSupport` which contains all of
the CRTP classes like `PyConcreteValue`, `PyConcreteType`,
`PyConcreteAttribute`, as well as other useful code such as the
`Defaulting*` helpers, enabling their reuse in downstream projects.
Downstream projects can now do

```c++
struct PyTestType : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteType<PyTestType> {
  ...
};

class PyTestAttr : public mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteAttribute<PyTestAttr> {
  ...
};

NB_MODULE(_mlirPythonTestNanobind, m) {
  PyTestType::bind(m);
  PyTestAttr::bind(m);
}
```

instead of using the discordant alternatives
`mlir_type_subclass`/`mlir_attr_subclass` (the same goes for
`PyConcreteValue`/`mlir_value_subclass`).
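
For readers unfamiliar with the pattern, here is a toy, self-contained sketch of what the CRTP base buys: the derived class supplies only static metadata, and the base supplies the shared checked-cast logic. It deliberately avoids nanobind and the MLIR C API, so `ToyType`, `ConcreteTypeBase`, and `TestType` here are hypothetical stand-ins. The `isaFunction`/`pyClassName` members are loosely modeled on what upstream concrete types declare, but treat their exact form as an assumption. The error text mirrors the `Cannot cast type to TestType` message checked by the Python test. Requires C++17.

```c++
#include <stdexcept>
#include <string>

// Hypothetical stand-in for an opaque C API handle such as MlirType.
struct ToyType {
  std::string kind;
};

// CRTP base: Derived supplies static metadata, the base supplies the logic.
template <typename Derived>
class ConcreteTypeBase {
public:
  // Checked cast on construction: only accept handles Derived recognizes.
  explicit ConcreteTypeBase(const ToyType &orig) : type(castFrom(orig)) {}

  static ToyType castFrom(const ToyType &orig) {
    if (!Derived::isaFunction(orig))
      throw std::invalid_argument(std::string("Cannot cast type to ") +
                                  Derived::pyClassName);
    return orig;
  }

  const ToyType &get() const { return type; }

private:
  ToyType type;
};

// A "dialect" type declares only what is specific to it.
struct TestType : ConcreteTypeBase<TestType> {
  static bool isaFunction(const ToyType &t) { return t.kind == "test_type"; }
  static constexpr const char *pyClassName = "TestType";
  using ConcreteTypeBase::ConcreteTypeBase; // inherit the checked constructor
};
```

Because the base only calls static members of `Derived`, no virtual dispatch is involved, and each new type costs a handful of declarations rather than a hand-written binding.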

# Why

This PR is mostly code motion (along with CMake changes), but before
describing the changes I want to state the goals/benefits:

1. Currently, upstream "core" extensions and "dialect" extensions ([all
of the `Dialect*` extensions
here](d7c734b5a1/mlir/lib/Bindings/Python))
form a two-tier system:
**a**. [core
extensions](https://github.com/llvm/llvm-project/blob/main/mlir/lib/Bindings/Python/IRTypes.cpp#L361)
enjoy first-class support for type inference[^3], type stub
generation, and ease of implementation, while dialect extensions [have
poorer support](https://reviews.llvm.org/D150927), incorrect type stub
generation, and a much more tedious (boilerplate-heavy) implementation;
**b**. Crucially, this two-tiered system is reflected in the fact that
**the two sets of types/attributes are not in the same Python object
hierarchy**. To wit: `isinstance(..., Type)` and `isinstance(...,
Attribute)` are not supported for the dialect extensions[^2];
**c**. Since these types are not exposed in public headers, downstream
users (dialect extensions or not) cannot write functions that overload
on, e.g., `PyFloat8*Type`; that's quite a [useful
feature](fdbee98df8/cpp_ext/TorchOps.cpp (L29-L69))!
2. The dialect extensions incur a sizeable performance penalty relative
to the core extensions in that every single trip across the wire (either
`python->cpp` or `cpp->python`) requires work in addition to nanobind's
own casting/construction pipeline;
**a**. When going from `python->cpp`, [we extract the capsule object
from the Python
object](https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h#L219C24-L219C46)
and then extract from the capsule the `Mlir*` opaque struct/ptr. This
side isn't so onerous;
**b**. When going from `cpp->python`, we make long-hand calls to the Python
`import` APIs and construct the Python object using `_CAPICreate`. Note
that there are at least 2 `attr` calls incurred in addition to `_CAPICreate`;
this is already handled much more [efficiently by nanobind
itself](4ba51fcf79/src/nb_internals.h (L381-L382))!
3. This division blocks various features: in some configurations[^1] we
trigger a circular import bug because "dialect" types and attributes
perform an [import of the root `_mlir`
module](bd9651bf78/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h (L585))
when they are created (the types themselves, not even instances of those
types). This blocks type stub generation for dialect extensions (which is
why we currently only generate type stubs for `_mlir`).

# How

Previously this was not done (or possible) because of "ODR" issues, but I
have resolved those issues; the basic idea is to "move the things we want
to share into shared libraries":

1. Move IRCore (stuff like `PyConcreteValue`, `PyConcreteType`,
`PyConcreteAttribute`) into `MLIRPythonSupport`;
- Note that we also move everything else in `IRModule.h` (renamed to
`IRCore.h`) because `PyConcreteValue`, `PyConcreteType`, and
`PyConcreteAttribute` depend on it. This makes for a bigger PR than
one would hope for, but ultimately I think we should give people access
to these classes to use as they see fit (specifically to inherit from,
but also to use liberally in bindings signatures instead of the opaque
`Mlir*` struct wrappers).
2. Put all of this code into a nested namespace
`MLIR_BINDINGS_PYTHON_DOMAIN`, which is determined by a compile-time
define (and tied to `MLIR_BINDINGS_PYTHON_NB_DOMAIN`). This prevents
conflicts on both symbol names **and** typeids (the latter so that
nanobind does not double-register bound types) between multiple
bindings libraries (e.g., `torch-mlir` and `jax`). Note that
[nanobind doesn't support `module_local` like
pybind11](https://nanobind.readthedocs.io/en/latest/porting.html#removed-features).
It does support `NB_DOMAIN`, but that is not sufficient for
disambiguating typeids across projects (to wit: we already define
`NB_DOMAIN` and it was still necessary to move everything to a nested
namespace);
3. Build the [nanobind library itself as a shared
object](https://github.com/wjakob/nanobind/blob/master/cmake/nanobind-config.cmake#L127)
(and link it to both the extensions and `MLIRPythonSupport`).
4. CMake changes to make this work in-tree, out-of-tree, downstream,
upstream, etc.
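
To make item 2 concrete, here is a self-contained sketch (no nanobind involved) of how a nested namespace chosen per build makes two otherwise identical copies of the support code distinct C++ types, with distinct symbols and distinct `std::type_info`. The `DECLARE_SUPPORT_TYPES` macro and the domain names are hypothetical. In a real build each bindings library is compiled once with its own `MLIR_BINDINGS_PYTHON_DOMAIN` value; here both domains are expanded in one translation unit only so the difference can be observed. Requires C++17 for the nested namespace syntax.

```c++
#include <string>
#include <typeinfo>

// Hypothetical helper: stamps out the same "support library" type inside a
// domain-specific nested namespace. In a real build there is one domain per
// library, selected by a compile-time define rather than a macro argument.
#define DECLARE_SUPPORT_TYPES(DOMAIN)                                        \
  namespace mlir::python::DOMAIN {                                           \
  struct PyType {                                                            \
    static std::string domain() { return #DOMAIN; }                         \
  };                                                                         \
  }

// Simulate two bindings libraries that each ship an identical copy of the
// support code, built with different domains.
DECLARE_SUPPORT_TYPES(torch_mlir)
DECLARE_SUPPORT_TYPES(jax)

// nanobind identifies bound C++ classes by their std::type_info, so what
// matters is that the two copies are now distinct types.
inline bool distinctTypeIds() {
  return typeid(mlir::python::torch_mlir::PyType) !=
         typeid(mlir::python::jax::PyType);
}
```

Without the nested namespace, both copies would define the same `mlir::python::PyType`. Inside one program that is a redefinition error; across separately built shared libraries it becomes the symbol and typeid clash item 2 describes.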

# Testing

Three tests are added here:

1. `PythonTestModuleNanobind` is ported to use
`PyConcreteType<PyTestType>` instead of `mlir_type_subclass` and
`PyConcreteAttribute<PyTestAttr>` instead of `mlir_attr_subclass`,
verifying this works for non-core extensions in-tree;
2. `StandaloneExtensionNanobind` is ported to use `struct PyCustomType :
mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteType<PyCustomType>`
instead of `mlir_type_subclass`, verifying this works for non-core
extensions out-of-tree;
3. `StandaloneExtensionNanobind`'s `smoketest` is extended to also load
another bindings package (namely `mlir`), verifying that
`MLIR_BINDINGS_PYTHON_DOMAIN` successfully disambiguates symbols and
typeids.

I have also tested this downstream
(https://github.com/llvm/eudsl/pull/287) and run the following
buildbots:

mlir-nvidia-gcc7:
https://lab.llvm.org/buildbot/#/buildrequests/6654424?redirect_to_build=true

I have also tested against IREE:
https://github.com/iree-org/iree/pull/21916

# Integration

It is highly recommended to set the CMake var
`MLIR_BINDINGS_PYTHON_NB_DOMAIN` (which will also determine
`MLIR_BINDINGS_PYTHON_DOMAIN`) to something unique for each downstream.
This can also be passed explicitly to `add_mlir_python_modules` if your
project builds multiple bindings packages. I added a `WARNING` to this
effect in `AddMLIRPython.cmake`.

[^3]: Python values being typed correctly when returned from C++.
[^1]: Specifically when the modules are imported using `importlib`,
which occurs with nanobind's
[stubgen](https://github.com/wjakob/nanobind/blob/master/src/stubgen.py#L965).
[^2]: The workaround we implemented was a class method for the dialect
bindings called `Class.isinstance(...)`.
2026-01-05 09:08:13 -08:00


# RUN: %PYTHON %s | FileCheck %s

import sys
import typing
from typing import Union, Optional
from mlir.ir import *
import mlir.dialects.func as func
import mlir.dialects.python_test as test
import mlir.dialects.tensor as tensor
import mlir.dialects.arith as arith
from mlir._mlir_libs._mlirPythonTestNanobind import (
    TestAttr,
    TestType,
    TestTensorValue,
    TestIntegerRankedTensorType,
)

test.register_python_test_dialect(get_dialect_registry())


def run(f):
    print("\nTEST:", f.__name__)
    f()
    return f


# CHECK-LABEL: TEST: testAttributes
@run
def testAttributes():
    with Context() as ctx, Location.unknown():
        #
        # Check op construction with attributes.
        #

        i32 = IntegerType.get_signless(32)
        one = IntegerAttr.get(i32, 1)
        two = IntegerAttr.get(i32, 2)
        unit = UnitAttr.get()

        # CHECK: python_test.attributed_op {
        # CHECK-DAG: mandatory_i32 = 1 : i32
        # CHECK-DAG: optional_i32 = 2 : i32
        # CHECK-DAG: unit
        # CHECK: }
        op = test.AttributedOp(one, optional_i32=two, unit=unit)
        print(f"{op}")

        # CHECK: python_test.attributed_op {
        # CHECK: mandatory_i32 = 2 : i32
        # CHECK: }
        op2 = test.AttributedOp(two)
        print(f"{op2}")

        #
        # Check generic "attributes" access and mutation.
        #

        assert "additional" not in op.attributes

        # CHECK: python_test.attributed_op {
        # CHECK-DAG: additional = 1 : i32
        # CHECK-DAG: mandatory_i32 = 2 : i32
        # CHECK: }
        op2.attributes["additional"] = one
        print(f"{op2}")

        # CHECK: python_test.attributed_op {
        # CHECK-DAG: additional = 2 : i32
        # CHECK-DAG: mandatory_i32 = 2 : i32
        # CHECK: }
        op2.attributes["additional"] = two
        print(f"{op2}")

        # CHECK: python_test.attributed_op {
        # CHECK-NOT: additional = 2 : i32
        # CHECK: mandatory_i32 = 2 : i32
        # CHECK: }
        del op2.attributes["additional"]
        print(f"{op2}")

        try:
            print(op.attributes["additional"])
        except KeyError:
            pass
        else:
            assert False, "expected KeyError on unknown attribute key"

        #
        # Check accessors to defined attributes.
        #

        # CHECK: Mandatory: 1
        # CHECK: Optional: 2
        # CHECK: Unit: True
        print(f"Mandatory: {op.mandatory_i32.value}")
        print(f"Optional: {op.optional_i32.value}")
        print(f"Unit: {op.unit}")

        # CHECK: Mandatory: 2
        # CHECK: Optional: None
        # CHECK: Unit: False
        print(f"Mandatory: {op2.mandatory_i32.value}")
        print(f"Optional: {op2.optional_i32}")
        print(f"Unit: {op2.unit}")

        # CHECK: Mandatory: 2
        # CHECK: Optional: None
        # CHECK: Unit: False
        op.mandatory_i32 = two
        op.optional_i32 = None
        op.unit = False
        print(f"Mandatory: {op.mandatory_i32.value}")
        print(f"Optional: {op.optional_i32}")
        print(f"Unit: {op.unit}")
        assert "optional_i32" not in op.attributes
        assert "unit" not in op.attributes

        try:
            op.mandatory_i32 = None
        except ValueError:
            pass
        else:
            assert False, "expected ValueError on setting a mandatory attribute to None"

        # CHECK: Optional: 2
        op.optional_i32 = two
        print(f"Optional: {op.optional_i32.value}")

        # CHECK: Optional: None
        del op.optional_i32
        print(f"Optional: {op.optional_i32}")

        # CHECK: Unit: False
        op.unit = None
        print(f"Unit: {op.unit}")
        assert "unit" not in op.attributes

        # CHECK: Unit: True
        op.unit = True
        print(f"Unit: {op.unit}")

        # CHECK: Unit: False
        del op.unit
        print(f"Unit: {op.unit}")


# CHECK-LABEL: TEST: attrBuilder
@run
def attrBuilder():
    with Context() as ctx, Location.unknown():
        # CHECK: python_test.attributes_op
        op = test.AttributesOp(
            # CHECK-DAG: x_affinemap = affine_map<() -> (2)>
            x_affinemap=AffineMap.get_constant(2),
            # CHECK-DAG: x_affinemaparr = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
            x_affinemaparr=[AffineMap.get_identity(3)],
            # CHECK-DAG: x_arr = [true, "x"]
            x_arr=[BoolAttr.get(True), StringAttr.get("x")],
            x_boolarr=[False, True],  # CHECK-DAG: x_boolarr = [false, true]
            x_bool=True,  # CHECK-DAG: x_bool = true
            x_dboolarr=[True, False],  # CHECK-DAG: x_dboolarr = array<i1: true, false>
            x_df16arr=[21, 22],  # CHECK-DAG: x_df16arr = array<i16: 21, 22>
            # CHECK-DAG: x_df32arr = array<f32: 2.300000e+01, 2.400000e+01>
            x_df32arr=[23, 24],
            # CHECK-DAG: x_df64arr = array<f64: 2.500000e+01, 2.600000e+01>
            x_df64arr=[25, 26],
            x_di32arr=[0, 1],  # CHECK-DAG: x_di32arr = array<i32: 0, 1>
            # CHECK-DAG: x_di64arr = array<i64: 1, 2>
            x_di64arr=[1, 2],
            x_di8arr=[2, 3],  # CHECK-DAG: x_di8arr = array<i8: 2, 3>
            # CHECK-DAG: x_dictarr = [{a = false}]
            x_dictarr=[{"a": BoolAttr.get(False)}],
            x_dict={"b": BoolAttr.get(True)},  # CHECK-DAG: x_dict = {b = true}
            x_f32=-2.25,  # CHECK-DAG: x_f32 = -2.250000e+00 : f32
            # CHECK-DAG: x_f32arr = [2.000000e+00 : f32, 3.000000e+00 : f32]
            x_f32arr=[2.0, 3.0],
            x_f64=4.25,  # CHECK-DAG: x_f64 = 4.250000e+00 : f64
            x_f64arr=[4.0, 8.0],  # CHECK-DAG: x_f64arr = [4.000000e+00, 8.000000e+00]
            # CHECK-DAG: x_f64elems = dense<[8.000000e+00, 1.600000e+01]> : tensor<2xf64>
            x_f64elems=[8.0, 16.0],
            # CHECK-DAG: x_flatsymrefarr = [@symbol1, @symbol2]
            x_flatsymrefarr=["symbol1", "symbol2"],
            x_flatsymref="symbol3",  # CHECK-DAG: x_flatsymref = @symbol3
            x_i1=0,  # CHECK-DAG: x_i1 = false
            x_i16=42,  # CHECK-DAG: x_i16 = 42 : i16
            x_i32=6,  # CHECK-DAG: x_i32 = 6 : i32
            x_i32arr=[4, 5],  # CHECK-DAG: x_i32arr = [4 : i32, 5 : i32]
            x_i32elems=[5, 6],  # CHECK-DAG: x_i32elems = dense<[5, 6]> : tensor<2xi32>
            x_i64=9,  # CHECK-DAG: x_i64 = 9 : i64
            x_i64arr=[7, 8],  # CHECK-DAG: x_i64arr = [7, 8]
            x_i64elems=[8, 9],  # CHECK-DAG: x_i64elems = dense<[8, 9]> : tensor<2xi64>
            x_i64svecarr=[10, 11],  # CHECK-DAG: x_i64svecarr = [10, 11]
            x_i8=11,  # CHECK-DAG: x_i8 = 11 : i8
            x_idx=10,  # CHECK-DAG: x_idx = 10 : index
            # CHECK-DAG: x_idxelems = dense<[11, 12]> : tensor<2xindex>
            x_idxelems=[11, 12],
            # CHECK-DAG: x_idxlistarr = [{{\[}}13], [14, 15]]
            x_idxlistarr=[[13], [14, 15]],
            x_si1=-1,  # CHECK-DAG: x_si1 = -1 : si1
            x_si16=-2,  # CHECK-DAG: x_si16 = -2 : si16
            x_si32=-3,  # CHECK-DAG: x_si32 = -3 : si32
            x_si64=-123,  # CHECK-DAG: x_si64 = -123 : si64
            x_si8=-4,  # CHECK-DAG: x_si8 = -4 : si8
            x_strarr=["hello", "world"],  # CHECK-DAG: x_strarr = ["hello", "world"]
            x_str="hello world!",  # CHECK-DAG: x_str = "hello world!"
            # CHECK-DAG: x_symrefarr = [@flatsym, @deep::@sym]
            x_symrefarr=["flatsym", ["deep", "sym"]],
            x_symref=["deep", "sym2"],  # CHECK-DAG: x_symref = @deep::@sym2
            x_sym="symbol",  # CHECK-DAG: x_sym = "symbol"
            x_typearr=[F32Type.get()],  # CHECK-DAG: x_typearr = [f32]
            x_type=F64Type.get(),  # CHECK-DAG: x_type = f64
            x_ui1=1,  # CHECK-DAG: x_ui1 = 1 : ui1
            x_ui16=2,  # CHECK-DAG: x_ui16 = 2 : ui16
            x_ui32=3,  # CHECK-DAG: x_ui32 = 3 : ui32
            x_ui64=4,  # CHECK-DAG: x_ui64 = 4 : ui64
            x_ui8=5,  # CHECK-DAG: x_ui8 = 5 : ui8
            x_unit=True,  # CHECK-DAG: x_unit
        )
        op.verify()
        op.print(use_local_scope=True)

        # fmt: off
        assert typing.get_type_hints(test.AttributesOp.x_affinemaparr.fset)["value"] is ArrayAttr
        assert typing.get_type_hints(test.AttributesOp.x_affinemaparr.fget)["return"] is ArrayAttr
        assert type(op.x_affinemaparr) is typing.get_type_hints(test.AttributesOp.x_affinemaparr.fget)["return"]
        assert typing.get_type_hints(test.AttributesOp.x_affinemap.fset)["value"] is AffineMapAttr
        assert typing.get_type_hints(test.AttributesOp.x_affinemap.fget)["return"] is AffineMapAttr
        assert type(op.x_affinemap) is typing.get_type_hints(test.AttributesOp.x_affinemap.fget)["return"]
        assert typing.get_type_hints(test.AttributesOp.x_arr.fset)["value"] is ArrayAttr
        assert typing.get_type_hints(test.AttributesOp.x_arr.fget)["return"] is ArrayAttr
        assert type(op.x_arr) is typing.get_type_hints(test.AttributesOp.x_arr.fget)["return"]
        assert typing.get_type_hints(test.AttributesOp.x_boolarr.fset)["value"] is ArrayAttr
        assert typing.get_type_hints(test.AttributesOp.x_boolarr.fget)["return"] is ArrayAttr
        assert type(op.x_boolarr) is typing.get_type_hints(test.AttributesOp.x_boolarr.fget)["return"]
        assert typing.get_type_hints(test.AttributesOp.x_bool.fset)["value"] is BoolAttr
        assert typing.get_type_hints(test.AttributesOp.x_bool.fget)["return"] is BoolAttr
        assert type(op.x_bool) is typing.get_type_hints(test.AttributesOp.x_bool.fget)["return"]
        assert typing.get_type_hints(test.AttributesOp.x_dboolarr.fset)["value"] is DenseBoolArrayAttr
        assert typing.get_type_hints(test.AttributesOp.x_dboolarr.fget)["return"] is DenseBoolArrayAttr
        assert type(op.x_dboolarr) is typing.get_type_hints(test.AttributesOp.x_dboolarr.fget)["return"]
        assert typing.get_type_hints(test.AttributesOp.x_df32arr.fset)["value"] is DenseF32ArrayAttr
        assert typing.get_type_hints(test.AttributesOp.x_df32arr.fget)["return"] is DenseF32ArrayAttr
        assert type(op.x_df32arr) is typing.get_type_hints(test.AttributesOp.x_df32arr.fget)["return"]
        assert typing.get_type_hints(test.AttributesOp.x_df64arr.fset)["value"] is DenseF64ArrayAttr
        assert typing.get_type_hints(test.AttributesOp.x_df64arr.fget)["return"] is DenseF64ArrayAttr
        assert type(op.x_df64arr) is typing.get_type_hints(test.AttributesOp.x_df64arr.fget)["return"]
        assert typing.get_type_hints(test.AttributesOp.x_df16arr.fset)["value"] is DenseI16ArrayAttr
        assert typing.get_type_hints(test.AttributesOp.x_df16arr.fget)["return"] is DenseI16ArrayAttr
        assert type(op.x_df16arr) is typing.get_type_hints(test.AttributesOp.x_df16arr.fget)["return"]
        assert typing.get_type_hints(test.AttributesOp.x_di32arr.fset)["value"] is DenseI32ArrayAttr
        assert typing.get_type_hints(test.AttributesOp.x_di32arr.fget)["return"] is DenseI32ArrayAttr
        assert type(op.x_di32arr) is typing.get_type_hints(test.AttributesOp.x_di32arr.fget)["return"]
        assert typing.get_type_hints(test.AttributesOp.x_di64arr.fset)["value"] is DenseI64ArrayAttr
        assert typing.get_type_hints(test.AttributesOp.x_di64arr.fget)["return"] is DenseI64ArrayAttr
        assert type(op.x_di64arr) is typing.get_type_hints(test.AttributesOp.x_di64arr.fget)["return"]
        assert typing.get_type_hints(test.AttributesOp.x_di8arr.fset)["value"] is DenseI8ArrayAttr
        assert typing.get_type_hints(test.AttributesOp.x_di8arr.fget)["return"] is DenseI8ArrayAttr
        assert type(op.x_di8arr) is typing.get_type_hints(test.AttributesOp.x_di8arr.fget)["return"]
        assert typing.get_type_hints(test.AttributesOp.x_dictarr.fset)["value"] is ArrayAttr
        assert typing.get_type_hints(test.AttributesOp.x_dictarr.fget)["return"] is ArrayAttr
        assert type(op.x_dictarr) is typing.get_type_hints(test.AttributesOp.x_dictarr.fget)["return"]
        assert typing.get_type_hints(test.AttributesOp.x_dict.fset)["value"] is DictAttr
        assert typing.get_type_hints(test.AttributesOp.x_dict.fget)["return"] is DictAttr
        assert type(op.x_dict) is typing.get_type_hints(test.AttributesOp.x_dict.fget)["return"]
        assert typing.get_type_hints(test.AttributesOp.x_f32arr.fset)["value"] is ArrayAttr
        assert typing.get_type_hints(test.AttributesOp.x_f32arr.fget)["return"] is ArrayAttr
        assert type(op.x_f32arr) is typing.get_type_hints(test.AttributesOp.x_f32arr.fget)["return"]
        assert typing.get_type_hints(test.AttributesOp.x_f32.fset)["value"] is FloatAttr
        assert typing.get_type_hints(test.AttributesOp.x_f32.fget)["return"] is FloatAttr
        assert type(op.x_f32) is typing.get_type_hints(test.AttributesOp.x_f32.fget)["return"]
        assert typing.get_type_hints(test.AttributesOp.x_f64arr.fset)["value"] is ArrayAttr
        assert typing.get_type_hints(test.AttributesOp.x_f64arr.fget)["return"] is ArrayAttr
        assert type(op.x_f64arr) is typing.get_type_hints(test.AttributesOp.x_f64arr.fget)["return"]
        assert typing.get_type_hints(test.AttributesOp.x_f64.fset)["value"] is FloatAttr
        assert typing.get_type_hints(test.AttributesOp.x_f64.fget)["return"] is FloatAttr
        assert type(op.x_f64) is typing.get_type_hints(test.AttributesOp.x_f64.fget)["return"]
        assert typing.get_type_hints(test.AttributesOp.x_f64elems.fset)["value"] is DenseFPElementsAttr
        assert typing.get_type_hints(test.AttributesOp.x_f64elems.fget)["return"] is DenseFPElementsAttr
        assert type(op.x_f64elems) is typing.get_type_hints(test.AttributesOp.x_f64elems.fget)["return"]
        assert typing.get_type_hints(test.AttributesOp.x_flatsymrefarr.fset)["value"] is ArrayAttr
        assert typing.get_type_hints(test.AttributesOp.x_flatsymrefarr.fget)["return"] is ArrayAttr
        assert type(op.x_flatsymrefarr) is typing.get_type_hints(test.AttributesOp.x_flatsymrefarr.fget)["return"]
        assert typing.get_type_hints(test.AttributesOp.x_flatsymref.fset)["value"] is FlatSymbolRefAttr
        assert typing.get_type_hints(test.AttributesOp.x_flatsymref.fget)["return"] is FlatSymbolRefAttr
        assert type(op.x_flatsymref) is typing.get_type_hints(test.AttributesOp.x_flatsymref.fget)["return"]
        assert typing.get_type_hints(test.AttributesOp.x_i16.fset)["value"] is IntegerAttr
        assert typing.get_type_hints(test.AttributesOp.x_i16.fget)["return"] is IntegerAttr
        assert type(op.x_i16) is typing.get_type_hints(test.AttributesOp.x_i16.fget)["return"]
        assert typing.get_type_hints(test.AttributesOp.x_i1.fset)["value"] is BoolAttr
        assert typing.get_type_hints(test.AttributesOp.x_i1.fget)["return"] is BoolAttr
        assert type(op.x_i1) is typing.get_type_hints(test.AttributesOp.x_i1.fget)["return"]
        assert typing.get_type_hints(test.AttributesOp.x_i32arr.fset)["value"] is ArrayAttr
        assert typing.get_type_hints(test.AttributesOp.x_i32arr.fget)["return"] is ArrayAttr
        assert type(op.x_i32arr) is typing.get_type_hints(test.AttributesOp.x_i32arr.fget)["return"]
        assert typing.get_type_hints(test.AttributesOp.x_i32.fset)["value"] is IntegerAttr
        assert typing.get_type_hints(test.AttributesOp.x_i32.fget)["return"] is IntegerAttr
        assert type(op.x_i32) is typing.get_type_hints(test.AttributesOp.x_i32.fget)["return"]
        assert typing.get_type_hints(test.AttributesOp.x_i32elems.fset)["value"] is DenseIntElementsAttr
        assert typing.get_type_hints(test.AttributesOp.x_i32elems.fget)["return"] is DenseIntElementsAttr
        assert type(op.x_i32elems) is typing.get_type_hints(test.AttributesOp.x_i32elems.fget)["return"]
        assert typing.get_type_hints(test.AttributesOp.x_i64arr.fset)["value"] is ArrayAttr
        assert typing.get_type_hints(test.AttributesOp.x_i64arr.fget)["return"] is ArrayAttr
        assert type(op.x_i64arr) is typing.get_type_hints(test.AttributesOp.x_i64arr.fget)["return"]
        assert typing.get_type_hints(test.AttributesOp.x_i64.fset)["value"] is IntegerAttr
        assert typing.get_type_hints(test.AttributesOp.x_i64.fget)["return"] is IntegerAttr
        assert type(op.x_i64) is typing.get_type_hints(test.AttributesOp.x_i64.fget)["return"]
        assert typing.get_type_hints(test.AttributesOp.x_i64elems.fset)["value"] is DenseIntElementsAttr
        assert typing.get_type_hints(test.AttributesOp.x_i64elems.fget)["return"] is DenseIntElementsAttr
        assert type(op.x_i64elems) is typing.get_type_hints(test.AttributesOp.x_i64elems.fget)["return"]
        assert typing.get_type_hints(test.AttributesOp.x_i64svecarr.fset)["value"] is ArrayAttr
        assert typing.get_type_hints(test.AttributesOp.x_i64svecarr.fget)["return"] is ArrayAttr
        assert type(op.x_i64svecarr) is typing.get_type_hints(test.AttributesOp.x_i64svecarr.fget)["return"]
        assert typing.get_type_hints(test.AttributesOp.x_i8.fset)["value"] is IntegerAttr
        assert typing.get_type_hints(test.AttributesOp.x_i8.fget)["return"] is IntegerAttr
        assert type(op.x_i8) is typing.get_type_hints(test.AttributesOp.x_i8.fget)["return"]
        assert typing.get_type_hints(test.AttributesOp.x_idx.fset)["value"] is IntegerAttr
        assert typing.get_type_hints(test.AttributesOp.x_idx.fget)["return"] is IntegerAttr
        assert type(op.x_idx) is typing.get_type_hints(test.AttributesOp.x_idx.fget)["return"]
        assert typing.get_type_hints(test.AttributesOp.x_idxelems.fset)["value"] is DenseIntElementsAttr
        assert typing.get_type_hints(test.AttributesOp.x_idxelems.fget)["return"] is DenseIntElementsAttr
        assert type(op.x_idxelems) is typing.get_type_hints(test.AttributesOp.x_idxelems.fget)["return"]
        assert typing.get_type_hints(test.AttributesOp.x_idxlistarr.fset)["value"] is ArrayAttr
        assert typing.get_type_hints(test.AttributesOp.x_idxlistarr.fget)["return"] is ArrayAttr
        assert type(op.x_idxlistarr) is typing.get_type_hints(test.AttributesOp.x_idxlistarr.fget)["return"]
        assert typing.get_type_hints(test.AttributesOp.x_si16.fset)["value"] is IntegerAttr
        assert typing.get_type_hints(test.AttributesOp.x_si16.fget)["return"] is IntegerAttr
        assert type(op.x_si16) is typing.get_type_hints(test.AttributesOp.x_si16.fget)["return"]
        assert typing.get_type_hints(test.AttributesOp.x_si1.fset)["value"] is IntegerAttr
        assert typing.get_type_hints(test.AttributesOp.x_si1.fget)["return"] is IntegerAttr
        assert type(op.x_si1) is typing.get_type_hints(test.AttributesOp.x_si1.fget)["return"]
        assert typing.get_type_hints(test.AttributesOp.x_si32.fset)["value"] is IntegerAttr
        assert typing.get_type_hints(test.AttributesOp.x_si32.fget)["return"] is IntegerAttr
        assert type(op.x_si32) is typing.get_type_hints(test.AttributesOp.x_si32.fget)["return"]
        assert typing.get_type_hints(test.AttributesOp.x_si64.fset)["value"] is IntegerAttr
        assert typing.get_type_hints(test.AttributesOp.x_si64.fget)["return"] is IntegerAttr
        assert type(op.x_si64) is typing.get_type_hints(test.AttributesOp.x_si64.fget)["return"]
        assert typing.get_type_hints(test.AttributesOp.x_si8.fset)["value"] is IntegerAttr
        assert typing.get_type_hints(test.AttributesOp.x_si8.fget)["return"] is IntegerAttr
        assert type(op.x_si8) is typing.get_type_hints(test.AttributesOp.x_si8.fget)["return"]
        assert typing.get_type_hints(test.AttributesOp.x_strarr.fset)["value"] is ArrayAttr
        assert typing.get_type_hints(test.AttributesOp.x_strarr.fget)["return"] is ArrayAttr
        assert type(op.x_strarr) is typing.get_type_hints(test.AttributesOp.x_strarr.fget)["return"]
        assert typing.get_type_hints(test.AttributesOp.x_str.fset)["value"] is StringAttr
        assert typing.get_type_hints(test.AttributesOp.x_str.fget)["return"] is StringAttr
        assert type(op.x_str) is typing.get_type_hints(test.AttributesOp.x_str.fget)["return"]
        assert typing.get_type_hints(test.AttributesOp.x_sym.fset)["value"] is StringAttr
        assert typing.get_type_hints(test.AttributesOp.x_sym.fget)["return"] is StringAttr
        assert type(op.x_sym) is typing.get_type_hints(test.AttributesOp.x_sym.fget)["return"]
        assert typing.get_type_hints(test.AttributesOp.x_symrefarr.fset)["value"] is ArrayAttr
        assert typing.get_type_hints(test.AttributesOp.x_symrefarr.fget)["return"] is ArrayAttr
        assert type(op.x_symrefarr) is typing.get_type_hints(test.AttributesOp.x_symrefarr.fget)["return"]
        assert typing.get_type_hints(test.AttributesOp.x_symref.fset)["value"] is SymbolRefAttr
        assert typing.get_type_hints(test.AttributesOp.x_symref.fget)["return"] is SymbolRefAttr
        assert type(op.x_symref) is typing.get_type_hints(test.AttributesOp.x_symref.fget)["return"]
        assert typing.get_type_hints(test.AttributesOp.x_typearr.fset)["value"] is ArrayAttr
        assert typing.get_type_hints(test.AttributesOp.x_typearr.fget)["return"] is ArrayAttr
        assert type(op.x_typearr) is typing.get_type_hints(test.AttributesOp.x_typearr.fget)["return"]
        assert typing.get_type_hints(test.AttributesOp.x_type.fset)["value"] is TypeAttr
        assert typing.get_type_hints(test.AttributesOp.x_type.fget)["return"] is TypeAttr
        assert type(op.x_type) is typing.get_type_hints(test.AttributesOp.x_type.fget)["return"]
        assert typing.get_type_hints(test.AttributesOp.x_ui16.fset)["value"] is IntegerAttr
        assert typing.get_type_hints(test.AttributesOp.x_ui16.fget)["return"] is IntegerAttr
        assert type(op.x_ui16) is typing.get_type_hints(test.AttributesOp.x_ui16.fget)["return"]
        assert typing.get_type_hints(test.AttributesOp.x_ui1.fset)["value"] is IntegerAttr
        assert typing.get_type_hints(test.AttributesOp.x_ui1.fget)["return"] is IntegerAttr
        assert type(op.x_ui1) is typing.get_type_hints(test.AttributesOp.x_ui1.fget)["return"]
        assert typing.get_type_hints(test.AttributesOp.x_ui32.fset)["value"] is IntegerAttr
        assert typing.get_type_hints(test.AttributesOp.x_ui32.fget)["return"] is IntegerAttr
        assert type(op.x_ui32) is typing.get_type_hints(test.AttributesOp.x_ui32.fget)["return"]
        assert typing.get_type_hints(test.AttributesOp.x_ui64.fset)["value"] is IntegerAttr
        assert typing.get_type_hints(test.AttributesOp.x_ui64.fget)["return"] is IntegerAttr
        assert type(op.x_ui64) is typing.get_type_hints(test.AttributesOp.x_ui64.fget)["return"]
        assert typing.get_type_hints(test.AttributesOp.x_ui8.fset)["value"] is IntegerAttr
        assert typing.get_type_hints(test.AttributesOp.x_ui8.fget)["return"] is IntegerAttr
        assert type(op.x_ui8) is typing.get_type_hints(test.AttributesOp.x_ui8.fget)["return"]
        # fmt: on


# CHECK-LABEL: TEST: inferReturnTypes
@run
def inferReturnTypes():
    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            op = test.InferResultsOp()
            dummy = test.DummyOp()

        # CHECK: [Type(i32), Type(i64)]
        iface = InferTypeOpInterface(op)
        print(iface.inferReturnTypes())

        # CHECK: [Type(i32), Type(i64)]
        iface_static = InferTypeOpInterface(test.InferResultsOp)
        print(iface.inferReturnTypes())

        assert isinstance(iface.opview, test.InferResultsOp)
        assert iface.opview == iface.operation.opview

        try:
            iface_static.opview
        except TypeError:
            pass
        else:
            assert False, (
                "not expected to be able to obtain an opview from a static interface"
            )

        try:
            InferTypeOpInterface(dummy)
        except ValueError:
            pass
        else:
            assert False, "not expected dummy op to implement the interface"

        try:
            InferTypeOpInterface(test.DummyOp)
        except ValueError:
            pass
        else:
            assert False, "not expected dummy op class to implement the interface"


# CHECK-LABEL: TEST: resultTypesDefinedByTraits
@run
def resultTypesDefinedByTraits():
    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            inferred = test.InferResultsOp()
            # CHECK: i32 i64
            print(inferred.single.type, inferred.doubled.type)

            same = test.SameOperandAndResultTypeOp([inferred.results[0]])
            # CHECK-COUNT-2: i32
            print(same.one.type)
            print(same.two.type)
            assert (
                typing.get_type_hints(test.SameOperandAndResultTypeOp.one.fget)[
                    "return"
                ]
                is OpResult
            )
            assert type(same.one) is OpResult

            first_type_attr = test.FirstAttrDeriveTypeAttrOp(
                inferred.results[1], TypeAttr.get(IndexType.get())
            )
            # CHECK-COUNT-2: index
            print(first_type_attr.one.type)
            print(first_type_attr.two.type)

            first_attr = test.FirstAttrDeriveAttrOp(FloatAttr.get(F32Type.get(), 3.14))
            # CHECK-COUNT-3: f32
            print(first_attr.one.type)
            print(first_attr.two.type)
            print(first_attr.three.type)

            implied = test.InferResultsImpliedOp()
            # CHECK: i32
            print(implied.integer.type)
            # CHECK: f64
            print(implied.flt.type)
            # CHECK: index
            print(implied.index.type)

            # provide the result types to avoid inferring them
            f64 = F64Type.get()
            no_imply = test.InferResultsImpliedOp(results=[f64, f64, f64])
            # CHECK-COUNT-3: f64
            print(no_imply.integer.type, no_imply.flt.type, no_imply.index.type)

            no_infer = test.InferResultsOp(results=[F32Type.get(), IndexType.get()])
            # CHECK: f32 index
            print(no_infer.single.type, no_infer.doubled.type)


# CHECK-LABEL: TEST: testOptionalOperandOp
@run
def testOptionalOperandOp():
    with Context() as ctx, Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            op1 = test.OptionalOperandOp()
            # CHECK: op1.input is None: True
            print(f"op1.input is None: {op1.input is None}")
            assert (
                typing.get_type_hints(test.OptionalOperandOp.input.fget)["return"]
                is Optional[Value]
            )
            assert (
                typing.get_type_hints(test.OptionalOperandOp.result.fget)["return"]
                == OpResult[IntegerType]
            )
            assert type(op1.result) is OpResult

            op2 = test.OptionalOperandOp(input=op1)
            # CHECK: op2.input is None: False
            print(f"op2.input is None: {op2.input is None}")


# CHECK-LABEL: TEST: testCustomAttribute
@run
def testCustomAttribute():
    with Context() as ctx, Location.unknown():
        a = TestAttr.get()
        # CHECK: #python_test.test_attr
        print(a)

        # CHECK: python_test.custom_attributed_op {
        # CHECK: #python_test.test_attr
        # CHECK: }
        op2 = test.CustomAttributedOp(a)
        print(f"{op2}")

        # CHECK: #python_test.test_attr
        print(f"{op2.test_attr}")

        # CHECK: TestAttr(#python_test.test_attr)
        print(repr(op2.test_attr))

        # The following cast must not assert.
        b = TestAttr(a)

        unit = UnitAttr.get()
        try:
            TestAttr(unit)
        except ValueError as e:
            assert "Cannot cast attribute to TestAttr" in str(e)
        else:
            raise

        # The following must trigger a TypeError from our adaptors and must not
        # crash.
        try:
            TestAttr(42)
        except TypeError as e:
            assert (
                "__init__(): incompatible function arguments. The following argument types are supported"
                in str(e)
            )
            assert (
                "__init__(self, cast_from_attr: mlir._mlir_libs._mlir.ir.Attribute) -> None"
                in str(e)
            )
            assert (
                "Invoked with types: mlir._mlir_libs._mlirPythonTestNanobind.TestAttr, int"
                in str(e)
            )
        else:
            raise

        # The following must trigger a TypeError from nanobind (therefore, not
        # checking its message) and must not crash.
        try:
            TestAttr(42, 56)
        except TypeError:
            pass
        else:
            raise


@run
def testCustomType():
    with Context() as ctx:
        a = TestType.get()
        # CHECK: !python_test.test_type
        print(a)

        # The following cast must not assert.
        b = TestType(a)
        # Instances of custom types should have typeids
        assert isinstance(b.typeid, TypeID)

        i8 = IntegerType.get_signless(8)
        try:
            TestType(i8)
        except ValueError as e:
            assert "Cannot cast type to TestType" in str(e)
        else:
            raise

        # The following must trigger a TypeError from our adaptors and must not
        # crash.
        try:
            TestType(42)
        except TypeError as e:
            assert (
                "__init__(): incompatible function arguments. The following argument types are supported"
                in str(e)
            )
            assert (
                "__init__(self, cast_from_type: mlir._mlir_libs._mlir.ir.Type) -> None"
                in str(e)
            )
            assert (
                "Invoked with types: mlir._mlir_libs._mlirPythonTestNanobind.TestType, int"
                in str(e)
            )
        else:
            raise

        # The following must trigger a TypeError from nanobind (therefore, not
        # checking its message) and must not crash.
        try:
            TestType(42, 56)
        except TypeError:
            pass
        else:
            raise


@run
# CHECK-LABEL: TEST: testValue
def testValue():
    # Check that Value is a generic class at runtime.
    assert hasattr(Value, "__class_getitem__")


@run
# CHECK-LABEL: TEST: testTensorValue
def testTensorValue():
    with Context() as ctx, Location.unknown():
        i8 = IntegerType.get_signless(8)

        class Tensor(TestTensorValue):
            def __str__(self):
                return super().__str__().replace("Value", "Tensor")

        module = Module.create()
        with InsertionPoint(module.body):
            t = tensor.EmptyOp([10, 10], i8).result
            # CHECK: Value(%{{.*}} = tensor.empty() : tensor<10x10xi8>)
            print(Value(t))

            tt = Tensor(t)
            # CHECK: Tensor(%{{.*}} = tensor.empty() : tensor<10x10xi8>)
            print(tt)

            # CHECK: False
            print(tt.is_null())

            # Classes of custom types that inherit from concrete types should have
            # a static_typeid.
            assert isinstance(TestIntegerRankedTensorType.static_typeid, TypeID)
            # And it should equal the TypeID of the in-tree concrete type.
            assert TestIntegerRankedTensorType.static_typeid == t.type.typeid

            d = tensor.EmptyOp([1, 2, 3], IntegerType.get_signless(5)).result
            # CHECK: Value(%{{.*}} = tensor.empty() : tensor<1x2x3xi5>)
            print(d)
            # CHECK: TestTensorValue
            print(repr(d))


# CHECK-LABEL: TEST: inferReturnTypeComponents
@run
def inferReturnTypeComponents():
    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        i32 = IntegerType.get_signless(32)
        with InsertionPoint(module.body):
            resultType = UnrankedTensorType.get(i32)
            operandTypes = [
                RankedTensorType.get([1, 3, 10, 10], i32),
                UnrankedTensorType.get(i32),
            ]
            f = func.FuncOp(
                "test_inferReturnTypeComponents", (operandTypes, [resultType])
            )
            entry_block = Block.create_at_start(f.operation.regions[0], operandTypes)
            with InsertionPoint(entry_block):
                ranked_op = test.InferShapedTypeComponentsOp(
                    resultType, entry_block.arguments[0]
                )
                unranked_op = test.InferShapedTypeComponentsOp(
                    resultType, entry_block.arguments[1]
                )

        # CHECK: has rank: True
        # CHECK: rank: 4
        # CHECK: element type: i32
        # CHECK: shape: [1, 3, 10, 10]
        iface = InferShapedTypeOpInterface(ranked_op)
        shaped_type_components = iface.inferReturnTypeComponents(
            operands=[ranked_op.operand]
        )[0]
        print("has rank:", shaped_type_components.has_rank)
        print("rank:", shaped_type_components.rank)
        print("element type:", shaped_type_components.element_type)
        print("shape:", shaped_type_components.shape)

        # CHECK: has rank: False
        # CHECK: rank: None
        # CHECK: element type: i32
        # CHECK: shape: None
        iface = InferShapedTypeOpInterface(unranked_op)
        shaped_type_components = iface.inferReturnTypeComponents(
            operands=[unranked_op.operand]
        )[0]
        print("has rank:", shaped_type_components.has_rank)
        print("rank:", shaped_type_components.rank)
        print("element type:", shaped_type_components.element_type)
        print("shape:", shaped_type_components.shape)


# CHECK-LABEL: TEST: testCustomTypeTypeCaster
@run
def testCustomTypeTypeCaster():
    with Context() as ctx, Location.unknown():
        a = TestType.get()
        assert a.typeid is not None

        b = Type.parse("!python_test.test_type")
        # CHECK: !python_test.test_type
        print(b)
        # CHECK: TestType(!python_test.test_type)
        print(repr(b))

        c = TestIntegerRankedTensorType.get([10, 10], 5)
        # CHECK: tensor<10x10xi5>
        print(c)
        # CHECK: TestIntegerRankedTensorType(tensor<10x10xi5>)
        print(repr(c))

        # CHECK: Type caster is already registered
        try:

            @register_type_caster(c.typeid)
            def type_caster(pytype):
                return TestIntegerRankedTensorType(pytype)

        except RuntimeError as e:
            print(e)

        # The python_test dialect registers a caster for RankedTensorType in its
        # extension module, so this registration replaces that one (successfully).
        # The original caster is restored below, just to be sure.
        @register_type_caster(c.typeid, replace=True)
        def type_caster(pytype):
            return RankedTensorType(pytype)

        d = tensor.EmptyOp([10, 10], IntegerType.get_signless(5)).result
        # CHECK: tensor<10x10xi5>
        print(d.type)
        # CHECK: ranked tensor type RankedTensorType(tensor<10x10xi5>)
        print("ranked tensor type", repr(d.type))

        @register_type_caster(c.typeid, replace=True)
        def type_caster(pytype):
            return TestIntegerRankedTensorType(pytype)

        d = tensor.EmptyOp([10, 10], IntegerType.get_signless(5)).result
        # CHECK: tensor<10x10xi5>
        print(d.type)
        # CHECK: TestIntegerRankedTensorType(tensor<10x10xi5>)
        print(repr(d.type))
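

# Illustrative sketch (hypothetical, pure Python; not how MLIR implements
# ``register_type_caster``): the test above exercises a registry in which a
# second caster for the same TypeID is rejected with a RuntimeError unless
# ``replace=True`` is passed, and the most recently registered caster wins.
# A dictionary keyed by type ID captures that contract. The class name, method
# names, and error text here are assumptions made for this sketch.
class _SketchCasterRegistry:
    def __init__(self):
        self._casters = {}

    def register(self, type_id, replace=False):
        """Return a decorator that registers ``fn`` as the caster for ``type_id``."""

        def decorator(fn):
            if type_id in self._casters and not replace:
                raise RuntimeError(
                    f"Type caster is already registered for {type_id!r}"
                )
            self._casters[type_id] = fn
            return fn

        return decorator

    def cast(self, type_id, value):
        """Apply the registered caster, or return ``value`` unchanged if none."""
        caster = self._casters.get(type_id)
        return caster(value) if caster is not None else value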


# CHECK-LABEL: TEST: testInferTypeOpInterface
@run
def testInferTypeOpInterface():
    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            i64 = IntegerType.get_signless(64)
            zero = arith.ConstantOp(i64, 0)

            one_operand = test.InferResultsVariadicInputsOp(single=zero, doubled=None)
            # CHECK: i32
            print(one_operand.result.type)

            two_operands = test.InferResultsVariadicInputsOp(single=zero, doubled=zero)
            # CHECK: f32
            print(two_operands.result.type)

            assert (
                typing.get_type_hints(test.infer_results_variadic_inputs_op)["return"]
                is OpResult
            )
            assert (
                type(test.infer_results_variadic_inputs_op(single=zero, doubled=zero))
                is OpResult
            )


# CHECK-LABEL: TEST: testVariadicOperandAccess
@run
def testVariadicOperandAccess():
    def values(lst):
        return [str(e) for e in lst]

    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            i32 = IntegerType.get_signless(32)
            zero = arith.ConstantOp(i32, 0)
            one = arith.ConstantOp(i32, 1)
            two = arith.ConstantOp(i32, 2)
            three = arith.ConstantOp(i32, 3)
            four = arith.ConstantOp(i32, 4)

            variadic_operands = test.SameVariadicOperandSizeOp(
                [zero, one], two, [three, four]
            )
            # CHECK: Value(%{{.*}} = arith.constant 2 : i32)
            print(variadic_operands.non_variadic)
            assert (
                typing.get_type_hints(test.SameVariadicOperandSizeOp.non_variadic.fget)[
                    "return"
                ]
                is Value
            )
            assert type(variadic_operands.non_variadic) is Value

            # CHECK: ['Value(%{{.*}} = arith.constant 0 : i32)', 'Value(%{{.*}} = arith.constant 1 : i32)']
            print(values(variadic_operands.variadic1))
            assert (
                typing.get_type_hints(test.SameVariadicOperandSizeOp.variadic1.fget)[
                    "return"
                ]
                is OpOperandList
            )
            assert type(variadic_operands.variadic1) is OpOperandList

            # CHECK: ['Value(%{{.*}} = arith.constant 3 : i32)', 'Value(%{{.*}} = arith.constant 4 : i32)']
            print(values(variadic_operands.variadic2))
            assert type(variadic_operands.variadic2) is OpOperandList

            assert (
                typing.get_type_hints(test.same_variadic_operand)["return"]
                is test.SameVariadicOperandSizeOp
            )
            assert (
                type(test.same_variadic_operand([zero, one], two, [three, four]))
                is test.SameVariadicOperandSizeOp
            )
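

# Illustrative sketch (hypothetical names, standard-library behavior only): the
# return-annotation checks in these tests read hints with
# ``typing.get_type_hints``. For a property, the annotation lives on the getter
# function, which is why the tests go through ``.fget``. ``Union`` equality
# ignores argument order, so ``==`` is the right comparison for union-typed
# builders, while ``is`` suits single-class annotations. ``_HintSketchOp`` and
# ``_hint_sketch_builder`` are stand-ins invented for this sketch.
import typing
from typing import Union


class _HintSketchOp:
    @property
    def res(self) -> int:
        return 0


def _hint_sketch_builder() -> Union[int, str, _HintSketchOp]:
    return 0


def _check_hint_sketch():
    # The property object itself carries no annotations; its getter does.
    assert typing.get_type_hints(_HintSketchOp.res.fget)["return"] is int
    # Same members in a different order still compare equal.
    assert typing.get_type_hints(_hint_sketch_builder)["return"] == Union[
        _HintSketchOp, str, int
    ]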


# CHECK-LABEL: TEST: testVariadicResultAccess
@run
def testVariadicResultAccess():
    def types(lst):
        return [e.type for e in lst]

    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            i = [IntegerType.get_signless(k) for k in range(7)]

            # Test Variadic-Fixed-Variadic
            op = test.SameVariadicResultSizeOpVFV([i[0], i[1]], i[2], [i[3], i[4]])
            # CHECK: i2
            print(op.non_variadic.type)
            # CHECK: [IntegerType(i0), IntegerType(i1)]
            print(types(op.variadic1))
            # CHECK: [IntegerType(i3), IntegerType(i4)]
            print(types(op.variadic2))
            assert (
                typing.get_type_hints(test.same_variadic_result_vfv)["return"]
                == Union[OpResult, OpResultList, test.SameVariadicResultSizeOpVFV]
            )
            assert (
                type(test.same_variadic_result_vfv([i[0], i[1]], i[2], [i[3], i[4]]))
                is OpResultList
            )

            # Test Variadic-Variadic-Variadic
            op = test.SameVariadicResultSizeOpVVV(
                [i[0], i[1]], [i[2], i[3]], [i[4], i[5]]
            )
            # CHECK: [IntegerType(i0), IntegerType(i1)]
            print(types(op.variadic1))
            # CHECK: [IntegerType(i2), IntegerType(i3)]
            print(types(op.variadic2))
            # CHECK: [IntegerType(i4), IntegerType(i5)]
            print(types(op.variadic3))

            # Test Fixed-Fixed-Variadic
            op = test.SameVariadicResultSizeOpFFV(i[0], i[1], [i[2], i[3], i[4]])
            # CHECK: i0
            print(op.non_variadic1.type)
            # CHECK: i1
            print(op.non_variadic2.type)
            # CHECK: [IntegerType(i2), IntegerType(i3), IntegerType(i4)]
            print(types(op.variadic))
            assert (
                typing.get_type_hints(test.SameVariadicResultSizeOpFFV.variadic.fget)[
                    "return"
                ]
                is OpResultList
            )
            assert type(op.variadic) is OpResultList

            # Test Variadic-Variadic-Fixed
            op = test.SameVariadicResultSizeOpVVF(
                [i[0], i[1], i[2]], [i[3], i[4], i[5]], i[6]
            )
            # CHECK: [IntegerType(i0), IntegerType(i1), IntegerType(i2)]
            print(types(op.variadic1))
            # CHECK: [IntegerType(i3), IntegerType(i4), IntegerType(i5)]
            print(types(op.variadic2))
            # CHECK: i6
            print(op.non_variadic.type)

            # Test Fixed-Variadic-Fixed-Variadic-Fixed
            op = test.SameVariadicResultSizeOpFVFVF(
                i[0], [i[1], i[2]], i[3], [i[4], i[5]], i[6]
            )
            # CHECK: i0
            print(op.non_variadic1.type)
            # CHECK: [IntegerType(i1), IntegerType(i2)]
            print(types(op.variadic1))
            # CHECK: i3
            print(op.non_variadic2.type)
            # CHECK: [IntegerType(i4), IntegerType(i5)]
            print(types(op.variadic2))
            # CHECK: i6
            print(op.non_variadic3.type)

            # Test Fixed-Variadic-Fixed-Variadic-Fixed - Variadic group size 0
            op = test.SameVariadicResultSizeOpFVFVF(i[0], [], i[1], [], i[2])
            # CHECK: i0
            print(op.non_variadic1.type)
            # CHECK: []
            print(types(op.variadic1))
            # CHECK: i1
            print(op.non_variadic2.type)
            # CHECK: []
            print(types(op.variadic2))
            # CHECK: i2
            print(op.non_variadic3.type)

            # Test Fixed-Variadic-Fixed-Variadic-Fixed - Variadic group size 1
            op = test.SameVariadicResultSizeOpFVFVF(i[0], [i[1]], i[2], [i[3]], i[4])
            # CHECK: i0
            print(op.non_variadic1.type)
            # CHECK: [IntegerType(i1)]
            print(types(op.variadic1))
            # CHECK: i2
            print(op.non_variadic2.type)
            # CHECK: [IntegerType(i3)]
            print(types(op.variadic2))
            # CHECK: i4
            print(op.non_variadic3.type)

            assert (
                typing.get_type_hints(test.results_variadic)["return"]
                == Union[OpResult, OpResultList, test.ResultsVariadicOp]
            )
            assert type(test.results_variadic([i[0]])) is OpResult
            op_res_variadic = test.ResultsVariadicOp([i[0]])
            assert (
                typing.get_type_hints(test.ResultsVariadicOp.res.fget)["return"]
                is OpResultList
            )
            assert type(op_res_variadic.res) is OpResultList
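

# Illustrative sketch (hypothetical helper; not claimed to be how the generated
# op accessors are implemented): with the "same variadic size" traits exercised
# above, every variadic group has the same length, so any group's position in
# the flat result (or operand) list follows from the total count alone.
# ``_equal_group_slice`` and its parameter names are assumptions for this
# sketch. For a fixed (non-variadic) group, only ``start`` is meaningful.
def _equal_group_slice(total, n_fixed, n_variadic, n_fixed_before, n_variadic_before):
    """Return ``(start, per_group)`` for a group in a flat list of ``total`` values.

    ``n_fixed_before`` / ``n_variadic_before`` count the fixed and variadic
    groups that precede the group being located.
    """
    variadic_total = total - n_fixed
    # The op verifier is expected to guarantee an even split; check it anyway.
    assert n_variadic > 0 and variadic_total % n_variadic == 0
    per_group = variadic_total // n_variadic
    start = n_fixed_before + n_variadic_before * per_group
    return start, per_group


# Example: SameVariadicResultSizeOpFVFVF with 7 results has 3 fixed and 2
# variadic groups, so each variadic group holds 2 results and ``variadic2``
# starts at index 2 + 1 * 2 = 4.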


# CHECK-LABEL: TEST: testVariadicAndNormalRegionOp
@run
def testVariadicAndNormalRegionOp():
    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            region_op = test.VariadicAndNormalRegionOp(2)
            assert (
                typing.get_type_hints(test.VariadicAndNormalRegionOp.region.fget)[
                    "return"
                ]
                is Region
            )
            assert type(region_op.region) is Region
            assert (
                typing.get_type_hints(test.VariadicAndNormalRegionOp.variadic.fget)[
                    "return"
                ]
                is RegionSequence
            )
            assert type(region_op.variadic) is RegionSequence

            assert isinstance(region_op.opview, OpView)
            assert isinstance(region_op.operation.opview, OpView)