[MLIR][Linalg][Python] Improve bindings for linalg.elementwise (#139462)

Adds wrappers for ElementWiseOp, in particular to ensure appropriate
default indexing maps are derived.
This commit is contained in:
Rolf Morel 2025-05-12 11:34:55 +02:00 committed by GitHub
parent 688bccb290
commit ba739c166d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 247 additions and 0 deletions

View File

@ -216,6 +216,67 @@ def contract(
)
# Extend and shadow the TableGen-derived version to make sure correct default
# indexing_maps are derived (as there is no mechanism for doing so given the
# Python API bypasses the C++-builders).
class ElementwiseOp_(ElementwiseOp):
def __init__(
self,
result_tensors,
inputs,
outputs,
kind,
*,
indexing_maps=None,
loc=None,
ip=None,
):
if indexing_maps is None:
inputs = [_get_op_result_or_value(in_) for in_ in inputs]
for in0, in1 in zip(inputs[:-1], inputs[1:]):
assert in0.type == in1.type
output = _get_op_result_or_value(outputs[0])
assert inputs[0].type == output.type
num_args = len(inputs) + 1
indexing_maps = [AffineMap.get_identity(output.type.rank)] * num_args
super().__init__(
result_tensors=result_tensors,
inputs=inputs,
outputs=outputs,
kind=kind,
indexing_maps=indexing_maps,
loc=loc,
ip=ip,
)
ElementwiseOp = ElementwiseOp_
def elementwise(
*ins: Union[Operation, OpView, Value],
outs: Sequence[Union[Operation, OpView, Value]],
kind: Union[ElementwiseKind, Attribute],
indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
):
ins = [_get_op_result_or_value(input) for input in ins]
if len(outs) != 1:
raise ValueError(f"{outs=} must have length 1.")
init = _get_op_result_or_value(outs[0])
result_types = [init.type] if isinstance(init.type, RankedTensorType) else []
op = ElementwiseOp(
result_tensors=result_types,
inputs=ins,
outputs=[init],
kind=kind,
indexing_maps=indexing_maps,
)
fill_builtin_region(op.operation)
return _get_op_result_or_op_results(op)
def pack(
source,
dest,

View File

@ -606,3 +606,189 @@ def testPackUnPackOp():
# CHECK: return %[[VAL_4]] : tensor<128x128xf32>
# CHECK: }
print(module)
# CHECK-LABEL: TEST: testElementwiseOp
@run
def testElementwiseOp():
with Context(), Location.unknown():
module = Module.create()
f32 = F32Type.get()
with InsertionPoint(module.body):
rect_shape = (8, 16)
vert_line_shape = (8,)
hor_line_shape = (16,)
transposed_rect_shape = (16, 8)
# CHECK-DAG: #[[$IdentMap2D:.*]] = affine_map<(d0, d1) -> (d0, d1)>
# CHECK-DAG: #[[$TransMap2D:.*]] = affine_map<(d0, d1) -> (d1, d0)>
# CHECK-DAG: #[[$VertLineBCastMap:.*]] = affine_map<(d0, d1) -> (d0)>
# CHECK-DAG: #[[$HorLineBCastMap:.*]] = affine_map<(d0, d1) -> (d1)>
ident_map_2d = AffineMap.get_identity(2)
transposed_map_2d = AffineMap.get_permutation((1, 0))
vert_line_bcast_map = AffineMap.get(2, 0, [AffineDimExpr.get(0)])
hor_line_bcast_map = AffineMap.get(2, 0, [AffineDimExpr.get(1)])
# CHECK: func.func @elementwise_op(
@func.FuncOp.from_py_func(
# CHECK-SAME: %[[Rect:.*]]: tensor<8x16xf32>,
RankedTensorType.get(rect_shape, f32),
# CHECK-SAME: %[[RectMem:.*]]: memref<8x16xf32>,
MemRefType.get(rect_shape, f32),
# CHECK-SAME: %[[VertLine:.*]]: tensor<8xf32>,
RankedTensorType.get(vert_line_shape, f32),
# CHECK-SAME: %[[VertLineMem:.*]]: memref<8xf32>,
MemRefType.get(vert_line_shape, f32),
# CHECK-SAME: %[[HorLine:.*]]: tensor<16xf32>,
RankedTensorType.get(hor_line_shape, f32),
# CHECK-SAME: %[[HorLineMem:.*]]: memref<16xf32>,
MemRefType.get(hor_line_shape, f32),
# CHECK-SAME: %[[TransRect:.*]]: tensor<16x8xf32>,
RankedTensorType.get(transposed_rect_shape, f32),
# CHECK-SAME: %[[TransRectMem:.*]]: memref<16x8xf32>)
MemRefType.get(transposed_rect_shape, f32),
)
def elementwise_op(
rect,
rect_mem,
vert_line,
vert_line_mem,
hor_line,
hor_line_mem,
trans_rect,
trans_rect_mem,
):
# CHECK: %[[OutRect:.*]] = tensor.empty() : tensor<8x16xf32>
out_rect = tensor.EmptyOp(rect_shape, f32)
# CHECK: %[[OutRectMem:.*]] = memref.alloca() : memref<8x16xf32>
out_rect_mem = memref.alloca(MemRefType.get(rect_shape, f32), [], [])
if _inferred_affine_maps := True:
# CHECK: linalg.elementwise
# CHECK-SAME: kind=#linalg.elementwise_kind<exp>
# CHECK-SAME: ins(%[[Rect]] : tensor<8x16xf32>)
# CHECK-SAME: outs(%[[OutRect]] : tensor<8x16xf32>) -> tensor<8x16xf32>
op1 = linalg.ElementwiseOp(
result_tensors=(out_rect.result.type,),
inputs=(rect,),
outputs=(out_rect,),
kind=linalg.ElementwiseKind.exp,
)
linalg.fill_builtin_region(op1.operation)
# CHECK: linalg.elementwise
# CHECK-SAME: kind=#linalg.elementwise_kind<exp>
# CHECK-SAME: ins(%[[Rect]] : tensor<8x16xf32>)
# CHECK-SAME: outs(%[[OutRect]] : tensor<8x16xf32>) -> tensor<8x16xf32>
linalg.elementwise(
rect,
outs=(out_rect,),
kind=linalg.ElementwiseKind.exp,
)
# CHECK: linalg.elementwise
# CHECK-SAME: kind=#linalg.elementwise_kind<exp>
# CHECK-SAME: ins(%[[RectMem]] : memref<8x16xf32>)
# CHECK-SAME: outs(%[[OutRectMem]] : memref<8x16xf32>)
linalg.elementwise(
rect_mem,
outs=(out_rect_mem,),
kind=linalg.ElementwiseKind.exp,
)
if _explicit_ident_affine_maps := True:
# Same as above but with default identity indexing_maps explicitly provided.
# CHECK: linalg.elementwise
# CHECK-SAME: kind=#linalg.elementwise_kind<exp>
# CHECK-SAME: ins(%[[Rect]] : tensor<8x16xf32>)
# CHECK-SAME: outs(%[[OutRect]] : tensor<8x16xf32>) -> tensor<8x16xf32>
op3 = linalg.ElementwiseOp(
result_tensors=(out_rect.result.type,),
inputs=(rect,),
outputs=(out_rect,),
kind=linalg.ElementwiseKind.exp,
indexing_maps=[ident_map_2d, ident_map_2d],
)
linalg.fill_builtin_region(op3.operation)
# CHECK: linalg.elementwise
# CHECK-SAME: kind=#linalg.elementwise_kind<exp>
# CHECK-SAME: ins(%[[RectMem]] : memref<8x16xf32>)
# CHECK-SAME: outs(%[[OutRectMem]] : memref<8x16xf32>)
linalg.elementwise(
rect_mem,
outs=(out_rect_mem,),
kind=linalg.ElementwiseKind.exp,
indexing_maps=[ident_map_2d, ident_map_2d],
)
if _ops_with_non_ident_input_maps := True:
# CHECK: linalg.elementwise kind=#linalg.elementwise_kind<exp>
# CHECK-SAME: indexing_maps = [#[[$VertLineBCastMap]], #[[$IdentMap2D]]]
# CHECK-SAME: ins(%[[VertLine]] : tensor<8xf32>)
# CHECK-SAME: outs(%[[OutRect]] : tensor<8x16xf32>) -> tensor<8x16xf32>
op4 = linalg.ElementwiseOp(
result_tensors=(out_rect.result.type,),
inputs=(vert_line,),
outputs=(out_rect,),
kind=linalg.ElementwiseKind.exp,
indexing_maps=[vert_line_bcast_map, ident_map_2d],
)
linalg.fill_builtin_region(op4.operation)
# CHECK: linalg.elementwise kind=#linalg.elementwise_kind<add>
# CHECK-SAME: indexing_maps = [#[[$IdentMap2D]], #[[$VertLineBCastMap]], #[[$IdentMap2D]]]
# CHECK-SAME: ins(%[[Rect]], %[[VertLine]] : tensor<8x16xf32>, tensor<8xf32>)
# CHECK-SAME: outs(%[[OutRect]] : tensor<8x16xf32>) -> tensor<8x16xf32>
op4 = linalg.ElementwiseOp(
result_tensors=(out_rect.result.type,),
inputs=(rect, vert_line),
outputs=(out_rect,),
kind=linalg.ElementwiseKind.add,
indexing_maps=[ident_map_2d, vert_line_bcast_map, ident_map_2d],
)
linalg.fill_builtin_region(op4.operation)
# CHECK: linalg.elementwise kind=#linalg.elementwise_kind<div>
# CHECK-SAME: indexing_maps = [#[[$VertLineBCastMap]], #[[$HorLineBCastMap]], #[[$IdentMap2D]]]
# CHECK-SAME: ins(%[[VertLine]], %[[HorLine]] : tensor<8xf32>, tensor<16xf32>)
# CHECK-SAME: outs(%[[OutRect]] : tensor<8x16xf32>) -> tensor<8x16xf32>
linalg.elementwise(
vert_line,
hor_line,
outs=(out_rect,),
kind=linalg.ElementwiseKind.div,
indexing_maps=[
vert_line_bcast_map,
hor_line_bcast_map,
ident_map_2d,
],
)
if _ops_with_non_ident_and_transposed_input_maps := True:
# CHECK: %[[VertLineBoolsMem:.*]] = memref.alloca() : memref<8xi1>
vert_line_bools_mem = memref.alloca(
MemRefType.get(vert_line_shape, IntegerType.get_signless(1)),
[],
[],
)
# CHECK: linalg.elementwise kind=#linalg.elementwise_kind<select>
# CHECK-SAME: indexing_maps = [#[[$VertLineBCastMap]], #[[$HorLineBCastMap]], #[[$TransMap2D]], #[[$IdentMap2D]]]
# CHECK-SAME: ins(%[[VertLineBoolsMem]], %[[HorLineMem]], %[[TransRectMem]] : memref<8xi1>, memref<16xf32>, memref<16x8xf32>)
# CHECK-SAME: outs(%[[OutRectMem]] : memref<8x16xf32>)
linalg.elementwise(
vert_line_bools_mem,
hor_line_mem,
trans_rect_mem,
outs=(out_rect_mem,),
kind=linalg.ElementwiseKind.select,
indexing_maps=[
vert_line_bcast_map,
hor_line_bcast_map,
transposed_map_2d,
ident_map_2d,
],
)
print(module)