
This commit removes the last typed-pointer remnants from the MemRef-to-LLVM conversions, including the transform dialect operation. Typed pointers have been deprecated for a while and are planned to be removed from the LLVM dialect soon. Related PSA: https://discourse.llvm.org/t/psa-removal-of-typed-pointers-from-the-llvm-dialect/74502
336 lines
13 KiB
C++
336 lines
13 KiB
C++
//===- MemRefTransformOps.cpp - Implementation of Memref transform ops ----===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
|
|
|
|
#include "mlir/Analysis/DataLayoutAnalysis.h"
|
|
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
|
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
|
|
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
|
|
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
|
|
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
|
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
|
|
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
|
|
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
|
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
|
|
#include "mlir/Interfaces/LoopLikeInterface.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
#include "llvm/Support/Debug.h"
|
|
|
|
using namespace mlir;
|
|
|
|
#define DEBUG_TYPE "memref-transforms"
|
|
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Apply...ConversionPatternsOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Builds the LLVM type converter configured by this transform op.
///
/// The converter starts from the default `LowerToLLVMOptions` of the current
/// context and then applies this op's settings:
/// - allocation lowering strategy,
/// - generic-function usage,
/// - optional index bitwidth override,
/// - optional data layout,
/// - bare-pointer calling convention.
std::unique_ptr<TypeConverter>
transform::MemrefToLLVMTypeConverterOp::getTypeConverter() {
  MLIRContext *context = getContext();
  LowerToLLVMOptions lowering(context);

  // Select the strategy used to lower heap allocations.
  if (getUseAlignedAlloc())
    lowering.allocLowering = LowerToLLVMOptions::AllocLowering::AlignedAlloc;
  else
    lowering.allocLowering = LowerToLLVMOptions::AllocLowering::Malloc;
  lowering.useGenericFunctions = getUseGenericFunctions();

  // Only override the index bitwidth when the attribute differs from the
  // "derive from data layout" sentinel; otherwise keep the options' default.
  const auto indexBitwidth = getIndexBitwidth();
  if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
    lowering.overrideIndexBitwidth(indexBitwidth);

  // TODO: the data layout and bare-pointer options below don't really make
  // sense for memref_to_llvm_type_converter specifically, but we should have a
  // single to_llvm_type_converter.
  if (auto layoutSpec = getDataLayout())
    lowering.dataLayout = llvm::DataLayout(*layoutSpec);
  lowering.useBarePtrCallConv = getUseBarePtrCallConv();

  return std::make_unique<LLVMTypeConverter>(context, lowering);
}
|
|
|
|
/// Returns the name of the converter class built by getTypeConverter().
StringRef transform::MemrefToLLVMTypeConverterOp::getTypeConverterType() {
  return StringRef("LLVMTypeConverter");
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Apply...PatternsOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
class AllocToAllocaPattern : public OpRewritePattern<memref::AllocOp> {
|
|
public:
|
|
explicit AllocToAllocaPattern(Operation *analysisRoot, int64_t maxSize = 0)
|
|
: OpRewritePattern<memref::AllocOp>(analysisRoot->getContext()),
|
|
dataLayoutAnalysis(analysisRoot), maxSize(maxSize) {}
|
|
|
|
LogicalResult matchAndRewrite(memref::AllocOp op,
|
|
PatternRewriter &rewriter) const override {
|
|
return success(memref::allocToAlloca(
|
|
rewriter, op, [this](memref::AllocOp alloc, memref::DeallocOp dealloc) {
|
|
MemRefType type = alloc.getMemref().getType();
|
|
if (!type.hasStaticShape())
|
|
return false;
|
|
|
|
const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(alloc);
|
|
int64_t elementSize = dataLayout.getTypeSize(type.getElementType());
|
|
return maxSize == 0 || type.getNumElements() * elementSize < maxSize;
|
|
}));
|
|
}
|
|
|
|
private:
|
|
DataLayoutAnalysis dataLayoutAnalysis;
|
|
int64_t maxSize;
|
|
};
|
|
} // namespace
|
|
|
|
/// Stateless pattern population: deliberately adds nothing.
void transform::ApplyAllocToAllocaOp::populatePatterns(
    RewritePatternSet &patterns) {
  // AllocToAllocaPattern is rooted at the transform state's top-level op, so
  // it is registered from populatePatternsWithState rather than here.
}
|
|
|
|
/// Adds the alloc-to-alloca pattern, using the transform state's top-level op
/// as the root of its data layout analysis.
void transform::ApplyAllocToAllocaOp::populatePatternsWithState(
    RewritePatternSet &patterns, transform::TransformState &state) {
  // An unset size limit becomes 0, which the pattern treats as "unlimited".
  int64_t sizeLimit = static_cast<int64_t>(getSizeLimit().value_or(0));
  Operation *analysisRoot = state.getTopLevel();
  patterns.insert<AllocToAllocaPattern>(analysisRoot, sizeLimit);
}
|
|
|
|
/// Forwards to the MemRef "expand ops" pattern population helper.
void transform::ApplyExpandOpsPatternsOp::populatePatterns(
    RewritePatternSet &patterns) {
  mlir::memref::populateExpandOpsPatterns(patterns);
}
|
|
|
|
/// Forwards to the MemRef "expand strided metadata" pattern population helper.
void transform::ApplyExpandStridedMetadataPatternsOp::populatePatterns(
    RewritePatternSet &patterns) {
  mlir::memref::populateExpandStridedMetadataPatterns(patterns);
}
|
|
|
|
/// Forwards to the MemRef "extract address computations" pattern population
/// helper.
void transform::ApplyExtractAddressComputationsPatternsOp::populatePatterns(
    RewritePatternSet &patterns) {
  mlir::memref::populateExtractAddressComputationsPatterns(patterns);
}
|
|
|
|
/// Forwards to the MemRef alias-op folding pattern population helper.
void transform::ApplyFoldMemrefAliasOpsPatternsOp::populatePatterns(
    RewritePatternSet &patterns) {
  mlir::memref::populateFoldMemRefAliasOpPatterns(patterns);
}
|
|
|
|
void transform::ApplyResolveRankedShapedTypeResultDimsPatternsOp::
|
|
populatePatterns(RewritePatternSet &patterns) {
|
|
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// AllocaToGlobalOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Replaces every payload `memref.alloca` with a `memref.get_global` of a newly
/// created private `memref.global` named "alloca" (uniqued by the symbol
/// table). The global has the alloca's result type.
///
/// Each global is inserted into the symbol table nearest to its alloca. The
/// created globals and get_global ops are returned through the `global` and
/// `get_global` result handles, in payload order. Always returns success.
DiagnosedSilenceableFailure
transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter,
                                         transform::TransformResults &results,
                                         transform::TransformState &state) {
  // The context does not change across payload ops; fetch it once.
  MLIRContext *ctx = rewriter.getContext();
  auto allocaOps = state.getPayloadOps(getAlloca());

  SmallVector<memref::GlobalOp> globalOps;
  SmallVector<memref::GetGlobalOp> getGlobalOps;

  // Transform `memref.alloca`s.
  for (auto *op : allocaOps) {
    auto alloca = cast<memref::AllocaOp>(op);
    Location loc = alloca->getLoc();

    memref::GlobalOp globalOp;
    {
      // Find nearest symbol table.
      Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(op);
      assert(symbolTableOp && "expected alloca payload to be in symbol table");
      SymbolTable symbolTable(symbolTableOp);

      // Create a detached `memref.global` and let the symbol table insert it,
      // renaming it if "alloca" is already taken.
      Type resultType = alloca.getResult().getType();
      OpBuilder builder(ctx);
      // TODO: Add a better builder for this.
      globalOp = builder.create<memref::GlobalOp>(
          loc, StringAttr::get(ctx, "alloca"), StringAttr::get(ctx, "private"),
          TypeAttr::get(resultType), Attribute{}, UnitAttr{}, IntegerAttr{});
      symbolTable.insert(globalOp);
    }

    // Replace the `memref.alloca` with a `memref.get_global` accessing the
    // global symbol inserted above. `replaceOpWithNewOp` builds at the current
    // insertion point, so move it to the alloca first.
    rewriter.setInsertionPoint(alloca);
    auto getGlobalOp = rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(
        alloca, globalOp.getType(), globalOp.getName());

    globalOps.push_back(globalOp);
    getGlobalOps.push_back(getGlobalOp);
  }

  // Assemble results. Use the free-function `cast` rather than the deprecated
  // `Value::cast` member.
  results.set(cast<OpResult>(getGlobal()), globalOps);
  results.set(cast<OpResult>(getGetGlobal()), getGlobalOps);

  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// Reports the handle and payload effects of this op:
/// - the `global` and `get_global` result handles are produced;
/// - the `alloca` handle is consumed, since its payload ops are replaced;
/// - the payload IR is modified.
void transform::MemRefAllocaToGlobalOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // Result handles.
  transform::producesHandle(getGlobal(), effects);
  transform::producesHandle(getGetGlobal(), effects);
  // Operand handle: its allocas are replaced by apply().
  transform::consumesHandle(getAlloca(), effects);
  transform::modifiesPayload(effects);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MemRefMultiBufferOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Multi-buffers each targeted `memref.alloc` by `getFactor()` using
/// `memref::multiBuffer`.
///
/// An allocation is skipped (left untouched and not returned) when any of its
/// users, other than a `memref.dealloc`, has no enclosing loop-like op. A
/// failure of `memref::multiBuffer` itself aborts with a silenceable failure.
/// The new buffers are returned through the result handle.
DiagnosedSilenceableFailure transform::MemRefMultiBufferOp::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &transformResults,
    transform::TransformState &state) {
  SmallVector<Operation *> results;
  for (Operation *op : state.getPayloadOps(getTarget())) {
    auto target = cast<memref::AllocOp>(op);
    LLVM_DEBUG(DBGS() << "Start multibuffer transform op: " << target << "\n";);

    // Skip allocations not used in a loop: find the first non-dealloc user
    // that is not nested in any loop-like op.
    Operation *userOutsideLoop = nullptr;
    for (Operation *user : target->getUsers()) {
      if (isa<memref::DeallocOp>(user))
        continue;
      if (!user->getParentOfType<LoopLikeOpInterface>()) {
        userOutsideLoop = user;
        break;
      }
    }
    if (userOutsideLoop) {
      // Terminate the user dump with a newline so it does not run into the
      // next debug message.
      LLVM_DEBUG(DBGS() << "--allocation not used in a loop\n";
                 DBGS() << "----due to user: " << *userOutsideLoop << "\n";);
      LLVM_DEBUG(DBGS() << "--cannot apply multibuffering -> Skip\n";);
      continue;
    }

    auto newBuffer =
        memref::multiBuffer(rewriter, target, getFactor(), getSkipAnalysis());
    if (failed(newBuffer)) {
      LLVM_DEBUG(DBGS() << "--op failed to multibuffer\n";);
      return emitSilenceableFailure(target->getLoc())
             << "op failed to multibuffer";
    }

    results.push_back(*newBuffer);
  }
  transformResults.set(cast<OpResult>(getResult()), results);
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MemRefEraseDeadAllocAndStoresOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Runs two cleanups on the IR under `target`, in order:
/// 1. `vector::transferOpflowOpt`: store-to-load forwarding and dead-store
///    elimination.
/// 2. `memref::eraseDeadAllocAndStores`.
///
/// No result handles are populated and the op always succeeds.
DiagnosedSilenceableFailure
transform::MemRefEraseDeadAllocAndStoresOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  // Step 1: store-to-load forwarding and dead-store elimination.
  vector::transferOpflowOpt(rewriter, target);
  // Step 2: dead allocation cleanup.
  memref::eraseDeadAllocAndStores(rewriter, target);
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// The target handle is only read (not consumed), while the payload IR it
/// refers to is modified in place.
void transform::MemRefEraseDeadAllocAndStoresOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  onlyReadsHandle(getTarget(), effects);
  modifiesPayload(effects);
}
|
|
/// Custom builder: records `target` as the op's only operand. No result types
/// or attributes are added, and `builder` is unused.
void transform::MemRefEraseDeadAllocAndStoresOp::build(OpBuilder &builder,
                                                       OperationState &result,
                                                       Value target) {
  result.addOperands(ValueRange(target));
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MemRefMakeLoopIndependentOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Makes `target` independent of the induction variables of its
/// `getNumLoops()` closest enclosing `scf.for` loops.
///
/// Only `memref.alloca` targets are supported; they are rewritten with
/// `memref::replaceWithIndependentOp`, and the op defining the replacement
/// value is appended to `results`. Silenceable failures are reported when:
/// - fewer enclosing loops exist than requested;
/// - the target op kind is unsupported;
/// - the rewrite fails.
/// Every failure carries a "target op" note at the payload op's location.
DiagnosedSilenceableFailure transform::MemRefMakeLoopIndependentOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  // Attach the standard "target op" note to a diagnostic and hand it back.
  auto withTargetNote = [&](DiagnosedSilenceableFailure diag) {
    diag.attachNote(target->getLoc()) << "target op";
    return diag;
  };

  // Collect induction variables, walking outward from the target (closest
  // loop first).
  SmallVector<Value> ivs;
  Operation *current = target;
  for (uint64_t depth = 0, numLoops = getNumLoops(); depth < numLoops;
       ++depth) {
    auto forOp = current->getParentOfType<scf::ForOp>();
    if (!forOp)
      return withTargetNote(emitSilenceableError()
                            << "could not find " << depth
                            << "-th enclosing loop");
    ivs.push_back(forOp.getInductionVar());
    current = forOp.getOperation();
  }

  // Only `memref.alloca` can currently be made loop-independent.
  auto allocaOp = dyn_cast<memref::AllocaOp>(target);
  if (!allocaOp)
    return withTargetNote(emitSilenceableError() << "unsupported target op");

  FailureOr<Value> replacement =
      memref::replaceWithIndependentOp(rewriter, allocaOp, ivs);
  if (failed(replacement))
    return withTargetNote(emitSilenceableError()
                          << "could not make target op loop-independent");

  results.push_back(replacement->getDefiningOp());
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Transform op registration
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
/// Transform dialect extension that registers the MemRef transform ops and
/// declares the dialects whose ops they may generate.
struct MemRefTransformDialectExtension
    : transform::TransformDialectExtension<MemRefTransformDialectExtension> {
  using Base::Base;

  void init() {
    // Dialects declared as generated by the ops of this extension.
    declareGeneratedDialect<affine::AffineDialect>();
    declareGeneratedDialect<arith::ArithDialect>();
    declareGeneratedDialect<memref::MemRefDialect>();
    declareGeneratedDialect<nvgpu::NVGPUDialect>();
    declareGeneratedDialect<vector::VectorDialect>();

    // Register every op from the TableGen-generated op list.
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc"
        >();
  }
};
} // namespace
|
|
|
|
#define GET_OP_CLASSES
|
|
#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc"
|
|
|
|
void mlir::memref::registerTransformDialectExtension(
|
|
DialectRegistry ®istry) {
|
|
registry.addExtensions<MemRefTransformDialectExtension>();
|
|
}
|