From 5ae2fe75c3898cbf78f170d3cd686e02182f36fc Mon Sep 17 00:00:00 2001 From: Jorn Tuyls Date: Sat, 28 Mar 2026 09:06:19 +0100 Subject: [PATCH] [mlir][vector] Reject alignment attribute on tensor-level gather/scatter (#188924) --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 8 ++++++++ mlir/test/Dialect/Vector/invalid.mlir | 18 ++++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index c1536d6e062c..bd419f2ba93e 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -6192,6 +6192,10 @@ LogicalResult GatherOp::verify() { return emitOpError("expected result dim to match mask dim"); if (resVType != getPassThruVectorType()) return emitOpError("expected pass_thru of same type as result type"); + if (getAlignmentAttr() && !isa(baseType)) { + return emitOpError( + "alignment is only supported for memref bases, not tensor bases"); + } return success(); } @@ -6300,6 +6304,10 @@ LogicalResult ScatterOp::verify() { return emitOpError("expected valueToStore dim to match indices dim"); if (valueVType.getShape() != maskVType.getShape()) return emitOpError("expected valueToStore dim to match mask dim"); + if (getAlignmentAttr() && !isa(baseType)) { + return emitOpError( + "alignment is only supported for memref bases, not tensor bases"); + } return success(); } namespace { diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index 8f8429e5844d..f90312c91533 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1545,6 +1545,15 @@ func.func @gather_non_power_of_two_alignment(%base: memref<16xf32>, %indices: ve // ----- +func.func @gather_tensor_alignment(%base: tensor<16xf32>, %indices: vector<16xi32>, + %mask: vector<16xi1>, %pass_thru: vector<16xf32>, %c0 : index) { + // expected-error@+1 {{'vector.gather' op alignment is only supported for memref bases, not tensor bases}} + %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru + { alignment = 8 : i64 } : tensor<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32> +} + +// ----- + func.func @scatter_to_vector(%base: vector<16xf32>, %indices: vector<16xi32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) { %c0 = arith.constant 0 : index @@ -1624,6 +1633,15 @@ func.func @scatter_non_power_of_2_alignment(%base: memref, %indices: vect // ----- +func.func @scatter_tensor_alignment(%base: tensor, %indices: vector<16xi32>, + %mask: vector<16xi1>, %value: vector<16xf32>, %c0: index) { + // expected-error@+1 {{'vector.scatter' op alignment is only supported for memref bases, not tensor bases}} + vector.scatter %base[%c0][%indices], %mask, %value { alignment = 8 : i64 } + : tensor, vector<16xi32>, vector<16xi1>, vector<16xf32> -> tensor +} + +// ----- + func.func @expand_base_type_mismatch(%base: memref, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) { %c0 = arith.constant 0 : index // expected-error@+1 {{'vector.expandload' op base element type ('f64') does not match result element type ('f32')}}