
`PassManager::run` loads the dependent dialects for each pass into the
current context prior to invoking the individual passes. If the
dependent dialect is already loaded into the context, this should be a
no-op. However, if there are extensions registered in the
`DialectRegistry`, the dependent dialects are unconditionally registered
into the context.
This poses a problem for dynamic pass pipelines because they are likely to
execute while the context is in an immutable state (the parent pass
pipeline is still running).
To solve this, we'll update the extension registration API on
`DialectRegistry` to require a type ID for each extension that is
registered. Then, instead of unconditionally registering dialects into a
context if extensions are present, we'll check against the extension
type IDs already present in the context's internal `DialectRegistry`.
The context will only be marked as dirty if there are net-new extension
types present in the `DialectRegistry` populated by
`PassManager::getDependentDialects`.
Note: this PR removes the `addExtension` overload that utilizes
`std::function` as the parameter. This is because `std::function` is
copyable and potentially allocates memory for the contained function so
we can't use the function pointer as the unique type ID for the
extension.
Downstream changes required:
- Existing `DialectExtension` subclasses will need a type ID to be
registered for each subclass. More details on how to register a type ID
can be found in `mlir/include/mlir/Support/TypeID.h` (line 30, at commit
8b68e06731).
- Existing uses of the `std::function` overload of `addExtension` will
need to be refactored into dedicated `DialectExtension` classes with
associated type IDs. The attached `std::function` can either be inlined
into or called directly from `DialectExtension::apply`.
---------
Co-authored-by: Mehdi Amini <joker.eph@gmail.com>
338 lines
13 KiB
C++
338 lines
13 KiB
C++
//===- MemRefTransformOps.cpp - Implementation of Memref transform ops ----===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
|
|
|
|
#include "mlir/Analysis/DataLayoutAnalysis.h"
|
|
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
|
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
|
|
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
|
|
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
|
|
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
|
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
|
|
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
|
|
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
|
|
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
|
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
|
|
#include "mlir/Interfaces/LoopLikeInterface.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
#include "llvm/Support/Debug.h"
|
|
|
|
using namespace mlir;
|
|
|
|
#define DEBUG_TYPE "memref-transforms"
|
|
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Apply...ConversionPatternsOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
std::unique_ptr<TypeConverter>
|
|
transform::MemrefToLLVMTypeConverterOp::getTypeConverter() {
|
|
LowerToLLVMOptions options(getContext());
|
|
options.allocLowering =
|
|
(getUseAlignedAlloc() ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
|
|
: LowerToLLVMOptions::AllocLowering::Malloc);
|
|
options.useGenericFunctions = getUseGenericFunctions();
|
|
|
|
if (getIndexBitwidth() != kDeriveIndexBitwidthFromDataLayout)
|
|
options.overrideIndexBitwidth(getIndexBitwidth());
|
|
|
|
// TODO: the following two options don't really make sense for
|
|
// memref_to_llvm_type_converter specifically but we should have a single
|
|
// to_llvm_type_converter.
|
|
if (getDataLayout().has_value())
|
|
options.dataLayout = llvm::DataLayout(getDataLayout().value());
|
|
options.useBarePtrCallConv = getUseBarePtrCallConv();
|
|
|
|
return std::make_unique<LLVMTypeConverter>(getContext(), options);
|
|
}
|
|
|
|
/// Returns the name of the type converter class built by getTypeConverter().
StringRef transform::MemrefToLLVMTypeConverterOp::getTypeConverterType() {
  static constexpr const char kConverterName[] = "LLVMTypeConverter";
  return kConverterName;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Apply...PatternsOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
class AllocToAllocaPattern : public OpRewritePattern<memref::AllocOp> {
|
|
public:
|
|
explicit AllocToAllocaPattern(Operation *analysisRoot, int64_t maxSize = 0)
|
|
: OpRewritePattern<memref::AllocOp>(analysisRoot->getContext()),
|
|
dataLayoutAnalysis(analysisRoot), maxSize(maxSize) {}
|
|
|
|
LogicalResult matchAndRewrite(memref::AllocOp op,
|
|
PatternRewriter &rewriter) const override {
|
|
return success(memref::allocToAlloca(
|
|
rewriter, op, [this](memref::AllocOp alloc, memref::DeallocOp dealloc) {
|
|
MemRefType type = alloc.getMemref().getType();
|
|
if (!type.hasStaticShape())
|
|
return false;
|
|
|
|
const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(alloc);
|
|
int64_t elementSize = dataLayout.getTypeSize(type.getElementType());
|
|
return maxSize == 0 || type.getNumElements() * elementSize < maxSize;
|
|
}));
|
|
}
|
|
|
|
private:
|
|
DataLayoutAnalysis dataLayoutAnalysis;
|
|
int64_t maxSize;
|
|
};
|
|
} // namespace
|
|
|
|
void transform::ApplyAllocToAllocaOp::populatePatterns(
|
|
RewritePatternSet &patterns) {}
|
|
|
|
void transform::ApplyAllocToAllocaOp::populatePatternsWithState(
|
|
RewritePatternSet &patterns, transform::TransformState &state) {
|
|
patterns.insert<AllocToAllocaPattern>(
|
|
state.getTopLevel(), static_cast<int64_t>(getSizeLimit().value_or(0)));
|
|
}
|
|
|
|
/// Forwards to memref::populateExpandOpsPatterns.
void transform::ApplyExpandOpsPatternsOp::populatePatterns(
    RewritePatternSet &patterns) {
  ::mlir::memref::populateExpandOpsPatterns(patterns);
}

