[mlir][LLVM][NFC] Simplify computeSizes function (#153588)

Rename `computeSizes` to `computeSize` and make it compute just a single
size. This is in preparation of adding 1:N support to the Func->LLVM
lowering patterns.
This commit is contained in:
Matthias Springer 2025-08-14 17:00:03 +02:00 committed by GitHub
parent 7d91213559
commit 0ff92fe2f0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 39 additions and 56 deletions

View File

@ -189,15 +189,13 @@ public:
/// `unpack`.
static unsigned getNumUnpackedValues() { return 2; }
/// Builds IR computing the sizes in bytes (suitable for opaque allocation)
/// and appends the corresponding values into `sizes`. `addressSpaces`
/// which must have the same length as `values`, is needed to handle layouts
/// where sizeof(ptr addrspace(N)) != sizeof(ptr addrspace(0)).
static void computeSizes(OpBuilder &builder, Location loc,
/// Builds and returns IR computing the size in bytes (suitable for opaque
/// allocation). `addressSpace` is needed to handle layouts where
/// sizeof(ptr addrspace(N)) != sizeof(ptr addrspace(0)).
static Value computeSize(OpBuilder &builder, Location loc,
const LLVMTypeConverter &typeConverter,
ArrayRef<UnrankedMemRefDescriptor> values,
ArrayRef<unsigned> addressSpaces,
SmallVectorImpl<Value> &sizes);
UnrankedMemRefDescriptor desc,
unsigned addressSpace);
/// TODO: The following accessors don't take alignment rules between elements
/// of the descriptor struct into account. For some architectures, it might be

View File

@ -353,14 +353,9 @@ void UnrankedMemRefDescriptor::unpack(OpBuilder &builder, Location loc,
results.push_back(d.memRefDescPtr(builder, loc));
}
void UnrankedMemRefDescriptor::computeSizes(
Value UnrankedMemRefDescriptor::computeSize(
OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter,
ArrayRef<UnrankedMemRefDescriptor> values, ArrayRef<unsigned> addressSpaces,
SmallVectorImpl<Value> &sizes) {
if (values.empty())
return;
assert(values.size() == addressSpaces.size() &&
"must provide address space for each descriptor");
UnrankedMemRefDescriptor desc, unsigned addressSpace) {
// Cache the index type.
Type indexType = typeConverter.getIndexType();
@ -371,34 +366,31 @@ void UnrankedMemRefDescriptor::computeSizes(
builder, loc, indexType,
llvm::divideCeil(typeConverter.getIndexTypeBitwidth(), 8));
sizes.reserve(sizes.size() + values.size());
for (auto [desc, addressSpace] : llvm::zip(values, addressSpaces)) {
// Emit IR computing the memory necessary to store the descriptor. This
// assumes the descriptor to be
// { type*, type*, index, index[rank], index[rank] }
// and densely packed, so the total size is
// 2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index).
// TODO: consider including the actual size (including eventual padding due
// to data layout) into the unranked descriptor.
Value pointerSize = createIndexAttrConstant(
builder, loc, indexType,
llvm::divideCeil(typeConverter.getPointerBitwidth(addressSpace), 8));
Value doublePointerSize =
LLVM::MulOp::create(builder, loc, indexType, two, pointerSize);
// Emit IR computing the memory necessary to store the descriptor. This
// assumes the descriptor to be
// { type*, type*, index, index[rank], index[rank] }
// and densely packed, so the total size is
// 2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index).
// TODO: consider including the actual size (including eventual padding due
// to data layout) into the unranked descriptor.
Value pointerSize = createIndexAttrConstant(
builder, loc, indexType,
llvm::divideCeil(typeConverter.getPointerBitwidth(addressSpace), 8));
Value doublePointerSize =
LLVM::MulOp::create(builder, loc, indexType, two, pointerSize);
// (1 + 2 * rank) * sizeof(index)
Value rank = desc.rank(builder, loc);
Value doubleRank = LLVM::MulOp::create(builder, loc, indexType, two, rank);
Value doubleRankIncremented =
LLVM::AddOp::create(builder, loc, indexType, doubleRank, one);
Value rankIndexSize = LLVM::MulOp::create(builder, loc, indexType,
doubleRankIncremented, indexSize);
// (1 + 2 * rank) * sizeof(index)
Value rank = desc.rank(builder, loc);
Value doubleRank = LLVM::MulOp::create(builder, loc, indexType, two, rank);
Value doubleRankIncremented =
LLVM::AddOp::create(builder, loc, indexType, doubleRank, one);
Value rankIndexSize = LLVM::MulOp::create(builder, loc, indexType,
doubleRankIncremented, indexSize);
// Total allocation size.
Value allocationSize = LLVM::AddOp::create(
builder, loc, indexType, doublePointerSize, rankIndexSize);
sizes.push_back(allocationSize);
}
// Total allocation size.
Value allocationSize = LLVM::AddOp::create(builder, loc, indexType,
doublePointerSize, rankIndexSize);
return allocationSize;
}
Value UnrankedMemRefDescriptor::allocatedPtr(

View File

@ -239,12 +239,6 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
if (unrankedMemrefs.empty())
return success();
// Compute allocation sizes.
SmallVector<Value> sizes;
UnrankedMemRefDescriptor::computeSizes(builder, loc, *getTypeConverter(),
unrankedMemrefs, unrankedAddressSpaces,
sizes);
// Get frequently used types.
Type indexType = getTypeConverter()->getIndexType();
@ -267,8 +261,10 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
Type type = origTypes[i];
if (!isa<UnrankedMemRefType>(type))
continue;
Value allocationSize = sizes[unrankedMemrefPos++];
UnrankedMemRefDescriptor desc(operands[i]);
Value allocationSize = UnrankedMemRefDescriptor::computeSize(
builder, loc, *getTypeConverter(), desc,
unrankedAddressSpaces[unrankedMemrefPos++]);
// Allocate memory, copy, and free the source if necessary.
Value memory =

View File

@ -1246,10 +1246,8 @@ struct MemorySpaceCastOpLowering
auto result = UnrankedMemRefDescriptor::poison(
rewriter, loc, typeConverter->convertType(resultTypeU));
result.setRank(rewriter, loc, rank);
SmallVector<Value, 1> sizes;
UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
result, resultAddrSpace, sizes);
Value resultUnderlyingSize = sizes.front();
Value resultUnderlyingSize = UnrankedMemRefDescriptor::computeSize(
rewriter, loc, *getTypeConverter(), result, resultAddrSpace);
Value resultUnderlyingDesc =
LLVM::AllocaOp::create(rewriter, loc, getPtrType(),
rewriter.getI8Type(), resultUnderlyingSize);
@ -1530,12 +1528,11 @@ private:
auto targetDesc = UnrankedMemRefDescriptor::poison(
rewriter, loc, typeConverter->convertType(targetType));
targetDesc.setRank(rewriter, loc, resultRank);
SmallVector<Value, 4> sizes;
UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
targetDesc, addressSpace, sizes);
Value allocationSize = UnrankedMemRefDescriptor::computeSize(
rewriter, loc, *getTypeConverter(), targetDesc, addressSpace);
Value underlyingDescPtr = LLVM::AllocaOp::create(
rewriter, loc, getPtrType(), IntegerType::get(getContext(), 8),
sizes.front());
allocationSize);
targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
// Extract pointers and offset from the source memref.