[MLIR][DataLayout] Add support for scalable vectors (#89349)
This commit extends the data layout to support scalable vectors. For scalable vectors, the `TypeSize`'s scalable field is set accordingly, and the alignment information remains the same as for normal vectors. This behavior is in sync with what LLVM's data layout queries are producing. Before this change, scalable vectors incorrectly returned the same size as "normal" vectors.
This commit is contained in:
parent
4d7f3d9e0f
commit
df411fbac6
@ -75,10 +75,12 @@ mlir::detail::getDefaultTypeSizeInBits(Type type, const DataLayout &dataLayout,
|
|||||||
// there is no bit-packing at the moment element sizes are taken in bytes and
|
// there is no bit-packing at the moment element sizes are taken in bytes and
|
||||||
// multiplied with 8 bits.
|
// multiplied with 8 bits.
|
||||||
// TODO: make this extensible.
|
// TODO: make this extensible.
|
||||||
if (auto vecType = dyn_cast<VectorType>(type))
|
if (auto vecType = dyn_cast<VectorType>(type)) {
|
||||||
return vecType.getNumElements() / vecType.getShape().back() *
|
uint64_t baseSize = vecType.getNumElements() / vecType.getShape().back() *
|
||||||
llvm::PowerOf2Ceil(vecType.getShape().back()) *
|
llvm::PowerOf2Ceil(vecType.getShape().back()) *
|
||||||
dataLayout.getTypeSize(vecType.getElementType()) * 8;
|
dataLayout.getTypeSize(vecType.getElementType()) * 8;
|
||||||
|
return llvm::TypeSize::get(baseSize, vecType.isScalable());
|
||||||
|
}
|
||||||
|
|
||||||
if (auto typeInterface = dyn_cast<DataLayoutTypeInterface>(type))
|
if (auto typeInterface = dyn_cast<DataLayoutTypeInterface>(type))
|
||||||
return typeInterface.getTypeSizeInBits(dataLayout, params);
|
return typeInterface.getTypeSizeInBits(dataLayout, params);
|
||||||
@ -138,9 +140,10 @@ getFloatTypeABIAlignment(FloatType fltType, const DataLayout &dataLayout,
|
|||||||
uint64_t mlir::detail::getDefaultABIAlignment(
|
uint64_t mlir::detail::getDefaultABIAlignment(
|
||||||
Type type, const DataLayout &dataLayout,
|
Type type, const DataLayout &dataLayout,
|
||||||
ArrayRef<DataLayoutEntryInterface> params) {
|
ArrayRef<DataLayoutEntryInterface> params) {
|
||||||
// Natural alignment is the closest power-of-two number above.
|
// Natural alignment is the closest power-of-two number above. For scalable
|
||||||
|
// vectors, aligning them to the same as the base vector is sufficient.
|
||||||
if (isa<VectorType>(type))
|
if (isa<VectorType>(type))
|
||||||
return llvm::PowerOf2Ceil(dataLayout.getTypeSize(type));
|
return llvm::PowerOf2Ceil(dataLayout.getTypeSize(type).getKnownMinValue());
|
||||||
|
|
||||||
if (auto fltType = dyn_cast<FloatType>(type))
|
if (auto fltType = dyn_cast<FloatType>(type))
|
||||||
return getFloatTypeABIAlignment(fltType, dataLayout, params);
|
return getFloatTypeABIAlignment(fltType, dataLayout, params);
|
||||||
|
@ -32,6 +32,18 @@ func.func @no_layout_builtin() {
|
|||||||
// CHECK: preferred = 8
|
// CHECK: preferred = 8
|
||||||
// CHECK: size = 8
|
// CHECK: size = 8
|
||||||
"test.data_layout_query"() : () -> index
|
"test.data_layout_query"() : () -> index
|
||||||
|
// CHECK: alignment = 16
|
||||||
|
// CHECK: bitsize = 128
|
||||||
|
// CHECK: index = 0
|
||||||
|
// CHECK: preferred = 16
|
||||||
|
// CHECK: size = 16
|
||||||
|
"test.data_layout_query"() : () -> vector<4xi32>
|
||||||
|
// CHECK: alignment = 16
|
||||||
|
// CHECK: bitsize = {minimal_size = 128 : index, scalable}
|
||||||
|
// CHECK: index = 0
|
||||||
|
// CHECK: preferred = 16
|
||||||
|
// CHECK: size = {minimal_size = 16 : index, scalable}
|
||||||
|
"test.data_layout_query"() : () -> vector<[4]xi32>
|
||||||
return
|
return
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -46,9 +46,22 @@ struct TestDataLayoutQuery
|
|||||||
Attribute programMemorySpace = layout.getProgramMemorySpace();
|
Attribute programMemorySpace = layout.getProgramMemorySpace();
|
||||||
Attribute globalMemorySpace = layout.getGlobalMemorySpace();
|
Attribute globalMemorySpace = layout.getGlobalMemorySpace();
|
||||||
uint64_t stackAlignment = layout.getStackAlignment();
|
uint64_t stackAlignment = layout.getStackAlignment();
|
||||||
|
|
||||||
|
auto convertTypeSizeToAttr = [&](llvm::TypeSize typeSize) -> Attribute {
|
||||||
|
if (!typeSize.isScalable())
|
||||||
|
return builder.getIndexAttr(typeSize);
|
||||||
|
|
||||||
|
return builder.getDictionaryAttr({
|
||||||
|
builder.getNamedAttr("scalable", builder.getUnitAttr()),
|
||||||
|
builder.getNamedAttr(
|
||||||
|
"minimal_size",
|
||||||
|
builder.getIndexAttr(typeSize.getKnownMinValue())),
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
op->setAttrs(
|
op->setAttrs(
|
||||||
{builder.getNamedAttr("size", builder.getIndexAttr(size)),
|
{builder.getNamedAttr("size", convertTypeSizeToAttr(size)),
|
||||||
builder.getNamedAttr("bitsize", builder.getIndexAttr(bitsize)),
|
builder.getNamedAttr("bitsize", convertTypeSizeToAttr(bitsize)),
|
||||||
builder.getNamedAttr("alignment", builder.getIndexAttr(alignment)),
|
builder.getNamedAttr("alignment", builder.getIndexAttr(alignment)),
|
||||||
builder.getNamedAttr("preferred", builder.getIndexAttr(preferred)),
|
builder.getNamedAttr("preferred", builder.getIndexAttr(preferred)),
|
||||||
builder.getNamedAttr("index", builder.getIndexAttr(index)),
|
builder.getNamedAttr("index", builder.getIndexAttr(index)),
|
||||||
|
Loading…
x
Reference in New Issue
Block a user