This PR exposes `linalg::inferContractionDims(ArrayRef<AffineMap>)` to
Python, allowing users to infer contraction dimensions (batch/m/n/k)
directly from a list of affine maps without needing an operation.
---------
Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
This is a follow-up PR of #162699.
In this PR we clean CAPI and Python bindings of MLIR rewrite part by:
- remove all manually-defined `wrap`/`unwrap` functions;
- remove useless nanobind-defined Python class `RewritePattern`.
This PR adds support for defining custom **`RewritePattern`**
implementations directly in the Python bindings.
Previously, users could define similar patterns using the PDL dialect’s
bindings. However, for more complex patterns, this often required
writing multiple Python callbacks as PDL native constraints or rewrite
functions, which made the overall logic less intuitive—though it could
be more performant than a pure Python implementation (especially for
simple patterns).
With this change, we introduce an additional, straightforward way to
define patterns purely in Python, complementing the existing PDL-based
approach.
### Example
```python
def to_muli(op, rewriter):
with rewriter.ip:
new_op = arith.muli(op.operands[0], op.operands[1], loc=op.location)
rewriter.replace_op(op, new_op.owner)
with Context():
patterns = RewritePatternSet()
patterns.add(arith.AddIOp, to_muli) # a pattern that rewrites arith.addi to arith.muli
frozen = patterns.freeze()
module = ...
apply_patterns_and_fold_greedily(module, frozen)
```
---------
Co-authored-by: Maksim Levental <maksim.levental@gmail.com>
`PassManager::enableStatistics` seems currently missing in both C API
and Python bindings. So here we added them in this PR, which includes
the `PassDisplayMode` enum type and the `EnableStatistics` method.
In [#160520](https://github.com/llvm/llvm-project/pull/160520), we
discussed the current limitations of PDL rewriting in Python (see [this
comment](https://github.com/llvm/llvm-project/pull/160520#issuecomment-3332326184)).
At the moment, we cannot create new operations in PDL native (python)
rewrite functions because the `PatternRewriter` APIs are not exposed.
This PR introduces bindings to retrieve the insertion point of the
`PatternRewriter`, enabling users to create new operations within Python
rewrite functions. With this capability, more complex rewrites e.g. with
branching and loops that involve op creations become possible.
---------
Co-authored-by: Maksim Levental <maksim.levental@gmail.com>
This is a follow-up to #159926.
That PR (#159926) exposed native rewrite function registration in PDL
through the C API and Python, enabling use with
`pdl.apply_native_rewrite`.
In this PR, we add support for native constraint functions in PDL via
`pdl.apply_native_constraint`, further completing the PDL API.
In the MLIR Python bindings, we can currently use PDL to define simple
patterns and then execute them with the greedy rewrite driver. However,
when dealing with more complex patterns—such as constant folding for
integer addition—we find that we need `apply_native_rewrite` to actually
perform arithmetic (i.e., compute the sum of two constants). For
example, consider the following PDL pseudocode:
```mlir
pdl.pattern : benefit(1) {
%a0 = pdl.attribute
%a1 = pdl.attribute
%c0 = pdl.operation "arith.constant" {value = %a0}
%c1 = pdl.operation "arith.constant" {value = %a1}
%op = pdl.operation "arith.addi"(%c0, %c1)
%sum = pdl.apply_native_rewrite "addIntegers"(%a0, %a1)
%new_cst = pdl.operation "arith.constant" {value = %sum}
pdl.replace %op with %new_cst
}
```
Here, `addIntegers` cannot be expressed in PDL alone—it requires a
*native rewrite function*. This PR introduces a mechanism to support
exactly that, allowing complex rewrite patterns to be expressed in
Python and enabling many passes to be implemented directly in Python as
well.
As a test case, we defined two new operations (`myint.constant` and
`myint.add`) in Python and implemented a constant-folding rewrite
pattern for them. The core code looks like this:
```python
m = Module.create()
with InsertionPoint(m.body):
@pdl.pattern(benefit=1, sym_name="myint_add_fold")
def pat():
...
op0 = pdl.OperationOp(name="myint.add", args=[v0, v1], types=[t])
@pdl.rewrite()
def rew():
sum = pdl.apply_native_rewrite(
[pdl.AttributeType.get()], "add_fold", [a0, a1]
)
newOp = pdl.OperationOp(
name="myint.constant", attributes={"value": sum}, types=[t]
)
pdl.ReplaceOp(op0, with_op=newOp)
def add_fold(rewriter, results, values):
a0, a1 = values
results.push_back(IntegerAttr.get(i32, a0.value + a1.value))
pdl_module = PDLModule(m)
pdl_module.register_rewrite_function("add_fold", add_fold)
```
The idea is previously discussed in Discord #mlir-python channel with
@makslevental.
---------
Co-authored-by: Maksim Levental <maksim.levental@gmail.com>
It closes#155996.
This PR added a method `add(callable, ..)` to
`mlir.passmanager.PassManager` to accept a callable object for defining
passes in the Python side.
This is a simple example of a Python-defined pass.
```python
from mlir.passmanager import PassManager
def demo_pass_1(op):
# do something with op
pass
class DemoPass:
def __init__(self, ...):
pass
def __call__(op):
# do something
pass
demo_pass_2 = DemoPass(..)
pm = PassManager('any', ctx)
pm.add(demo_pass_1)
pm.add(demo_pass_2)
pm.add("registered-passes")
pm.run(..)
```
---------
Co-authored-by: cnb.bsD2OPwAgEA <QejD2DJ2eEahUVy6Zg0aZI+cnb.bsD2OPwAgEA@noreply.cnb.cool>
Co-authored-by: Maksim Levental <maksim.levental@gmail.com>
In https://github.com/llvm/llvm-project/pull/94714, we add a python
function `apply_patterns_and_fold_greedily` which accepts an
`MlirModule` as the argument type. However, sometimes we want to apply
patterns with an `MlirOperation` argument, and there is currently no
python API to convert an `MlirOperation` to `MlirModule`.
So here we overload this function `apply_patterns_and_fold_greedily` to
do this (also a corresponding new C API
`mlirApplyPatternsAndFoldGreedilyWithOp`)
This commit moves the "element" param of `DICompositeType` to the end of
the parameter list. This is required as there seems to be a bug in the
attribute parser that breaks a print + parse roundtrip.
Related ticket: https://github.com/llvm/llvm-project/issues/156623
Historical context: `PyMlirContext::liveOperations` was an optimization
meant to cut down on the number of Python object allocations and
(partially) a mechanism for updating validity of ops after
transformation. E.g. during walking/transforming the AST. See original
patch [here](https://reviews.llvm.org/D87958).
Inspired by a
[renewed](https://github.com/llvm/llvm-project/pull/139721#issuecomment-3217131918)
interest in https://github.com/llvm/llvm-project/pull/139721 (which has
become a little stale...)
<p align="center">
<img width="504" height="375" alt="image"
src="https://github.com/user-attachments/assets/0daad562-d3d1-4876-8d01-5dba382ab186"
/>
</p>
In the previous go-around
(https://github.com/llvm/llvm-project/pull/92631) there were two issues
which have been resolved
1. ops that were "fetched" under a root op which has been transformed
are no longer reported as invalid. We simply "[formally
forbid](https://github.com/llvm/llvm-project/pull/92631#issuecomment-2119397018)"
this;
2. `Module._CAPICreate(module_capsule)` must now be followed by a
`module._clear_mlir_module()` to prevent double-freeing of the actual
`ModuleOp` object (i.e. calling the dtor on the
`OwningOpRef<ModuleOp>`):
```python
module = ...
module_dup = Module._CAPICreate(module._CAPIPtr)
module._clear_mlir_module()
```
- **the alternative choice** here is to remove the `Module._CAPICreate`
API altogether and replace it with something like `Module._move(module)`
which will do both `Module._CAPICreate` and `module._clear_mlir_module`.
Note, the other approach I explored last year was a [weakref
system](https://github.com/llvm/llvm-project/pull/97340) for
`mlir::Operation` which would effectively hoist this `liveOperations`
thing into MLIR core. Possibly doable but I now believe it's a bad idea.
The other potentially breaking change is `is`, which checks object
equality rather than value equality, will now report `False` because we
are always allocating `new` Python objects (ie that's the whole point of
this change). Users wanting to check equality for `Operation` and
`Module` should use `==`.
Retry landing https://github.com/llvm/llvm-project/pull/153373
## Major changes from previous attempt
- remove the test in CAPI because no existing tests in CAPI deal with
sanitizer exemptions
- update `mlir/docs/Dialects/GPU.md` to reflect the new behavior: load
GPU binary in global ctors, instead of loading them at call site.
- skip the test on Aarch64 since we have an issue with initialization there
---------
Co-authored-by: Mehdi Amini <joker.eph@gmail.com>
This PR introduces a mechanism to defer JIT engine initialization,
enabling registration of required symbols before global constructor
execution.
## Problem
Modules containing `gpu.module` generate global constructors (e.g.,
kernel load/unload) that execute *during* engine creation. This can
force premature symbol resolution, causing failures when:
- Symbols are registered via `mlirExecutionEngineRegisterSymbol` *after*
creation
- Global constructors exist (even if not directly using unresolved
symbols, e.g., an external function declaration)
- GPU modules introduce mandatory binary loading logic
## Usage
```c
// Create engine without initialization
MlirExecutionEngine jit = mlirExecutionEngineCreate(...);
// Register required symbols
mlirExecutionEngineRegisterSymbol(jit, ...);
// Explicitly initialize (runs global constructors)
mlirExecutionEngineInitialize(jit);
```
---------
Co-authored-by: Mehdi Amini <joker.eph@gmail.com>
Reland https://github.com/llvm/llvm-project/pull/150805
Shared libs build was broken.
Add `${dialect_libs}` and `${conversion_libs}` to
`MLIRRegisterAllExtensions` because it depends on
`registerConvert***ToLLVMInterface` functions.
`InitAll***` functions are used by `opt`-style tools to init all MLIR
dialects/passes/extensions. Currently they are implemeted as inline
functions and include essentially the entire MLIR header tree. Each file
which includes this header (~10 currently) takes 10+ sec and multiple GB
of ram to compile (tested with clang-19), which limits amount of
parallel compiler jobs which can be run. Also, flang just includes this
file into one of its headers.
Move the actual registration code to the static library, so it's
compiled only once.
Discourse thread
https://discourse.llvm.org/t/rfc-moving-initall-implementation-into-static-library/87559
The motivation is to avoid having to negate `isDynamic*` checks, avoid
double negations, and allow for `ShapedType::isStaticDim` to be used in
ADT functions without having to wrap it in a lambda performing the
negation.
Also add the new functions to C and Python bindings.
This thread through proper error handling / reporting capabilities to
avoid hitting llvm_unreachable while parsing linalg ops.
Fixes#132755Fixes#132740Fixes#129185
This PR is mainly about exposing the python bindings for
`linalg::isaConvolutionOpInterface` and `linalg::inferConvolutionDims`.
---------
Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
This PR is mainly about exposing the python bindings for`
linalg::isaContractionOpInterface` and` linalg::inferContractionDims`.
---------
Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
I observed that we have the boundary comments in the codebase like:
```
//===----------------------------------------------------------------------===//
// ...
//===----------------------------------------------------------------------===//
```
I also observed that there are incomplete boundary comments. The
revision is generated by a script that completes the boundary comments.
```
//===----------------------------------------------------------------------===//
// ...
...
```
Signed-off-by: hanhanW <hanhan0912@gmail.com>
This is an implementation for [RFC: Supporting Sub-Channel Quantization
in
MLIR](https://discourse.llvm.org/t/rfc-supporting-sub-channel-quantization-in-mlir/82694).
In order to make the review process easier, the PR has been divided into
the following commit labels:
1. **Add implementation for sub-channel type:** Includes the class
design for `UniformQuantizedSubChannelType`, printer/parser and bytecode
read/write support. The existing types (per-tensor and per-axis) are
unaltered.
2. **Add implementation for sub-channel type:** Lowering of
`quant.qcast` and `quant.dcast` operations to Linalg operations.
3. **Adding C/Python Apis:** We first define he C-APIs and build the
Python-APIs on top of those.
4. **Add pass to normalize generic ....:** This pass normalizes
sub-channel quantized types to per-tensor per-axis types, if possible.
A design note:
- **Explicitly storing the `quantized_dimensions`, even when they can be
derived for ranked tensor.**
While it's possible to infer quantized dimensions from the static shape
of the scales (or zero-points) tensor for ranked
data tensors
([ref](https://discourse.llvm.org/t/rfc-supporting-sub-channel-quantization-in-mlir/82694/3)
for background), there are cases where this can lead to ambiguity and
issues with round-tripping.
```
Consider the example: tensor<2x4x!quant.uniform<i8:f32:{0:2, 0:2}, {{s00:z00, s01:z01}}>>
```
The shape of the scales tensor is [1, 2], which might suggest that only
axis 1 is quantized. While this inference is technically correct, as the
block size for axis 0 is a degenerate case (equal to the dimension
size), it can cause problems with round-tripping. Therefore, even for
ranked tensors, we are explicitly storing the quantized dimensions.
Suggestions welcome!
PS: I understand that the upcoming holidays may impact your schedule, so
please take your time with the review. There's no rush.
In some projects like JAX ir.Context are used with disabled multi-threading to avoid
caching multiple threading pools:
623865fe95/jax/_src/interpreters/mlir.py (L606-L611)
However, when context has enabled multithreading it also uses locks on
the StorageUniquers and this can be helpful to avoid data races in the
multi-threaded execution (for example with free-threaded cpython,
https://github.com/jax-ml/jax/issues/26272).
With this PR user can enable the multi-threading: 1) enables additional
locking and 2) set a shared threading pool such that cached contexts can
have one global pool.
This PR extends the python bindings for CallSiteLoc, FileLineColRange,
FusedLoc, NameLoc with field accessors. It also adds the missing
`value.location` accessor.
I also did some "spring cleaning" here (`cast` -> `dyn_cast`) after
running into some of my own illegal casts.
Apparently trying to lookup a function pointer using the C api
`mlirExecutionEngineLookup` will throw an assert instead of just
returning a nullptr on builds with asserts.
The docs itself says it returns a nullptr when no function is found so
it should be sensible to not throw an assert in this case.
For extremely large models, it may be inefficient to load the model into
memory in Python prior to passing it to the MLIR C APIs for
deserialization. This change adds an API to parse a ModuleOp directly
from a file path.
Re-lands
[4e14b8a](4e14b8afb4).