[mlir][LLVM] Delete getVectorElementType
(#134981)
The LLVM dialect no longer has its own vector types. It uses `mlir::VectorType` everywhere. Remove `LLVM::getVectorElementType` and use `cast<VectorType>(ty).getElementType()` instead. This commit addresses a [comment](https://github.com/llvm/llvm-project/pull/133286#discussion_r2022192500) on the PR that deleted the LLVM vector types. Also improve vector type constraints by specifying the `mlir::VectorType` C++ class, so that explicit casts to `VectorType` can be avoided in some places.
This commit is contained in:
parent
e0950ebb9c
commit
a0d449016b
@ -334,8 +334,6 @@ compatible with the LLVM dialect:
|
||||
|
||||
- `bool LLVM::isCompatibleVectorType(Type)` - checks whether a type is a
|
||||
vector type compatible with the LLVM dialect;
|
||||
- `Type LLVM::getVectorElementType(Type)` - returns the element type of any
|
||||
vector type compatible with the LLVM dialect;
|
||||
- `llvm::ElementCount LLVM::getVectorNumElements(Type)` - returns the number
|
||||
of elements in any vector type compatible with the LLVM dialect;
|
||||
- `Type LLVM::getFixedVectorType(Type, unsigned)` - gets a fixed vector type
|
||||
|
@ -874,7 +874,7 @@ def LLVM_MatrixColumnMajorLoadOp : LLVM_OneResultIntrOp<"matrix.column.major.loa
|
||||
const llvm::DataLayout &dl =
|
||||
builder.GetInsertBlock()->getModule()->getDataLayout();
|
||||
llvm::Type *ElemTy = moduleTranslation.convertType(
|
||||
getVectorElementType(op.getType()));
|
||||
op.getType().getElementType());
|
||||
llvm::Align align = dl.getABITypeAlign(ElemTy);
|
||||
$res = mb.CreateColumnMajorLoad(
|
||||
ElemTy, $data, align, $stride, $isVolatile, $rows,
|
||||
@ -907,7 +907,7 @@ def LLVM_MatrixColumnMajorStoreOp : LLVM_ZeroResultIntrOp<"matrix.column.major.s
|
||||
llvm::MatrixBuilder mb(builder);
|
||||
const llvm::DataLayout &dl =
|
||||
builder.GetInsertBlock()->getModule()->getDataLayout();
|
||||
Type elementType = getVectorElementType(op.getMatrix().getType());
|
||||
Type elementType = op.getMatrix().getType().getElementType();
|
||||
llvm::Align align = dl.getABITypeAlign(
|
||||
moduleTranslation.convertType(elementType));
|
||||
mb.CreateColumnMajorStore(
|
||||
@ -1164,7 +1164,8 @@ def LLVM_vector_insert
|
||||
let extraClassDeclaration = [{
|
||||
uint64_t getVectorBitWidth(Type vector) {
|
||||
return getVectorNumElements(vector).getKnownMinValue() *
|
||||
getVectorElementType(vector).getIntOrFloatBitWidth();
|
||||
::llvm::cast<VectorType>(vector).getElementType()
|
||||
.getIntOrFloatBitWidth();
|
||||
}
|
||||
uint64_t getSrcVectorBitWidth() {
|
||||
return getVectorBitWidth(getSrcvec().getType());
|
||||
@ -1196,7 +1197,8 @@ def LLVM_vector_extract
|
||||
let extraClassDeclaration = [{
|
||||
uint64_t getVectorBitWidth(Type vector) {
|
||||
return getVectorNumElements(vector).getKnownMinValue() *
|
||||
getVectorElementType(vector).getIntOrFloatBitWidth();
|
||||
::llvm::cast<VectorType>(vector).getElementType()
|
||||
.getIntOrFloatBitWidth();
|
||||
}
|
||||
uint64_t getSrcVectorBitWidth() {
|
||||
return getVectorBitWidth(getSrcvec().getType());
|
||||
@ -1216,8 +1218,8 @@ def LLVM_vector_interleave2
|
||||
"result has twice as many elements as 'vec1'",
|
||||
And<[CPred<"getVectorNumElements($res.getType()) == "
|
||||
"getVectorNumElements($vec1.getType()) * 2">,
|
||||
CPred<"getVectorElementType($vec1.getType()) == "
|
||||
"getVectorElementType($res.getType())">]>>,
|
||||
CPred<"::llvm::cast<VectorType>($vec1.getType()).getElementType() == "
|
||||
"::llvm::cast<VectorType>($res.getType()).getElementType()">]>>,
|
||||
]>,
|
||||
Arguments<(ins LLVM_AnyVector:$vec1, LLVM_AnyVector:$vec2)>;
|
||||
|
||||
|
@ -113,17 +113,20 @@ def LLVM_AnyNonAggregate : Type<And<[LLVM_Type.predicate,
|
||||
|
||||
// Type constraint accepting any LLVM vector type.
|
||||
def LLVM_AnyVector : Type<CPred<"::mlir::LLVM::isCompatibleVectorType($_self)">,
|
||||
"LLVM dialect-compatible vector type">;
|
||||
"LLVM dialect-compatible vector type",
|
||||
"::mlir::VectorType">;
|
||||
|
||||
// Type constraint accepting any LLVM fixed-length vector type.
|
||||
def LLVM_AnyFixedVector : Type<CPred<
|
||||
"!::mlir::LLVM::isScalableVectorType($_self)">,
|
||||
"LLVM dialect-compatible fixed-length vector type">;
|
||||
"LLVM dialect-compatible fixed-length vector type",
|
||||
"::mlir::VectorType">;
|
||||
|
||||
// Type constraint accepting any LLVM scalable vector type.
|
||||
def LLVM_AnyScalableVector : Type<CPred<
|
||||
"::mlir::LLVM::isScalableVectorType($_self)">,
|
||||
"LLVM dialect-compatible scalable vector type">;
|
||||
"LLVM dialect-compatible scalable vector type",
|
||||
"::mlir::VectorType">;
|
||||
|
||||
// Type constraint accepting an LLVM vector type with an additional constraint
|
||||
// on the vector element type.
|
||||
@ -131,9 +134,10 @@ class LLVM_VectorOf<Type element> : Type<
|
||||
And<[LLVM_AnyVector.predicate,
|
||||
SubstLeaves<
|
||||
"$_self",
|
||||
"::mlir::LLVM::getVectorElementType($_self)",
|
||||
"::llvm::cast<::mlir::VectorType>($_self).getElementType()",
|
||||
element.predicate>]>,
|
||||
"LLVM dialect-compatible vector of " # element.summary>;
|
||||
"LLVM dialect-compatible vector of " # element.summary,
|
||||
"::mlir::VectorType">;
|
||||
|
||||
// Type constraint accepting a constrained type, or a vector of such types.
|
||||
class LLVM_ScalarOrVectorOf<Type element> :
|
||||
|
@ -820,8 +820,9 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def LLVM_ExtractElementOp : LLVM_Op<"extractelement", [Pure,
|
||||
TypesMatchWith<"result type matches vector element type", "vector", "res",
|
||||
"LLVM::getVectorElementType($_self)">]> {
|
||||
TypesMatchWith<
|
||||
"result type matches vector element type", "vector", "res",
|
||||
"::llvm::cast<::mlir::VectorType>($_self).getElementType()">]> {
|
||||
let summary = "Extract an element from an LLVM vector.";
|
||||
|
||||
let arguments = (ins LLVM_AnyVector:$vector, AnySignlessInteger:$position);
|
||||
@ -881,7 +882,8 @@ def LLVM_ExtractValueOp : LLVM_Op<"extractvalue", [Pure]> {
|
||||
|
||||
def LLVM_InsertElementOp : LLVM_Op<"insertelement", [Pure,
|
||||
TypesMatchWith<"argument type matches vector element type", "vector",
|
||||
"value", "LLVM::getVectorElementType($_self)">,
|
||||
"value",
|
||||
"::llvm::cast<::mlir::VectorType>($_self).getElementType()">,
|
||||
AllTypesMatch<["res", "vector"]>]> {
|
||||
let summary = "Insert an element into an LLVM vector.";
|
||||
|
||||
|
@ -111,10 +111,6 @@ bool isCompatibleFloatingPointType(Type type);
|
||||
/// dialect pointers and LLVM dialect scalable vector types.
|
||||
bool isCompatibleVectorType(Type type);
|
||||
|
||||
/// Returns the element type of any vector type compatible with the LLVM
|
||||
/// dialect.
|
||||
Type getVectorElementType(Type type);
|
||||
|
||||
/// Returns the element count of any LLVM-compatible vector type.
|
||||
llvm::ElementCount getVectorNumElements(Type type);
|
||||
|
||||
|
@ -78,10 +78,9 @@ static unsigned getBitWidth(Type type) {
|
||||
|
||||
/// Returns the bit width of LLVMType integer or vector.
|
||||
static unsigned getLLVMTypeBitWidth(Type type) {
|
||||
return cast<IntegerType>((LLVM::isCompatibleVectorType(type)
|
||||
? LLVM::getVectorElementType(type)
|
||||
: type))
|
||||
.getWidth();
|
||||
if (auto vecTy = dyn_cast<VectorType>(type))
|
||||
type = vecTy.getElementType();
|
||||
return cast<IntegerType>(type).getWidth();
|
||||
}
|
||||
|
||||
/// Creates `IntegerAttribute` with all bits set for given type
|
||||
|
@ -2734,9 +2734,9 @@ void ShuffleVectorOp::build(OpBuilder &builder, OperationState &state, Value v1,
|
||||
Value v2, DenseI32ArrayAttr mask,
|
||||
ArrayRef<NamedAttribute> attrs) {
|
||||
auto containerType = v1.getType();
|
||||
auto vType = LLVM::getVectorType(LLVM::getVectorElementType(containerType),
|
||||
mask.size(),
|
||||
LLVM::isScalableVectorType(containerType));
|
||||
auto vType = LLVM::getVectorType(
|
||||
cast<VectorType>(containerType).getElementType(), mask.size(),
|
||||
LLVM::isScalableVectorType(containerType));
|
||||
build(builder, state, vType, v1, v2, mask);
|
||||
state.addAttributes(attrs);
|
||||
}
|
||||
@ -2752,8 +2752,9 @@ static ParseResult parseShuffleType(AsmParser &parser, Type v1Type,
|
||||
if (!LLVM::isCompatibleVectorType(v1Type))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"expected an LLVM compatible vector type");
|
||||
resType = LLVM::getVectorType(LLVM::getVectorElementType(v1Type), mask.size(),
|
||||
LLVM::isScalableVectorType(v1Type));
|
||||
resType =
|
||||
LLVM::getVectorType(cast<VectorType>(v1Type).getElementType(),
|
||||
mask.size(), LLVM::isScalableVectorType(v1Type));
|
||||
return success();
|
||||
}
|
||||
|
||||
@ -3318,7 +3319,7 @@ LogicalResult AtomicRMWOp::verify() {
|
||||
if (isCompatibleVectorType(valType)) {
|
||||
if (isScalableVectorType(valType))
|
||||
return emitOpError("expected LLVM IR fixed vector type");
|
||||
Type elemType = getVectorElementType(valType);
|
||||
Type elemType = llvm::cast<VectorType>(valType).getElementType();
|
||||
if (!isCompatibleFloatingPointType(elemType))
|
||||
return emitOpError(
|
||||
"expected LLVM IR floating point type for vector element");
|
||||
@ -3423,9 +3424,10 @@ static LogicalResult verifyExtOp(ExtOp op) {
|
||||
return op.emitError("input and output vectors are of incompatible shape");
|
||||
// Because this is a CastOp, the element of vectors is guaranteed to be an
|
||||
// integer.
|
||||
inputType = cast<IntegerType>(getVectorElementType(op.getArg().getType()));
|
||||
outputType =
|
||||
cast<IntegerType>(getVectorElementType(op.getResult().getType()));
|
||||
inputType = cast<IntegerType>(
|
||||
cast<VectorType>(op.getArg().getType()).getElementType());
|
||||
outputType = cast<IntegerType>(
|
||||
cast<VectorType>(op.getResult().getType()).getElementType());
|
||||
} else {
|
||||
// Because this is a CastOp and arg is not a vector, arg is guaranteed to be
|
||||
// an integer.
|
||||
|
@ -821,12 +821,6 @@ bool mlir::LLVM::isCompatibleVectorType(Type type) {
|
||||
return false;
|
||||
}
|
||||
|
||||
Type mlir::LLVM::getVectorElementType(Type type) {
|
||||
auto vecTy = dyn_cast<VectorType>(type);
|
||||
assert(vecTy && "incompatible with LLVM vector type");
|
||||
return vecTy.getElementType();
|
||||
}
|
||||
|
||||
llvm::ElementCount mlir::LLVM::getVectorNumElements(Type type) {
|
||||
auto vecTy = dyn_cast<VectorType>(type);
|
||||
assert(vecTy && "incompatible with LLVM vector type");
|
||||
|
@ -826,7 +826,7 @@ static Type getVectorTypeForAttr(Type type, ArrayRef<int64_t> arrayShape = {}) {
|
||||
}
|
||||
|
||||
// An LLVM dialect vector can only contain scalars.
|
||||
Type elementType = LLVM::getVectorElementType(type);
|
||||
Type elementType = cast<VectorType>(type).getElementType();
|
||||
if (!elementType.isIntOrFloat())
|
||||
return {};
|
||||
|
||||
|
@ -515,21 +515,21 @@ func.func @extractvalue_wrong_nesting() {
|
||||
// -----
|
||||
|
||||
func.func @invalid_vector_type_1(%arg0: vector<4xf32>, %arg1: i32, %arg2: f32) {
|
||||
// expected-error@+1 {{'vector' must be LLVM dialect-compatible vector}}
|
||||
// expected-error@+1 {{invalid kind of type specified: expected builtin.vector, but found 'f32'}}
|
||||
%0 = llvm.extractelement %arg2[%arg1 : i32] : f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @invalid_vector_type_2(%arg0: vector<4xf32>, %arg1: i32, %arg2: f32) {
|
||||
// expected-error@+1 {{'vector' must be LLVM dialect-compatible vector}}
|
||||
// expected-error@+1 {{invalid kind of type specified: expected builtin.vector, but found 'f32'}}
|
||||
%0 = llvm.insertelement %arg2, %arg2[%arg1 : i32] : f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @invalid_vector_type_3(%arg0: vector<4xf32>, %arg1: i32, %arg2: f32) {
|
||||
// expected-error@+2 {{expected an LLVM compatible vector type}}
|
||||
// expected-error@+1 {{invalid kind of type specified: expected builtin.vector, but found 'f32'}}
|
||||
%0 = llvm.shufflevector %arg2, %arg2 [0, 0, 0, 0, 7] : f32
|
||||
}
|
||||
|
||||
|
@ -211,7 +211,7 @@ llvm.func @vec_reduce_fmax_intr_wrong_type(%arg0 : vector<4xi32>) -> i32 {
|
||||
// -----
|
||||
|
||||
llvm.func @matrix_load_intr_wrong_type(%ptr : !llvm.ptr, %stride : i32) -> f32 {
|
||||
// expected-error @below{{op result #0 must be LLVM dialect-compatible vector type, but got 'f32'}}
|
||||
// expected-error @+2{{invalid kind of type specified: expected builtin.vector, but found 'f32'}}
|
||||
%0 = llvm.intr.matrix.column.major.load %ptr, <stride=%stride>
|
||||
{ isVolatile = 0: i1, rows = 3: i32, columns = 16: i32} : f32 from !llvm.ptr stride i32
|
||||
llvm.return %0 : f32
|
||||
@ -229,7 +229,7 @@ llvm.func @matrix_store_intr_wrong_type(%matrix : vector<48xf32>, %ptr : i32, %s
|
||||
// -----
|
||||
|
||||
llvm.func @matrix_multiply_intr_wrong_type(%arg0 : vector<64xf32>, %arg1 : f32) -> vector<12xf32> {
|
||||
// expected-error @below{{op operand #1 must be LLVM dialect-compatible vector type, but got 'f32'}}
|
||||
// expected-error @+2{{invalid kind of type specified: expected builtin.vector, but found 'f32'}}
|
||||
%0 = llvm.intr.matrix.multiply %arg0, %arg1
|
||||
{ lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32} : (vector<64xf32>, f32) -> vector<12xf32>
|
||||
llvm.return %0 : vector<12xf32>
|
||||
@ -238,7 +238,7 @@ llvm.func @matrix_multiply_intr_wrong_type(%arg0 : vector<64xf32>, %arg1 : f32)
|
||||
// -----
|
||||
|
||||
llvm.func @matrix_transpose_intr_wrong_type(%matrix : f32) -> vector<48xf32> {
|
||||
// expected-error @below{{op operand #0 must be LLVM dialect-compatible vector type, but got 'f32'}}
|
||||
// expected-error @below{{invalid kind of type specified: expected builtin.vector, but found 'f32'}}
|
||||
%0 = llvm.intr.matrix.transpose %matrix {rows = 3: i32, columns = 16: i32} : f32 into vector<48xf32>
|
||||
llvm.return %0 : vector<48xf32>
|
||||
}
|
||||
@ -286,7 +286,7 @@ llvm.func @masked_gather_intr_wrong_type_scalable(%ptrs : vector<7x!llvm.ptr>, %
|
||||
// -----
|
||||
|
||||
llvm.func @masked_scatter_intr_wrong_type(%vec : f32, %ptrs : vector<7x!llvm.ptr>, %mask : vector<7xi1>) {
|
||||
// expected-error @below{{op operand #0 must be LLVM dialect-compatible vector type, but got 'f32'}}
|
||||
// expected-error @below{{invalid kind of type specified: expected builtin.vector, but found 'f32'}}
|
||||
llvm.intr.masked.scatter %vec, %ptrs, %mask { alignment = 1: i32} : f32, vector<7xi1> into vector<7x!llvm.ptr>
|
||||
llvm.return
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user