[MLIR][CAPI][python] expose the python binding for linalgOp.getIndexingMaps (#136054)
This PR is mainly about exposing the python bindings for `linalgOp.getIndexingMaps`. --------- Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
This commit is contained in:
parent
b3a53cc721
commit
7119b0cfd3
@ -50,6 +50,9 @@ typedef struct MlirLinalgConvolutionDimensions {
|
||||
MLIR_CAPI_EXPORTED MlirLinalgConvolutionDimensions
|
||||
mlirLinalgInferConvolutionDimensions(MlirOperation op);
|
||||
|
||||
MLIR_CAPI_EXPORTED MlirAttribute
|
||||
mlirLinalgGetIndexingMapsAttribute(MlirOperation op);
|
||||
|
||||
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg);
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
@ -120,6 +120,16 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
|
||||
|
||||
m.def("infer_convolution_dimensions", &InferConvolutionDimensions,
|
||||
"Infers convolution dimensions", nb::arg("op"));
|
||||
|
||||
m.def(
|
||||
"get_indexing_maps",
|
||||
[](MlirOperation op) -> std::optional<MlirAttribute> {
|
||||
MlirAttribute attr = mlirLinalgGetIndexingMapsAttribute(op);
|
||||
if (mlirAttributeIsNull(attr))
|
||||
return std::nullopt;
|
||||
return attr;
|
||||
},
|
||||
"Returns the indexing_maps attribute for a linalg op.");
|
||||
}
|
||||
|
||||
NB_MODULE(_mlirDialectsLinalg, m) {
|
||||
|
@ -120,4 +120,14 @@ mlirLinalgInferConvolutionDimensions(MlirOperation op) {
|
||||
return result;
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED MlirAttribute
|
||||
mlirLinalgGetIndexingMapsAttribute(MlirOperation op) {
|
||||
auto linalgOp = llvm::dyn_cast<mlir::linalg::LinalgOp>(unwrap(op));
|
||||
if (!linalgOp)
|
||||
return MlirAttribute{nullptr};
|
||||
|
||||
ArrayAttr attr = linalgOp.getIndexingMaps();
|
||||
return wrap(attr);
|
||||
}
|
||||
|
||||
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect)
|
||||
|
@ -159,3 +159,52 @@ def test_infer_convolution_dimensions_from_ops():
|
||||
assert list(dims.depth) == []
|
||||
assert list(dims.strides) == [1, 1]
|
||||
assert list(dims.dilations) == [1, 1]
|
||||
|
||||
|
||||
@run
|
||||
def test_get_indexing_maps_attr():
|
||||
with Context(), Location.unknown():
|
||||
module = Module.create()
|
||||
f32 = F32Type.get()
|
||||
with InsertionPoint(module.body):
|
||||
a_type = RankedTensorType.get((4, 8), f32)
|
||||
b_type = RankedTensorType.get((8, 16), f32)
|
||||
c_type = RankedTensorType.get((4, 16), f32)
|
||||
|
||||
dim_m = AffineDimExpr.get(0)
|
||||
dim_n = AffineDimExpr.get(1)
|
||||
dim_k = AffineDimExpr.get(2)
|
||||
|
||||
a_map = AffineMap.get(3, 0, [dim_m, dim_k])
|
||||
b_map = AffineMap.get(3, 0, [dim_k, dim_n])
|
||||
c_map = AffineMap.get(3, 0, [dim_m, dim_n])
|
||||
|
||||
@func.FuncOp.from_py_func(a_type, b_type, c_type)
|
||||
def matmul_func(a, b, c):
|
||||
zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.0), result=f32)
|
||||
assert not linalg.get_indexing_maps(
|
||||
zero.operation
|
||||
), "Expected no indexing_maps on non-linalg op"
|
||||
|
||||
init = linalg.fill(zero, outs=[c])
|
||||
fill_op = init.owner
|
||||
fill_maps = linalg.get_indexing_maps(fill_op)
|
||||
assert fill_maps is not None
|
||||
assert len(fill_maps) == 2
|
||||
|
||||
# The fill op should have maps like (d0, d1) -> () and (d0, d1).
|
||||
fill_input_map = fill_maps[0].value
|
||||
fill_output_map = fill_maps[1].value
|
||||
assert fill_input_map == AffineMap.get(2, 0, [])
|
||||
assert fill_output_map == AffineMap.get(2, 0, [dim_m, dim_n])
|
||||
|
||||
result = linalg.matmul(a, b, outs=(init,))
|
||||
matmul_op = result.owner
|
||||
matmul_maps = linalg.get_indexing_maps(matmul_op)
|
||||
assert matmul_maps is not None
|
||||
assert len(matmul_maps) == 3
|
||||
|
||||
maps = [map_attr.value for map_attr in matmul_maps]
|
||||
assert maps[0] == a_map
|
||||
assert maps[1] == b_map
|
||||
assert maps[2] == c_map
|
||||
|
Loading…
x
Reference in New Issue
Block a user