[mlir][vector][memref] Add alignment attribute to memory access ops (#144344)

Alignment information is important to allow LLVM backends such as AMDGPU
to select wide memory accesses (e.g., dwordx4 or b128). Since this info
is not always inferable, it's better to inform LLVM backends explicitly
about it. Furthermore, alignment is not necessarily a property of the
element type, but of each individual memory access op (we can have
overaligned and underaligned accesses compared to the natural/preferred
alignment of the element type).

This patch introduces `alignment` attribute to memref/vector.load/store
ops.

Follow-up PRs will

1. Propagate the attribute to LLVM/SPIR-V.

2. Introduce `alignment` attribute to other vector memory access ops:
    vector.gather + vector.scatter
    vector.transfer_read + vector.transfer_write
    vector.compressstore + vector.expandload
    vector.maskedload + vector.maskedstore

3. Replace `--convert-vector-to-llvm='use-vector-alignment=1` with a
   simple pass to populate alignment attributes based on the vector
   types.
This commit is contained in:
tyb0807 2025-07-17 19:38:21 +02:00 committed by GitHub
parent 163da8796b
commit aa3978573e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 172 additions and 6 deletions

View File

@ -306,6 +306,8 @@ Right now, the following primitive constraints are supported:
* `IntPositive`: Specifying an integer attribute whose value is positive
* `IntNonNegative`: Specifying an integer attribute whose value is
non-negative
* `IntPowerOf2`: Specifying an integer attribute whose value is a power of
two > 0
* `ArrayMinCount<N>`: Specifying an array attribute to have at least `N`
elements
* `ArrayMaxCount<N>`: Specifying an array attribute to have at most `N`

View File

@ -1216,6 +1216,11 @@ def LoadOp : MemRef_Op<"load",
be reused in the cache. For details, refer to the
[https://llvm.org/docs/LangRef.html#load-instruction](LLVM load instruction).
An optional `alignment` attribute allows to specify the byte alignment of the
load operation. It must be a positive power of 2. The operation must access
memory at an address aligned to this boundary. Violations may lead to
architecture-specific faults or performance penalties.
A value of 0 indicates no specific alignment requirement.
Example:
```mlir
@ -1226,7 +1231,39 @@ def LoadOp : MemRef_Op<"load",
let arguments = (ins Arg<AnyMemRef, "the reference to load from",
[MemRead]>:$memref,
Variadic<Index>:$indices,
DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal);
DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
ConfinedAttr<OptionalAttr<I64Attr>,
[AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment);
let builders = [
OpBuilder<(ins "Value":$memref,
"ValueRange":$indices,
CArg<"bool", "false">:$nontemporal,
CArg<"uint64_t", "0">:$alignment), [{
return build($_builder, $_state, memref, indices, nontemporal,
alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
nullptr);
}]>,
OpBuilder<(ins "Type":$resultType,
"Value":$memref,
"ValueRange":$indices,
CArg<"bool", "false">:$nontemporal,
CArg<"uint64_t", "0">:$alignment), [{
return build($_builder, $_state, resultType, memref, indices, nontemporal,
alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
nullptr);
}]>,
OpBuilder<(ins "TypeRange":$resultTypes,
"Value":$memref,
"ValueRange":$indices,
CArg<"bool", "false">:$nontemporal,
CArg<"uint64_t", "0">:$alignment), [{
return build($_builder, $_state, resultTypes, memref, indices, nontemporal,
alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
nullptr);
}]>
];
let results = (outs AnyType:$result);
let extraClassDeclaration = [{
@ -1912,6 +1949,11 @@ def MemRef_StoreOp : MemRef_Op<"store",
be reused in the cache. For details, refer to the
[https://llvm.org/docs/LangRef.html#store-instruction](LLVM store instruction).
An optional `alignment` attribute allows to specify the byte alignment of the
store operation. It must be a positive power of 2. The operation must access
memory at an address aligned to this boundary. Violations may lead to
architecture-specific faults or performance penalties.
A value of 0 indicates no specific alignment requirement.
Example:
```mlir
@ -1923,13 +1965,25 @@ def MemRef_StoreOp : MemRef_Op<"store",
Arg<AnyMemRef, "the reference to store to",
[MemWrite]>:$memref,
Variadic<Index>:$indices,
DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal);
DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
ConfinedAttr<OptionalAttr<I64Attr>,
[AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment);
let builders = [
OpBuilder<(ins "Value":$valueToStore,
"Value":$memref,
"ValueRange":$indices,
CArg<"bool", "false">:$nontemporal,
CArg<"uint64_t", "0">:$alignment), [{
return build($_builder, $_state, valueToStore, memref, indices, nontemporal,
alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
nullptr);
}]>,
OpBuilder<(ins "Value":$valueToStore, "Value":$memref), [{
$_state.addOperands(valueToStore);
$_state.addOperands(memref);
}]>];
}]>
];
let extraClassDeclaration = [{
Value getValueToStore() { return getOperand(0); }

View File

@ -1809,12 +1809,42 @@ def Vector_LoadOp : Vector_Op<"load", [
```mlir
%result = vector.load %memref[%c0] : memref<7xf32>, vector<8xf32>
```
An optional `alignment` attribute allows to specify the byte alignment of the
load operation. It must be a positive power of 2. The operation must access
memory at an address aligned to this boundary. Violations may lead to
architecture-specific faults or performance penalties.
A value of 0 indicates no specific alignment requirement.
}];
let arguments = (ins Arg<AnyMemRef, "the reference to load from",
[MemRead]>:$base,
Variadic<Index>:$indices,
DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal);
DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
ConfinedAttr<OptionalAttr<I64Attr>,
[AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment);
let builders = [
OpBuilder<(ins "VectorType":$resultType,
"Value":$base,
"ValueRange":$indices,
CArg<"bool", "false">:$nontemporal,
CArg<"uint64_t", "0">:$alignment), [{
return build($_builder, $_state, resultType, base, indices, nontemporal,
alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
nullptr);
}]>,
OpBuilder<(ins "TypeRange":$resultTypes,
"Value":$base,
"ValueRange":$indices,
CArg<"bool", "false">:$nontemporal,
CArg<"uint64_t", "0">:$alignment), [{
return build($_builder, $_state, resultTypes, base, indices, nontemporal,
alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
nullptr);
}]>
];
let results = (outs AnyVectorOfAnyRank:$result);
let extraClassDeclaration = [{
@ -1895,6 +1925,12 @@ def Vector_StoreOp : Vector_Op<"store", [
```mlir
vector.store %valueToStore, %memref[%c0] : memref<7xf32>, vector<8xf32>
```
An optional `alignment` attribute allows to specify the byte alignment of the
store operation. It must be a positive power of 2. The operation must access
memory at an address aligned to this boundary. Violations may lead to
architecture-specific faults or performance penalties.
A value of 0 indicates no specific alignment requirement.
}];
let arguments = (ins
@ -1902,8 +1938,21 @@ def Vector_StoreOp : Vector_Op<"store", [
Arg<AnyMemRef, "the reference to store to",
[MemWrite]>:$base,
Variadic<Index>:$indices,
DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal
);
DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
ConfinedAttr<OptionalAttr<I64Attr>,
[AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment);
let builders = [
OpBuilder<(ins "Value":$valueToStore,
"Value":$base,
"ValueRange":$indices,
CArg<"bool", "false">:$nontemporal,
CArg<"uint64_t", "0">:$alignment), [{
return build($_builder, $_state, valueToStore, base, indices, nontemporal,
alignment != 0 ? $_builder.getI64IntegerAttr(alignment) :
nullptr);
}]>
];
let extraClassDeclaration = [{
MemRefType getMemRefType() {

View File

@ -796,6 +796,10 @@ def IntPositive : AttrConstraint<
CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getValue().isStrictlyPositive()">,
"whose value is positive">;
def IntPowerOf2 : AttrConstraint<
CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getValue().isPowerOf2()">,
"whose value is a power of two > 0">;
class ArrayMaxCount<int n> : AttrConstraint<
CPred<"::llvm::cast<::mlir::ArrayAttr>($_self).size() <= " # n>,
"with at most " # n # " elements">;

View File

@ -962,6 +962,24 @@ func.func @test_store_zero_results2(%x: i32, %p: memref<i32>) {
// -----
func.func @invalid_load_alignment(%memref: memref<4xi32>) {
%c0 = arith.constant 0 : index
// expected-error @below {{'memref.load' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
%val = memref.load %memref[%c0] { alignment = -1 } : memref<4xi32>
return
}
// -----
func.func @invalid_store_alignment(%memref: memref<4xi32>, %val: i32) {
%c0 = arith.constant 0 : index
// expected-error @below {{'memref.store' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
memref.store %val, %memref[%c0] { alignment = 3 } : memref<4xi32>
return
}
// -----
func.func @test_alloc_memref_map_rank_mismatch() {
^bb0:
// expected-error@+1 {{memref layout mismatch between rank and affine map: 2 != 1}}

View File

@ -265,6 +265,17 @@ func.func @zero_dim_no_idx(%arg0 : memref<i32>, %arg1 : memref<i32>, %arg2 : mem
// CHECK: memref.store %{{.*}}, %{{.*}}[] : memref<i32>
}
// CHECK-LABEL: func @load_store_alignment
func.func @load_store_alignment(%memref: memref<4xi32>) {
%c0 = arith.constant 0 : index
// CHECK: memref.load {{.*}} {alignment = 16 : i64}
%val = memref.load %memref[%c0] { alignment = 16 } : memref<4xi32>
// CHECK: memref.store {{.*}} {alignment = 16 : i64}
memref.store %val, %memref[%c0] { alignment = 16 } : memref<4xi32>
return
}
// CHECK-LABEL: func @memref_view(%arg0
func.func @memref_view(%arg0 : index, %arg1 : index, %arg2 : index) {
%0 = memref.alloc() : memref<2048xi8>

View File

@ -1995,6 +1995,15 @@ func.func @vector_load(%src : memref<?xi8>) {
// -----
func.func @invalid_load_alignment(%memref: memref<4xi32>) {
%c0 = arith.constant 0 : index
// expected-error @below {{'vector.load' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
%val = vector.load %memref[%c0] { alignment = -1 } : memref<4xi32>, vector<4xi32>
return
}
// -----
//===----------------------------------------------------------------------===//
// vector.store
//===----------------------------------------------------------------------===//
@ -2005,3 +2014,12 @@ func.func @vector_store(%dest : memref<?xi8>, %vec : vector<16x16xi8>) {
vector.store %vec, %dest[%c0] : memref<?xi8>, vector<16x16xi8>
return
}
// -----
func.func @invalid_store_alignment(%memref: memref<4xi32>, %val: vector<4xi32>) {
%c0 = arith.constant 0 : index
// expected-error @below {{'vector.store' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}}
vector.store %val, %memref[%c0] { alignment = 3 } : memref<4xi32>, vector<4xi32>
return
}

View File

@ -853,6 +853,16 @@ func.func @vector_load_and_store_2d_vector_memref(%memref : memref<200x100xvecto
return
}
// CHECK-LABEL: func @load_store_alignment
func.func @load_store_alignment(%memref: memref<4xi32>) {
%c0 = arith.constant 0 : index
// CHECK: vector.load {{.*}} {alignment = 16 : i64}
%val = vector.load %memref[%c0] { alignment = 16 } : memref<4xi32>, vector<4xi32>
// CHECK: vector.store {{.*}} {alignment = 16 : i64}
vector.store %val, %memref[%c0] { alignment = 16 } : memref<4xi32>, vector<4xi32>
return
}
// CHECK-LABEL: @masked_load_and_store
func.func @masked_load_and_store(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) {
%c0 = arith.constant 0 : index