The 1:N type converter derived from the 1:1 type converter and extends
it with 1:N target materializations. This commit merges the two type
converters and stores 1:N target materializations in the 1:1 type
converter. This is in preparation of merging the 1:1 and 1:N dialect
conversion infrastructures.
1:1 target materializations (producing a single `Value`) will remain
valid. An additional API is added to the type converter to register 1:N
target materializations (producing a `SmallVector<Value>`). Internally,
all target materializations are stored as 1:N materializations.
The 1:N type converter is removed.
Note for LLVM integration: If you are using the `OneToNTypeConverter`,
simply switch all occurrences to `TypeConverter`.
---------
Co-authored-by: Markus Böck <markus.boeck02@gmail.com>
This applies when the shape_cast is simply for dropping unit dims, and
the result rank is >= 2.
This simplifies the transpose making it possible for other ArmSME
legalization patterns to handle it.
Example:
```mlir
%0 = vector.transpose %vector, [3, 0, 1, 2]
: vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
%1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>
```
```mlir
%0 = vector.shape_cast %arg0 : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
%1 = vector.transpose %0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
```
This adds a workaround rewrite that allows stores of unsupported SVE
transposes such as:
```mlir
%tr = vector.transpose %vec, [1, 0]
: vector<2x[4]xf32> to vector<[4]x2xf32>
vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]}
: vector<[4]x2xf32>, memref<?x?xf32>
```
To use SME tiles, which are possible to lower (when SME is available):
```mlir
// Insert vector<2x[4]xf32> into an SME tile:
%0 = arm_sme.get_tile : vector<[4]x[4]xf32>
%1 = vector.extract %vec[0] : vector<[4]xf32> from vector<2x[4]xf32>
%2 = vector.insert %1, %0 [0] : vector<[4]xf32> into vector<[4]x[4]xf32>
%3 = vector.extract %vec[1] : vector<[4]xf32> from vector<2x[4]xf32>
%4 = vector.insert %3, %2 [1] : vector<[4]xf32> into vector<[4]x[4]xf32>
// Store the tile with a transpose + mask:
%c4_vscale = arith.muli %vscale, %c4 : index
%mask = vector.create_mask %c4_vscale, %c2 : vector<[4]x[4]xi1>
vector.transfer_write %4, %arg1[%arg2, %arg3], %mask
{permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
: vector<[4]x[4]xf32>, memref<?x?xf32>
```
This adds a new pattern that can legalize a multi-tile transfer_write as
a single store loop. This is done as part of type decomposition as at
this level we know each tile write is disjoint, but that information is
lost after decomposition (without analysis to reconstruct it).
Example (pseudo-MLIR):
```
vector.transfer_write %vector, %dest[%y, %x], %mask
: vector<[16]x[8]xi16>, memref<?x?xi16>
```
Is rewritten to:
```
scf.for %slice_idx = %c0 to %c8_vscale step %c1 {
%upper_slice_mask = vector.extract %mask[%slice_idx] ─┐
: vector<[8]xi1> from vector<[16]x[8]xi1> |
%upper_slice = vector.extract %upper_tile[%slice_idx] |- Store upper tile
: vector<[8]xi16> from vector<[8]x[8]xi16> |
vector.transfer_write %upper_slice, |
%dest[%slice_idx + %y, %x], %upper_slice_mask |
: vector<[8]xi16>, memref<?x?xi16> ┘
%lower_slice_idx = %slice_idx + %c8_vscale ─┐
%lower_slice_mask = vector.extract %mask[%lower_slice_idx] |
: vector<[8]xi1> from vector<[16]x[8]xi1> |
%lower_slice = vector.extract %lower_tile[%slice_idx] |- Store lower
: vector<[8]xi16> from vector<[8]x[8]xi16> | tile
vector.transfer_write %lower_slice, |
%dest[%lower_slice_idx + %y, %x], %lower_slice_mask |
: vector<[8]xi16>, memref<?x?xi16> ┘
}
```
This adds a simple rewrite/legalization to decompose constant splats
larger than a single ArmSME tile into multiple SME virtual tile sized
splats. E.g. a constant splat to `vector<[8]x[8]xi32>` would decompose
into four `vector<[4]x[4]xi32>` splats.
This adds a rewrite that converts illegal 2D unit-dim `shape_casts` into
`vector.transpose` ops.
E.g.
```mlir
// Case 1:
%a = vector.shape_cast %0 : vector<[4]x1xf32> to vector<1x[4]xf32>
// Case 2:
%b = vector.shape_cast %1 : vector<[4]x1xf32> to vector<[4]xf32>
```
Becomes:
```mlir
// Case 1:
%a = vector.transpose %0 : [1, 0] vector<[4]x1xf32> to vector<1x[4]xf32>
// Case 2:
%t = vector.transpose %1 : [1, 0] vector<[4]x1xf32> to vector<1x[4]xf32>
%b = vector.shape_cast %t : vector<1x[4]xf32> to vector<[4]xf32>
```
Various lowerings and drop unit-dims patterns add such shape_casts,
however, if they do not cancel out (which they likely won't if we've
reached the vector-legalization pass) they will prevent lowering the IR.
Rewriting them as a transpose gives `LiftIllegalVectorTransposeToMemory`
a chance to eliminate the illegal types.
When unrolling the reduction dimension of something like a matmul for
SME, you can end up with transposed reads of illegal types, like so:
```mlir
%illegalRead = vector.transfer_read %memref[%a, %b]
: memref<?x?xf32>, vector<[8]x4xf32>
%legalType = vector.transpose %illegalRead, [1, 0]
: vector<[8]x4xf32> to vector<4x[8]xf32>
```
Here the `vector<[8]x4xf32>` is an illegal type, there's no way to lower
a scalable vector of fixed vectors. However, as the final type
`vector<4x[8]xf32>` is legal, we can instead lift the transpose to
memory (producing a strided memref), and eliminate all the illegal
types. This is shown below.
```mlir
%readSubview = memref.subview %memref[%a, %b] [%c8_vscale, %c4] [%c1, %c1]
: memref<?x?xf32> to memref<?x?xf32>
%transpose = memref.transpose %readSubview (d0, d1) -> (d1, d0)
: memref<?x?xf32> to memref<?x?xf32>
%legalType = vector.transfer_read %transpose[%c0, %c0]
: memref<?x?xf32>, vector<4x[8]xf32>
```
When unrolling the reduction dimension of something like a matmul for
SME, it is possible to get 3D masks, which are vectors of SME-like
masks. The 2D masks for individual operations are then extracted from
the 3D masks.
i.e.:
```mlir
%mask = vector.create_mask %nonConstantDim, %a, %b : vector<4x[4]x[4]xi1>
%subMask = vector.extract %mask[2]
: vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
```
ArmSME only supports lowering 2D create_masks, so we must fold the
extract into the create_mask. This can be done by checking if the
extraction index is within the true region, then using that select the
first dimension of the 2D mask. This is shown below.
```mlir
%extractionInTrueRegion = arith.cmpi slt, %c2, %nonConstantDim : index
%newMaskFrontDim = arith.select %extractionInTrueRegion, %a, %c0 : index
%subMask = vector.create_mask %newMaskFrontDim, %b : vector<[4]x[4]xi1>
```
This adds a new pass (`-arm-sme-vector-legalization`) which legalizes
vector operations so that they can be lowered to ArmSME. This initial
patch adds decomposition for `vector.outerproduct`,
`vector.transfer_read`, and `vector.transfer_write` when they operate on
vector types larger than a single SME tile. For example, a [8]x[8]xf32
outer product would be decomposed into four [4]x[4]xf32 outer products,
which could then be lowered to ArmSME. These three ops have been picked
as supporting them alone allows lowering matmuls that use all ZA
accumulators to ArmSME.
For it to be possible to legalize a vector type it has to be a multiple
of an SME tile size, but other than that any shape can be used. E.g.
`vector<[8]x[8]xf32>`, `vector<[4]x[16]xf32>`, `vector<[16]x[4]xf32>`
can all be lowered to four `vector<[4]x[4]xf32>` operations.
In future, this pass will be extended with more SME-specific rewrites to
legalize unrolling the reduction dimension of matmuls (which is not
type-decomposition), which is why the pass has quite a general name.