6 Commits

Author SHA1 Message Date
Benjamin Maxwell
d1fc59c3b5
[mlir][ArmSME] Rewrite illegal shape_casts to vector.transpose ops (#82985)
This adds a rewrite that converts illegal 2D unit-dim `shape_casts` into
`vector.transpose` ops.

E.g.

```mlir
// Case 1:
%a = vector.shape_cast %0 : vector<[4]x1xf32> to vector<1x[4]xf32>
// Case 2:
%b = vector.shape_cast %1 : vector<[4]x1xf32> to vector<[4]xf32>
```

Becomes:

```mlir
// Case 1:
%a = vector.transpose %0 : [1, 0] vector<[4]x1xf32> to vector<1x[4]xf32>
// Case 2:
%t = vector.transpose %1 : [1, 0] vector<[4]x1xf32> to vector<1x[4]xf32>
%b = vector.shape_cast %t : vector<1x[4]xf32> to vector<[4]xf32>
```

Various lowerings and drop unit-dims patterns add such shape_casts,
however, if they do not cancel out (which they likely won't if we've
reached the vector-legalization pass) they will prevent lowering the IR.

Rewriting them as a transpose gives `LiftIllegalVectorTransposeToMemory`
a chance to eliminate the illegal types.
2024-03-07 17:04:12 +00:00
Benjamin Maxwell
8cfb71613c
[mlir][ArmSME] Replace use of isa with isa_and_present (#82798)
`op` can be null here, in which case this should just return a null
value back.
2024-02-26 09:44:26 +00:00
Benjamin Maxwell
1408667fdd [mlir][ArmSME] Follow MLIR constant style in VectorLegalization.cpp (NFC) 2024-02-23 16:55:32 +00:00
Benjamin Maxwell
0473e322f6
[mlir][ArmSME] Add rewrite to lift illegal vector.transposes to memory (#80170)
When unrolling the reduction dimension of something like a matmul for
SME, you can end up with transposed reads of illegal types, like so:

```mlir
%illegalRead = vector.transfer_read %memref[%a, %b]
                : memref<?x?xf32>, vector<[8]x4xf32>
%legalType = vector.transpose %illegalRead, [1, 0]
                : vector<[8]x4xf32> to vector<4x[8]xf32>
```

Here the `vector<[8]x4xf32>` is an illegal type, there's no way to lower
a scalable vector of fixed vectors. However, as the final type
`vector<4x[8]xf32>` is legal, we can instead lift the transpose to
memory (producing a strided memref), and eliminate all the illegal
types. This is shown below.

```mlir
%readSubview = memref.subview %memref[%a, %b] [%c8_vscale, %c4] [%c1, %c1]
                : memref<?x?xf32> to memref<?x?xf32>
%transpose = memref.transpose %readSubview (d0, d1) -> (d1, d0)
                : memref<?x?xf32> to memref<?x?xf32>
%legalType = vector.transfer_read %transpose[%c0, %c0]
                : memref<?x?xf32>, vector<4x[8]xf32>
```
2024-02-06 09:30:55 +00:00
Benjamin Maxwell
c2dea7122c
[mlir][ArmSME] Fold extracts from 3D create_masks of SME-like masks (#80148)
When unrolling the reduction dimension of something like a matmul for
SME, it is possible to get 3D masks, which are vectors of SME-like
masks. The 2D masks for individual operations are then extracted from
the 3D masks.

i.e.:

```mlir
%mask = vector.create_mask %nonConstantDim, %a, %b : vector<4x[4]x[4]xi1>
%subMask = vector.extract %mask[2]
        : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
```

ArmSME only supports lowering 2D create_masks, so we must fold the
extract into the create_mask. This can be done by checking if the
extraction index is within the true region, then using that select the
first dimension of the 2D mask. This is shown below.

```mlir
%extractionInTrueRegion = arith.cmpi slt, %c2, %nonConstantDim : index
%newMaskFrontDim = arith.select %extractionInTrueRegion, %a, %c0 : index
%subMask = vector.create_mask %newMaskFrontDim, %b : vector<[4]x[4]xi1>
```
2024-02-02 10:06:11 +00:00
Benjamin Maxwell
042800a4dd
[mlir][ArmSME] Add initial SME vector legalization pass (#79152)
This adds a new pass (`-arm-sme-vector-legalization`) which legalizes
vector operations so that they can be lowered to ArmSME. This initial
patch adds decomposition for `vector.outerproduct`,
`vector.transfer_read`, and `vector.transfer_write` when they operate on
vector types larger than a single SME tile. For example, a [8]x[8]xf32
outer product would be decomposed into four [4]x[4]xf32 outer products,
which could then be lowered to ArmSME. These three ops have been picked
as supporting them alone allows lowering matmuls that use all ZA
accumulators to ArmSME.

For it to be possible to legalize a vector type it has to be a multiple
of an SME tile size, but other than that any shape can be used. E.g.
`vector<[8]x[8]xf32>`, `vector<[4]x[16]xf32>`, `vector<[16]x[4]xf32>`
can all be lowered to four `vector<[4]x[4]xf32>` operations.

In future, this pass will be extended with more SME-specific rewrites to
legalize unrolling the reduction dimension of matmuls (which is not
type-decomposition), which is why the pass has quite a general name.
2024-01-31 11:55:22 +00:00