[mlir][LLVM][NFC] Simplify computeSizes
function (#153588)
Rename `computeSizes` to `computeSize` and make it compute just a single size. This is in preparation of adding 1:N support to the Func->LLVM lowering patterns.
This commit is contained in:
parent
7d91213559
commit
0ff92fe2f0
@ -189,15 +189,13 @@ public:
|
||||
/// `unpack`.
|
||||
static unsigned getNumUnpackedValues() { return 2; }
|
||||
|
||||
/// Builds IR computing the sizes in bytes (suitable for opaque allocation)
|
||||
/// and appends the corresponding values into `sizes`. `addressSpaces`
|
||||
/// which must have the same length as `values`, is needed to handle layouts
|
||||
/// where sizeof(ptr addrspace(N)) != sizeof(ptr addrspace(0)).
|
||||
static void computeSizes(OpBuilder &builder, Location loc,
|
||||
/// Builds and returns IR computing the size in bytes (suitable for opaque
|
||||
/// allocation). `addressSpace` is needed to handle layouts where
|
||||
/// sizeof(ptr addrspace(N)) != sizeof(ptr addrspace(0)).
|
||||
static Value computeSize(OpBuilder &builder, Location loc,
|
||||
const LLVMTypeConverter &typeConverter,
|
||||
ArrayRef<UnrankedMemRefDescriptor> values,
|
||||
ArrayRef<unsigned> addressSpaces,
|
||||
SmallVectorImpl<Value> &sizes);
|
||||
UnrankedMemRefDescriptor desc,
|
||||
unsigned addressSpace);
|
||||
|
||||
/// TODO: The following accessors don't take alignment rules between elements
|
||||
/// of the descriptor struct into account. For some architectures, it might be
|
||||
|
@ -353,14 +353,9 @@ void UnrankedMemRefDescriptor::unpack(OpBuilder &builder, Location loc,
|
||||
results.push_back(d.memRefDescPtr(builder, loc));
|
||||
}
|
||||
|
||||
void UnrankedMemRefDescriptor::computeSizes(
|
||||
Value UnrankedMemRefDescriptor::computeSize(
|
||||
OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter,
|
||||
ArrayRef<UnrankedMemRefDescriptor> values, ArrayRef<unsigned> addressSpaces,
|
||||
SmallVectorImpl<Value> &sizes) {
|
||||
if (values.empty())
|
||||
return;
|
||||
assert(values.size() == addressSpaces.size() &&
|
||||
"must provide address space for each descriptor");
|
||||
UnrankedMemRefDescriptor desc, unsigned addressSpace) {
|
||||
// Cache the index type.
|
||||
Type indexType = typeConverter.getIndexType();
|
||||
|
||||
@ -371,34 +366,31 @@ void UnrankedMemRefDescriptor::computeSizes(
|
||||
builder, loc, indexType,
|
||||
llvm::divideCeil(typeConverter.getIndexTypeBitwidth(), 8));
|
||||
|
||||
sizes.reserve(sizes.size() + values.size());
|
||||
for (auto [desc, addressSpace] : llvm::zip(values, addressSpaces)) {
|
||||
// Emit IR computing the memory necessary to store the descriptor. This
|
||||
// assumes the descriptor to be
|
||||
// { type*, type*, index, index[rank], index[rank] }
|
||||
// and densely packed, so the total size is
|
||||
// 2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index).
|
||||
// TODO: consider including the actual size (including eventual padding due
|
||||
// to data layout) into the unranked descriptor.
|
||||
Value pointerSize = createIndexAttrConstant(
|
||||
builder, loc, indexType,
|
||||
llvm::divideCeil(typeConverter.getPointerBitwidth(addressSpace), 8));
|
||||
Value doublePointerSize =
|
||||
LLVM::MulOp::create(builder, loc, indexType, two, pointerSize);
|
||||
// Emit IR computing the memory necessary to store the descriptor. This
|
||||
// assumes the descriptor to be
|
||||
// { type*, type*, index, index[rank], index[rank] }
|
||||
// and densely packed, so the total size is
|
||||
// 2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index).
|
||||
// TODO: consider including the actual size (including eventual padding due
|
||||
// to data layout) into the unranked descriptor.
|
||||
Value pointerSize = createIndexAttrConstant(
|
||||
builder, loc, indexType,
|
||||
llvm::divideCeil(typeConverter.getPointerBitwidth(addressSpace), 8));
|
||||
Value doublePointerSize =
|
||||
LLVM::MulOp::create(builder, loc, indexType, two, pointerSize);
|
||||
|
||||
// (1 + 2 * rank) * sizeof(index)
|
||||
Value rank = desc.rank(builder, loc);
|
||||
Value doubleRank = LLVM::MulOp::create(builder, loc, indexType, two, rank);
|
||||
Value doubleRankIncremented =
|
||||
LLVM::AddOp::create(builder, loc, indexType, doubleRank, one);
|
||||
Value rankIndexSize = LLVM::MulOp::create(builder, loc, indexType,
|
||||
doubleRankIncremented, indexSize);
|
||||
// (1 + 2 * rank) * sizeof(index)
|
||||
Value rank = desc.rank(builder, loc);
|
||||
Value doubleRank = LLVM::MulOp::create(builder, loc, indexType, two, rank);
|
||||
Value doubleRankIncremented =
|
||||
LLVM::AddOp::create(builder, loc, indexType, doubleRank, one);
|
||||
Value rankIndexSize = LLVM::MulOp::create(builder, loc, indexType,
|
||||
doubleRankIncremented, indexSize);
|
||||
|
||||
// Total allocation size.
|
||||
Value allocationSize = LLVM::AddOp::create(
|
||||
builder, loc, indexType, doublePointerSize, rankIndexSize);
|
||||
sizes.push_back(allocationSize);
|
||||
}
|
||||
// Total allocation size.
|
||||
Value allocationSize = LLVM::AddOp::create(builder, loc, indexType,
|
||||
doublePointerSize, rankIndexSize);
|
||||
return allocationSize;
|
||||
}
|
||||
|
||||
Value UnrankedMemRefDescriptor::allocatedPtr(
|
||||
|
@ -239,12 +239,6 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
|
||||
if (unrankedMemrefs.empty())
|
||||
return success();
|
||||
|
||||
// Compute allocation sizes.
|
||||
SmallVector<Value> sizes;
|
||||
UnrankedMemRefDescriptor::computeSizes(builder, loc, *getTypeConverter(),
|
||||
unrankedMemrefs, unrankedAddressSpaces,
|
||||
sizes);
|
||||
|
||||
// Get frequently used types.
|
||||
Type indexType = getTypeConverter()->getIndexType();
|
||||
|
||||
@ -267,8 +261,10 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
|
||||
Type type = origTypes[i];
|
||||
if (!isa<UnrankedMemRefType>(type))
|
||||
continue;
|
||||
Value allocationSize = sizes[unrankedMemrefPos++];
|
||||
UnrankedMemRefDescriptor desc(operands[i]);
|
||||
Value allocationSize = UnrankedMemRefDescriptor::computeSize(
|
||||
builder, loc, *getTypeConverter(), desc,
|
||||
unrankedAddressSpaces[unrankedMemrefPos++]);
|
||||
|
||||
// Allocate memory, copy, and free the source if necessary.
|
||||
Value memory =
|
||||
|
@ -1246,10 +1246,8 @@ struct MemorySpaceCastOpLowering
|
||||
auto result = UnrankedMemRefDescriptor::poison(
|
||||
rewriter, loc, typeConverter->convertType(resultTypeU));
|
||||
result.setRank(rewriter, loc, rank);
|
||||
SmallVector<Value, 1> sizes;
|
||||
UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
|
||||
result, resultAddrSpace, sizes);
|
||||
Value resultUnderlyingSize = sizes.front();
|
||||
Value resultUnderlyingSize = UnrankedMemRefDescriptor::computeSize(
|
||||
rewriter, loc, *getTypeConverter(), result, resultAddrSpace);
|
||||
Value resultUnderlyingDesc =
|
||||
LLVM::AllocaOp::create(rewriter, loc, getPtrType(),
|
||||
rewriter.getI8Type(), resultUnderlyingSize);
|
||||
@ -1530,12 +1528,11 @@ private:
|
||||
auto targetDesc = UnrankedMemRefDescriptor::poison(
|
||||
rewriter, loc, typeConverter->convertType(targetType));
|
||||
targetDesc.setRank(rewriter, loc, resultRank);
|
||||
SmallVector<Value, 4> sizes;
|
||||
UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
|
||||
targetDesc, addressSpace, sizes);
|
||||
Value allocationSize = UnrankedMemRefDescriptor::computeSize(
|
||||
rewriter, loc, *getTypeConverter(), targetDesc, addressSpace);
|
||||
Value underlyingDescPtr = LLVM::AllocaOp::create(
|
||||
rewriter, loc, getPtrType(), IntegerType::get(getContext(), 8),
|
||||
sizes.front());
|
||||
allocationSize);
|
||||
targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
|
||||
|
||||
// Extract pointers and offset from the source memref.
|
||||
|
Loading…
x
Reference in New Issue
Block a user