/// Forwards to memref::populateExpandStridedMetadataPatterns.
void transform::ApplyExpandStridedMetadataPatternsOp::populatePatterns(
    RewritePatternSet &patterns) {
  ::mlir::memref::populateExpandStridedMetadataPatterns(patterns);
}

/// Forwards to memref::populateExtractAddressComputationsPatterns.
void transform::ApplyExtractAddressComputationsPatternsOp::populatePatterns(
    RewritePatternSet &patterns) {
  ::mlir::memref::populateExtractAddressComputationsPatterns(patterns);
}

/// Forwards to memref::populateFoldMemRefAliasOpPatterns.
void transform::ApplyFoldMemrefAliasOpsPatternsOp::populatePatterns(
    RewritePatternSet &patterns) {
  ::mlir::memref::populateFoldMemRefAliasOpPatterns(patterns);
}

/// Forwards to memref::populateResolveRankedShapedTypeResultDimsPatterns.
void transform::ApplyResolveRankedShapedTypeResultDimsPatternsOp::
    populatePatterns(RewritePatternSet &patterns) {
  ::mlir::memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// AllocaToGlobalOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Replaces every `memref.alloca` payload op with a `memref.get_global` that
/// reads a new private `memref.global` (requested name "alloca") inserted into
/// the alloca's nearest symbol table.
///
/// Every payload op is validated before any IR is modified:
///  - it must have an enclosing symbol table, and
///  - its memref type must have a static shape (a `memref.global` cannot be
///    dynamically shaped, and the alloca's dynamic-size operands would
///    otherwise be silently dropped).
/// If either check fails, a silenceable failure is returned and the payload
/// IR is left untouched.
///
/// Results: the created `memref.global` ops and the `memref.get_global` ops,
/// both in payload order.
DiagnosedSilenceableFailure
transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter,
                                         transform::TransformResults &results,
                                         transform::TransformState &state) {
  // Copy the payload list before rewriting: replacing the ops below goes
  // through the transform rewriter, which may update the payload mappings
  // that `getPayloadOps` iterates over.
  SmallVector<Operation *> allocaOps =
      llvm::to_vector(state.getPayloadOps(getAlloca()));

  // Validate up front so that a bad payload cannot leave the IR half-rewritten.
  for (Operation *op : allocaOps) {
    auto alloca = cast<memref::AllocaOp>(op);
    if (!SymbolTable::getNearestSymbolTable(op)) {
      DiagnosedSilenceableFailure diag =
          emitSilenceableError()
          << "expected alloca payload to be in symbol table";
      diag.attachNote(op->getLoc()) << "alloca op";
      return diag;
    }
    if (!alloca.getMemref().getType().hasStaticShape()) {
      DiagnosedSilenceableFailure diag =
          emitSilenceableError()
          << "expected alloca payload to have a static shape";
      diag.attachNote(op->getLoc()) << "alloca op";
      return diag;
    }
  }

  MLIRContext *ctx = rewriter.getContext();
  // Builder without an insertion point: each global is created detached and
  // then placed by SymbolTable::insert.
  OpBuilder builder(ctx);
  // One cached SymbolTable per symbol-table op, so allocas sharing a module
  // do not rebuild the symbol map on every iteration.
  SymbolTableCollection symbolTables;

  SmallVector<memref::GlobalOp> globalOps;
  SmallVector<memref::GetGlobalOp> getGlobalOps;
  globalOps.reserve(allocaOps.size());
  getGlobalOps.reserve(allocaOps.size());

  for (Operation *op : allocaOps) {
    auto alloca = cast<memref::AllocaOp>(op);
    Location loc = alloca->getLoc();

    // Insert a private `memref.global` into the nearest symbol table; the
    // table is expected to uniquify the name if "alloca" is already taken.
    SymbolTable &symbolTable =
        symbolTables.getSymbolTable(SymbolTable::getNearestSymbolTable(op));
    // TODO: Add a better builder for this.
    auto globalOp = builder.create<memref::GlobalOp>(
        loc, StringAttr::get(ctx, "alloca"), StringAttr::get(ctx, "private"),
        TypeAttr::get(alloca.getResult().getType()), Attribute{}, UnitAttr{},
        IntegerAttr{});
    symbolTable.insert(globalOp);

    // Replace the `memref.alloca` with a `memref.get_global` accessing the
    // global symbol inserted above.
    rewriter.setInsertionPoint(alloca);
    auto getGlobalOp = rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(
        alloca, globalOp.getType(), globalOp.getName());

    globalOps.push_back(globalOp);
    getGlobalOps.push_back(getGlobalOp);
  }

  // Assemble results.
  results.set(cast<OpResult>(getGlobal()), globalOps);
  results.set(cast<OpResult>(getGetGlobal()), getGlobalOps);

  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// Memory effects of the op:
///  - it produces all of its result handles,
///  - it consumes the `alloca` handle, and
///  - it modifies the payload IR.
void transform::MemRefAllocaToGlobalOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  transform::producesHandle(getOperation()->getOpResults(), effects);
  transform::consumesHandle(getAllocaMutable(), effects);
  transform::modifiesPayload(effects);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MemRefMultiBufferOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Multi-buffers each `memref.alloc` payload op by `getFactor()`.
///
/// An allocation is skipped (without failing) when any of its users, other
/// than `memref.dealloc`, has no enclosing LoopLikeOpInterface op.
/// The first allocation for which `memref::multiBuffer` fails ends the
/// transform with a silenceable failure.
///
/// The result handle receives the new allocations, in payload order.
DiagnosedSilenceableFailure transform::MemRefMultiBufferOp::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &transformResults,
    transform::TransformState &state) {
  // First non-dealloc user of `alloc` that is not nested in a loop-like op,
  // or nullptr when every such user is inside some loop.
  auto firstUserOutsideLoop = [](memref::AllocOp alloc) -> Operation * {
    for (Operation *user : alloc->getUsers()) {
      if (isa<memref::DeallocOp>(user))
        continue;
      if (!user->getParentOfType<LoopLikeOpInterface>())
        return user;
    }
    return nullptr;
  };

  SmallVector<Operation *> newAllocs;
  for (Operation *payload : state.getPayloadOps(getTarget())) {
    auto allocOp = cast<memref::AllocOp>(payload);
    LLVM_DEBUG(DBGS() << "Start multibuffer transform op: " << allocOp
                      << "\n";);

    // Skip allocations not used in a loop.
    if (Operation *blocker = firstUserOutsideLoop(allocOp)) {
      LLVM_DEBUG(DBGS() << "--allocation not used in a loop\n";
                 DBGS() << "----due to user: " << *blocker;);
      LLVM_DEBUG(DBGS() << "--cannot apply multibuffering -> Skip\n";);
      continue;
    }

    auto multiBuffered =
        memref::multiBuffer(rewriter, allocOp, getFactor(), getSkipAnalysis());
    if (failed(multiBuffered)) {
      LLVM_DEBUG(DBGS() << "--op failed to multibuffer\n";);
      return emitSilenceableFailure(allocOp->getLoc())
             << "op failed to multibuffer";
    }

    newAllocs.push_back(*multiBuffered);
  }

  transformResults.set(cast<OpResult>(getResult()), newAllocs);
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MemRefEraseDeadAllocAndStoresOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Simplifies memory traffic under `target` in two steps and always succeeds
/// without producing results.
DiagnosedSilenceableFailure
transform::MemRefEraseDeadAllocAndStoresOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  // Step 1: store-to-load forwarding and dead store elimination
  // (vector::transferOpflowOpt).
  vector::transferOpflowOpt(rewriter, target);

  // Step 2: erase dead allocations together with their stores.
  memref::eraseDeadAllocAndStores(rewriter, target);

  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// Memory effects of the op:
///  - the target handle is only read, not consumed, and
///  - the payload IR is modified.
void transform::MemRefEraseDeadAllocAndStoresOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  transform::onlyReadsHandle(getTargetMutable(), effects);
  transform::modifiesPayload(effects);
}
|
|
/// Builds the op from a single `target` handle. This builder records only
/// that operand; `builder` is unused.
void transform::MemRefEraseDeadAllocAndStoresOp::build(OpBuilder &builder,
                                                       OperationState &result,
                                                       Value target) {
  result.addOperands(ValueRange(target));
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MemRefMakeLoopIndependentOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Makes a `memref.alloca` independent of the induction variables of its
/// `getNumLoops()` closest enclosing `scf.for` loops, using
/// `memref::replaceWithIndependentOp`.
///
/// Emits a silenceable failure (with a "target op" note) in three cases:
///  - fewer enclosing loops exist than requested;
///  - the target is not a `memref.alloca`;
///  - the replacement fails.
///
/// On success, yields the op defining the replacement value.
DiagnosedSilenceableFailure transform::MemRefMakeLoopIndependentOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  // Gather IVs, innermost enclosing loop first.
  SmallVector<Value> ivs;
  Operation *cursor = target;
  for (uint64_t depth = 0, numLoops = getNumLoops(); depth < numLoops;
       ++depth) {
    auto forOp = cursor->getParentOfType<scf::ForOp>();
    if (!forOp) {
      DiagnosedSilenceableFailure diag = emitSilenceableError()
                                         << "could not find " << depth
                                         << "-th enclosing loop";
      diag.attachNote(target->getLoc()) << "target op";
      return diag;
    }
    ivs.push_back(forOp.getInductionVar());
    cursor = forOp;
  }

  // Only `memref.alloca` targets are supported.
  auto allocaOp = dyn_cast<memref::AllocaOp>(target);
  if (!allocaOp) {
    DiagnosedSilenceableFailure diag = emitSilenceableError()
                                       << "unsupported target op";
    diag.attachNote(target->getLoc()) << "target op";
    return diag;
  }

  // Rewrite IR.
  FailureOr<Value> replacement =
      memref::replaceWithIndependentOp(rewriter, allocaOp, ivs);
  if (failed(replacement)) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError() << "could not make target op loop-independent";
    diag.attachNote(target->getLoc()) << "target op";
    return diag;
  }

  results.push_back(replacement->getDefiningOp());
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Transform op registration
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
/// Transform dialect extension that registers the MemRef transform ops listed
/// in the generated op list. It also declares the dialects whose operations
/// those transform ops may create in payload IR.
struct MemRefTransformDialectExtension
    : public transform::TransformDialectExtension<
          MemRefTransformDialectExtension> {
  // Explicit type ID so the registry can tell whether this extension type is
  // already present.
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MemRefTransformDialectExtension)

  using Base::Base;

  void init() {
    // Dialects whose ops may be generated by the MemRef transform ops.
    declareGeneratedDialect<affine::AffineDialect>();
    declareGeneratedDialect<arith::ArithDialect>();
    declareGeneratedDialect<memref::MemRefDialect>();
    declareGeneratedDialect<nvgpu::NVGPUDialect>();
    declareGeneratedDialect<vector::VectorDialect>();

    // Register every op from the ODS-generated list.
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc"
        >();
  }
};
} // namespace
|
|
|
|
#define GET_OP_CLASSES
|
|
#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc"
|
|
|
|
void mlir::memref::registerTransformDialectExtension(
|
|
DialectRegistry ®istry) {
|
|
registry.addExtensions<MemRefTransformDialectExtension>();
|
|
}
|