[mlir] TileUsingInterface bugfix for dominance error (#178190)
In this PR i move the insertion point in the
`yieldReplacementForFusedProducer` because i ran into some issue where a
`tensor.extract_slices` tried to use a result of `affine.apply` that was
inserted at the end of the block instead of the start of it.
This is the full error of the test i added before this change:
```mlir
third-party/llvm-project/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-scfforall.mlir:83:11: error: operand #1 does not dominate this use
%pack = linalg.pack %gen#1
^
third-party/llvm-project/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-scfforall.mlir:83:11: note: see current operation: %24 = "tensor.extract_slice"(%23, %36, %8) <{operandSegmentSizes = array<i32: 1, 1, 1, 0>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: -9223372036854775808, 1024>, static_strides = array<i64: 1, 1>}> : (tensor<32x1024xf32>, index, index) -> tensor<?x1024xf32>
third-party/llvm-project/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-scfforall.mlir:71:12: note: operand defined here (op in the same block)
%gen:2 = linalg.generic {
^
// -----// IR Dump After InterpreterPass Failed (transform-interpreter) //----- //
#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0) -> (d0 * 16)>
#map2 = affine_map<(d0) -> (d0 * -16 + 32)>
#map3 = affine_map<(d0) -> (16, d0 * -16 + 32)>
#map4 = affine_map<(d0) -> (d0 - 1)>
"builtin.module"() ({
"func.func"() <{function_type = (tensor<32x1024xf32>) -> (tensor<32x1024xf32>, tensor<2x512x16x2xi8>), sym_name = "fuse_pack_consumer_into_multi_output_generic"}> ({
^bb0(%arg1: tensor<32x1024xf32>):
%2 = "arith.constant"() <{value = 0 : i8}> : () -> i8
%3 = "tensor.empty"() : () -> tensor<32x1024xf32>
%4 = "tensor.empty"() : () -> tensor<32x1024xi8>
%5 = "tensor.empty"() : () -> tensor<2x512x16x2xi8>
%6:2 = "linalg.generic"(%arg1, %3, %4) <{indexing_maps = [#map, #map, #map], iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>], operandSegmentSizes = array<i32: 1, 2>}> ({
^bb0(%arg9: f32, %arg10: f32, %arg11: i8):
%41 = "arith.fptoui"(%arg9) : (f32) -> i8
"linalg.yield"(%arg9, %41) : (f32, i8) -> ()
}) : (tensor<32x1024xf32>, tensor<32x1024xf32>, tensor<32x1024xi8>) -> (tensor<32x1024xf32>, tensor<32x1024xi8>)
%7:3 = "scf.forall"(%5, %3, %4) <{operandSegmentSizes = array<i32: 0, 0, 0, 3>, staticLowerBound = array<i64: 0>, staticStep = array<i64: 1>, staticUpperBound = array<i64: 2>}> ({
^bb0(%arg2: index, %arg3: tensor<2x512x16x2xi8>, %arg4: tensor<32x1024xf32>, %arg5: tensor<32x1024xi8>):
%8 = "affine.apply"(%arg2) <{map = #map1}> : (index) -> index
%9 = "affine.apply"(%arg2) <{map = #map2}> : (index) -> index
%10 = "affine.min"(%arg2) <{map = #map3}> : (index) -> index
%11 = "affine.apply"(%10) <{map = #map4}> : (index) -> index
%12 = "affine.apply"(%arg2) <{map = #map1}> : (index) -> index
%13 = "affine.apply"(%10) <{map = #map4}> : (index) -> index
%14 = "affine.apply"(%arg2) <{map = #map1}> : (index) -> index
%15 = "affine.apply"(%10) <{map = #map4}> : (index) -> index
%16 = "affine.apply"(%arg2) <{map = #map1}> : (index) -> index
%17 = "affine.apply"(%10) <{map = #map4}> : (index) -> index
%18 = "tensor.extract_slice"(%arg1, %12, %10) <{operandSegmentSizes = array<i32: 1, 1, 1, 0>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: -9223372036854775808, 1024>, static_strides = array<i64: 1, 1>}> : (tensor<32x1024xf32>, index, index) -> tensor<?x1024xf32>
%19 = "tensor.empty"() : () -> tensor<32x1024xf32>
%20 = "tensor.extract_slice"(%19, %14, %10) <{operandSegmentSizes = array<i32: 1, 1, 1, 0>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: -9223372036854775808, 1024>, static_strides = array<i64: 1, 1>}> : (tensor<32x1024xf32>, index, index) -> tensor<?x1024xf32>
%21 = "tensor.extract_slice"(%3, %14, %10) <{operandSegmentSizes = array<i32: 1, 1, 1, 0>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: -9223372036854775808, 1024>, static_strides = array<i64: 1, 1>}> : (tensor<32x1024xf32>, index, index) -> tensor<?x1024xf32>
%22 = "tensor.empty"() : () -> tensor<32x1024xi8>
%23 = "tensor.extract_slice"(%22, %16, %10) <{operandSegmentSizes = array<i32: 1, 1, 1, 0>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: -9223372036854775808, 1024>, static_strides = array<i64: 1, 1>}> : (tensor<32x1024xi8>, index, index) -> tensor<?x1024xi8>
%24 = "tensor.extract_slice"(%4, %16, %10) <{operandSegmentSizes = array<i32: 1, 1, 1, 0>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: -9223372036854775808, 1024>, static_strides = array<i64: 1, 1>}> : (tensor<32x1024xi8>, index, index) -> tensor<?x1024xi8>
%25 = "tensor.empty"() : () -> tensor<32x1024xf32>
%26 = "tensor.extract_slice"(%25, %38, %10) <{operandSegmentSizes = array<i32: 1, 1, 1, 0>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: -9223372036854775808, 1024>, static_strides = array<i64: 1, 1>}> : (tensor<32x1024xf32>, index, index) -> tensor<?x1024xf32>
%27 = "tensor.extract_slice"(%arg4, %38, %10) <{operandSegmentSizes = array<i32: 1, 1, 1, 0>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: -9223372036854775808, 1024>, static_strides = array<i64: 1, 1>}> : (tensor<32x1024xf32>, index, index) -> tensor<?x1024xf32>
%28 = "tensor.empty"() : () -> tensor<32x1024xi8>
%29 = "tensor.extract_slice"(%28, %8, %10) <{operandSegmentSizes = array<i32: 1, 1, 1, 0>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: -9223372036854775808, 1024>, static_strides = array<i64: 1, 1>}> : (tensor<32x1024xi8>, index, index) -> tensor<?x1024xi8>
%30 = "tensor.extract_slice"(%arg5, %8, %10) <{operandSegmentSizes = array<i32: 1, 1, 1, 0>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: -9223372036854775808, 1024>, static_strides = array<i64: 1, 1>}> : (tensor<32x1024xi8>, index, index) -> tensor<?x1024xi8>
%31:2 = "linalg.generic"(%18, %27, %30) <{indexing_maps = [#map, #map, #map], iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>], operandSegmentSizes = array<i32: 1, 2>}> ({
^bb0(%arg6: f32, %arg7: f32, %arg8: i8):
%40 = "arith.fptoui"(%arg6) : (f32) -> i8
"linalg.yield"(%arg6, %40) : (f32, i8) -> ()
}) : (tensor<?x1024xf32>, tensor<?x1024xf32>, tensor<?x1024xi8>) -> (tensor<?x1024xf32>, tensor<?x1024xi8>)
%32 = "tensor.extract_slice"(%6#1, %8, %10) <{operandSegmentSizes = array<i32: 1, 1, 1, 0>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: -9223372036854775808, 1024>, static_strides = array<i64: 1, 1>}> : (tensor<32x1024xi8>, index, index) -> tensor<?x1024xi8>
%33 = "tensor.empty"() : () -> tensor<2x512x16x2xi8>
%34 = "tensor.extract_slice"(%33, %arg2) <{operandSegmentSizes = array<i32: 1, 1, 0, 0>, static_offsets = array<i64: -9223372036854775808, 0, 0, 0>, static_sizes = array<i64: 1, 512, 16, 2>, static_strides = array<i64: 1, 1, 1, 1>}> : (tensor<2x512x16x2xi8>, index) -> tensor<1x512x16x2xi8>
%35 = "tensor.extract_slice"(%arg3, %arg2) <{operandSegmentSizes = array<i32: 1, 1, 0, 0>, static_offsets = array<i64: -9223372036854775808, 0, 0, 0>, static_sizes = array<i64: 1, 512, 16, 2>, static_strides = array<i64: 1, 1, 1, 1>}> : (tensor<2x512x16x2xi8>, index) -> tensor<1x512x16x2xi8>
%36 = "linalg.pack"(%31#1, %35, %2) <{inner_dims_pos = array<i64: 0, 1>, operandSegmentSizes = array<i32: 1, 1, 1, 0>, static_inner_tiles = array<i64: 16, 2>}> : (tensor<?x1024xi8>, tensor<1x512x16x2xi8>, i8) -> tensor<1x512x16x2xi8>
%37 = "affine.apply"(%10) <{map = #map4}> : (index) -> index
%38 = "affine.apply"(%arg2) <{map = #map1}> : (index) -> index
%39 = "affine.apply"(%10) <{map = #map4}> : (index) -> index
"scf.forall.in_parallel"() ({
"tensor.parallel_insert_slice"(%36, %arg3, %arg2) <{operandSegmentSizes = array<i32: 1, 1, 1, 0, 0>, static_offsets = array<i64: -9223372036854775808, 0, 0, 0>, static_sizes = array<i64: 1, 512, 16, 2>, static_strides = array<i64: 1, 1, 1, 1>}> : (tensor<1x512x16x2xi8>, tensor<2x512x16x2xi8>, index) -> ()
"tensor.parallel_insert_slice"(%31#0, %arg4, %38, %10) <{operandSegmentSizes = array<i32: 1, 1, 1, 1, 0>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: -9223372036854775808, 1024>, static_strides = array<i64: 1, 1>}> : (tensor<?x1024xf32>, tensor<32x1024xf32>, index, index) -> ()
"tensor.parallel_insert_slice"(%31#1, %arg5, %8, %10) <{operandSegmentSizes = array<i32: 1, 1, 1, 1, 0>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: -9223372036854775808, 1024>, static_strides = array<i64: 1, 1>}> : (tensor<?x1024xi8>, tensor<32x1024xi8>, index, index) -> ()
}) : () -> ()
}) : (tensor<2x512x16x2xi8>, tensor<32x1024xf32>, tensor<32x1024xi8>) -> (tensor<2x512x16x2xi8>, tensor<32x1024xf32>, tensor<32x1024xi8>)
"func.return"(%7#1, %7#0) : (tensor<32x1024xf32>, tensor<2x512x16x2xi8>) -> ()
}) : () -> ()
"builtin.module"() ({
"transform.named_sequence"() <{arg_attrs = [{transform.readonly}], function_type = (!transform.any_op) -> (), sym_name = "__transform_main"}> ({
^bb0(%arg0: !transform.any_op):
%0 = "transform.structured.match"(%arg0) <{ops = ["linalg.pack"]}> : (!transform.any_op) -> !transform.any_op
%1:2 = "transform.test.fuse_and_yield"(%0) <{tile_interchange = [], tile_sizes = [1], use_forall = true}> : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
"transform.yield"() : () -> ()
}) : () -> ()
}) {transform.with_named_sequence} : () -> ()
}) : () -> ()
```
I also noticed that Interface tests are missing from the bazel overlay
so i also added this.