This a reland of https://github.com/llvm/llvm-project/pull/155741 which
was reverted at https://github.com/llvm/llvm-project/pull/157831. This
version is narrower in scope - it only turns on automatic stub
generation for `MLIRPythonExtension.Core._mlir` and **does not do
anything automatically**. Specifically, the only CMake code added to
`AddMLIRPython.cmake` is the `mlir_generate_type_stubs` function which
is then used only in a manual way. The API for
`mlir_generate_type_stubs` is:
```
Arguments:
MODULE_NAME: The fully-qualified name of the extension module (used for importing in python).
DEPENDS_TARGETS: List of targets these type stubs depend on being built; usually corresponding to the
specific extension module (e.g., something like StandalonePythonModules.extension._standaloneDialectsNanobind.dso)
and the core bindings extension module (e.g., something like StandalonePythonModules.extension._mlir.dso).
OUTPUT_DIR: The root output directory to emit the type stubs into.
OUTPUTS: List of expected outputs.
DEPENDS_TARGET_SRC_DEPS: List of cpp sources for extension library (for generating a DEPFILE).
IMPORT_PATHS: List of paths to add to PYTHONPATH for stubgen.
PATTERN_FILE: (Optional) Pattern file (see https://nanobind.readthedocs.io/en/latest/typing.html#pattern-files).
Outputs:
NB_STUBGEN_CUSTOM_TARGET: The target corresponding to generation which other targets can depend on.
```
Downstream users should use `mlir_generate_type_stubs` in coordination
with `declare_mlir_python_sources` to turn on stub generation for their
own downstream dialect extensions and upstream dialect extensions if
they so choose. Standalone example shows an example.
Note, downstream will also need to set
`-DMLIR_PYTHON_PACKAGE_PREFIX=...` correctly for their bindings.
There are cases where the same module can have multiple references (via
`PyModule::forModule` via `PyModule::createFromCapsule`) and thus when
`PyModule`s get gc'd `mlirModuleDestroy` can get called multiple times
for the same actual underlying `mlir::Module` (i.e., double free). So we
do actually need a "liveness map" for modules.
Note, if `type_caster<MlirModule>::from_cpp` weren't a thing we could guarantree
this never happened except explicitly when users called `PyModule::createFromCapsule`.
Historical context: `PyMlirContext::liveOperations` was an optimization
meant to cut down on the number of Python object allocations and
(partially) a mechanism for updating validity of ops after
transformation. E.g. during walking/transforming the AST. See original
patch [here](https://reviews.llvm.org/D87958).
Inspired by a
[renewed](https://github.com/llvm/llvm-project/pull/139721#issuecomment-3217131918)
interest in https://github.com/llvm/llvm-project/pull/139721 (which has
become a little stale...)
<p align="center">
<img width="504" height="375" alt="image"
src="https://github.com/user-attachments/assets/0daad562-d3d1-4876-8d01-5dba382ab186"
/>
</p>
In the previous go-around
(https://github.com/llvm/llvm-project/pull/92631) there were two issues
which have been resolved
1. ops that were "fetched" under a root op which has been transformed
are no longer reported as invalid. We simply "[formally
forbid](https://github.com/llvm/llvm-project/pull/92631#issuecomment-2119397018)"
this;
2. `Module._CAPICreate(module_capsule)` must now be followed by a
`module._clear_mlir_module()` to prevent double-freeing of the actual
`ModuleOp` object (i.e. calling the dtor on the
`OwningOpRef<ModuleOp>`):
```python
module = ...
module_dup = Module._CAPICreate(module._CAPIPtr)
module._clear_mlir_module()
```
- **the alternative choice** here is to remove the `Module._CAPICreate`
API altogether and replace it with something like `Module._move(module)`
which will do both `Module._CAPICreate` and `module._clear_mlir_module`.
Note, the other approach I explored last year was a [weakref
system](https://github.com/llvm/llvm-project/pull/97340) for
`mlir::Operation` which would effectively hoist this `liveOperations`
thing into MLIR core. Possibly doable but I now believe it's a bad idea.
The other potentially breaking change is `is`, which checks object
equality rather than value equality, will now report `False` because we
are always allocating `new` Python objects (ie that's the whole point of
this change). Users wanting to check equality for `Operation` and
`Module` should use `==`.
This PR implements "automatic" location inference in the bindings. The
way it works is it walks the frame stack collecting source locations
(Python captures these in the frame itself). It is inspired by JAX's
[implementation](523ddcfbca/jax/_src/interpreters/mlir.py (L462))
but moves the frame stack traversal into the bindings for better
performance.
The system supports registering "included" and "excluded" filenames;
frames originating from functions in included filenames **will not** be
filtered and frames originating from functions in excluded filenames
**will** be filtered (in that order). This allows excluding all the
generated `*_ops_gen.py` files.
The system is also "toggleable" and off by default to save people who
have their own systems (such as JAX) from the added cost.
Note, the system stores the entire stacktrace (subject to
`locTracebackFramesLimit`) in the `Location` using specifically a
`CallSiteLoc`. This can be useful for profiling tools (flamegraphs
etc.).
Shoutout to the folks at JAX for coming up with a good system.
---------
Co-authored-by: Jacques Pienaar <jpienaar@google.com>
- Introduces a `large_resource_limit` parameter across Python bindings,
enabling the eliding of resource strings exceeding a specified character
limit during IR printing.
- To maintain backward compatibilty, when using `operation.print()` API,
if `large_resource_limit` is None and the `large_elements_limit` is set,
the later will be used to elide the resource string as well. This change
was introduced by https://github.com/llvm/llvm-project/pull/125738.
- For printing using pass manager, the `large_resource_limit` and
`large_elements_limit` are completely independent of each other.
* `PyRegionList` is now sliceable. The dialect bindings generator seems
to assume it is sliceable already (!), yet accessing e.g. `cases` on
`scf.IndexedSwitchOp` raises a `TypeError` at runtime.
* `PyBlockList` and `PyOperationList` support negative indexing. It is
common for containers to do that in Python, and most container in the
MLIR Python bindings already allow the index to be negative.
In some projects like JAX ir.Context are used with disabled multi-threading to avoid
caching multiple threading pools:
623865fe95/jax/_src/interpreters/mlir.py (L606-L611)
However, when context has enabled multithreading it also uses locks on
the StorageUniquers and this can be helpful to avoid data races in the
multi-threaded execution (for example with free-threaded cpython,
https://github.com/jax-ml/jax/issues/26272).
With this PR user can enable the multi-threading: 1) enables additional
locking and 2) set a shared threading pool such that cached contexts can
have one global pool.
This PR extends the python bindings for CallSiteLoc, FileLineColRange,
FusedLoc, NameLoc with field accessors. It also adds the missing
`value.location` accessor.
I also did some "spring cleaning" here (`cast` -> `dyn_cast`) after
running into some of my own illegal casts.
The current `write_bytecode` implementation necessarily requires the
serialized module to be duplicated in memory when the python `bytes`
object is created and sent over the binding. For modules with large
resources, we may want to avoid this in-memory copy by serializing
directly to a file instead of sending bytes across the boundary.
For extremely large models, it may be inefficient to load the model into
memory in Python prior to passing it to the MLIR C APIs for
deserialization. This change adds an API to parse a ModuleOp directly
from a file path.
Re-lands
[4e14b8a](4e14b8afb4).
For extremely large models, it may be inefficient to load the model into
memory in Python prior to passing it to the MLIR C APIs for
deserialization. This change adds an API to parse a ModuleOp directly
from a file path.
If the large element limit is specified, large elements are hidden from
the asm but large resources are not. This change extends the large
elements limit to apply to printed resources as well.
This logic is in the critical path for constructing an operation from
Python. It is faster to compute this in C++ than it is in Python, and it
is a minor change to do this.
This change also alters the API contract of
_ods_common.get_op_results_or_values to avoid calling
get_op_result_or_value on each element of a sequence, since the C++ code
will now do this.
Most of the diff here is simply reordering the code in IRCore.cpp.
Currently we make two memory allocations for each PyOperation: a Python
object, and the PyOperation class itself. With some care we can allocate
the PyOperation inline inside the Python object, saving us a malloc()
call per object and perhaps improving cache locality.
Previously ODS-generated Python operations had code like this:
```
super().__init__(self.build_generic(attributes=attributes, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip))
```
we change it to:
```
super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)
```
This:
a) avoids an extra call dispatch (to `build_generic`), and
b) passes the class attributes directly to the constructor. Benchmarks
show that it is faster to pass these as arguments rather than having the
C++ code look up attributes on the class.
This PR improves the timing of the following benchmark on my workstation
from 5.3s to 4.5s:
```
def main(_):
with ir.Context(), ir.Location.unknown():
typ = ir.IntegerType.get_signless(32)
m = ir.Module.create()
with ir.InsertionPoint(m.body):
start = time.time()
for i in range(1000000):
arith.ConstantOp(typ, i)
end = time.time()
print(f"time: {end - start}")
```
Since this change adds an additional overload to the constructor and
does not alter any existing behaviors, it should be backwards
compatible.
In JAX, I observed a race between two PyOperation destructors from
different threads updating the same `liveOperations` map, despite not
intentionally sharing the context between different threads. Since I
don't think we can be completely sure when GC happens and on which
thread, it seems safest simply to add locking here.
We may also want to explicitly support sharing a context between threads
in the future, which would require this change or something similar.
This is a companion to #118583, although it can be landed independently
because since #117922 dialects do not have to use the same Python
binding framework as the Python core code.
This PR ports all of the in-tree dialect and pass extensions to
nanobind, with the exception of those that remain for testing pybind11
support.
This PR also:
* removes CollectDiagnosticsToStringScope from NanobindAdaptors.h. This
was overlooked in a previous PR and it is duplicated in Diagnostics.h.
---------
Co-authored-by: Jacques Pienaar <jpienaar@google.com>
Relands #118583, with a fix for Python 3.8 compatibility. It was not
possible to set the buffer protocol accessers via slots in Python 3.8.
Why? https://nanobind.readthedocs.io/en/latest/why.html says it better
than I can, but my primary motivation for this change is to improve MLIR
IR construction time from JAX.
For a complicated Google-internal LLM model in JAX, this change improves
the MLIR
lowering time by around 5s (out of around 30s), which is a significant
speedup for simply switching binding frameworks.
To a large extent, this is a mechanical change, for instance changing
`pybind11::` to `nanobind::`.
Notes:
* this PR needs Nanobind 2.4.0, because it needs a bug fix
(https://github.com/wjakob/nanobind/pull/806) that landed in that
release.
* this PR does not port the in-tree dialect extension modules. They can
be ported in a future PR.
* I removed the py::sibling() annotations from def_static and def_class
in `PybindAdapters.h`. These ask pybind11 to try to form an overload
with an existing method, but it's not possible to form mixed
pybind11/nanobind overloads this ways and the parent class is now
defined in nanobind. Better solutions may be possible here.
* nanobind does not contain an exact equivalent of pybind11's buffer
protocol support. It was not hard to add a nanobind implementation of a
similar API.
* nanobind is pickier about casting to std::vector<bool>, expecting that
the input is a sequence of bool types, not truthy values. In a couple of
places I added code to support truthy values during casting.
* nanobind distinguishes bytes (`nb::bytes`) from strings (e.g.,
`std::string`). This required nb::bytes overloads in a few places.
Why? https://nanobind.readthedocs.io/en/latest/why.html says it better
than I can, but my primary motivation for this change is to improve MLIR
IR construction time from JAX.
For a complicated Google-internal LLM model in JAX, this change improves
the MLIR
lowering time by around 5s (out of around 30s), which is a significant
speedup for simply switching binding frameworks.
To a large extent, this is a mechanical change, for instance changing
`pybind11::`
to `nanobind::`.
Notes:
* this PR needs Nanobind 2.4.0, because it needs a bug fix
(https://github.com/wjakob/nanobind/pull/806) that landed in that
release.
* this PR does not port the in-tree dialect extension modules. They can
be ported in a future PR.
* I removed the py::sibling() annotations from def_static and def_class
in `PybindAdapters.h`. These ask pybind11 to try to form an overload
with an existing method, but it's not possible to form mixed
pybind11/nanobind overloads this ways and the parent class is now
defined in nanobind. Better solutions may be possible here.
* nanobind does not contain an exact equivalent of pybind11's buffer
protocol support. It was not hard to add a nanobind implementation of a
similar API.
* nanobind is pickier about casting to std::vector<bool>, expecting that
the input is a sequence of bool types, not truthy values. In a couple of
places I added code to support truthy values during casting.
* nanobind distinguishes bytes (`nb::bytes`) from strings (e.g.,
`std::string`). This required nb::bytes overloads in a few places.
In the tablegen-generated Python bindings, we typically see a pattern
like:
```
class ConstantOp(_ods_ir.OpView):
...
def __init__(self, value, *, loc=None, ip=None):
...
super().__init__(self.build_generic(attributes=attributes, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip))
```
i.e., the generated code calls `OpView.__init__()` with the output of
`build_generic`. The purpose of `OpView` is to wrap another operation
object, and `OpView.__init__` can accept any `PyOperationBase` subclass,
and presumably the intention is that `build_generic` returns a
`PyOperation`, so the user ends up with a `PyOpView` wrapping a
`PyOperation`.
However, `PyOpView::buildGeneric` calls `PyOperation::create`, which
does not just build a PyOperation, but it also calls `createOpView` to
wrap that operation in a subclass of `PyOpView` and returns that view.
But that's rather pointless: we called this code from the constructor of
an `OpView` subclass, so we already have a view object ready to go; we
don't need to build another one!
If we change `PyOperation::create` to return the underlying
`PyOperation`, rather than a view wrapper, we can save allocating a
useless `PyOpView` object for each ODS-generated Python object.
This saves approximately 1.5s of Python time in a JAX LLM benchmark that
generates a mixture of upstream dialects and StableHLO.
The MLIR C and Python Bindings expose various methods from
`mlir::OpPrintingFlags` . This PR adds a binding for the `skipRegions`
method, which allows to skip the printing of Regions when printing Ops.
It also exposes this option as parameter in the python `get_asm` and
`print` methods
The PR implements MLIR Python Bindings for a few simple edit operations
on Block arguments, namely, `add_argument`, `erase_argument`, and
`erase_arguments`.
When an operation is erased in Python, its children may still be in the
"live" list inside Python bindings. After this, if some of the newly
allocated operations happen to reuse the same pointer address, this will
trigger an assertion in the bindings. This assertion would be incorrect
because the operations aren't actually live. Make sure we remove the
children operations from the "live" list when erasing the parent.
This also concentrates responsibility over the removal from the "live"
list and invalidation in a single place.
Note that this requires the IR to be sufficiently structurally valid so
a walk through it can succeed. If this invariant was broken by, e.g, C++
pass called from Python, there isn't much we can do.
If the python callback throws an error, the c++ code will throw a
py::error_already_set that needs to be caught and handled in the c++
code .
This change is inspired by the similar solution in
PySymbolTable::walkSymbolTables.