This patch introduces support for 4-way widening outer products. This
enables the fusion of 4 'arm_sme.outerproduct' operations that are
chained via the accumulator into single widened operations.
Changes:
- Adds the following operations:
- smopa_4way, smops_4way
- umopa_4way, umops_4way
- sumopa_4way, sumops_4way
- sumopa_4way, sumops_4way
- Implements conversions for the above ops to intrinsics in ArmSMEToLLVM.
- Extends 'arm-sme-outer-product' pass.
For a detailed description of these operations see the
'arm_sme.smopa_4way' description.
This patch introduces support for 2-way widening outer products. This
enables the fusion of 2 'arm_sme.outerproduct' operations that are
chained via the accumulator into a 2-way widening outer product
operation.
Changes:
- Add 'llvm.aarch64.sme.[us]mop[as].za32' intrinsics for 2-way variants.
These map to instruction variants added in SME2 and use different
intrinsics. Intrinsics are already implemented for widening variants
from SME1.
- Adds the following operations:
- fmopa_2way, fmops_2way
- smopa_2way, smops_2way
- umopa_2way, umops_2way
- Implements conversions for the above ops to intrinsics in
ArmSMEToLLVM.
- Adds a pass 'arm-sme-outer-product-fusion' that fuses
'arm_sme.outerproduct' operations.
For a detailed description of these operations see the
'arm_sme.fmopa_2way' description.
The reason for introducing many operations rather than one is the
signed/unsigned variants can't be distinguished with types (e.g., ui16,
si16) since 'arith.extui' and 'arith.extsi' only support signless
integers. A single operation would require this information and an
attribute (for example) for the sign doesn't feel right if
floating-point types are also supported where this wouldn't apply.
Furthermore, the SME FP8 extensions (FEAT_SME_F8F16, FEAT_SME_F8F32)
introduce FMOPA 2-way (FP8 to FP16) and 4-way (FP8 to FP32) variants but
no subtract variant. Whilst these are not supported in this patch, it
felt simpler to have separate ops for add/subtract given this.
This commit renames 4 pattern rewriter API functions:
* `updateRootInPlace` -> `modifyOpInPlace`
* `startRootUpdate` -> `startOpModification`
* `finalizeRootUpdate` -> `finalizeOpModification`
* `cancelRootUpdate` -> `cancelOpModification`
The term "root" is a misnomer. The root is the op that a rewrite pattern
matches against
(https://mlir.llvm.org/docs/PatternRewriter/#root-operation-name-optional).
A rewriter must be notified of all in-place op modifications, not just
in-place modifications of the root
(https://mlir.llvm.org/docs/PatternRewriter/#pattern-rewriter). The old
function names were confusing and have contributed to various broken
rewrite patterns.
Note: The new function names use the term "modify" instead of "update"
for consistency with the `RewriterBase::Listener` terminology
(`notifyOperationModified`).
This adds very basic (and inelegant) support for something like spilling
and reloading tiles, if you use more SME tiles than physically exist.
This is purely implemented to prevent the compiler from aborting if a
function uses too many tiles (i.e. due to bad unrolling), but is
expected to perform very poorly.
Currently, this works in two stages:
During tile allocation, if we run out of tiles instead of giving up, we
switch to allocating 'in-memory' tile IDs. These are tile IDs that start
at 16 (which is higher than any real tile ID). A warning will also be
emitted for each (root) tile op assigned an in-memory tile ID:
```
warning: failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance
```
Everything after this works like normal until `-convert-arm-sme-to-llvm`
Here the in-memory tile op:
```mlir
arm_sme.tile_op { tile_id = <IN MEMORY TILE> }
```
Is lowered to:
```mlir
// At function entry:
%alloca = memref.alloca ... : memref<?x?xty>
// Around the op:
// Swap the contents of %alloca and tile 0.
scf.for %slice_idx {
%current_slice = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
"arm_sme.intr.ld1h.horiz"(%alloca, %slice_idx) <{tile_id = 0 : i32}>
vector.store %current_slice, %alloca[%slice_idx, %c0]
}
// Execute op using tile 0.
arm_sme.tile_op { tile_id = 0 }
// Swap the contents of %alloca and tile 0.
// This restores tile 0 to its original state.
scf.for %slice_idx {
%current_slice = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
"arm_sme.intr.ld1h.horiz"(%alloca, %slice_idx) <{tile_id = 0 : i32}>
vector.store %current_slice, %alloca[%slice_idx, %c0]
}
```
This is inserted during the lowering to LLVM as spilling/reloading
registers is a very low-level concept, that can't really be modeled
correctly at a high level in MLIR.
Note: This is always doing the worst case full-tile swap. This could be
optimized to only spill/load data the tile op will use, which could be
just a slice. It's also not making any use of liveness, which could
allow reusing tiles. But these is not seen as important as correct code
should only use the available number of tiles.
This operation provides a convenient way to query the streaming vector
length regardless of the streaming mode. This most useful for functions
that call/pass data to streaming functions, but are not streaming
themselves.
Example:
```mlir
%svl_w = arm_sme.streaming_vl <word>
```
Created based on discussion here:
https://github.com/llvm/llvm-project/pull/76086#discussion_r1434226352
This patch removes the ArmSMETypeConverter, and instead updates
`populateArmSMEToLLVMConversionPatterns()` to add an ArmSME vector type
conversion to the existing LLVMTypeConverter. This makes it easier to
add these patterns to an existing `-to-llvm` lowering pass.
This reworks the ArmSME dialect to use attributes for tile allocation.
This has a number of advantages and corrects some issues with the
previous approach:
* Tile allocation can now be done ASAP (i.e. immediately after
`-convert-vector-to-arm-sme`)
* SSA form for control flow is now supported (e.g.`scf.for` loops that
yield tiles)
* ArmSME ops can be converted to intrinsics very late (i.e. after
lowering to control flow)
* Tests are simplified by removing constants and casts
* Avoids correctness issues with representing LLVM `immargs` as MLIR
values
- The tile ID on the SME intrinsics is an `immarg` (so is required to be
a compile-time constant), `immargs` should be mapped to MLIR attributes
(this is already the case for intrinsics in the LLVM dialect)
- Using MLIR values for `immargs` can lead to invalid LLVM IR being
generated (and passes such as -cse making incorrect optimizations)
As part of this patch we bid farewell to the following operations:
```mlir
arm_sme.get_tile_id : i32
arm_sme.cast_tile_to_vector : i32 to vector<[4]x[4]xi32>
arm_sme.cast_vector_to_tile : vector<[4]x[4]xi32> to i32
```
These are now replaced with:
```mlir
// Allocates a new tile with (indeterminate) state:
arm_sme.get_tile : vector<[4]x[4]xi32>
// A placeholder operation for lowering ArmSME ops to intrinsics:
arm_sme.materialize_ssa_tile : vector<[4]x[4]xi32>
```
The new tile allocation works by operations implementing the
`ArmSMETileOpInterface`. This interface says that an operation needs to
be assigned a tile ID, and may conditionally allocate a new SME tile.
Operations allocate a new tile by implementing...
```c++
std::optional<arm_sme::ArmSMETileType> getAllocatedTileType()
```
...and returning what type of tile the op allocates (ZAB, ZAH, etc).
Operations that don't allocate a tile return `std::nullopt` (which is
the default behaviour).
Currently the following ops are defined as allocating:
```mlir
arm_sme.get_tile
arm_sme.zero
arm_sme.tile_load
arm_sme.outerproduct // (if no accumulator is specified)
```
Allocating operations become the roots for the tile allocation pass,
which currently just (naively) assigns all transitive uses of a root
operation the same tile ID. However, this is enough to handle current
use cases.
Once tile IDs have been allocated subsequent rewrites can forward the
tile IDs to any newly created operations.
This gives more flexibility with when these lowerings are performed,
without also lowering unrelated vector ops.
This is a NFC (other than adding a new `-convert-arm-sme-to-llvm` pass)