diff --git a/mlir/docs/DefiningDialects/Operations.md b/mlir/docs/DefiningDialects/Operations.md index b3bde055f04f..2225329ff830 100644 --- a/mlir/docs/DefiningDialects/Operations.md +++ b/mlir/docs/DefiningDialects/Operations.md @@ -306,6 +306,8 @@ Right now, the following primitive constraints are supported: * `IntPositive`: Specifying an integer attribute whose value is positive * `IntNonNegative`: Specifying an integer attribute whose value is non-negative +* `IntPowerOf2`: Specifying an integer attribute whose value is a power of + two > 0 * `ArrayMinCount`: Specifying an array attribute to have at least `N` elements * `ArrayMaxCount`: Specifying an array attribute to have at most `N` diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td index 09bb3932ef29..9321089ab55f 100644 --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -1216,6 +1216,11 @@ def LoadOp : MemRef_Op<"load", be reused in the cache. For details, refer to the [https://llvm.org/docs/LangRef.html#load-instruction](LLVM load instruction). + An optional `alignment` attribute allows to specify the byte alignment of the + load operation. It must be a positive power of 2. The operation must access + memory at an address aligned to this boundary. Violations may lead to + architecture-specific faults or performance penalties. + A value of 0 indicates no specific alignment requirement. Example: ```mlir @@ -1226,7 +1231,39 @@ def LoadOp : MemRef_Op<"load", let arguments = (ins Arg:$memref, Variadic:$indices, - DefaultValuedOptionalAttr:$nontemporal); + DefaultValuedOptionalAttr:$nontemporal, + ConfinedAttr, + [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment); + + let builders = [ + OpBuilder<(ins "Value":$memref, + "ValueRange":$indices, + CArg<"bool", "false">:$nontemporal, + CArg<"uint64_t", "0">:$alignment), [{ + return build($_builder, $_state, memref, indices, nontemporal, + alignment != 0 ? $_builder.getI64IntegerAttr(alignment) : + nullptr); + }]>, + OpBuilder<(ins "Type":$resultType, + "Value":$memref, + "ValueRange":$indices, + CArg<"bool", "false">:$nontemporal, + CArg<"uint64_t", "0">:$alignment), [{ + return build($_builder, $_state, resultType, memref, indices, nontemporal, + alignment != 0 ? $_builder.getI64IntegerAttr(alignment) : + nullptr); + }]>, + OpBuilder<(ins "TypeRange":$resultTypes, + "Value":$memref, + "ValueRange":$indices, + CArg<"bool", "false">:$nontemporal, + CArg<"uint64_t", "0">:$alignment), [{ + return build($_builder, $_state, resultTypes, memref, indices, nontemporal, + alignment != 0 ? $_builder.getI64IntegerAttr(alignment) : + nullptr); + }]> + ]; + let results = (outs AnyType:$result); let extraClassDeclaration = [{ @@ -1912,6 +1949,11 @@ def MemRef_StoreOp : MemRef_Op<"store", be reused in the cache. For details, refer to the [https://llvm.org/docs/LangRef.html#store-instruction](LLVM store instruction). + An optional `alignment` attribute allows to specify the byte alignment of the + store operation. It must be a positive power of 2. The operation must access + memory at an address aligned to this boundary. Violations may lead to + architecture-specific faults or performance penalties. + A value of 0 indicates no specific alignment requirement. Example: ```mlir @@ -1923,13 +1965,25 @@ def MemRef_StoreOp : MemRef_Op<"store", Arg:$memref, Variadic:$indices, - DefaultValuedOptionalAttr:$nontemporal); + DefaultValuedOptionalAttr:$nontemporal, + ConfinedAttr, + [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment); let builders = [ + OpBuilder<(ins "Value":$valueToStore, + "Value":$memref, + "ValueRange":$indices, + CArg<"bool", "false">:$nontemporal, + CArg<"uint64_t", "0">:$alignment), [{ + return build($_builder, $_state, valueToStore, memref, indices, nontemporal, + alignment != 0 ? $_builder.getI64IntegerAttr(alignment) : + nullptr); + }]>, OpBuilder<(ins "Value":$valueToStore, "Value":$memref), [{ $_state.addOperands(valueToStore); $_state.addOperands(memref); - }]>]; + }]> + ]; let extraClassDeclaration = [{ Value getValueToStore() { return getOperand(0); } diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index cbe490f6e4dd..e07188a1a04b 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -1809,12 +1809,42 @@ def Vector_LoadOp : Vector_Op<"load", [ ```mlir %result = vector.load %memref[%c0] : memref<7xf32>, vector<8xf32> ``` + + An optional `alignment` attribute allows to specify the byte alignment of the + load operation. It must be a positive power of 2. The operation must access + memory at an address aligned to this boundary. Violations may lead to + architecture-specific faults or performance penalties. + A value of 0 indicates no specific alignment requirement. }]; let arguments = (ins Arg:$base, Variadic:$indices, - DefaultValuedOptionalAttr:$nontemporal); + DefaultValuedOptionalAttr:$nontemporal, + ConfinedAttr, + [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment); + + let builders = [ + OpBuilder<(ins "VectorType":$resultType, + "Value":$base, + "ValueRange":$indices, + CArg<"bool", "false">:$nontemporal, + CArg<"uint64_t", "0">:$alignment), [{ + return build($_builder, $_state, resultType, base, indices, nontemporal, + alignment != 0 ? $_builder.getI64IntegerAttr(alignment) : + nullptr); + }]>, + OpBuilder<(ins "TypeRange":$resultTypes, + "Value":$base, + "ValueRange":$indices, + CArg<"bool", "false">:$nontemporal, + CArg<"uint64_t", "0">:$alignment), [{ + return build($_builder, $_state, resultTypes, base, indices, nontemporal, + alignment != 0 ? $_builder.getI64IntegerAttr(alignment) : + nullptr); + }]> + ]; + let results = (outs AnyVectorOfAnyRank:$result); let extraClassDeclaration = [{ @@ -1895,6 +1925,12 @@ def Vector_StoreOp : Vector_Op<"store", [ ```mlir vector.store %valueToStore, %memref[%c0] : memref<7xf32>, vector<8xf32> ``` + + An optional `alignment` attribute allows to specify the byte alignment of the + store operation. It must be a positive power of 2. The operation must access + memory at an address aligned to this boundary. Violations may lead to + architecture-specific faults or performance penalties. + A value of 0 indicates no specific alignment requirement. }]; let arguments = (ins @@ -1902,8 +1938,21 @@ def Vector_StoreOp : Vector_Op<"store", [ Arg:$base, Variadic:$indices, - DefaultValuedOptionalAttr:$nontemporal - ); + DefaultValuedOptionalAttr:$nontemporal, + ConfinedAttr, + [AllAttrOf<[IntPositive, IntPowerOf2]>]>:$alignment); + + let builders = [ + OpBuilder<(ins "Value":$valueToStore, + "Value":$base, + "ValueRange":$indices, + CArg<"bool", "false">:$nontemporal, + CArg<"uint64_t", "0">:$alignment), [{ + return build($_builder, $_state, valueToStore, base, indices, nontemporal, + alignment != 0 ? $_builder.getI64IntegerAttr(alignment) : + nullptr); + }]> + ]; let extraClassDeclaration = [{ MemRefType getMemRefType() { diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td index e91a13fea5c7..18da85a58071 100644 --- a/mlir/include/mlir/IR/CommonAttrConstraints.td +++ b/mlir/include/mlir/IR/CommonAttrConstraints.td @@ -796,6 +796,10 @@ def IntPositive : AttrConstraint< CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getValue().isStrictlyPositive()">, "whose value is positive">; +def IntPowerOf2 : AttrConstraint< + CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getValue().isPowerOf2()">, + "whose value is a power of two > 0">; + class ArrayMaxCount : AttrConstraint< CPred<"::llvm::cast<::mlir::ArrayAttr>($_self).size() <= " # n>, "with at most " # n # " elements">; diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir index 704cdaf838f4..fa803efa1d91 100644 --- a/mlir/test/Dialect/MemRef/invalid.mlir +++ b/mlir/test/Dialect/MemRef/invalid.mlir @@ -962,6 +962,24 @@ func.func @test_store_zero_results2(%x: i32, %p: memref) { // ----- +func.func @invalid_load_alignment(%memref: memref<4xi32>) { + %c0 = arith.constant 0 : index + // expected-error @below {{'memref.load' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}} + %val = memref.load %memref[%c0] { alignment = -1 } : memref<4xi32> + return +} + +// ----- + +func.func @invalid_store_alignment(%memref: memref<4xi32>, %val: i32) { + %c0 = arith.constant 0 : index + // expected-error @below {{'memref.store' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}} + memref.store %val, %memref[%c0] { alignment = 3 } : memref<4xi32> + return +} + +// ----- + func.func @test_alloc_memref_map_rank_mismatch() { ^bb0: // expected-error@+1 {{memref layout mismatch between rank and affine map: 2 != 1}} diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir index e11de7bec2d0..6c2298a3f8ac 100644 --- a/mlir/test/Dialect/MemRef/ops.mlir +++ b/mlir/test/Dialect/MemRef/ops.mlir @@ -265,6 +265,17 @@ func.func @zero_dim_no_idx(%arg0 : memref, %arg1 : memref, %arg2 : mem // CHECK: memref.store %{{.*}}, %{{.*}}[] : memref } + +// CHECK-LABEL: func @load_store_alignment +func.func @load_store_alignment(%memref: memref<4xi32>) { + %c0 = arith.constant 0 : index + // CHECK: memref.load {{.*}} {alignment = 16 : i64} + %val = memref.load %memref[%c0] { alignment = 16 } : memref<4xi32> + // CHECK: memref.store {{.*}} {alignment = 16 : i64} + memref.store %val, %memref[%c0] { alignment = 16 } : memref<4xi32> + return +} + // CHECK-LABEL: func @memref_view(%arg0 func.func @memref_view(%arg0 : index, %arg1 : index, %arg2 : index) { %0 = memref.alloc() : memref<2048xi8> diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index 5038646e1f02..8017140a0bfa 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1995,6 +1995,15 @@ func.func @vector_load(%src : memref) { // ----- +func.func @invalid_load_alignment(%memref: memref<4xi32>) { + %c0 = arith.constant 0 : index + // expected-error @below {{'vector.load' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}} + %val = vector.load %memref[%c0] { alignment = -1 } : memref<4xi32>, vector<4xi32> + return +} + +// ----- + //===----------------------------------------------------------------------===// // vector.store //===----------------------------------------------------------------------===// @@ -2005,3 +2014,12 @@ func.func @vector_store(%dest : memref, %vec : vector<16x16xi8>) { vector.store %vec, %dest[%c0] : memref, vector<16x16xi8> return } + +// ----- + +func.func @invalid_store_alignment(%memref: memref<4xi32>, %val: vector<4xi32>) { + %c0 = arith.constant 0 : index + // expected-error @below {{'vector.store' op attribute 'alignment' failed to satisfy constraint: 64-bit signless integer attribute whose value is positive and whose value is a power of two > 0}} + vector.store %val, %memref[%c0] { alignment = 3 } : memref<4xi32>, vector<4xi32> + return +} diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir index 10bf0f162056..39578ac56e36 100644 --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -853,6 +853,16 @@ func.func @vector_load_and_store_2d_vector_memref(%memref : memref<200x100xvecto return } +// CHECK-LABEL: func @load_store_alignment +func.func @load_store_alignment(%memref: memref<4xi32>) { + %c0 = arith.constant 0 : index + // CHECK: vector.load {{.*}} {alignment = 16 : i64} + %val = vector.load %memref[%c0] { alignment = 16 } : memref<4xi32>, vector<4xi32> + // CHECK: vector.store {{.*}} {alignment = 16 : i64} + vector.store %val, %memref[%c0] { alignment = 16 } : memref<4xi32>, vector<4xi32> + return +} + // CHECK-LABEL: @masked_load_and_store func.func @masked_load_and_store(%base: memref, %mask: vector<16xi1>, %passthru: vector<16xf32>) { %c0 = arith.constant 0 : index