
Rename `computeSizes` to `computeSize` and make it compute just a single size. This is in preparation for adding 1:N support to the Func->LLVM lowering patterns.
//===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"

#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/MathExtras.h"

#include <optional>

#define DEBUG_TYPE "memref-to-llvm"

namespace mlir {
#define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

static constexpr LLVM::GEPNoWrapFlags kNoWrapFlags =
    LLVM::GEPNoWrapFlags::inbounds | LLVM::GEPNoWrapFlags::nuw;
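
// These flags are attached to GEPs whose indices are known to be in-bounds and
// non-negative (for example in the memref.load and memref.store lowerings
// below), so both `inbounds` and `nuw` can be used safely.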

namespace {

static bool isStaticStrideOrOffset(int64_t strideOrOffset) {
  return ShapedType::isStatic(strideOrOffset);
}

static FailureOr<LLVM::LLVMFuncOp>
getFreeFn(OpBuilder &b, const LLVMTypeConverter *typeConverter, ModuleOp module,
          SymbolTableCollection *symbolTables) {
  bool useGenericFn = typeConverter->getOptions().useGenericFunctions;

  if (useGenericFn)
    return LLVM::lookupOrCreateGenericFreeFn(b, module, symbolTables);

  return LLVM::lookupOrCreateFreeFn(b, module, symbolTables);
}

static FailureOr<LLVM::LLVMFuncOp>
getNotalignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
                     Operation *module, Type indexType,
                     SymbolTableCollection *symbolTables) {
  bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
  if (useGenericFn)
    return LLVM::lookupOrCreateGenericAllocFn(b, module, indexType,
                                              symbolTables);

  return LLVM::lookupOrCreateMallocFn(b, module, indexType, symbolTables);
}

static FailureOr<LLVM::LLVMFuncOp>
getAlignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
                  Operation *module, Type indexType,
                  SymbolTableCollection *symbolTables) {
  bool useGenericFn = typeConverter->getOptions().useGenericFunctions;

  if (useGenericFn)
    return LLVM::lookupOrCreateGenericAlignedAllocFn(b, module, indexType,
                                                     symbolTables);

  return LLVM::lookupOrCreateAlignedAllocFn(b, module, indexType, symbolTables);
}

/// Computes the aligned value for 'input' as follows:
///   bumped = input + alignment - 1
///   aligned = bumped - bumped % alignment
static Value createAligned(ConversionPatternRewriter &rewriter, Location loc,
                           Value input, Value alignment) {
  Value one = LLVM::ConstantOp::create(rewriter, loc, alignment.getType(),
                                       rewriter.getIndexAttr(1));
  Value bump = LLVM::SubOp::create(rewriter, loc, alignment, one);
  Value bumped = LLVM::AddOp::create(rewriter, loc, input, bump);
  Value mod = LLVM::URemOp::create(rewriter, loc, bumped, alignment);
  return LLVM::SubOp::create(rewriter, loc, bumped, mod);
}
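
// Worked example (illustrative values): for input = 13 and alignment = 8 the
// sequence above yields bumped = 13 + 7 = 20, mod = 20 % 8 = 4, and returns
// aligned = 20 - 4 = 16, i.e. the smallest multiple of 8 that is >= 13.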

/// Computes the byte size for the MemRef element type.
static unsigned getMemRefEltSizeInBytes(const LLVMTypeConverter *typeConverter,
                                        MemRefType memRefType, Operation *op,
                                        const DataLayout *defaultLayout) {
  const DataLayout *layout = defaultLayout;
  if (const DataLayoutAnalysis *analysis =
          typeConverter->getDataLayoutAnalysis()) {
    layout = &analysis->getAbove(op);
  }
  Type elementType = memRefType.getElementType();
  if (auto memRefElementType = dyn_cast<MemRefType>(elementType))
    return typeConverter->getMemRefDescriptorSize(memRefElementType, *layout);
  if (auto memRefElementType = dyn_cast<UnrankedMemRefType>(elementType))
    return typeConverter->getUnrankedMemRefDescriptorSize(memRefElementType,
                                                          *layout);
  return layout->getTypeSize(elementType);
}
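
// Illustrative examples: for a memref of f32 the element size is 4 bytes; if
// the element type is itself a (possibly unranked) memref, the size of the
// nested descriptor as computed by the type converter is used instead.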

static Value castAllocFuncResult(ConversionPatternRewriter &rewriter,
                                 Location loc, Value allocatedPtr,
                                 MemRefType memRefType, Type elementPtrType,
                                 const LLVMTypeConverter &typeConverter) {
  auto allocatedPtrTy = cast<LLVM::LLVMPointerType>(allocatedPtr.getType());
  FailureOr<unsigned> maybeMemrefAddrSpace =
      typeConverter.getMemRefAddressSpace(memRefType);
  assert(succeeded(maybeMemrefAddrSpace) && "unsupported address space");
  unsigned memrefAddrSpace = *maybeMemrefAddrSpace;
  if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace)
    allocatedPtr = LLVM::AddrSpaceCastOp::create(
        rewriter, loc,
        LLVM::LLVMPointerType::get(rewriter.getContext(), memrefAddrSpace),
        allocatedPtr);
  return allocatedPtr;
}

class AllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
  SymbolTableCollection *symbolTables = nullptr;

public:
  explicit AllocOpLowering(const LLVMTypeConverter &typeConverter,
                           SymbolTableCollection *symbolTables = nullptr,
                           PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern<memref::AllocOp>(typeConverter, benefit),
        symbolTables(symbolTables) {}

  LogicalResult
  matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    MemRefType memRefType = op.getType();
    if (!isConvertibleAndHasIdentityMaps(memRefType))
      return rewriter.notifyMatchFailure(op, "incompatible memref type");

    // Get or insert alloc function into the module.
    FailureOr<LLVM::LLVMFuncOp> allocFuncOp =
        getNotalignedAllocFn(rewriter, getTypeConverter(),
                             op->getParentWithTrait<OpTrait::SymbolTable>(),
                             getIndexType(), symbolTables);
    if (failed(allocFuncOp))
      return failure();

    // Get actual sizes of the memref as values: static sizes are constant
    // values and dynamic sizes are passed to 'alloc' as operands. In case of
    // zero-dimensional memref, assume a scalar (size 1).
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 4> strides;
    Value sizeBytes;

    this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
                                   rewriter, sizes, strides, sizeBytes, true);

    Value alignment = getAlignment(rewriter, loc, op);
    if (alignment) {
      // Adjust the allocation size to consider alignment.
      sizeBytes = LLVM::AddOp::create(rewriter, loc, sizeBytes, alignment);
    }

    // Allocate the underlying buffer.
    Type elementPtrType = this->getElementPtrType(memRefType);
    assert(elementPtrType && "could not compute element ptr type");
    auto results =
        LLVM::CallOp::create(rewriter, loc, allocFuncOp.value(), sizeBytes);

    Value allocatedPtr =
        castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
                            elementPtrType, *getTypeConverter());
    Value alignedPtr = allocatedPtr;
    if (alignment) {
      // Compute the aligned pointer.
      Value allocatedInt =
          LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), allocatedPtr);
      Value alignmentInt =
          createAligned(rewriter, loc, allocatedInt, alignment);
      alignedPtr =
          LLVM::IntToPtrOp::create(rewriter, loc, elementPtrType, alignmentInt);
    }

    // Create the MemRef descriptor.
    auto memRefDescriptor = this->createMemRefDescriptor(
        loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter);

    // Return the final value of the descriptor.
    rewriter.replaceOp(op, {memRefDescriptor});
    return success();
  }

  /// Computes the alignment for the given memory allocation op.
  template <typename OpType>
  Value getAlignment(ConversionPatternRewriter &rewriter, Location loc,
                     OpType op) const {
    MemRefType memRefType = op.getType();
    Value alignment;
    if (auto alignmentAttr = op.getAlignment()) {
      Type indexType = getIndexType();
      alignment =
          createIndexAttrConstant(rewriter, loc, indexType, *alignmentAttr);
    } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
      // In the case where no alignment is specified, we may want to override
      // `malloc's` behavior. `malloc` typically aligns at the size of the
      // biggest scalar on a target HW. For non-scalars, use the natural
      // alignment of the LLVM type given by the LLVM DataLayout.
      alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
    }
    return alignment;
  }
};
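
// Rough sketch of what the pattern above emits (illustrative, not verbatim):
// for `%0 = memref.alloc() {alignment = 32} : memref<16xf32>` the buffer is
// obtained via `llvm.call @malloc` with the size padded by the alignment, the
// aligned pointer is recomputed via ptrtoint / createAligned / inttoptr, and
// both pointers are packed into the resulting memref descriptor.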

class AlignedAllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
  SymbolTableCollection *symbolTables = nullptr;

public:
  explicit AlignedAllocOpLowering(const LLVMTypeConverter &typeConverter,
                                  SymbolTableCollection *symbolTables = nullptr,
                                  PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern<memref::AllocOp>(typeConverter, benefit),
        symbolTables(symbolTables) {}

  LogicalResult
  matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    MemRefType memRefType = op.getType();
    if (!isConvertibleAndHasIdentityMaps(memRefType))
      return rewriter.notifyMatchFailure(op, "incompatible memref type");

    // Get or insert alloc function into module.
    FailureOr<LLVM::LLVMFuncOp> allocFuncOp =
        getAlignedAllocFn(rewriter, getTypeConverter(),
                          op->getParentWithTrait<OpTrait::SymbolTable>(),
                          getIndexType(), symbolTables);
    if (failed(allocFuncOp))
      return failure();

    // Get actual sizes of the memref as values: static sizes are constant
    // values and dynamic sizes are passed to 'alloc' as operands. In case of
    // zero-dimensional memref, assume a scalar (size 1).
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 4> strides;
    Value sizeBytes;

    this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
                                   rewriter, sizes, strides, sizeBytes,
                                   /*sizeInBytes=*/true);

    int64_t alignment = alignedAllocationGetAlignment(op, &defaultLayout);

    Value allocAlignment =
        createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);

    // Function aligned_alloc requires size to be a multiple of alignment; we
    // pad the size to the next multiple if necessary.
    if (!isMemRefSizeMultipleOf(memRefType, alignment, op, &defaultLayout))
      sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);

    Type elementPtrType = this->getElementPtrType(memRefType);
    auto results =
        LLVM::CallOp::create(rewriter, loc, allocFuncOp.value(),
                             ValueRange({allocAlignment, sizeBytes}));

    Value ptr =
        castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
                            elementPtrType, *getTypeConverter());

    // Create the MemRef descriptor.
    auto memRefDescriptor = this->createMemRefDescriptor(
        loc, memRefType, ptr, ptr, sizes, strides, rewriter);

    // Return the final value of the descriptor.
    rewriter.replaceOp(op, {memRefDescriptor});
    return success();
  }

  /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
  static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;

  /// Computes the alignment for aligned_alloc used to allocate the buffer for
  /// the memory allocation op.
  ///
  /// Aligned_alloc requires the alignment to be a power of two and the
  /// allocation size to be a multiple of the alignment.
  int64_t alignedAllocationGetAlignment(memref::AllocOp op,
                                        const DataLayout *defaultLayout) const {
    if (std::optional<uint64_t> alignment = op.getAlignment())
      return *alignment;

    // Whenever we don't have the alignment set, we use an alignment consistent
    // with the element type; since the alignment has to be a power of two, we
    // bump the element size to the next power of two if it is not one already.
    unsigned eltSizeBytes = getMemRefEltSizeInBytes(
        getTypeConverter(), op.getType(), op, defaultLayout);
    return std::max(kMinAlignedAllocAlignment,
                    llvm::PowerOf2Ceil(eltSizeBytes));
  }

  /// Returns true if the memref size in bytes is known to be a multiple of
  /// factor.
  bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor, Operation *op,
                              const DataLayout *defaultLayout) const {
    uint64_t sizeDivisor =
        getMemRefEltSizeInBytes(getTypeConverter(), type, op, defaultLayout);
    for (unsigned i = 0, e = type.getRank(); i < e; i++) {
      if (type.isDynamicDim(i))
        continue;
      sizeDivisor = sizeDivisor * type.getDimSize(i);
    }
    return sizeDivisor % factor == 0;
  }

private:
  /// Default layout to use in absence of the corresponding analysis.
  DataLayout defaultLayout;
};
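
// Illustrative example: for `memref.alloc() : memref<10xf32>` with the default
// 16-byte minimum alignment, the 40-byte buffer size is not a multiple of 16,
// so it is padded to 48 and the pattern emits roughly
// `llvm.call @aligned_alloc(%c16, %c48)`.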

struct AllocaOpLowering : public ConvertOpToLLVMPattern<memref::AllocaOp> {
  using ConvertOpToLLVMPattern<memref::AllocaOp>::ConvertOpToLLVMPattern;

  /// Lowers memref.alloca to llvm.alloca. The buffer is allocated on the
  /// stack, so there is no separate deallocation, and any requested alignment
  /// is forwarded directly to the llvm.alloca operation.
  LogicalResult
  matchAndRewrite(memref::AllocaOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    MemRefType memRefType = op.getType();
    if (!isConvertibleAndHasIdentityMaps(memRefType))
      return rewriter.notifyMatchFailure(op, "incompatible memref type");

    // Get actual sizes of the memref as values: static sizes are constant
    // values and dynamic sizes are passed to 'alloc' as operands. In case of
    // zero-dimensional memref, assume a scalar (size 1).
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 4> strides;
    Value size;

    this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
                                   rewriter, sizes, strides, size,
                                   /*sizeInBytes=*/false);

    // With alloca, one gets a pointer to the element type right away; the
    // allocation lives on the stack.
    auto elementType =
        typeConverter->convertType(op.getType().getElementType());
    FailureOr<unsigned> maybeAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(op.getType());
    assert(succeeded(maybeAddressSpace) && "unsupported address space");
    unsigned addrSpace = *maybeAddressSpace;
    auto elementPtrType =
        LLVM::LLVMPointerType::get(rewriter.getContext(), addrSpace);

    auto allocatedElementPtr =
        LLVM::AllocaOp::create(rewriter, loc, elementPtrType, elementType, size,
                               op.getAlignment().value_or(0));

    // Create the MemRef descriptor.
    auto memRefDescriptor = this->createMemRefDescriptor(
        loc, memRefType, allocatedElementPtr, allocatedElementPtr, sizes,
        strides, rewriter);

    // Return the final value of the descriptor.
    rewriter.replaceOp(op, {memRefDescriptor});
    return success();
  }
};
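
// Note: unlike the malloc-based patterns above, the size passed to llvm.alloca
// is an element count rather than a byte count (sizeInBytes = false), because
// llvm.alloca allocates `size` elements of the converted element type.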

struct AllocaScopeOpLowering
    : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
  using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    OpBuilder::InsertionGuard guard(rewriter);
    Location loc = allocaScopeOp.getLoc();

    // Split the current block before the AllocaScopeOp to create the inlining
    // point.
    auto *currentBlock = rewriter.getInsertionBlock();
    auto *remainingOpsBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    Block *continueBlock;
    if (allocaScopeOp.getNumResults() == 0) {
      continueBlock = remainingOpsBlock;
    } else {
      continueBlock = rewriter.createBlock(
          remainingOpsBlock, allocaScopeOp.getResultTypes(),
          SmallVector<Location>(allocaScopeOp->getNumResults(),
                                allocaScopeOp.getLoc()));
      LLVM::BrOp::create(rewriter, loc, ValueRange(), remainingOpsBlock);
    }

    // Inline body region.
    Block *beforeBody = &allocaScopeOp.getBodyRegion().front();
    Block *afterBody = &allocaScopeOp.getBodyRegion().back();
    rewriter.inlineRegionBefore(allocaScopeOp.getBodyRegion(), continueBlock);

    // Save stack and then branch into the body of the region.
    rewriter.setInsertionPointToEnd(currentBlock);
    auto stackSaveOp = LLVM::StackSaveOp::create(rewriter, loc, getPtrType());
    LLVM::BrOp::create(rewriter, loc, ValueRange(), beforeBody);

    // Replace the alloca_scope return with a branch that jumps out of the
    // body. Stack restore before leaving the body region.
    rewriter.setInsertionPointToEnd(afterBody);
    auto returnOp =
        cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
    auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
        returnOp, returnOp.getResults(), continueBlock);

    // Insert the stack restore before jumping out of the body of the region.
    rewriter.setInsertionPoint(branchOp);
    LLVM::StackRestoreOp::create(rewriter, loc, stackSaveOp);

    // Replace the op with the values returned from the body region.
    rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());

    return success();
  }
};
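
// Structural sketch of the rewrite above (illustrative): the scope becomes
//   %sp = llvm.intr.stacksave ; branch into the inlined body
//   <inlined alloca_scope body> ; llvm.intr.stackrestore %sp ; branch onwards
// so that any llvm.alloca created inside the scope is released on exit.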

struct AssumeAlignmentOpLowering
    : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
  using ConvertOpToLLVMPattern<
      memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;
  explicit AssumeAlignmentOpLowering(const LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::AssumeAlignmentOp>(converter) {}

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value memref = adaptor.getMemref();
    unsigned alignment = op.getAlignment();
    auto loc = op.getLoc();

    auto srcMemRefType = cast<MemRefType>(op.getMemref().getType());
    Value ptr = getStridedElementPtr(rewriter, loc, srcMemRefType, memref,
                                     /*indices=*/{});

    // Emit llvm.assume(true) ["align"(memref, alignment)].
    // This is more direct than ptrtoint-based checks, is explicitly supported,
    // and works with non-integral address spaces.
    Value trueCond =
        LLVM::ConstantOp::create(rewriter, loc, rewriter.getBoolAttr(true));
    Value alignmentConst =
        createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);
    LLVM::AssumeOp::create(rewriter, loc, trueCond, LLVM::AssumeAlignTag(), ptr,
                           alignmentConst);
    rewriter.replaceOp(op, memref);
    return success();
  }
};
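
// Illustrative output: for an alignment of 16 this emits, roughly,
//   llvm.intr.assume %true ["align"(%ptr, %c16)]
// which records the alignment fact for LLVM without changing semantics.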

// A `dealloc` is converted into a call to `free` on the underlying data buffer.
// Since the memref descriptor is an SSA value, there is no need to clean it up
// in any way.
class DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
  SymbolTableCollection *symbolTables = nullptr;

public:
  explicit DeallocOpLowering(const LLVMTypeConverter &typeConverter,
                             SymbolTableCollection *symbolTables = nullptr,
                             PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern<memref::DeallocOp>(typeConverter, benefit),
        symbolTables(symbolTables) {}

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Insert the `free` declaration if it is not already present.
    FailureOr<LLVM::LLVMFuncOp> freeFunc =
        getFreeFn(rewriter, getTypeConverter(), op->getParentOfType<ModuleOp>(),
                  symbolTables);
    if (failed(freeFunc))
      return failure();
    Value allocatedPtr;
    if (auto unrankedTy =
            llvm::dyn_cast<UnrankedMemRefType>(op.getMemref().getType())) {
      auto elementPtrTy = LLVM::LLVMPointerType::get(
          rewriter.getContext(), unrankedTy.getMemorySpaceAsInt());
      allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
          rewriter, op.getLoc(),
          UnrankedMemRefDescriptor(adaptor.getMemref())
              .memRefDescPtr(rewriter, op.getLoc()),
          elementPtrTy);
    } else {
      allocatedPtr = MemRefDescriptor(adaptor.getMemref())
                         .allocatedPtr(rewriter, op.getLoc());
    }
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc.value(),
                                              allocatedPtr);
    return success();
  }
};

// A `dim` is converted to a constant for static sizes and to an access to the
// size stored in the memref descriptor for dynamic sizes.
struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
  using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type operandType = dimOp.getSource().getType();
    if (isa<UnrankedMemRefType>(operandType)) {
      FailureOr<Value> extractedSize = extractSizeOfUnrankedMemRef(
          operandType, dimOp, adaptor.getOperands(), rewriter);
      if (failed(extractedSize))
        return failure();
      rewriter.replaceOp(dimOp, {*extractedSize});
      return success();
    }
    if (isa<MemRefType>(operandType)) {
      rewriter.replaceOp(
          dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
                                            adaptor.getOperands(), rewriter)});
      return success();
    }
    llvm_unreachable("expected MemRefType or UnrankedMemRefType");
  }

private:
  FailureOr<Value>
  extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
                              OpAdaptor adaptor,
                              ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    auto unrankedMemRefType = cast<UnrankedMemRefType>(operandType);
    auto scalarMemRefType =
        MemRefType::get({}, unrankedMemRefType.getElementType());
    FailureOr<unsigned> maybeAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(unrankedMemRefType);
    if (failed(maybeAddressSpace)) {
      dimOp.emitOpError("memref memory space must be convertible to an integer "
                        "address space");
      return failure();
    }
    unsigned addressSpace = *maybeAddressSpace;

    // Extract pointer to the underlying ranked descriptor and bitcast it to a
    // memref<element_type> descriptor pointer to minimize the number of GEP
    // operations.
    UnrankedMemRefDescriptor unrankedDesc(adaptor.getSource());
    Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);

    Type elementType = typeConverter->convertType(scalarMemRefType);

    // Get pointer to offset field of memref<element_type> descriptor.
    auto indexPtrTy =
        LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
    Value offsetPtr =
        LLVM::GEPOp::create(rewriter, loc, indexPtrTy, elementType,
                            underlyingRankedDesc, ArrayRef<LLVM::GEPArg>{0, 2});

    // The size value that we have to extract can be obtained using a GEPOp
    // with `dimOp.index() + 1` as the index argument.
    Value idxPlusOne = LLVM::AddOp::create(
        rewriter, loc,
        createIndexAttrConstant(rewriter, loc, getIndexType(), 1),
        adaptor.getIndex());
    Value sizePtr = LLVM::GEPOp::create(rewriter, loc, indexPtrTy,
                                        getTypeConverter()->getIndexType(),
                                        offsetPtr, idxPlusOne);
    return LLVM::LoadOp::create(rewriter, loc,
                                getTypeConverter()->getIndexType(), sizePtr)
        .getResult();
  }
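
  // Layout reminder (illustrative): the underlying ranked descriptor is
  // { allocatedPtr, alignedPtr, offset, sizes[rank], strides[rank] }, so from
  // the offset field the value `index + 1` index-typed slots further on is
  // sizes[index]; e.g. for index 1 the GEP above reads sizes[1].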

  std::optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
    if (auto idx = dimOp.getConstantIndex())
      return idx;

    if (auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
      return cast<IntegerAttr>(constantOp.getValue()).getValue().getSExtValue();

    return std::nullopt;
  }

  Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    // Take advantage of a constant index, if available.
    MemRefType memRefType = cast<MemRefType>(operandType);
    Type indexType = getIndexType();
    if (std::optional<int64_t> index = getConstantDimIndex(dimOp)) {
      int64_t i = *index;
      if (i >= 0 && i < memRefType.getRank()) {
        if (memRefType.isDynamicDim(i)) {
          // Extract the dynamic size from the memref descriptor.
          MemRefDescriptor descriptor(adaptor.getSource());
          return descriptor.size(rewriter, loc, i);
        }
        // Use constant for static size.
        int64_t dimSize = memRefType.getDimSize(i);
        return createIndexAttrConstant(rewriter, loc, indexType, dimSize);
      }
    }
    Value index = adaptor.getIndex();
    int64_t rank = memRefType.getRank();
    MemRefDescriptor memrefDescriptor(adaptor.getSource());
    return memrefDescriptor.size(rewriter, loc, index, rank);
  }
};

/// Common base for load and store operations on MemRefs. Restricts the match
/// to supported MemRef types. Provides functionality to emit code accessing a
/// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
  using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
  using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
  using Base = LoadStoreOpLowering<Derived>;
};

/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
/// retried until it succeeds in atomically storing a new value into memory.
///
///      +---------------------------------+
///      |   <code before the AtomicRMWOp> |
///      |   <compute initial %loaded>     |
///      |   cf.br loop(%loaded)           |
///      +---------------------------------+
///             |
///  -------|   |
///  |      v   v
///  |   +--------------------------------+
///  |   | loop(%loaded):                 |
///  |   |   <body contents>              |
///  |   |   %pair = cmpxchg              |
///  |   |   %ok = %pair[0]               |
///  |   |   %new = %pair[1]              |
///  |   |   cf.cond_br %ok, end, loop(%new) |
///  |   +--------------------------------+
///  |          |         |
///  |----------          |
///                       v
///      +--------------------------------+
///      | end:                           |
///      |   <code after the AtomicRMWOp> |
///      +--------------------------------+
///
struct GenericAtomicRMWOpLowering
    : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = atomicOp.getLoc();
    Type valueType = typeConverter->convertType(atomicOp.getResult().getType());

    // Split the block into initial, loop, and ending parts.
    auto *initBlock = rewriter.getInsertionBlock();
    auto *loopBlock = rewriter.splitBlock(initBlock, Block::iterator(atomicOp));
    loopBlock->addArgument(valueType, loc);

    auto *endBlock =
        rewriter.splitBlock(loopBlock, Block::iterator(atomicOp)++);

    // Compute the loaded value and branch to the loop block.
    rewriter.setInsertionPointToEnd(initBlock);
    auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
    auto dataPtr = getStridedElementPtr(
        rewriter, loc, memRefType, adaptor.getMemref(), adaptor.getIndices());
    Value init = LLVM::LoadOp::create(
        rewriter, loc, typeConverter->convertType(memRefType.getElementType()),
        dataPtr);
    LLVM::BrOp::create(rewriter, loc, init, loopBlock);

    // Prepare the body of the loop block.
    rewriter.setInsertionPointToStart(loopBlock);

    // Clone the GenericAtomicRMWOp region and extract the result.
    auto loopArgument = loopBlock->getArgument(0);
    IRMapping mapping;
    mapping.map(atomicOp.getCurrentValue(), loopArgument);
    Block &entryBlock = atomicOp.body().front();
    for (auto &nestedOp : entryBlock.without_terminator()) {
      Operation *clone = rewriter.clone(nestedOp, mapping);
      mapping.map(nestedOp.getResults(), clone->getResults());
    }
    Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));

    // Prepare the epilog of the loop block.
    // Append the cmpxchg op to the end of the loop block.
    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
    auto cmpxchg =
        LLVM::AtomicCmpXchgOp::create(rewriter, loc, dataPtr, loopArgument,
                                      result, successOrdering, failureOrdering);
    // Extract the %new_loaded and %ok values from the pair.
    Value newLoaded = LLVM::ExtractValueOp::create(rewriter, loc, cmpxchg, 0);
    Value ok = LLVM::ExtractValueOp::create(rewriter, loc, cmpxchg, 1);

    // Conditionally branch to the end or back to the loop depending on %ok.
    LLVM::CondBrOp::create(rewriter, loc, ok, endBlock, ArrayRef<Value>(),
                           loopBlock, newLoaded);

    rewriter.setInsertionPointToEnd(endBlock);

    // The 'result' of the atomic_rmw op is the newly loaded value.
    rewriter.replaceOp(atomicOp, {newLoaded});

    return success();
  }
};

/// Returns the LLVM type of the global variable given the memref type `type`.
static Type
convertGlobalMemrefTypeToLLVM(MemRefType type,
                              const LLVMTypeConverter &typeConverter) {
  // LLVM type for a global memref will be a multi-dimensional array. For
  // declarations or uninitialized global memrefs, we can potentially flatten
  // this to a 1D array. However, for memref.global's with an initial value,
  // we do not intend to flatten the ElementsAttribute when going from std ->
  // LLVM dialect, so the LLVM type needs to be a multi-dimensional array.
  Type elementType = typeConverter.convertType(type.getElementType());
  Type arrayTy = elementType;
  // Shape has the outermost dim at index 0, so we need to walk it backwards.
  for (int64_t dim : llvm::reverse(type.getShape()))
    arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
  return arrayTy;
}
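
// Illustrative example: a memref<2x3xf32> global becomes the LLVM array type
// !llvm.array<2 x array<3 x f32>>, built innermost dimension first by the
// reversed walk above.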

/// GlobalMemrefOp is lowered to a LLVM Global Variable.
class GlobalMemrefOpLowering : public ConvertOpToLLVMPattern<memref::GlobalOp> {
  SymbolTableCollection *symbolTables = nullptr;

public:
  explicit GlobalMemrefOpLowering(const LLVMTypeConverter &typeConverter,
                                  SymbolTableCollection *symbolTables = nullptr,
                                  PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern<memref::GlobalOp>(typeConverter, benefit),
        symbolTables(symbolTables) {}

  LogicalResult
  matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = global.getType();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());

    LLVM::Linkage linkage =
        global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
    bool isExternal = global.isExternal();
    bool isUninitialized = global.isUninitialized();

    Attribute initialValue = nullptr;
    if (!isExternal && !isUninitialized) {
      auto elementsAttr = llvm::cast<ElementsAttr>(*global.getInitialValue());
      initialValue = elementsAttr;

      // For scalar memrefs, the global variable created is of the element type,
      // so unpack the elements attribute to extract the value.
      if (type.getRank() == 0)
        initialValue = elementsAttr.getSplatValue<Attribute>();
    }

    uint64_t alignment = global.getAlignment().value_or(0);
    FailureOr<unsigned> addressSpace =
        getTypeConverter()->getMemRefAddressSpace(type);
    if (failed(addressSpace))
      return global.emitOpError(
          "memory space cannot be converted to an integer address space");

    // Remove old operation from symbol table.
    SymbolTable *symbolTable = nullptr;
    if (symbolTables) {
      Operation *symbolTableOp =
          global->getParentWithTrait<OpTrait::SymbolTable>();
      symbolTable = &symbolTables->getSymbolTable(symbolTableOp);
      symbolTable->remove(global);
    }

    // Create new operation.
    auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
        global, arrayTy, global.getConstant(), linkage, global.getSymName(),
        initialValue, alignment, *addressSpace);

    // Insert new operation into symbol table.
    if (symbolTable)
      symbolTable->insert(newGlobal, rewriter.getInsertionPoint());

    if (!isExternal && isUninitialized) {
      rewriter.createBlock(&newGlobal.getInitializerRegion());
      Value undef[] = {
          LLVM::UndefOp::create(rewriter, newGlobal.getLoc(), arrayTy)};
      LLVM::ReturnOp::create(rewriter, newGlobal.getLoc(), undef);
    }
    return success();
  }
};

/// GetGlobalMemrefOp is lowered into a Memref descriptor with the pointer to
/// the first element stashed into the descriptor. The descriptor itself is
/// constructed the same way as in the allocation lowerings above.
struct GetGlobalMemrefOpLowering
    : public ConvertOpToLLVMPattern<memref::GetGlobalOp> {
  using ConvertOpToLLVMPattern<memref::GetGlobalOp>::ConvertOpToLLVMPattern;

  /// Buffer "allocation" for memref.get_global op is getting the address of
  /// the global variable referenced.
  LogicalResult
  matchAndRewrite(memref::GetGlobalOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    MemRefType memRefType = op.getType();
    if (!isConvertibleAndHasIdentityMaps(memRefType))
      return rewriter.notifyMatchFailure(op, "incompatible memref type");

    // Get actual sizes of the memref as values: static sizes are constant
    // values and dynamic sizes are passed to 'alloc' as operands. In case of
    // zero-dimensional memref, assume a scalar (size 1).
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 4> strides;
    Value sizeBytes;

    this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
                                   rewriter, sizes, strides, sizeBytes,
                                   /*sizeInBytes=*/true);

    MemRefType type = cast<MemRefType>(op.getResult().getType());

    // This is called after a type conversion, which would have failed if this
    // call fails.
    FailureOr<unsigned> maybeAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(type);
    assert(succeeded(maybeAddressSpace) && "unsupported address space");
    unsigned memSpace = *maybeAddressSpace;

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
    auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), memSpace);
    auto addressOf =
        LLVM::AddressOfOp::create(rewriter, loc, ptrTy, op.getName());

    // Get the address of the first element in the array by creating a GEP with
    // the address of the GV as the base, and (rank + 1) zero indices.
    auto gep =
        LLVM::GEPOp::create(rewriter, loc, ptrTy, arrayTy, addressOf,
                            SmallVector<LLVM::GEPArg>(type.getRank() + 1, 0));

    // We do not expect the memref obtained using `memref.get_global` to ever
    // be deallocated. Set the allocated pointer to a known bad value to help
    // debug if that ever happens.
    auto intPtrType = getIntPtrType(memSpace);
    Value deadBeefConst =
        createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
    auto deadBeefPtr =
        LLVM::IntToPtrOp::create(rewriter, loc, ptrTy, deadBeefConst);

    // Both allocated and aligned pointers are the same. We could potentially
    // stash a nullptr for the allocated pointer since we do not expect any
    // dealloc.
    // Create the MemRef descriptor.
    auto memRefDescriptor = this->createMemRefDescriptor(
        loc, memRefType, deadBeefPtr, gep, sizes, strides, rewriter);

    // Return the final value of the descriptor.
    rewriter.replaceOp(op, {memRefDescriptor});
    return success();
  }
};

// Load operation is lowered to obtaining a pointer to the indexed element
// and loading it.
struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = loadOp.getMemRefType();

    // Per memref.load spec, the indices must be in-bounds:
    // 0 <= idx < dim_size, and additionally all offsets are non-negative,
    // hence inbounds and nuw are used when lowering to llvm.getelementptr.
    Value dataPtr = getStridedElementPtr(rewriter, loadOp.getLoc(), type,
                                         adaptor.getMemref(),
                                         adaptor.getIndices(), kNoWrapFlags);
    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
        loadOp, typeConverter->convertType(type.getElementType()), dataPtr,
        loadOp.getAlignment().value_or(0), false, loadOp.getNontemporal());
    return success();
  }
};
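
// Illustrative output: `memref.load %m[%i, %j]` becomes, roughly, a
// `llvm.getelementptr inbounds|nuw` over the aligned pointer at position
// offset + %i * stride0 + %j * stride1, followed by an `llvm.load`.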

// Store operation is lowered to obtaining a pointer to the indexed element,
// and storing the given value to it.
struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = op.getMemRefType();

    // Per memref.store spec, the indices must be in-bounds:
    // 0 <= idx < dim_size, and additionally all offsets are non-negative,
    // hence inbounds and nuw are used when lowering to llvm.getelementptr.
    Value dataPtr =
        getStridedElementPtr(rewriter, op.getLoc(), type, adaptor.getMemref(),
                             adaptor.getIndices(), kNoWrapFlags);
    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr,
                                               op.getAlignment().value_or(0),
                                               false, op.getNontemporal());
    return success();
  }
};

// The prefetch operation is lowered in a way similar to the load operation
// except that the llvm.prefetch operation is used for replacement.
struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = prefetchOp.getMemRefType();
    auto loc = prefetchOp.getLoc();

    Value dataPtr = getStridedElementPtr(
        rewriter, loc, type, adaptor.getMemref(), adaptor.getIndices());

    // Replace with llvm.prefetch.
    IntegerAttr isWrite = rewriter.getI32IntegerAttr(prefetchOp.getIsWrite());
    IntegerAttr localityHint = prefetchOp.getLocalityHintAttr();
    IntegerAttr isData =
        rewriter.getI32IntegerAttr(prefetchOp.getIsDataCache());
    rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
                                                localityHint, isData);
    return success();
  }
};

struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
  using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type operandType = op.getMemref().getType();
    if (isa<UnrankedMemRefType>(operandType)) {
      UnrankedMemRefDescriptor desc(adaptor.getMemref());
      rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
      return success();
    }
    if (auto rankedMemRefType = dyn_cast<MemRefType>(operandType)) {
      Type indexType = getIndexType();
      rewriter.replaceOp(op,
                         {createIndexAttrConstant(rewriter, loc, indexType,
                                                  rankedMemRefType.getRank())});
      return success();
    }
    return failure();
  }
};

struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = memRefCastOp.getOperand().getType();
    Type dstType = memRefCastOp.getType();

    // memref::CastOp reduces to a bitcast in the ranked MemRef case and can be
    // used for type erasure. For now it must preserve the underlying element
    // type and requires source and result types to have the same rank.
    // Therefore, perform a sanity check that the underlying structs are the
    // same. Once op semantics are relaxed we can revisit.
    if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
      if (typeConverter->convertType(srcType) !=
          typeConverter->convertType(dstType))
        return failure();

    // Unranked to unranked casts are disallowed.
    if (isa<UnrankedMemRefType>(srcType) && isa<UnrankedMemRefType>(dstType))
      return failure();

    auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
    auto loc = memRefCastOp.getLoc();

    // For ranked/ranked case, just keep the original descriptor.
    if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType)) {
      rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});
      return success();
    }

    if (isa<MemRefType>(srcType) && isa<UnrankedMemRefType>(dstType)) {
      // Casting ranked to unranked memref type:
      // Set the rank in the destination from the memref type.
      // Allocate space on the stack and copy the src memref descriptor.
      // Set the ptr in the destination to the stack space.
      auto srcMemRefType = cast<MemRefType>(srcType);
      int64_t rank = srcMemRefType.getRank();
      // ptr = AllocaOp sizeof(MemRefDescriptor)
      auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
          loc, adaptor.getSource(), rewriter);

      // rank = ConstantOp srcRank
      auto rankVal = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
                                              rewriter.getIndexAttr(rank));
      // poison = PoisonOp
      UnrankedMemRefDescriptor memRefDesc =
          UnrankedMemRefDescriptor::poison(rewriter, loc, targetStructType);
      // d1 = InsertValueOp poison, rank, 0
      memRefDesc.setRank(rewriter, loc, rankVal);
      // d2 = InsertValueOp d1, ptr, 1
      memRefDesc.setMemRefDescPtr(rewriter, loc, ptr);
      rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);

    } else if (isa<UnrankedMemRefType>(srcType) && isa<MemRefType>(dstType)) {
      // Casting from unranked type to ranked.
      // The operation is assumed to be doing a correct cast. If the destination
      // type does not match the runtime type of the source, the behavior is
      // undefined.
      UnrankedMemRefDescriptor memRefDesc(adaptor.getSource());
      // ptr = ExtractValueOp src, 1
      auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);

      // struct = LoadOp ptr
      auto loadOp = LLVM::LoadOp::create(rewriter, loc, targetStructType, ptr);
      rewriter.replaceOp(memRefCastOp, loadOp.getResult());
    } else {
      llvm_unreachable("Unsupported unranked memref to unranked memref cast");
    }

    return success();
  }
};

/// Pattern to lower a `memref.copy` to llvm.
///
/// For memrefs with identity layouts, the copy is lowered to the llvm
/// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
/// to the generic `MemrefCopyFn`.
class MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
  SymbolTableCollection *symbolTables = nullptr;

public:
  explicit MemRefCopyOpLowering(const LLVMTypeConverter &typeConverter,
                                SymbolTableCollection *symbolTables = nullptr,
                                PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern<memref::CopyOp>(typeConverter, benefit),
        symbolTables(symbolTables) {}

  LogicalResult
  lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = dyn_cast<MemRefType>(op.getSource().getType());

    MemRefDescriptor srcDesc(adaptor.getSource());

    // Compute number of elements.
    Value numElements = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
                                                 rewriter.getIndexAttr(1));
    for (int pos = 0; pos < srcType.getRank(); ++pos) {
      auto size = srcDesc.size(rewriter, loc, pos);
      numElements = LLVM::MulOp::create(rewriter, loc, numElements, size);
    }

    // Get element size.
    auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    // Compute total.
    Value totalSize =
        LLVM::MulOp::create(rewriter, loc, numElements, sizeInBytes);

    Type elementType = typeConverter->convertType(srcType.getElementType());

    Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
    Value srcOffset = srcDesc.offset(rewriter, loc);
    Value srcPtr = LLVM::GEPOp::create(rewriter, loc, srcBasePtr.getType(),
                                       elementType, srcBasePtr, srcOffset);
    MemRefDescriptor targetDesc(adaptor.getTarget());
    Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
    Value targetOffset = targetDesc.offset(rewriter, loc);
    Value targetPtr =
        LLVM::GEPOp::create(rewriter, loc, targetBasePtr.getType(), elementType,
                            targetBasePtr, targetOffset);
    LLVM::MemcpyOp::create(rewriter, loc, targetPtr, srcPtr, totalSize,
                           /*isVolatile=*/false);
    rewriter.eraseOp(op);

    return success();
  }
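
  // Illustrative arithmetic for the intrinsic path: copying a contiguous
  // memref<4x8xf32> multiplies the sizes (4 * 8 = 32 elements) by the 4-byte
  // element size and emits a single memcpy of 128 bytes.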

  LogicalResult
  lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
                             ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = cast<BaseMemRefType>(op.getSource().getType());
    auto targetType = cast<BaseMemRefType>(op.getTarget().getType());

    // First make sure we have an unranked memref descriptor representation.
    auto makeUnranked = [&, this](Value ranked, MemRefType type) {
      auto rank = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
                                           type.getRank());
      auto *typeConverter = getTypeConverter();
      auto ptr =
          typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);

      auto unrankedType =
          UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
      return UnrankedMemRefDescriptor::pack(
          rewriter, loc, *typeConverter, unrankedType, ValueRange{rank, ptr});
    };

    // Save the stack position before promoting descriptors.
    auto stackSaveOp = LLVM::StackSaveOp::create(rewriter, loc, getPtrType());

    auto srcMemRefType = dyn_cast<MemRefType>(srcType);
    Value unrankedSource =
        srcMemRefType ? makeUnranked(adaptor.getSource(), srcMemRefType)
                      : adaptor.getSource();
    auto targetMemRefType = dyn_cast<MemRefType>(targetType);
    Value unrankedTarget =
        targetMemRefType ? makeUnranked(adaptor.getTarget(), targetMemRefType)
                         : adaptor.getTarget();

    // Now promote the unranked descriptors to the stack.
    auto one = LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
                                        rewriter.getIndexAttr(1));
    auto promote = [&](Value desc) {
      auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
      auto allocated =
          LLVM::AllocaOp::create(rewriter, loc, ptrType, desc.getType(), one);
      LLVM::StoreOp::create(rewriter, loc, desc, allocated);
      return allocated;
    };

    auto sourcePtr = promote(unrankedSource);
    auto targetPtr = promote(unrankedTarget);

    // Derive the element size from llvm.getelementptr, which accounts for any
    // potential alignment.
    auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
        rewriter, op->getParentOfType<ModuleOp>(), getIndexType(),
        sourcePtr.getType(), symbolTables);
    if (failed(copyFn))
      return failure();
    LLVM::CallOp::create(rewriter, loc, copyFn.value(),
                         ValueRange{elemSize, sourcePtr, targetPtr});

    // Restore the stack used for the descriptors.
    LLVM::StackRestoreOp::create(rewriter, loc, stackSaveOp);

    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = cast<BaseMemRefType>(op.getSource().getType());
    auto targetType = cast<BaseMemRefType>(op.getTarget().getType());

    auto isContiguousMemrefType = [&](BaseMemRefType type) {
      auto memrefType = dyn_cast<mlir::MemRefType>(type);
      // We can use memcpy for memrefs if they have an identity layout or are
      // contiguous with an arbitrary offset. Ignore empty memrefs, which are a
      // special case handled by memrefCopy.
      return memrefType &&
             (memrefType.getLayout().isIdentity() ||
              (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
               memref::isStaticShapeAndContiguousRowMajor(memrefType)));
    };

    if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
      return lowerToMemCopyIntrinsic(op, adaptor, rewriter);

    return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
  }
};

struct MemorySpaceCastOpLowering
    : public ConvertOpToLLVMPattern<memref::MemorySpaceCastOp> {
  using ConvertOpToLLVMPattern<
      memref::MemorySpaceCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    Type resultType = op.getDest().getType();
    if (auto resultTypeR = dyn_cast<MemRefType>(resultType)) {
      auto resultDescType =
          cast<LLVM::LLVMStructType>(typeConverter->convertType(resultTypeR));
      Type newPtrType = resultDescType.getBody()[0];

      SmallVector<Value> descVals;
      MemRefDescriptor::unpack(rewriter, loc, adaptor.getSource(), resultTypeR,
                               descVals);
      descVals[0] =
          LLVM::AddrSpaceCastOp::create(rewriter, loc, newPtrType, descVals[0]);
      descVals[1] =
          LLVM::AddrSpaceCastOp::create(rewriter, loc, newPtrType, descVals[1]);
      Value result = MemRefDescriptor::pack(rewriter, loc, *getTypeConverter(),
                                            resultTypeR, descVals);
      rewriter.replaceOp(op, result);
      return success();
    }
    if (auto resultTypeU = dyn_cast<UnrankedMemRefType>(resultType)) {
      // Since the type converter won't be doing this for us, get the address
      // space.
      auto sourceType = cast<UnrankedMemRefType>(op.getSource().getType());
      FailureOr<unsigned> maybeSourceAddrSpace =
          getTypeConverter()->getMemRefAddressSpace(sourceType);
      if (failed(maybeSourceAddrSpace))
        return rewriter.notifyMatchFailure(loc,
                                           "non-integer source address space");
      unsigned sourceAddrSpace = *maybeSourceAddrSpace;
      FailureOr<unsigned> maybeResultAddrSpace =
          getTypeConverter()->getMemRefAddressSpace(resultTypeU);
      if (failed(maybeResultAddrSpace))
        return rewriter.notifyMatchFailure(loc,
                                           "non-integer result address space");
      unsigned resultAddrSpace = *maybeResultAddrSpace;

      UnrankedMemRefDescriptor sourceDesc(adaptor.getSource());
      Value rank = sourceDesc.rank(rewriter, loc);
      Value sourceUnderlyingDesc = sourceDesc.memRefDescPtr(rewriter, loc);

      // Create and allocate storage for new memref descriptor.
      auto result = UnrankedMemRefDescriptor::poison(
          rewriter, loc, typeConverter->convertType(resultTypeU));
      result.setRank(rewriter, loc, rank);
      Value resultUnderlyingSize = UnrankedMemRefDescriptor::computeSize(
          rewriter, loc, *getTypeConverter(), result, resultAddrSpace);
      Value resultUnderlyingDesc =
          LLVM::AllocaOp::create(rewriter, loc, getPtrType(),
                                 rewriter.getI8Type(), resultUnderlyingSize);
      result.setMemRefDescPtr(rewriter, loc, resultUnderlyingDesc);

      // Copy pointers, performing address space casts.
      auto sourceElemPtrType =
          LLVM::LLVMPointerType::get(rewriter.getContext(), sourceAddrSpace);
      auto resultElemPtrType =
          LLVM::LLVMPointerType::get(rewriter.getContext(), resultAddrSpace);

      Value allocatedPtr = sourceDesc.allocatedPtr(
          rewriter, loc, sourceUnderlyingDesc, sourceElemPtrType);
      Value alignedPtr =
          sourceDesc.alignedPtr(rewriter, loc, *getTypeConverter(),
                                sourceUnderlyingDesc, sourceElemPtrType);
      allocatedPtr = LLVM::AddrSpaceCastOp::create(
          rewriter, loc, resultElemPtrType, allocatedPtr);
      alignedPtr = LLVM::AddrSpaceCastOp::create(rewriter, loc,
                                                 resultElemPtrType, alignedPtr);

      result.setAllocatedPtr(rewriter, loc, resultUnderlyingDesc,
                             resultElemPtrType, allocatedPtr);
      result.setAlignedPtr(rewriter, loc, *getTypeConverter(),
                           resultUnderlyingDesc, resultElemPtrType, alignedPtr);

      // Copy all the index-valued fields (offset, sizes, strides).
      Value sourceIndexVals =
          sourceDesc.offsetBasePtr(rewriter, loc, *getTypeConverter(),
                                   sourceUnderlyingDesc, sourceElemPtrType);
      Value resultIndexVals =
          result.offsetBasePtr(rewriter, loc, *getTypeConverter(),
                               resultUnderlyingDesc, resultElemPtrType);

      int64_t bytesToSkip =
          2 * llvm::divideCeil(
                  getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8);
      Value bytesToSkipConst = LLVM::ConstantOp::create(
          rewriter, loc, getIndexType(), rewriter.getIndexAttr(bytesToSkip));
      Value copySize =
          LLVM::SubOp::create(rewriter, loc, getIndexType(),
                              resultUnderlyingSize, bytesToSkipConst);
      LLVM::MemcpyOp::create(rewriter, loc, resultIndexVals, sourceIndexVals,
                             copySize, /*isVolatile=*/false);

      rewriter.replaceOp(op, ValueRange{result});
      return success();
    }
    return rewriter.notifyMatchFailure(loc, "unexpected memref type");
  }
};

/// Extracts allocated, aligned pointers and offset from a ranked or unranked
/// memref type. In unranked case, the fields are extracted from the underlying
/// ranked descriptor.
static void extractPointersAndOffset(Location loc,
                                     ConversionPatternRewriter &rewriter,
                                     const LLVMTypeConverter &typeConverter,
                                     Value originalOperand,
                                     Value convertedOperand,
                                     Value *allocatedPtr, Value *alignedPtr,
                                     Value *offset = nullptr) {
  Type operandType = originalOperand.getType();
  if (isa<MemRefType>(operandType)) {
    MemRefDescriptor desc(convertedOperand);
    *allocatedPtr = desc.allocatedPtr(rewriter, loc);
    *alignedPtr = desc.alignedPtr(rewriter, loc);
    if (offset != nullptr)
      *offset = desc.offset(rewriter, loc);
    return;
  }

  // These will all cause assert()s on unconvertible types.
  unsigned memorySpace = *typeConverter.getMemRefAddressSpace(
      cast<UnrankedMemRefType>(operandType));
  auto elementPtrType =
      LLVM::LLVMPointerType::get(rewriter.getContext(), memorySpace);

  // Extract pointer to the underlying ranked memref descriptor and cast it to
  // ElemType**.
  UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
  Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);

  *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
      rewriter, loc, underlyingDescPtr, elementPtrType);
  *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
      rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
  if (offset != nullptr) {
    *offset = UnrankedMemRefDescriptor::offset(
        rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
  }
}

struct MemRefReinterpretCastOpLowering
    : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
  using ConvertOpToLLVMPattern<
      memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = castOp.getSource().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(castOp, {descriptor});
    return success();
  }

private:
  LogicalResult convertSourceMemRefToDescriptor(
      ConversionPatternRewriter &rewriter, Type srcType,
      memref::ReinterpretCastOp castOp,
      memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
    MemRefType targetMemRefType =
        cast<MemRefType>(castOp.getResult().getType());
    auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
        typeConverter->convertType(targetMemRefType));
    if (!llvmTargetDescriptorTy)
      return failure();

    // Create descriptor.
    Location loc = castOp.getLoc();
    auto desc = MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);

    // Set allocated and aligned pointers.
    Value allocatedPtr, alignedPtr;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             castOp.getSource(), adaptor.getSource(),
                             &allocatedPtr, &alignedPtr);
    desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
    desc.setAlignedPtr(rewriter, loc, alignedPtr);

    // Set offset.
    if (castOp.isDynamicOffset(0))
      desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
    else
      desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));

    // Set sizes and strides.
    unsigned dynSizeId = 0;
    unsigned dynStrideId = 0;
    for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
      if (castOp.isDynamicSize(i))
        desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
      else
        desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));

      if (castOp.isDynamicStride(i))
        desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
      else
        desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
    }
    *descriptor = desc;
    return success();
  }
};
|
|
|
|
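/// Lowers memref.reshape. When the shape operand has a static extent, the
/// result is a ranked memref and its descriptor is assembled directly. When
/// the shape operand's extent is dynamic, the result is an unranked memref: a
/// ranked descriptor of runtime-determined rank is alloca'ed and filled by an
/// explicit CFG loop that copies the sizes from the shape operand and
/// accumulates the strides. As an illustrative sketch (not taken from the
/// test suite), an op roughly like
///   %r = memref.reshape %src(%shape)
///       : (memref<?x?xf32>, memref<?xindex>) -> memref<*xf32>
/// takes the dynamic-rank path.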
struct MemRefReshapeOpLowering
    : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
  using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = reshapeOp.getSource().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(reshapeOp, {descriptor});
    return success();
  }

private:
  LogicalResult
  convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
                                  Type srcType, memref::ReshapeOp reshapeOp,
                                  memref::ReshapeOp::Adaptor adaptor,
                                  Value *descriptor) const {
    auto shapeMemRefType = cast<MemRefType>(reshapeOp.getShape().getType());
    if (shapeMemRefType.hasStaticShape()) {
      MemRefType targetMemRefType =
          cast<MemRefType>(reshapeOp.getResult().getType());
      auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
          typeConverter->convertType(targetMemRefType));
      if (!llvmTargetDescriptorTy)
        return failure();

      // Create descriptor.
      Location loc = reshapeOp.getLoc();
      auto desc =
          MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);

      // Set allocated and aligned pointers.
      Value allocatedPtr, alignedPtr;
      extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                               reshapeOp.getSource(), adaptor.getSource(),
                               &allocatedPtr, &alignedPtr);
      desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
      desc.setAlignedPtr(rewriter, loc, alignedPtr);

      // Extract the offset and strides from the type.
      int64_t offset;
      SmallVector<int64_t> strides;
      if (failed(targetMemRefType.getStridesAndOffset(strides, offset)))
        return rewriter.notifyMatchFailure(
            reshapeOp, "failed to get stride and offset exprs");

      if (!isStaticStrideOrOffset(offset))
        return rewriter.notifyMatchFailure(reshapeOp,
                                           "dynamic offset is unsupported");

      desc.setConstantOffset(rewriter, loc, offset);

      assert(targetMemRefType.getLayout().isIdentity() &&
             "Identity layout map is a precondition of a valid reshape op");

      Type indexType = getIndexType();
      Value stride = nullptr;
      int64_t targetRank = targetMemRefType.getRank();
      for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
        if (ShapedType::isStatic(strides[i])) {
          // If the stride for this dimension is static, materialize it as a
          // constant.
          stride =
              createIndexAttrConstant(rewriter, loc, indexType, strides[i]);
        } else if (!stride) {
          // `stride` is null only in the first iteration of the loop. However,
          // since the target memref has an identity layout, we can safely set
          // the innermost stride to 1.
          stride = createIndexAttrConstant(rewriter, loc, indexType, 1);
        }

        Value dimSize;
        // If the size of this dimension is dynamic, then load it at runtime
        // from the shape operand.
        if (!targetMemRefType.isDynamicDim(i)) {
          dimSize = createIndexAttrConstant(rewriter, loc, indexType,
                                            targetMemRefType.getDimSize(i));
        } else {
          Value shapeOp = reshapeOp.getShape();
          Value index = createIndexAttrConstant(rewriter, loc, indexType, i);
          dimSize = memref::LoadOp::create(rewriter, loc, shapeOp, index);
          if (dimSize.getType() != indexType)
            dimSize = typeConverter->materializeTargetConversion(
                rewriter, loc, indexType, dimSize);
          assert(dimSize && "Invalid memref element type");
        }

        desc.setSize(rewriter, loc, i, dimSize);
        desc.setStride(rewriter, loc, i, stride);

        // Prepare the stride value for the next dimension.
        stride = LLVM::MulOp::create(rewriter, loc, stride, dimSize);
      }

      *descriptor = desc;
      return success();
    }

    // The shape is a rank-1 tensor with unknown length.
    Location loc = reshapeOp.getLoc();
    MemRefDescriptor shapeDesc(adaptor.getShape());
    Value resultRank = shapeDesc.size(rewriter, loc, 0);

    // Extract address space and element type.
    auto targetType = cast<UnrankedMemRefType>(reshapeOp.getResult().getType());
    unsigned addressSpace =
        *getTypeConverter()->getMemRefAddressSpace(targetType);

    // Create the unranked memref descriptor that holds the ranked one. The
    // inner descriptor is allocated on stack.
    auto targetDesc = UnrankedMemRefDescriptor::poison(
        rewriter, loc, typeConverter->convertType(targetType));
    targetDesc.setRank(rewriter, loc, resultRank);
    Value allocationSize = UnrankedMemRefDescriptor::computeSize(
        rewriter, loc, *getTypeConverter(), targetDesc, addressSpace);
    Value underlyingDescPtr = LLVM::AllocaOp::create(
        rewriter, loc, getPtrType(), IntegerType::get(getContext(), 8),
        allocationSize);
    targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);

    // Extract pointers and offset from the source memref.
    Value allocatedPtr, alignedPtr, offset;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             reshapeOp.getSource(), adaptor.getSource(),
                             &allocatedPtr, &alignedPtr, &offset);

    // Set pointers and offset.
    auto elementPtrType =
        LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);

    UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
                                              elementPtrType, allocatedPtr);
    UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
                                            underlyingDescPtr, elementPtrType,
                                            alignedPtr);
    UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
                                        underlyingDescPtr, elementPtrType,
                                        offset);

    // Use the offset pointer as base for further addressing. Copy over the new
    // shape and compute strides. For this, we create a loop from rank-1 to 0.
    Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
        rewriter, loc, *getTypeConverter(), underlyingDescPtr, elementPtrType);
    Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
        rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
    Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
    Value oneIndex = createIndexAttrConstant(rewriter, loc, getIndexType(), 1);
    Value resultRankMinusOne =
        LLVM::SubOp::create(rewriter, loc, resultRank, oneIndex);

    Block *initBlock = rewriter.getInsertionBlock();
    Type indexType = getTypeConverter()->getIndexType();
    Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());

    Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
                                            {indexType, indexType}, {loc, loc});

    // Move the remaining initBlock ops to condBlock.
    Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
    rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());

    rewriter.setInsertionPointToEnd(initBlock);
    LLVM::BrOp::create(rewriter, loc,
                       ValueRange({resultRankMinusOne, oneIndex}), condBlock);
    rewriter.setInsertionPointToStart(condBlock);
    Value indexArg = condBlock->getArgument(0);
    Value strideArg = condBlock->getArgument(1);

    Value zeroIndex = createIndexAttrConstant(rewriter, loc, indexType, 0);
    Value pred = LLVM::ICmpOp::create(
        rewriter, loc, IntegerType::get(rewriter.getContext(), 1),
        LLVM::ICmpPredicate::sge, indexArg, zeroIndex);

    Block *bodyBlock =
        rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
    rewriter.setInsertionPointToStart(bodyBlock);

    // Copy size from shape to descriptor.
    auto llvmIndexPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
    Value sizeLoadGep = LLVM::GEPOp::create(
        rewriter, loc, llvmIndexPtrType,
        typeConverter->convertType(shapeMemRefType.getElementType()),
        shapeOperandPtr, indexArg);
    Value size = LLVM::LoadOp::create(rewriter, loc, indexType, sizeLoadGep);
    UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
                                      targetSizesBase, indexArg, size);

    // Write stride value and compute next one.
    UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
                                        targetStridesBase, indexArg, strideArg);
    Value nextStride = LLVM::MulOp::create(rewriter, loc, strideArg, size);

    // Decrement loop counter and branch back.
    Value decrement = LLVM::SubOp::create(rewriter, loc, indexArg, oneIndex);
    LLVM::BrOp::create(rewriter, loc, ValueRange({decrement, nextStride}),
                       condBlock);

    Block *remainder =
        rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());

    // Hook up the cond exit to the remainder.
    rewriter.setInsertionPointToEnd(condBlock);
    LLVM::CondBrOp::create(rewriter, loc, pred, bodyBlock, ValueRange(),
                           remainder, ValueRange());

    // Reset position to beginning of new remainder block.
    rewriter.setInsertionPointToStart(remainder);

    *descriptor = targetDesc;
    return success();
  }
};

/// ReassociatingReshapeOp must be expanded before we reach this stage.
/// Report that information.
template <typename ReshapeOp>
class ReassociatingReshapeOpConversion
    : public ConvertOpToLLVMPattern<ReshapeOp> {
public:
  using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
  using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;

  LogicalResult
  matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    return rewriter.notifyMatchFailure(
        reshapeOp,
        "reassociation operations should have been expanded beforehand");
  }
};

/// Subviews must be expanded before we reach this stage.
/// Report that information.
struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
  using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    return rewriter.notifyMatchFailure(
        subViewOp, "subview operations should have been expanded beforehand");
  }
};

/// Conversion pattern that transforms a transpose op into:
/// 1. A `poison` memref descriptor for the result.
/// 2. Copies of the allocated pointer, aligned pointer and offset from the
///    source descriptor.
/// 3. Updates that write the sizes and strides, permuted according to the
///    transpose's permutation map.
/// The transpose op is replaced by the new descriptor; if the permutation is
/// the identity, the source descriptor is reused as-is.
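///
/// As an illustrative sketch (not taken from the test suite), a 2-D transpose
/// roughly like
///   %t = memref.transpose %m (i, j) -> (j, i)
///       : memref<?x?xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
/// yields a descriptor whose size and stride at result position 0 are read
/// from source position 1 and vice versa.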
class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
public:
  using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = transposeOp.getLoc();
    MemRefDescriptor viewMemRef(adaptor.getIn());

    // No permutation, early exit.
    if (transposeOp.getPermutation().isIdentity())
      return rewriter.replaceOp(transposeOp, {viewMemRef}), success();

    auto targetMemRef = MemRefDescriptor::poison(
        rewriter, loc,
        typeConverter->convertType(transposeOp.getIn().getType()));

    // Copy the base and aligned pointers from the old descriptor to the new
    // one.
    targetMemRef.setAllocatedPtr(rewriter, loc,
                                 viewMemRef.allocatedPtr(rewriter, loc));
    targetMemRef.setAlignedPtr(rewriter, loc,
                               viewMemRef.alignedPtr(rewriter, loc));

    // Copy the offset from the old descriptor to the new one.
    targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));

    // Iterate over the dimensions and apply size/stride permutation:
    // When enumerating the results of the permutation map, the enumeration
    // index is the index into the target dimensions and the DimExpr points to
    // the dimension of the source memref.
    for (const auto &en :
         llvm::enumerate(transposeOp.getPermutation().getResults())) {
      int targetPos = en.index();
      int sourcePos = cast<AffineDimExpr>(en.value()).getPosition();
      targetMemRef.setSize(rewriter, loc, targetPos,
                           viewMemRef.size(rewriter, loc, sourcePos));
      targetMemRef.setStride(rewriter, loc, targetPos,
                             viewMemRef.stride(rewriter, loc, sourcePos));
    }

    rewriter.replaceOp(transposeOp, {targetMemRef});
    return success();
  }
};

/// Conversion pattern that transforms a memref.view op into:
/// 1. An `llvm.mlir.poison` operation to create a memref descriptor.
/// 2. Updates to the descriptor to introduce the data ptr, offset, size
///    and stride.
/// The view op is replaced by the descriptor.
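///
/// As an illustrative sketch (not taken from the test suite), an op roughly
/// like
///   %v = memref.view %buf[%offset][%d] : memref<2048xi8> to memref<4x?xf32>
/// produces a descriptor whose aligned pointer is the source's aligned
/// pointer advanced by %offset bytes, whose offset is 0, and whose strides
/// are rebuilt from the sizes (innermost stride 1, each outer stride the
/// running product of the inner sizes).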
struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
  using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;

  // Build and return the value for the idx^th shape dimension, either by
  // returning the constant shape dimension or selecting the matching dynamic
  // size operand.
  Value getSize(ConversionPatternRewriter &rewriter, Location loc,
                ArrayRef<int64_t> shape, ValueRange dynamicSizes, unsigned idx,
                Type indexType) const {
    assert(idx < shape.size());
    if (ShapedType::isStatic(shape[idx]))
      return createIndexAttrConstant(rewriter, loc, indexType, shape[idx]);
    // Count the number of dynamic dims in range [0, idx).
    unsigned nDynamic =
        llvm::count_if(shape.take_front(idx), ShapedType::isDynamic);
    return dynamicSizes[nDynamic];
  }

  // Build and return the idx^th stride, either by returning the constant stride
  // or by computing the dynamic stride from the current `runningStride` and
  // `nextSize`. The caller should keep a running stride and update it with the
  // result returned by this function.
  Value getStride(ConversionPatternRewriter &rewriter, Location loc,
                  ArrayRef<int64_t> strides, Value nextSize,
                  Value runningStride, unsigned idx, Type indexType) const {
    assert(idx < strides.size());
    if (ShapedType::isStatic(strides[idx]))
      return createIndexAttrConstant(rewriter, loc, indexType, strides[idx]);
    if (nextSize)
      return runningStride
                 ? LLVM::MulOp::create(rewriter, loc, runningStride, nextSize)
                 : nextSize;
    assert(!runningStride);
    return createIndexAttrConstant(rewriter, loc, indexType, 1);
  }

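  // The loop in matchAndRewrite below walks the dimensions from innermost to
  // outermost, threading the running stride through getStride. As an
  // illustrative trace (assuming a contiguous result of shape %a x %b x %c
  // with identity layout), the materialized strides are 1 for the innermost
  // dimension, then %c, then %c * %b.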
  LogicalResult
  matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = viewOp.getLoc();

    auto viewMemRefType = viewOp.getType();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
             failure();

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = viewMemRefType.getStridesAndOffset(strides, offset);
    if (failed(successStrides))
      return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
    assert(offset == 0 && "expected offset to be 0");

    // Target memref must be contiguous in memory (innermost stride is 1), or
    // empty (special case when at least one of the memref dimensions is 0).
    if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
      return viewOp.emitWarning("cannot cast to non-contiguous shape"),
             failure();

    // Create the descriptor.
    MemRefDescriptor sourceMemRef(adaptor.getSource());
    auto targetMemRef = MemRefDescriptor::poison(rewriter, loc, targetDescTy);

    // Field 1: Copy the allocated pointer, used for malloc/free.
    Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
    auto srcMemRefType = cast<MemRefType>(viewOp.getSource().getType());
    targetMemRef.setAllocatedPtr(rewriter, loc, allocatedPtr);

    // Field 2: Copy the actual aligned pointer to payload.
    Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
    alignedPtr = LLVM::GEPOp::create(
        rewriter, loc, alignedPtr.getType(),
        typeConverter->convertType(srcMemRefType.getElementType()), alignedPtr,
        adaptor.getByteShift());

    targetMemRef.setAlignedPtr(rewriter, loc, alignedPtr);

    Type indexType = getIndexType();
    // Field 3: The offset in the resulting type must be 0. This is
    // because of the type change: an offset on srcType* may not be
    // expressible as an offset on dstType*.
    targetMemRef.setOffset(
        rewriter, loc,
        createIndexAttrConstant(rewriter, loc, indexType, offset));

    // Early exit for 0-D corner case.
    if (viewMemRefType.getRank() == 0)
      return rewriter.replaceOp(viewOp, {targetMemRef}), success();

    // Fields 4 and 5: Update sizes and strides.
    Value stride = nullptr, nextSize = nullptr;
    for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
      // Update size.
      Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
                           adaptor.getSizes(), i, indexType);
      targetMemRef.setSize(rewriter, loc, i, size);
      // Update stride.
      stride =
          getStride(rewriter, loc, strides, nextSize, stride, i, indexType);
      targetMemRef.setStride(rewriter, loc, i, stride);
      nextSize = size;
    }

    rewriter.replaceOp(viewOp, {targetMemRef});
    return success();
  }
};

//===----------------------------------------------------------------------===//
// AtomicRMWOpLowering
//===----------------------------------------------------------------------===//

/// Try to match the kind of a memref.atomic_rmw to determine whether to use a
/// lowering to llvm.atomicrmw or fall back to llvm.cmpxchg.
static std::optional<LLVM::AtomicBinOp>
matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
  switch (atomicOp.getKind()) {
  case arith::AtomicRMWKind::addf:
    return LLVM::AtomicBinOp::fadd;
  case arith::AtomicRMWKind::addi:
    return LLVM::AtomicBinOp::add;
  case arith::AtomicRMWKind::assign:
    return LLVM::AtomicBinOp::xchg;
  case arith::AtomicRMWKind::maximumf:
    // TODO: remove this by end of 2025.
    LDBG() << "the lowering of memref.atomicrmw maximumf changed "
              "from fmax to fmaximum, expect more NaNs";
    return LLVM::AtomicBinOp::fmaximum;
  case arith::AtomicRMWKind::maxnumf:
    return LLVM::AtomicBinOp::fmax;
  case arith::AtomicRMWKind::maxs:
    return LLVM::AtomicBinOp::max;
  case arith::AtomicRMWKind::maxu:
    return LLVM::AtomicBinOp::umax;
  case arith::AtomicRMWKind::minimumf:
    // TODO: remove this by end of 2025.
    LDBG() << "the lowering of memref.atomicrmw minimumf changed "
              "from fmin to fminimum, expect more NaNs";
    return LLVM::AtomicBinOp::fminimum;
  case arith::AtomicRMWKind::minnumf:
    return LLVM::AtomicBinOp::fmin;
  case arith::AtomicRMWKind::mins:
    return LLVM::AtomicBinOp::min;
  case arith::AtomicRMWKind::minu:
    return LLVM::AtomicBinOp::umin;
  case arith::AtomicRMWKind::ori:
    return LLVM::AtomicBinOp::_or;
  case arith::AtomicRMWKind::xori:
    return LLVM::AtomicBinOp::_xor;
  case arith::AtomicRMWKind::andi:
    return LLVM::AtomicBinOp::_and;
  default:
    return std::nullopt;
  }
  llvm_unreachable("Invalid AtomicRMWKind");
}

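/// Lowers memref.atomic_rmw kinds that have a direct llvm.atomicrmw
/// counterpart (see matchSimpleAtomicOp above) to an LLVM::AtomicRMWOp on the
/// strided element pointer. As an illustrative sketch (not taken from the
/// test suite), an op roughly like
///   %old = memref.atomic_rmw addf %v, %m[%i] : (f32, memref<8xf32>) -> f32
/// becomes an llvm.atomicrmw fadd on the element address with acq_rel
/// ordering.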
struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto maybeKind = matchSimpleAtomicOp(atomicOp);
    if (!maybeKind)
      return failure();
    auto memRefType = atomicOp.getMemRefType();
    SmallVector<int64_t> strides;
    int64_t offset;
    if (failed(memRefType.getStridesAndOffset(strides, offset)))
      return failure();
    auto dataPtr =
        getStridedElementPtr(rewriter, atomicOp.getLoc(), memRefType,
                             adaptor.getMemref(), adaptor.getIndices());
    rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
        atomicOp, *maybeKind, dataPtr, adaptor.getValue(),
        LLVM::AtomicOrdering::acq_rel);
    return success();
  }
};

/// Lowers memref.extract_aligned_pointer_as_index by extracting the aligned
/// pointer from the (un)ranked descriptor and converting it to an index with
/// llvm.ptrtoint.
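///
/// As an illustrative sketch (not taken from the test suite):
///   %p = memref.extract_aligned_pointer_as_index %m : memref<?xf32> -> index
/// becomes an llvm.ptrtoint of the descriptor's aligned pointer.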
class ConvertExtractAlignedPointerAsIndex
    : public ConvertOpToLLVMPattern<memref::ExtractAlignedPointerAsIndexOp> {
public:
  using ConvertOpToLLVMPattern<
      memref::ExtractAlignedPointerAsIndexOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    BaseMemRefType sourceTy = extractOp.getSource().getType();

    Value alignedPtr;
    if (sourceTy.hasRank()) {
      MemRefDescriptor desc(adaptor.getSource());
      alignedPtr = desc.alignedPtr(rewriter, extractOp->getLoc());
    } else {
      auto elementPtrTy = LLVM::LLVMPointerType::get(
          rewriter.getContext(), sourceTy.getMemorySpaceAsInt());

      UnrankedMemRefDescriptor desc(adaptor.getSource());
      Value descPtr = desc.memRefDescPtr(rewriter, extractOp->getLoc());

      alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
          rewriter, extractOp->getLoc(), *getTypeConverter(), descPtr,
          elementPtrTy);
    }

    rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(
        extractOp, getTypeConverter()->getIndexType(), alignedPtr);
    return success();
  }
};

/// Materialize the MemRef descriptor represented by the results of
/// ExtractStridedMetadataOp.
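///
/// As an illustrative sketch (not taken from the test suite):
///   %base, %offset, %sizes:2, %strides:2 =
///       memref.extract_strided_metadata %m
///       : memref<?x?xf32> -> memref<f32>, index, index, index, index, index
/// is replaced by the individual descriptor fields: a rank-0 descriptor
/// wrapping the base pointers, then the offset, the sizes and the strides.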
class ExtractStridedMetadataOpLowering
    : public ConvertOpToLLVMPattern<memref::ExtractStridedMetadataOp> {
public:
  using ConvertOpToLLVMPattern<
      memref::ExtractStridedMetadataOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
      return failure();

    // Create the descriptor.
    MemRefDescriptor sourceMemRef(adaptor.getSource());
    Location loc = extractStridedMetadataOp.getLoc();
    Value source = extractStridedMetadataOp.getSource();

    auto sourceMemRefType = cast<MemRefType>(source.getType());
    int64_t rank = sourceMemRefType.getRank();
    SmallVector<Value> results;
    results.reserve(2 + rank * 2);

    // Base buffer.
    Value baseBuffer = sourceMemRef.allocatedPtr(rewriter, loc);
    Value alignedBuffer = sourceMemRef.alignedPtr(rewriter, loc);
    MemRefDescriptor dstMemRef = MemRefDescriptor::fromStaticShape(
        rewriter, loc, *getTypeConverter(),
        cast<MemRefType>(extractStridedMetadataOp.getBaseBuffer().getType()),
        baseBuffer, alignedBuffer);
    results.push_back((Value)dstMemRef);

    // Offset.
    results.push_back(sourceMemRef.offset(rewriter, loc));

    // Sizes.
    for (unsigned i = 0; i < rank; ++i)
      results.push_back(sourceMemRef.size(rewriter, loc, i));
    // Strides.
    for (unsigned i = 0; i < rank; ++i)
      results.push_back(sourceMemRef.stride(rewriter, loc, i));

    rewriter.replaceOp(extractStridedMetadataOp, results);
    return success();
  }
};

} // namespace

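// A typical downstream use of the function below, shown only as an
// illustrative sketch (standard conversion boilerplate, not code from this
// file):
//
//   LLVMTypeConverter converter(ctx);
//   RewritePatternSet patterns(ctx);
//   populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);
//   LLVMConversionTarget target(*ctx);
//   if (failed(applyPartialConversion(op, target, std::move(patterns))))
//     return failure();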
void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    SymbolTableCollection *symbolTables) {
  // clang-format off
  patterns.add<
      AllocaOpLowering,
      AllocaScopeOpLowering,
      AtomicRMWOpLowering,
      AssumeAlignmentOpLowering,
      ConvertExtractAlignedPointerAsIndex,
      DimOpLowering,
      ExtractStridedMetadataOpLowering,
      GenericAtomicRMWOpLowering,
      GetGlobalMemrefOpLowering,
      LoadOpLowering,
      MemRefCastOpLowering,
      MemorySpaceCastOpLowering,
      MemRefReinterpretCastOpLowering,
      MemRefReshapeOpLowering,
      PrefetchOpLowering,
      RankOpLowering,
      ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
      ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
      StoreOpLowering,
      SubViewOpLowering,
      TransposeOpLowering,
      ViewOpLowering>(converter);
  // clang-format on
  patterns.add<GlobalMemrefOpLowering, MemRefCopyOpLowering>(converter,
                                                             symbolTables);
  auto allocLowering = converter.getOptions().allocLowering;
  if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
    patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter,
                                                            symbolTables);
  else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
    patterns.add<AllocOpLowering, DeallocOpLowering>(converter, symbolTables);
}

namespace {
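/// Pass wrapping the patterns above. Registered via tablegen; the pass
/// argument is expected to be `finalize-memref-to-llvm` (treat the exact
/// spelling as informational), e.g. `mlir-opt --finalize-memref-to-llvm`.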
struct FinalizeMemRefToLLVMConversionPass
    : public impl::FinalizeMemRefToLLVMConversionPassBase<
          FinalizeMemRefToLLVMConversionPass> {
  using FinalizeMemRefToLLVMConversionPassBase::
      FinalizeMemRefToLLVMConversionPassBase;

  void runOnOperation() override {
    Operation *op = getOperation();
    const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
    LowerToLLVMOptions options(&getContext(),
                               dataLayoutAnalysis.getAtOrAbove(op));
    options.allocLowering =
        (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
                         : LowerToLLVMOptions::AllocLowering::Malloc);

    options.useGenericFunctions = useGenericFunctions;

    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);

    LLVMTypeConverter typeConverter(&getContext(), options,
                                    &dataLayoutAnalysis);
    RewritePatternSet patterns(&getContext());
    SymbolTableCollection symbolTables;
    populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns,
                                                   &symbolTables);
    LLVMConversionTarget target(getContext());
    target.addLegalOp<func::FuncOp>();
    if (failed(applyPartialConversion(op, target, std::move(patterns))))
      signalPassFailure();
  }
};

/// Implement the interface to convert MemRef to LLVM.
struct MemRefToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
  using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
  void loadDependentDialects(MLIRContext *context) const final {
    context->loadDialect<LLVM::LLVMDialect>();
  }

  /// Hook for derived dialect interface to provide conversion patterns
  /// and mark dialect legal for the conversion target.
  void populateConvertToLLVMConversionPatterns(
      ConversionTarget &target, LLVMTypeConverter &typeConverter,
      RewritePatternSet &patterns) const final {
    populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
  }
};

} // namespace

void mlir::registerConvertMemRefToLLVMInterface(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
    dialect->addInterfaces<MemRefToLLVMDialectInterface>();
  });
}