[MLIR][Linalg][Python] Improve bindings for linalg.elementwise (#139462)
Adds wrappers for ElementWiseOp, in particular to ensure appropriate default indexing maps are derived.
This commit is contained in:
parent
688bccb290
commit
ba739c166d
@ -216,6 +216,67 @@ def contract(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Extend and shadow the TableGen-derived version to make sure correct default
|
||||||
|
# indexing_maps are derived (as there is no mechanism for doing so given the
|
||||||
|
# Python API bypasses the C++-builders).
|
||||||
|
class ElementwiseOp_(ElementwiseOp):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
result_tensors,
|
||||||
|
inputs,
|
||||||
|
outputs,
|
||||||
|
kind,
|
||||||
|
*,
|
||||||
|
indexing_maps=None,
|
||||||
|
loc=None,
|
||||||
|
ip=None,
|
||||||
|
):
|
||||||
|
if indexing_maps is None:
|
||||||
|
inputs = [_get_op_result_or_value(in_) for in_ in inputs]
|
||||||
|
for in0, in1 in zip(inputs[:-1], inputs[1:]):
|
||||||
|
assert in0.type == in1.type
|
||||||
|
output = _get_op_result_or_value(outputs[0])
|
||||||
|
assert inputs[0].type == output.type
|
||||||
|
num_args = len(inputs) + 1
|
||||||
|
indexing_maps = [AffineMap.get_identity(output.type.rank)] * num_args
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
result_tensors=result_tensors,
|
||||||
|
inputs=inputs,
|
||||||
|
outputs=outputs,
|
||||||
|
kind=kind,
|
||||||
|
indexing_maps=indexing_maps,
|
||||||
|
loc=loc,
|
||||||
|
ip=ip,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
ElementwiseOp = ElementwiseOp_
|
||||||
|
|
||||||
|
|
||||||
|
def elementwise(
|
||||||
|
*ins: Union[Operation, OpView, Value],
|
||||||
|
outs: Sequence[Union[Operation, OpView, Value]],
|
||||||
|
kind: Union[ElementwiseKind, Attribute],
|
||||||
|
indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
|
||||||
|
):
|
||||||
|
ins = [_get_op_result_or_value(input) for input in ins]
|
||||||
|
if len(outs) != 1:
|
||||||
|
raise ValueError(f"{outs=} must have length 1.")
|
||||||
|
init = _get_op_result_or_value(outs[0])
|
||||||
|
result_types = [init.type] if isinstance(init.type, RankedTensorType) else []
|
||||||
|
|
||||||
|
op = ElementwiseOp(
|
||||||
|
result_tensors=result_types,
|
||||||
|
inputs=ins,
|
||||||
|
outputs=[init],
|
||||||
|
kind=kind,
|
||||||
|
indexing_maps=indexing_maps,
|
||||||
|
)
|
||||||
|
fill_builtin_region(op.operation)
|
||||||
|
return _get_op_result_or_op_results(op)
|
||||||
|
|
||||||
|
|
||||||
def pack(
|
def pack(
|
||||||
source,
|
source,
|
||||||
dest,
|
dest,
|
||||||
|
@ -606,3 +606,189 @@ def testPackUnPackOp():
|
|||||||
# CHECK: return %[[VAL_4]] : tensor<128x128xf32>
|
# CHECK: return %[[VAL_4]] : tensor<128x128xf32>
|
||||||
# CHECK: }
|
# CHECK: }
|
||||||
print(module)
|
print(module)
|
||||||
|
|
||||||
|
|
||||||
|
# CHECK-LABEL: TEST: testElementwiseOp
|
||||||
|
@run
|
||||||
|
def testElementwiseOp():
|
||||||
|
with Context(), Location.unknown():
|
||||||
|
module = Module.create()
|
||||||
|
f32 = F32Type.get()
|
||||||
|
with InsertionPoint(module.body):
|
||||||
|
rect_shape = (8, 16)
|
||||||
|
vert_line_shape = (8,)
|
||||||
|
hor_line_shape = (16,)
|
||||||
|
transposed_rect_shape = (16, 8)
|
||||||
|
|
||||||
|
# CHECK-DAG: #[[$IdentMap2D:.*]] = affine_map<(d0, d1) -> (d0, d1)>
|
||||||
|
# CHECK-DAG: #[[$TransMap2D:.*]] = affine_map<(d0, d1) -> (d1, d0)>
|
||||||
|
# CHECK-DAG: #[[$VertLineBCastMap:.*]] = affine_map<(d0, d1) -> (d0)>
|
||||||
|
# CHECK-DAG: #[[$HorLineBCastMap:.*]] = affine_map<(d0, d1) -> (d1)>
|
||||||
|
|
||||||
|
ident_map_2d = AffineMap.get_identity(2)
|
||||||
|
transposed_map_2d = AffineMap.get_permutation((1, 0))
|
||||||
|
vert_line_bcast_map = AffineMap.get(2, 0, [AffineDimExpr.get(0)])
|
||||||
|
hor_line_bcast_map = AffineMap.get(2, 0, [AffineDimExpr.get(1)])
|
||||||
|
|
||||||
|
# CHECK: func.func @elementwise_op(
|
||||||
|
@func.FuncOp.from_py_func(
|
||||||
|
# CHECK-SAME: %[[Rect:.*]]: tensor<8x16xf32>,
|
||||||
|
RankedTensorType.get(rect_shape, f32),
|
||||||
|
# CHECK-SAME: %[[RectMem:.*]]: memref<8x16xf32>,
|
||||||
|
MemRefType.get(rect_shape, f32),
|
||||||
|
# CHECK-SAME: %[[VertLine:.*]]: tensor<8xf32>,
|
||||||
|
RankedTensorType.get(vert_line_shape, f32),
|
||||||
|
# CHECK-SAME: %[[VertLineMem:.*]]: memref<8xf32>,
|
||||||
|
MemRefType.get(vert_line_shape, f32),
|
||||||
|
# CHECK-SAME: %[[HorLine:.*]]: tensor<16xf32>,
|
||||||
|
RankedTensorType.get(hor_line_shape, f32),
|
||||||
|
# CHECK-SAME: %[[HorLineMem:.*]]: memref<16xf32>,
|
||||||
|
MemRefType.get(hor_line_shape, f32),
|
||||||
|
# CHECK-SAME: %[[TransRect:.*]]: tensor<16x8xf32>,
|
||||||
|
RankedTensorType.get(transposed_rect_shape, f32),
|
||||||
|
# CHECK-SAME: %[[TransRectMem:.*]]: memref<16x8xf32>)
|
||||||
|
MemRefType.get(transposed_rect_shape, f32),
|
||||||
|
)
|
||||||
|
def elementwise_op(
|
||||||
|
rect,
|
||||||
|
rect_mem,
|
||||||
|
vert_line,
|
||||||
|
vert_line_mem,
|
||||||
|
hor_line,
|
||||||
|
hor_line_mem,
|
||||||
|
trans_rect,
|
||||||
|
trans_rect_mem,
|
||||||
|
):
|
||||||
|
# CHECK: %[[OutRect:.*]] = tensor.empty() : tensor<8x16xf32>
|
||||||
|
out_rect = tensor.EmptyOp(rect_shape, f32)
|
||||||
|
# CHECK: %[[OutRectMem:.*]] = memref.alloca() : memref<8x16xf32>
|
||||||
|
out_rect_mem = memref.alloca(MemRefType.get(rect_shape, f32), [], [])
|
||||||
|
|
||||||
|
if _inferred_affine_maps := True:
|
||||||
|
# CHECK: linalg.elementwise
|
||||||
|
# CHECK-SAME: kind=#linalg.elementwise_kind<exp>
|
||||||
|
# CHECK-SAME: ins(%[[Rect]] : tensor<8x16xf32>)
|
||||||
|
# CHECK-SAME: outs(%[[OutRect]] : tensor<8x16xf32>) -> tensor<8x16xf32>
|
||||||
|
op1 = linalg.ElementwiseOp(
|
||||||
|
result_tensors=(out_rect.result.type,),
|
||||||
|
inputs=(rect,),
|
||||||
|
outputs=(out_rect,),
|
||||||
|
kind=linalg.ElementwiseKind.exp,
|
||||||
|
)
|
||||||
|
linalg.fill_builtin_region(op1.operation)
|
||||||
|
|
||||||
|
# CHECK: linalg.elementwise
|
||||||
|
# CHECK-SAME: kind=#linalg.elementwise_kind<exp>
|
||||||
|
# CHECK-SAME: ins(%[[Rect]] : tensor<8x16xf32>)
|
||||||
|
# CHECK-SAME: outs(%[[OutRect]] : tensor<8x16xf32>) -> tensor<8x16xf32>
|
||||||
|
linalg.elementwise(
|
||||||
|
rect,
|
||||||
|
outs=(out_rect,),
|
||||||
|
kind=linalg.ElementwiseKind.exp,
|
||||||
|
)
|
||||||
|
|
||||||
|
# CHECK: linalg.elementwise
|
||||||
|
# CHECK-SAME: kind=#linalg.elementwise_kind<exp>
|
||||||
|
# CHECK-SAME: ins(%[[RectMem]] : memref<8x16xf32>)
|
||||||
|
# CHECK-SAME: outs(%[[OutRectMem]] : memref<8x16xf32>)
|
||||||
|
linalg.elementwise(
|
||||||
|
rect_mem,
|
||||||
|
outs=(out_rect_mem,),
|
||||||
|
kind=linalg.ElementwiseKind.exp,
|
||||||
|
)
|
||||||
|
|
||||||
|
if _explicit_ident_affine_maps := True:
|
||||||
|
# Same as above but with default identity indexing_maps explicitly provided.
|
||||||
|
# CHECK: linalg.elementwise
|
||||||
|
# CHECK-SAME: kind=#linalg.elementwise_kind<exp>
|
||||||
|
# CHECK-SAME: ins(%[[Rect]] : tensor<8x16xf32>)
|
||||||
|
# CHECK-SAME: outs(%[[OutRect]] : tensor<8x16xf32>) -> tensor<8x16xf32>
|
||||||
|
op3 = linalg.ElementwiseOp(
|
||||||
|
result_tensors=(out_rect.result.type,),
|
||||||
|
inputs=(rect,),
|
||||||
|
outputs=(out_rect,),
|
||||||
|
kind=linalg.ElementwiseKind.exp,
|
||||||
|
indexing_maps=[ident_map_2d, ident_map_2d],
|
||||||
|
)
|
||||||
|
linalg.fill_builtin_region(op3.operation)
|
||||||
|
|
||||||
|
# CHECK: linalg.elementwise
|
||||||
|
# CHECK-SAME: kind=#linalg.elementwise_kind<exp>
|
||||||
|
# CHECK-SAME: ins(%[[RectMem]] : memref<8x16xf32>)
|
||||||
|
# CHECK-SAME: outs(%[[OutRectMem]] : memref<8x16xf32>)
|
||||||
|
linalg.elementwise(
|
||||||
|
rect_mem,
|
||||||
|
outs=(out_rect_mem,),
|
||||||
|
kind=linalg.ElementwiseKind.exp,
|
||||||
|
indexing_maps=[ident_map_2d, ident_map_2d],
|
||||||
|
)
|
||||||
|
|
||||||
|
if _ops_with_non_ident_input_maps := True:
|
||||||
|
# CHECK: linalg.elementwise kind=#linalg.elementwise_kind<exp>
|
||||||
|
# CHECK-SAME: indexing_maps = [#[[$VertLineBCastMap]], #[[$IdentMap2D]]]
|
||||||
|
# CHECK-SAME: ins(%[[VertLine]] : tensor<8xf32>)
|
||||||
|
# CHECK-SAME: outs(%[[OutRect]] : tensor<8x16xf32>) -> tensor<8x16xf32>
|
||||||
|
op4 = linalg.ElementwiseOp(
|
||||||
|
result_tensors=(out_rect.result.type,),
|
||||||
|
inputs=(vert_line,),
|
||||||
|
outputs=(out_rect,),
|
||||||
|
kind=linalg.ElementwiseKind.exp,
|
||||||
|
indexing_maps=[vert_line_bcast_map, ident_map_2d],
|
||||||
|
)
|
||||||
|
linalg.fill_builtin_region(op4.operation)
|
||||||
|
|
||||||
|
# CHECK: linalg.elementwise kind=#linalg.elementwise_kind<add>
|
||||||
|
# CHECK-SAME: indexing_maps = [#[[$IdentMap2D]], #[[$VertLineBCastMap]], #[[$IdentMap2D]]]
|
||||||
|
# CHECK-SAME: ins(%[[Rect]], %[[VertLine]] : tensor<8x16xf32>, tensor<8xf32>)
|
||||||
|
# CHECK-SAME: outs(%[[OutRect]] : tensor<8x16xf32>) -> tensor<8x16xf32>
|
||||||
|
op4 = linalg.ElementwiseOp(
|
||||||
|
result_tensors=(out_rect.result.type,),
|
||||||
|
inputs=(rect, vert_line),
|
||||||
|
outputs=(out_rect,),
|
||||||
|
kind=linalg.ElementwiseKind.add,
|
||||||
|
indexing_maps=[ident_map_2d, vert_line_bcast_map, ident_map_2d],
|
||||||
|
)
|
||||||
|
linalg.fill_builtin_region(op4.operation)
|
||||||
|
|
||||||
|
# CHECK: linalg.elementwise kind=#linalg.elementwise_kind<div>
|
||||||
|
# CHECK-SAME: indexing_maps = [#[[$VertLineBCastMap]], #[[$HorLineBCastMap]], #[[$IdentMap2D]]]
|
||||||
|
# CHECK-SAME: ins(%[[VertLine]], %[[HorLine]] : tensor<8xf32>, tensor<16xf32>)
|
||||||
|
# CHECK-SAME: outs(%[[OutRect]] : tensor<8x16xf32>) -> tensor<8x16xf32>
|
||||||
|
linalg.elementwise(
|
||||||
|
vert_line,
|
||||||
|
hor_line,
|
||||||
|
outs=(out_rect,),
|
||||||
|
kind=linalg.ElementwiseKind.div,
|
||||||
|
indexing_maps=[
|
||||||
|
vert_line_bcast_map,
|
||||||
|
hor_line_bcast_map,
|
||||||
|
ident_map_2d,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
if _ops_with_non_ident_and_transposed_input_maps := True:
|
||||||
|
# CHECK: %[[VertLineBoolsMem:.*]] = memref.alloca() : memref<8xi1>
|
||||||
|
vert_line_bools_mem = memref.alloca(
|
||||||
|
MemRefType.get(vert_line_shape, IntegerType.get_signless(1)),
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
)
|
||||||
|
# CHECK: linalg.elementwise kind=#linalg.elementwise_kind<select>
|
||||||
|
# CHECK-SAME: indexing_maps = [#[[$VertLineBCastMap]], #[[$HorLineBCastMap]], #[[$TransMap2D]], #[[$IdentMap2D]]]
|
||||||
|
# CHECK-SAME: ins(%[[VertLineBoolsMem]], %[[HorLineMem]], %[[TransRectMem]] : memref<8xi1>, memref<16xf32>, memref<16x8xf32>)
|
||||||
|
# CHECK-SAME: outs(%[[OutRectMem]] : memref<8x16xf32>)
|
||||||
|
linalg.elementwise(
|
||||||
|
vert_line_bools_mem,
|
||||||
|
hor_line_mem,
|
||||||
|
trans_rect_mem,
|
||||||
|
outs=(out_rect_mem,),
|
||||||
|
kind=linalg.ElementwiseKind.select,
|
||||||
|
indexing_maps=[
|
||||||
|
vert_line_bcast_map,
|
||||||
|
hor_line_bcast_map,
|
||||||
|
transposed_map_2d,
|
||||||
|
ident_map_2d,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
print(module)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user