[mlir][LLVM] Delete getVectorElementType (#134981)

The LLVM dialect no longer has its own vector types. It uses
`mlir::VectorType` everywhere. Remove `LLVM::getVectorElementType` and
use `cast<VectorType>(ty).getElementType()` instead. This commit
addresses a
[comment](https://github.com/llvm/llvm-project/pull/133286#discussion_r2022192500)
on the PR that deleted the LLVM vector types.

Also improve vector type constraints by specifying the
`mlir::VectorType` C++ class, so that explicit casts to `VectorType` can
be avoided in some places.
This commit is contained in:
Matthias Springer 2025-04-09 21:35:32 +02:00 committed by GitHub
parent e0950ebb9c
commit a0d449016b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 44 additions and 47 deletions

View File

@ -334,8 +334,6 @@ compatible with the LLVM dialect:
- `bool LLVM::isCompatibleVectorType(Type)` - checks whether a type is a
vector type compatible with the LLVM dialect;
- `Type LLVM::getVectorElementType(Type)` - returns the element type of any
vector type compatible with the LLVM dialect;
- `llvm::ElementCount LLVM::getVectorNumElements(Type)` - returns the number
of elements in any vector type compatible with the LLVM dialect;
- `Type LLVM::getFixedVectorType(Type, unsigned)` - gets a fixed vector type

View File

@ -874,7 +874,7 @@ def LLVM_MatrixColumnMajorLoadOp : LLVM_OneResultIntrOp<"matrix.column.major.loa
const llvm::DataLayout &dl =
builder.GetInsertBlock()->getModule()->getDataLayout();
llvm::Type *ElemTy = moduleTranslation.convertType(
getVectorElementType(op.getType()));
op.getType().getElementType());
llvm::Align align = dl.getABITypeAlign(ElemTy);
$res = mb.CreateColumnMajorLoad(
ElemTy, $data, align, $stride, $isVolatile, $rows,
@ -907,7 +907,7 @@ def LLVM_MatrixColumnMajorStoreOp : LLVM_ZeroResultIntrOp<"matrix.column.major.s
llvm::MatrixBuilder mb(builder);
const llvm::DataLayout &dl =
builder.GetInsertBlock()->getModule()->getDataLayout();
Type elementType = getVectorElementType(op.getMatrix().getType());
Type elementType = op.getMatrix().getType().getElementType();
llvm::Align align = dl.getABITypeAlign(
moduleTranslation.convertType(elementType));
mb.CreateColumnMajorStore(
@ -1164,7 +1164,8 @@ def LLVM_vector_insert
let extraClassDeclaration = [{
uint64_t getVectorBitWidth(Type vector) {
return getVectorNumElements(vector).getKnownMinValue() *
getVectorElementType(vector).getIntOrFloatBitWidth();
::llvm::cast<VectorType>(vector).getElementType()
.getIntOrFloatBitWidth();
}
uint64_t getSrcVectorBitWidth() {
return getVectorBitWidth(getSrcvec().getType());
@ -1196,7 +1197,8 @@ def LLVM_vector_extract
let extraClassDeclaration = [{
uint64_t getVectorBitWidth(Type vector) {
return getVectorNumElements(vector).getKnownMinValue() *
getVectorElementType(vector).getIntOrFloatBitWidth();
::llvm::cast<VectorType>(vector).getElementType()
.getIntOrFloatBitWidth();
}
uint64_t getSrcVectorBitWidth() {
return getVectorBitWidth(getSrcvec().getType());
@ -1216,8 +1218,8 @@ def LLVM_vector_interleave2
"result has twice as many elements as 'vec1'",
And<[CPred<"getVectorNumElements($res.getType()) == "
"getVectorNumElements($vec1.getType()) * 2">,
CPred<"getVectorElementType($vec1.getType()) == "
"getVectorElementType($res.getType())">]>>,
CPred<"::llvm::cast<VectorType>($vec1.getType()).getElementType() == "
"::llvm::cast<VectorType>($res.getType()).getElementType()">]>>,
]>,
Arguments<(ins LLVM_AnyVector:$vec1, LLVM_AnyVector:$vec2)>;

View File

@ -113,17 +113,20 @@ def LLVM_AnyNonAggregate : Type<And<[LLVM_Type.predicate,
// Type constraint accepting any LLVM vector type.
def LLVM_AnyVector : Type<CPred<"::mlir::LLVM::isCompatibleVectorType($_self)">,
"LLVM dialect-compatible vector type">;
"LLVM dialect-compatible vector type",
"::mlir::VectorType">;
// Type constraint accepting any LLVM fixed-length vector type.
def LLVM_AnyFixedVector : Type<CPred<
"!::mlir::LLVM::isScalableVectorType($_self)">,
"LLVM dialect-compatible fixed-length vector type">;
"LLVM dialect-compatible fixed-length vector type",
"::mlir::VectorType">;
// Type constraint accepting any LLVM scalable vector type.
def LLVM_AnyScalableVector : Type<CPred<
"::mlir::LLVM::isScalableVectorType($_self)">,
"LLVM dialect-compatible scalable vector type">;
"LLVM dialect-compatible scalable vector type",
"::mlir::VectorType">;
// Type constraint accepting an LLVM vector type with an additional constraint
// on the vector element type.
@ -131,9 +134,10 @@ class LLVM_VectorOf<Type element> : Type<
And<[LLVM_AnyVector.predicate,
SubstLeaves<
"$_self",
"::mlir::LLVM::getVectorElementType($_self)",
"::llvm::cast<::mlir::VectorType>($_self).getElementType()",
element.predicate>]>,
"LLVM dialect-compatible vector of " # element.summary>;
"LLVM dialect-compatible vector of " # element.summary,
"::mlir::VectorType">;
// Type constraint accepting a constrained type, or a vector of such types.
class LLVM_ScalarOrVectorOf<Type element> :

View File

@ -820,8 +820,9 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
//===----------------------------------------------------------------------===//
def LLVM_ExtractElementOp : LLVM_Op<"extractelement", [Pure,
TypesMatchWith<"result type matches vector element type", "vector", "res",
"LLVM::getVectorElementType($_self)">]> {
TypesMatchWith<
"result type matches vector element type", "vector", "res",
"::llvm::cast<::mlir::VectorType>($_self).getElementType()">]> {
let summary = "Extract an element from an LLVM vector.";
let arguments = (ins LLVM_AnyVector:$vector, AnySignlessInteger:$position);
@ -881,7 +882,8 @@ def LLVM_ExtractValueOp : LLVM_Op<"extractvalue", [Pure]> {
def LLVM_InsertElementOp : LLVM_Op<"insertelement", [Pure,
TypesMatchWith<"argument type matches vector element type", "vector",
"value", "LLVM::getVectorElementType($_self)">,
"value",
"::llvm::cast<::mlir::VectorType>($_self).getElementType()">,
AllTypesMatch<["res", "vector"]>]> {
let summary = "Insert an element into an LLVM vector.";

View File

@ -111,10 +111,6 @@ bool isCompatibleFloatingPointType(Type type);
/// dialect pointers and LLVM dialect scalable vector types.
bool isCompatibleVectorType(Type type);
/// Returns the element type of any vector type compatible with the LLVM
/// dialect.
Type getVectorElementType(Type type);
/// Returns the element count of any LLVM-compatible vector type.
llvm::ElementCount getVectorNumElements(Type type);

View File

@ -78,10 +78,9 @@ static unsigned getBitWidth(Type type) {
/// Returns the bit width of LLVMType integer or vector.
static unsigned getLLVMTypeBitWidth(Type type) {
return cast<IntegerType>((LLVM::isCompatibleVectorType(type)
? LLVM::getVectorElementType(type)
: type))
.getWidth();
if (auto vecTy = dyn_cast<VectorType>(type))
type = vecTy.getElementType();
return cast<IntegerType>(type).getWidth();
}
/// Creates `IntegerAttribute` with all bits set for given type

View File

@ -2734,9 +2734,9 @@ void ShuffleVectorOp::build(OpBuilder &builder, OperationState &state, Value v1,
Value v2, DenseI32ArrayAttr mask,
ArrayRef<NamedAttribute> attrs) {
auto containerType = v1.getType();
auto vType = LLVM::getVectorType(LLVM::getVectorElementType(containerType),
mask.size(),
LLVM::isScalableVectorType(containerType));
auto vType = LLVM::getVectorType(
cast<VectorType>(containerType).getElementType(), mask.size(),
LLVM::isScalableVectorType(containerType));
build(builder, state, vType, v1, v2, mask);
state.addAttributes(attrs);
}
@ -2752,8 +2752,9 @@ static ParseResult parseShuffleType(AsmParser &parser, Type v1Type,
if (!LLVM::isCompatibleVectorType(v1Type))
return parser.emitError(parser.getCurrentLocation(),
"expected an LLVM compatible vector type");
resType = LLVM::getVectorType(LLVM::getVectorElementType(v1Type), mask.size(),
LLVM::isScalableVectorType(v1Type));
resType =
LLVM::getVectorType(cast<VectorType>(v1Type).getElementType(),
mask.size(), LLVM::isScalableVectorType(v1Type));
return success();
}
@ -3318,7 +3319,7 @@ LogicalResult AtomicRMWOp::verify() {
if (isCompatibleVectorType(valType)) {
if (isScalableVectorType(valType))
return emitOpError("expected LLVM IR fixed vector type");
Type elemType = getVectorElementType(valType);
Type elemType = llvm::cast<VectorType>(valType).getElementType();
if (!isCompatibleFloatingPointType(elemType))
return emitOpError(
"expected LLVM IR floating point type for vector element");
@ -3423,9 +3424,10 @@ static LogicalResult verifyExtOp(ExtOp op) {
return op.emitError("input and output vectors are of incompatible shape");
// Because this is a CastOp, the element of vectors is guaranteed to be an
// integer.
inputType = cast<IntegerType>(getVectorElementType(op.getArg().getType()));
outputType =
cast<IntegerType>(getVectorElementType(op.getResult().getType()));
inputType = cast<IntegerType>(
cast<VectorType>(op.getArg().getType()).getElementType());
outputType = cast<IntegerType>(
cast<VectorType>(op.getResult().getType()).getElementType());
} else {
// Because this is a CastOp and arg is not a vector, arg is guaranteed to be
// an integer.

View File

@ -821,12 +821,6 @@ bool mlir::LLVM::isCompatibleVectorType(Type type) {
return false;
}
Type mlir::LLVM::getVectorElementType(Type type) {
auto vecTy = dyn_cast<VectorType>(type);
assert(vecTy && "incompatible with LLVM vector type");
return vecTy.getElementType();
}
llvm::ElementCount mlir::LLVM::getVectorNumElements(Type type) {
auto vecTy = dyn_cast<VectorType>(type);
assert(vecTy && "incompatible with LLVM vector type");

View File

@ -826,7 +826,7 @@ static Type getVectorTypeForAttr(Type type, ArrayRef<int64_t> arrayShape = {}) {
}
// An LLVM dialect vector can only contain scalars.
Type elementType = LLVM::getVectorElementType(type);
Type elementType = cast<VectorType>(type).getElementType();
if (!elementType.isIntOrFloat())
return {};

View File

@ -515,21 +515,21 @@ func.func @extractvalue_wrong_nesting() {
// -----
func.func @invalid_vector_type_1(%arg0: vector<4xf32>, %arg1: i32, %arg2: f32) {
// expected-error@+1 {{'vector' must be LLVM dialect-compatible vector}}
// expected-error@+1 {{invalid kind of type specified: expected builtin.vector, but found 'f32'}}
%0 = llvm.extractelement %arg2[%arg1 : i32] : f32
}
// -----
func.func @invalid_vector_type_2(%arg0: vector<4xf32>, %arg1: i32, %arg2: f32) {
// expected-error@+1 {{'vector' must be LLVM dialect-compatible vector}}
// expected-error@+1 {{invalid kind of type specified: expected builtin.vector, but found 'f32'}}
%0 = llvm.insertelement %arg2, %arg2[%arg1 : i32] : f32
}
// -----
func.func @invalid_vector_type_3(%arg0: vector<4xf32>, %arg1: i32, %arg2: f32) {
// expected-error@+2 {{expected an LLVM compatible vector type}}
// expected-error@+1 {{invalid kind of type specified: expected builtin.vector, but found 'f32'}}
%0 = llvm.shufflevector %arg2, %arg2 [0, 0, 0, 0, 7] : f32
}

View File

@ -211,7 +211,7 @@ llvm.func @vec_reduce_fmax_intr_wrong_type(%arg0 : vector<4xi32>) -> i32 {
// -----
llvm.func @matrix_load_intr_wrong_type(%ptr : !llvm.ptr, %stride : i32) -> f32 {
// expected-error @below{{op result #0 must be LLVM dialect-compatible vector type, but got 'f32'}}
// expected-error @+2{{invalid kind of type specified: expected builtin.vector, but found 'f32'}}
%0 = llvm.intr.matrix.column.major.load %ptr, <stride=%stride>
{ isVolatile = 0: i1, rows = 3: i32, columns = 16: i32} : f32 from !llvm.ptr stride i32
llvm.return %0 : f32
@ -229,7 +229,7 @@ llvm.func @matrix_store_intr_wrong_type(%matrix : vector<48xf32>, %ptr : i32, %s
// -----
llvm.func @matrix_multiply_intr_wrong_type(%arg0 : vector<64xf32>, %arg1 : f32) -> vector<12xf32> {
// expected-error @below{{op operand #1 must be LLVM dialect-compatible vector type, but got 'f32'}}
// expected-error @+2{{invalid kind of type specified: expected builtin.vector, but found 'f32'}}
%0 = llvm.intr.matrix.multiply %arg0, %arg1
{ lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32} : (vector<64xf32>, f32) -> vector<12xf32>
llvm.return %0 : vector<12xf32>
@ -238,7 +238,7 @@ llvm.func @matrix_multiply_intr_wrong_type(%arg0 : vector<64xf32>, %arg1 : f32)
// -----
llvm.func @matrix_transpose_intr_wrong_type(%matrix : f32) -> vector<48xf32> {
// expected-error @below{{op operand #0 must be LLVM dialect-compatible vector type, but got 'f32'}}
// expected-error @below{{invalid kind of type specified: expected builtin.vector, but found 'f32'}}
%0 = llvm.intr.matrix.transpose %matrix {rows = 3: i32, columns = 16: i32} : f32 into vector<48xf32>
llvm.return %0 : vector<48xf32>
}
@ -286,7 +286,7 @@ llvm.func @masked_gather_intr_wrong_type_scalable(%ptrs : vector<7x!llvm.ptr>, %
// -----
llvm.func @masked_scatter_intr_wrong_type(%vec : f32, %ptrs : vector<7x!llvm.ptr>, %mask : vector<7xi1>) {
// expected-error @below{{op operand #0 must be LLVM dialect-compatible vector type, but got 'f32'}}
// expected-error @below{{invalid kind of type specified: expected builtin.vector, but found 'f32'}}
llvm.intr.masked.scatter %vec, %ptrs, %mask { alignment = 1: i32} : f32, vector<7xi1> into vector<7x!llvm.ptr>
llvm.return
}