[MLIR][CAPI][python] expose the python binding for linalgOp.getIndexingMaps (#136054)

This PR is mainly about exposing the python bindings for
`linalgOp.getIndexingMaps`.

---------

Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
This commit is contained in:
Bangtian Liu 2025-04-17 16:52:36 -04:00 committed by GitHub
parent b3a53cc721
commit 7119b0cfd3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 72 additions and 0 deletions

View File

@ -50,6 +50,9 @@ typedef struct MlirLinalgConvolutionDimensions {
MLIR_CAPI_EXPORTED MlirLinalgConvolutionDimensions MLIR_CAPI_EXPORTED MlirLinalgConvolutionDimensions
mlirLinalgInferConvolutionDimensions(MlirOperation op); mlirLinalgInferConvolutionDimensions(MlirOperation op);
MLIR_CAPI_EXPORTED MlirAttribute
mlirLinalgGetIndexingMapsAttribute(MlirOperation op);
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg); MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg);
#ifdef __cplusplus #ifdef __cplusplus

View File

@ -120,6 +120,16 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
m.def("infer_convolution_dimensions", &InferConvolutionDimensions, m.def("infer_convolution_dimensions", &InferConvolutionDimensions,
"Infers convolution dimensions", nb::arg("op")); "Infers convolution dimensions", nb::arg("op"));
m.def(
"get_indexing_maps",
[](MlirOperation op) -> std::optional<MlirAttribute> {
MlirAttribute attr = mlirLinalgGetIndexingMapsAttribute(op);
if (mlirAttributeIsNull(attr))
return std::nullopt;
return attr;
},
"Returns the indexing_maps attribute for a linalg op.");
} }
NB_MODULE(_mlirDialectsLinalg, m) { NB_MODULE(_mlirDialectsLinalg, m) {

View File

@ -120,4 +120,14 @@ mlirLinalgInferConvolutionDimensions(MlirOperation op) {
return result; return result;
} }
MLIR_CAPI_EXPORTED MlirAttribute
mlirLinalgGetIndexingMapsAttribute(MlirOperation op) {
auto linalgOp = llvm::dyn_cast<mlir::linalg::LinalgOp>(unwrap(op));
if (!linalgOp)
return MlirAttribute{nullptr};
ArrayAttr attr = linalgOp.getIndexingMaps();
return wrap(attr);
}
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect) MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect)

View File

@ -159,3 +159,52 @@ def test_infer_convolution_dimensions_from_ops():
assert list(dims.depth) == [] assert list(dims.depth) == []
assert list(dims.strides) == [1, 1] assert list(dims.strides) == [1, 1]
assert list(dims.dilations) == [1, 1] assert list(dims.dilations) == [1, 1]
@run
def test_get_indexing_maps_attr():
with Context(), Location.unknown():
module = Module.create()
f32 = F32Type.get()
with InsertionPoint(module.body):
a_type = RankedTensorType.get((4, 8), f32)
b_type = RankedTensorType.get((8, 16), f32)
c_type = RankedTensorType.get((4, 16), f32)
dim_m = AffineDimExpr.get(0)
dim_n = AffineDimExpr.get(1)
dim_k = AffineDimExpr.get(2)
a_map = AffineMap.get(3, 0, [dim_m, dim_k])
b_map = AffineMap.get(3, 0, [dim_k, dim_n])
c_map = AffineMap.get(3, 0, [dim_m, dim_n])
@func.FuncOp.from_py_func(a_type, b_type, c_type)
def matmul_func(a, b, c):
zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.0), result=f32)
assert not linalg.get_indexing_maps(
zero.operation
), "Expected no indexing_maps on non-linalg op"
init = linalg.fill(zero, outs=[c])
fill_op = init.owner
fill_maps = linalg.get_indexing_maps(fill_op)
assert fill_maps is not None
assert len(fill_maps) == 2
# The fill op should have maps like (d0, d1) -> () and (d0, d1).
fill_input_map = fill_maps[0].value
fill_output_map = fill_maps[1].value
assert fill_input_map == AffineMap.get(2, 0, [])
assert fill_output_map == AffineMap.get(2, 0, [dim_m, dim_n])
result = linalg.matmul(a, b, outs=(init,))
matmul_op = result.owner
matmul_maps = linalg.get_indexing_maps(matmul_op)
assert matmul_maps is not None
assert len(matmul_maps) == 3
maps = [map_attr.value for map_attr in matmul_maps]
assert maps[0] == a_map
assert maps[1] == b_map
assert maps[2] == c_map