13 Commits

Author SHA1 Message Date
Matthias Springer
8c4bc1e75d
[mlir][Transforms] Merge 1:1 and 1:N type converters (#113032)
The 1:N type converter derived from the 1:1 type converter and extends
it with 1:N target materializations. This commit merges the two type
converters and stores 1:N target materializations in the 1:1 type
converter. This is in preparation of merging the 1:1 and 1:N dialect
conversion infrastructures.

1:1 target materializations (producing a single `Value`) will remain
valid. An additional API is added to the type converter to register 1:N
target materializations (producing a `SmallVector<Value>`). Internally,
all target materializations are stored as 1:N materializations.

The 1:N type converter is removed.

Note for LLVM integration: If you are using the `OneToNTypeConverter`,
simply switch all occurrences to `TypeConverter`.

---------

Co-authored-by: Markus Böck <markus.boeck02@gmail.com>
2024-10-25 11:44:20 -07:00
Benjamin Maxwell
fc4485bf98
Revert "[mlir][ArmSME] Pattern to swap shape_cast(tranpose) with transpose(shape_cast) (#100731)" (#102457)
This reverts commit 88accd9aaa20c6a30661c48cc2ca6dbbdf991ec0.

This change can be dropped in favor of just #102017.
2024-08-09 13:29:57 +01:00
Andrzej Warzyński
fe07d9aa41
[mlir][vector] Switch to using getNumScalableDims (nfc) (#100806) 2024-07-27 08:11:00 +01:00
Benjamin Maxwell
88accd9aaa
[mlir][ArmSME] Pattern to swap shape_cast(tranpose) with transpose(shape_cast) (#100731)
This applies when the shape_cast is simply for dropping unit dims, and
the result rank is >= 2.

This simplifies the transpose making it possible for other ArmSME
legalization patterns to handle it.

Example:

```mlir
%0 = vector.transpose %vector, [3, 0, 1, 2]
       : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
%1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>
```

```mlir
%0 = vector.shape_cast %arg0 : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
%1 = vector.transpose %0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
```
2024-07-26 14:45:51 +01:00
Benjamin Maxwell
c194bc77a2
[mlir][ArmSME] Add rewrite to handle unsupported SVE transposes via SME/ZA (#98620)
This adds a workaround rewrite that allows stores of unsupported SVE
transposes such as:

```mlir
%tr = vector.transpose %vec, [1, 0]
  : vector<2x[4]xf32> to vector<[4]x2xf32>
vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]}
  : vector<[4]x2xf32>,  memref<?x?xf32>
```

To use SME tiles, which are possible to lower (when SME is available):

```mlir
// Insert vector<2x[4]xf32> into an SME tile:
%0 = arm_sme.get_tile : vector<[4]x[4]xf32>
%1 = vector.extract %vec[0] : vector<[4]xf32> from vector<2x[4]xf32>
%2 = vector.insert %1, %0 [0] : vector<[4]xf32> into vector<[4]x[4]xf32>
%3 = vector.extract %vec[1] : vector<[4]xf32> from vector<2x[4]xf32>
%4 = vector.insert %3, %2 [1] : vector<[4]xf32> into vector<[4]x[4]xf32>
// Store the tile with a transpose + mask:
%c4_vscale = arith.muli %vscale, %c4 : index
%mask = vector.create_mask %c4_vscale, %c2 : vector<[4]x[4]xi1>
vector.transfer_write %4, %arg1[%arg2, %arg3], %mask
   {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
   : vector<[4]x[4]xf32>, memref<?x?xf32>
```
2024-07-25 18:15:14 +01:00
Benjamin Maxwell
5ed5d723db
[mlir][ArmSME] Lower multi-tile stores to a single loop (#96187)
This adds a new pattern that can legalize a multi-tile transfer_write as
a single store loop. This is done as part of type decomposition as at
this level we know each tile write is disjoint, but that information is
lost after decomposition (without analysis to reconstruct it).

Example (pseudo-MLIR):

```
vector.transfer_write %vector, %dest[%y, %x], %mask
  : vector<[16]x[8]xi16>, memref<?x?xi16>
```
Is rewritten to:
```
scf.for %slice_idx = %c0 to %c8_vscale step %c1 {
  %upper_slice_mask = vector.extract %mask[%slice_idx] ─┐
    : vector<[8]xi1> from vector<[16]x[8]xi1>           |
  %upper_slice = vector.extract %upper_tile[%slice_idx] |- Store upper tile
    : vector<[8]xi16> from vector<[8]x[8]xi16>          |
  vector.transfer_write %upper_slice,                   |
    %dest[%slice_idx + %y, %x], %upper_slice_mask       |
    : vector<[8]xi16>, memref<?x?xi16>                  ┘
  %lower_slice_idx = %slice_idx + %c8_vscale                 ─┐
  %lower_slice_mask = vector.extract %mask[%lower_slice_idx]  |
    : vector<[8]xi1> from vector<[16]x[8]xi1>                 |
  %lower_slice = vector.extract %lower_tile[%slice_idx]       |- Store lower
    : vector<[8]xi16> from vector<[8]x[8]xi16>                |  tile
  vector.transfer_write %lower_slice,                         |
    %dest[%lower_slice_idx + %y, %x], %lower_slice_mask       |
    : vector<[8]xi16>, memref<?x?xi16>                        ┘
}
```
2024-06-25 12:46:56 +01:00
Benjamin Maxwell
dadcaf8227
[mlir][ArmSME] Support decomposing constant splats into ArmSME tiles (#88762)
This adds a simple rewrite/legalization to decompose constant splats
larger than a single ArmSME tile into multiple SME virtual tile sized
splats. E.g. a constant splat to `vector<[8]x[8]xi32>` would decompose
into four `vector<[4]x[4]xi32>` splats.
2024-04-16 12:54:01 +01:00
Benjamin Maxwell
d1fc59c3b5
[mlir][ArmSME] Rewrite illegal shape_casts to vector.transpose ops (#82985)
This adds a rewrite that converts illegal 2D unit-dim `shape_casts` into
`vector.transpose` ops.

E.g.

```mlir
// Case 1:
%a = vector.shape_cast %0 : vector<[4]x1xf32> to vector<1x[4]xf32>
// Case 2:
%b = vector.shape_cast %1 : vector<[4]x1xf32> to vector<[4]xf32>
```

Becomes:

```mlir
// Case 1:
%a = vector.transpose %0 : [1, 0] vector<[4]x1xf32> to vector<1x[4]xf32>
// Case 2:
%t = vector.transpose %1 : [1, 0] vector<[4]x1xf32> to vector<1x[4]xf32>
%b = vector.shape_cast %t : vector<1x[4]xf32> to vector<[4]xf32>
```

Various lowerings and drop unit-dims patterns add such shape_casts,
however, if they do not cancel out (which they likely won't if we've
reached the vector-legalization pass) they will prevent lowering the IR.

Rewriting them as a transpose gives `LiftIllegalVectorTransposeToMemory`
a chance to eliminate the illegal types.
2024-03-07 17:04:12 +00:00
Benjamin Maxwell
8cfb71613c
[mlir][ArmSME] Replace use of isa with isa_and_present (#82798)
`op` can be null here, in which case this should just return a null
value back.
2024-02-26 09:44:26 +00:00
Benjamin Maxwell
1408667fdd [mlir][ArmSME] Follow MLIR constant style in VectorLegalization.cpp (NFC) 2024-02-23 16:55:32 +00:00
Benjamin Maxwell
0473e322f6
[mlir][ArmSME] Add rewrite to lift illegal vector.transposes to memory (#80170)
When unrolling the reduction dimension of something like a matmul for
SME, you can end up with transposed reads of illegal types, like so:

```mlir
%illegalRead = vector.transfer_read %memref[%a, %b]
                : memref<?x?xf32>, vector<[8]x4xf32>
%legalType = vector.transpose %illegalRead, [1, 0]
                : vector<[8]x4xf32> to vector<4x[8]xf32>
```

Here the `vector<[8]x4xf32>` is an illegal type, there's no way to lower
a scalable vector of fixed vectors. However, as the final type
`vector<4x[8]xf32>` is legal, we can instead lift the transpose to
memory (producing a strided memref), and eliminate all the illegal
types. This is shown below.

```mlir
%readSubview = memref.subview %memref[%a, %b] [%c8_vscale, %c4] [%c1, %c1]
                : memref<?x?xf32> to memref<?x?xf32>
%transpose = memref.transpose %readSubview (d0, d1) -> (d1, d0)
                : memref<?x?xf32> to memref<?x?xf32>
%legalType = vector.transfer_read %transpose[%c0, %c0]
                : memref<?x?xf32>, vector<4x[8]xf32>
```
2024-02-06 09:30:55 +00:00
Benjamin Maxwell
c2dea7122c
[mlir][ArmSME] Fold extracts from 3D create_masks of SME-like masks (#80148)
When unrolling the reduction dimension of something like a matmul for
SME, it is possible to get 3D masks, which are vectors of SME-like
masks. The 2D masks for individual operations are then extracted from
the 3D masks.

i.e.:

```mlir
%mask = vector.create_mask %nonConstantDim, %a, %b : vector<4x[4]x[4]xi1>
%subMask = vector.extract %mask[2]
        : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
```

ArmSME only supports lowering 2D create_masks, so we must fold the
extract into the create_mask. This can be done by checking if the
extraction index is within the true region, then using that select the
first dimension of the 2D mask. This is shown below.

```mlir
%extractionInTrueRegion = arith.cmpi slt, %c2, %nonConstantDim : index
%newMaskFrontDim = arith.select %extractionInTrueRegion, %a, %c0 : index
%subMask = vector.create_mask %newMaskFrontDim, %b : vector<[4]x[4]xi1>
```
2024-02-02 10:06:11 +00:00
Benjamin Maxwell
042800a4dd
[mlir][ArmSME] Add initial SME vector legalization pass (#79152)
This adds a new pass (`-arm-sme-vector-legalization`) which legalizes
vector operations so that they can be lowered to ArmSME. This initial
patch adds decomposition for `vector.outerproduct`,
`vector.transfer_read`, and `vector.transfer_write` when they operate on
vector types larger than a single SME tile. For example, a [8]x[8]xf32
outer product would be decomposed into four [4]x[4]xf32 outer products,
which could then be lowered to ArmSME. These three ops have been picked
as supporting them alone allows lowering matmuls that use all ZA
accumulators to ArmSME.

For it to be possible to legalize a vector type it has to be a multiple
of an SME tile size, but other than that any shape can be used. E.g.
`vector<[8]x[8]xf32>`, `vector<[4]x[16]xf32>`, `vector<[16]x[4]xf32>`
can all be lowered to four `vector<[4]x[4]xf32>` operations.

In future, this pass will be extended with more SME-specific rewrites to
legalize unrolling the reduction dimension of matmuls (which is not
type-decomposition), which is why the pass has quite a general name.
2024-01-31 11:55:22 +00:00