[mlir][vector] Reject alignment attribute on tensor-level gather/scatter (#188924)
This commit is contained in:
parent
ad91a2f036
commit
5ae2fe75c3
@ -6192,6 +6192,10 @@ LogicalResult GatherOp::verify() {
|
||||
return emitOpError("expected result dim to match mask dim");
|
||||
if (resVType != getPassThruVectorType())
|
||||
return emitOpError("expected pass_thru of same type as result type");
|
||||
if (getAlignmentAttr() && !isa<MemRefType>(baseType)) {
|
||||
return emitOpError(
|
||||
"alignment is only supported for memref bases, not tensor bases");
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
@ -6300,6 +6304,10 @@ LogicalResult ScatterOp::verify() {
|
||||
return emitOpError("expected valueToStore dim to match indices dim");
|
||||
if (valueVType.getShape() != maskVType.getShape())
|
||||
return emitOpError("expected valueToStore dim to match mask dim");
|
||||
if (getAlignmentAttr() && !isa<MemRefType>(baseType)) {
|
||||
return emitOpError(
|
||||
"alignment is only supported for memref bases, not tensor bases");
|
||||
}
|
||||
return success();
|
||||
}
|
||||
namespace {
|
||||
|
||||
@ -1545,6 +1545,15 @@ func.func @gather_non_power_of_two_alignment(%base: memref<16xf32>, %indices: ve
|
||||
|
||||
// -----
|
||||
|
||||
func.func @gather_tensor_alignment(%base: tensor<16xf32>, %indices: vector<16xi32>,
|
||||
%mask: vector<16xi1>, %pass_thru: vector<16xf32>, %c0 : index) {
|
||||
// expected-error@+1 {{'vector.gather' op alignment is only supported for memref bases, not tensor bases}}
|
||||
%0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
|
||||
{ alignment = 8 : i64 } : tensor<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @scatter_to_vector(%base: vector<16xf32>, %indices: vector<16xi32>,
|
||||
%mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
|
||||
%c0 = arith.constant 0 : index
|
||||
@ -1624,6 +1633,15 @@ func.func @scatter_non_power_of_2_alignment(%base: memref<?xf32>, %indices: vect
|
||||
|
||||
// -----
|
||||
|
||||
func.func @scatter_tensor_alignment(%base: tensor<?xf32>, %indices: vector<16xi32>,
|
||||
%mask: vector<16xi1>, %value: vector<16xf32>, %c0: index) {
|
||||
// expected-error@+1 {{'vector.scatter' op alignment is only supported for memref bases, not tensor bases}}
|
||||
vector.scatter %base[%c0][%indices], %mask, %value { alignment = 8 : i64 }
|
||||
: tensor<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> -> tensor<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @expand_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
|
||||
%c0 = arith.constant 0 : index
|
||||
// expected-error@+1 {{'vector.expandload' op base element type ('f64') does not match result element type ('f32')}}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user