[MLIR][CAPI][python] expose the python binding for linalgOp.getIndexingMaps (#136054)
This PR is mainly about exposing the python bindings for `linalgOp.getIndexingMaps`. --------- Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
This commit is contained in:
parent
b3a53cc721
commit
7119b0cfd3
@ -50,6 +50,9 @@ typedef struct MlirLinalgConvolutionDimensions {
|
|||||||
MLIR_CAPI_EXPORTED MlirLinalgConvolutionDimensions
|
MLIR_CAPI_EXPORTED MlirLinalgConvolutionDimensions
|
||||||
mlirLinalgInferConvolutionDimensions(MlirOperation op);
|
mlirLinalgInferConvolutionDimensions(MlirOperation op);
|
||||||
|
|
||||||
|
MLIR_CAPI_EXPORTED MlirAttribute
|
||||||
|
mlirLinalgGetIndexingMapsAttribute(MlirOperation op);
|
||||||
|
|
||||||
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg);
|
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
|
@ -120,6 +120,16 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
|
|||||||
|
|
||||||
m.def("infer_convolution_dimensions", &InferConvolutionDimensions,
|
m.def("infer_convolution_dimensions", &InferConvolutionDimensions,
|
||||||
"Infers convolution dimensions", nb::arg("op"));
|
"Infers convolution dimensions", nb::arg("op"));
|
||||||
|
|
||||||
|
m.def(
|
||||||
|
"get_indexing_maps",
|
||||||
|
[](MlirOperation op) -> std::optional<MlirAttribute> {
|
||||||
|
MlirAttribute attr = mlirLinalgGetIndexingMapsAttribute(op);
|
||||||
|
if (mlirAttributeIsNull(attr))
|
||||||
|
return std::nullopt;
|
||||||
|
return attr;
|
||||||
|
},
|
||||||
|
"Returns the indexing_maps attribute for a linalg op.");
|
||||||
}
|
}
|
||||||
|
|
||||||
NB_MODULE(_mlirDialectsLinalg, m) {
|
NB_MODULE(_mlirDialectsLinalg, m) {
|
||||||
|
@ -120,4 +120,14 @@ mlirLinalgInferConvolutionDimensions(MlirOperation op) {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MLIR_CAPI_EXPORTED MlirAttribute
|
||||||
|
mlirLinalgGetIndexingMapsAttribute(MlirOperation op) {
|
||||||
|
auto linalgOp = llvm::dyn_cast<mlir::linalg::LinalgOp>(unwrap(op));
|
||||||
|
if (!linalgOp)
|
||||||
|
return MlirAttribute{nullptr};
|
||||||
|
|
||||||
|
ArrayAttr attr = linalgOp.getIndexingMaps();
|
||||||
|
return wrap(attr);
|
||||||
|
}
|
||||||
|
|
||||||
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect)
|
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect)
|
||||||
|
@ -159,3 +159,52 @@ def test_infer_convolution_dimensions_from_ops():
|
|||||||
assert list(dims.depth) == []
|
assert list(dims.depth) == []
|
||||||
assert list(dims.strides) == [1, 1]
|
assert list(dims.strides) == [1, 1]
|
||||||
assert list(dims.dilations) == [1, 1]
|
assert list(dims.dilations) == [1, 1]
|
||||||
|
|
||||||
|
|
||||||
|
@run
|
||||||
|
def test_get_indexing_maps_attr():
|
||||||
|
with Context(), Location.unknown():
|
||||||
|
module = Module.create()
|
||||||
|
f32 = F32Type.get()
|
||||||
|
with InsertionPoint(module.body):
|
||||||
|
a_type = RankedTensorType.get((4, 8), f32)
|
||||||
|
b_type = RankedTensorType.get((8, 16), f32)
|
||||||
|
c_type = RankedTensorType.get((4, 16), f32)
|
||||||
|
|
||||||
|
dim_m = AffineDimExpr.get(0)
|
||||||
|
dim_n = AffineDimExpr.get(1)
|
||||||
|
dim_k = AffineDimExpr.get(2)
|
||||||
|
|
||||||
|
a_map = AffineMap.get(3, 0, [dim_m, dim_k])
|
||||||
|
b_map = AffineMap.get(3, 0, [dim_k, dim_n])
|
||||||
|
c_map = AffineMap.get(3, 0, [dim_m, dim_n])
|
||||||
|
|
||||||
|
@func.FuncOp.from_py_func(a_type, b_type, c_type)
|
||||||
|
def matmul_func(a, b, c):
|
||||||
|
zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.0), result=f32)
|
||||||
|
assert not linalg.get_indexing_maps(
|
||||||
|
zero.operation
|
||||||
|
), "Expected no indexing_maps on non-linalg op"
|
||||||
|
|
||||||
|
init = linalg.fill(zero, outs=[c])
|
||||||
|
fill_op = init.owner
|
||||||
|
fill_maps = linalg.get_indexing_maps(fill_op)
|
||||||
|
assert fill_maps is not None
|
||||||
|
assert len(fill_maps) == 2
|
||||||
|
|
||||||
|
# The fill op should have maps like (d0, d1) -> () and (d0, d1).
|
||||||
|
fill_input_map = fill_maps[0].value
|
||||||
|
fill_output_map = fill_maps[1].value
|
||||||
|
assert fill_input_map == AffineMap.get(2, 0, [])
|
||||||
|
assert fill_output_map == AffineMap.get(2, 0, [dim_m, dim_n])
|
||||||
|
|
||||||
|
result = linalg.matmul(a, b, outs=(init,))
|
||||||
|
matmul_op = result.owner
|
||||||
|
matmul_maps = linalg.get_indexing_maps(matmul_op)
|
||||||
|
assert matmul_maps is not None
|
||||||
|
assert len(matmul_maps) == 3
|
||||||
|
|
||||||
|
maps = [map_attr.value for map_attr in matmul_maps]
|
||||||
|
assert maps[0] == a_map
|
||||||
|
assert maps[1] == b_map
|
||||||
|
assert maps[2] == c_map
|
||||||
|
Loading…
x
Reference in New Issue
Block a user