llvm-project/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
Christian Ulmann da5b382d8c
[MLIR][MemRefToLLVM] Remove last typed pointer remnants (#71113)
This commit removes the last typed pointer remnants from the MemRef to
LLVM conversions, including the transform dialect operation. Typed
pointers have been deprecated for a while now and it's planned to soon
remove them from the LLVM dialect.

Related PSA:
https://discourse.llvm.org/t/psa-removal-of-typed-pointers-from-the-llvm-dialect/74502
2023-11-03 10:55:11 +01:00

336 lines
13 KiB
C++

//===- MemRefTransformOps.cpp - Implementation of Memref transform ops ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"
using namespace mlir;
#define DEBUG_TYPE "memref-transforms"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
//===----------------------------------------------------------------------===//
// Apply...ConversionPatternsOp
//===----------------------------------------------------------------------===//
std::unique_ptr<TypeConverter>
transform::MemrefToLLVMTypeConverterOp::getTypeConverter() {
LowerToLLVMOptions options(getContext());
options.allocLowering =
(getUseAlignedAlloc() ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
: LowerToLLVMOptions::AllocLowering::Malloc);
options.useGenericFunctions = getUseGenericFunctions();
if (getIndexBitwidth() != kDeriveIndexBitwidthFromDataLayout)
options.overrideIndexBitwidth(getIndexBitwidth());
// TODO: the following two options don't really make sense for
// memref_to_llvm_type_converter specifically but we should have a single
// to_llvm_type_converter.
if (getDataLayout().has_value())
options.dataLayout = llvm::DataLayout(getDataLayout().value());
options.useBarePtrCallConv = getUseBarePtrCallConv();
return std::make_unique<LLVMTypeConverter>(getContext(), options);
}
/// Returns the name of the converter class built by `getTypeConverter`, i.e.
/// the string "LLVMTypeConverter".
StringRef transform::MemrefToLLVMTypeConverterOp::getTypeConverterType() {
return "LLVMTypeConverter";
}
//===----------------------------------------------------------------------===//
// Apply...PatternsOp
//===----------------------------------------------------------------------===//
namespace {
/// Rewrite pattern that hands `memref::AllocOp`s to `memref::allocToAlloca`.
///
/// An allocation is accepted by the filter only if its memref type has a
/// static shape and, when a non-zero `maxSize` was given, its total size
/// (number of elements times the data-layout size of the element type) is
/// strictly smaller than `maxSize`. A `maxSize` of 0 means "no size limit".
class AllocToAllocaPattern : public OpRewritePattern<memref::AllocOp> {
public:
  /// `analysisRoot` provides the context and is the root of the data layout
  /// analysis used to query element sizes; `maxSize` is the exclusive upper
  /// bound on the allocation size (0 disables the bound).
  explicit AllocToAllocaPattern(Operation *analysisRoot, int64_t maxSize = 0)
      : OpRewritePattern<memref::AllocOp>(analysisRoot->getContext()),
        dataLayoutAnalysis(analysisRoot), maxSize(maxSize) {}

  /// Succeeds iff `memref::allocToAlloca` returns a non-null result for `op`.
  LogicalResult matchAndRewrite(memref::AllocOp op,
                                PatternRewriter &rewriter) const override {
    // Filter deciding whether a given alloc (and its dealloc, which is not
    // inspected here) may be converted.
    auto isEligible = [this](memref::AllocOp allocOp,
                             memref::DeallocOp /*deallocOp*/) -> bool {
      MemRefType memrefType = allocOp.getMemref().getType();
      // Dynamically shaped buffers have no compile-time size to compare.
      if (!memrefType.hasStaticShape())
        return false;

      const DataLayout &layout = dataLayoutAnalysis.getAtOrAbove(allocOp);
      int64_t bytesPerElement =
          layout.getTypeSize(memrefType.getElementType());
      if (maxSize == 0)
        return true;
      return memrefType.getNumElements() * bytesPerElement < maxSize;
    };
    return success(memref::allocToAlloca(rewriter, op, isEligible));
  }

private:
  /// Data layout information rooted at the operation given at construction.
  DataLayoutAnalysis dataLayoutAnalysis;
  /// Exclusive size bound used by the filter; 0 means unbounded.
  int64_t maxSize;
};
} // namespace
/// Adds no patterns. The pattern for this op is registered in
/// `populatePatternsWithState`, which has access to the transform state.
void transform::ApplyAllocToAllocaOp::populatePatterns(
RewritePatternSet &patterns) {}
/// Registers `AllocToAllocaPattern` in `patterns`.
///
/// The pattern is constructed with the top-level operation of `state` as its
/// analysis root and with this op's size limit, defaulting to 0 when no limit
/// is set.
void transform::ApplyAllocToAllocaOp::populatePatternsWithState(
    RewritePatternSet &patterns, transform::TransformState &state) {
  Operation *analysisRoot = state.getTopLevel();
  const int64_t sizeLimit = static_cast<int64_t>(getSizeLimit().value_or(0));
  patterns.insert<AllocToAllocaPattern>(analysisRoot, sizeLimit);
}
/// Fills `patterns` by forwarding to `memref::populateExpandOpsPatterns`.
void transform::ApplyExpandOpsPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
memref::populateExpandOpsPatterns(patterns);
}
/// Fills `patterns` by forwarding to
/// `memref::populateExpandStridedMetadataPatterns`.
void transform::ApplyExpandStridedMetadataPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
memref::populateExpandStridedMetadataPatterns(patterns);
}
/// Fills `patterns` by forwarding to
/// `memref::populateExtractAddressComputationsPatterns`.
void transform::ApplyExtractAddressComputationsPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
memref::populateExtractAddressComputationsPatterns(patterns);
}
/// Fills `patterns` by forwarding to `memref::populateFoldMemRefAliasOpPatterns`.
void transform::ApplyFoldMemrefAliasOpsPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
memref::populateFoldMemRefAliasOpPatterns(patterns);
}
/// Fills `patterns` by forwarding to
/// `memref::populateResolveRankedShapedTypeResultDimsPatterns`.
void transform::ApplyResolveRankedShapedTypeResultDimsPatternsOp::
populatePatterns(RewritePatternSet &patterns) {
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
}
//===----------------------------------------------------------------------===//
// AllocaToGlobalOp
//===----------------------------------------------------------------------===//
/// Replaces every `memref.alloca` payload op of the `alloca` handle with a
/// `memref.get_global` that reads a newly created private `memref.global`.
///
/// For each payload op, a `memref.global` of the alloca's result type is
/// created and inserted into the nearest enclosing symbol table, and the
/// alloca is replaced (through `rewriter`) by a `memref.get_global` referring
/// to it. The created globals and get_global ops are returned through the
/// `global` and `get_global` result handles, in payload order.
///
/// Always returns success; a payload op without an enclosing symbol table
/// trips an assertion.
DiagnosedSilenceableFailure
transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter,
                                         transform::TransformResults &results,
                                         transform::TransformState &state) {
  auto allocaOps = state.getPayloadOps(getAlloca());
  SmallVector<memref::GlobalOp> globalOps;
  SmallVector<memref::GetGlobalOp> getGlobalOps;
  // The context does not change between payload ops; look it up once.
  MLIRContext *ctx = rewriter.getContext();

  // Transform `memref.alloca`s.
  for (auto *op : allocaOps) {
    auto alloca = cast<memref::AllocaOp>(op);
    Location loc = alloca->getLoc();

    memref::GlobalOp globalOp;
    {
      // Find nearest symbol table.
      Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(op);
      assert(symbolTableOp && "expected alloca payload to be in symbol table");
      SymbolTable symbolTable(symbolTableOp);

      // Insert a `memref.global` into the symbol table. The builder has no
      // insertion point, so the op is created detached and placed by
      // `symbolTable.insert`.
      // NOTE(review): every global is created with the fixed name "alloca";
      // presumably `SymbolTable::insert` renames on collision -- confirm.
      Type resultType = alloca.getResult().getType();
      OpBuilder builder(ctx);
      // TODO: Add a better builder for this.
      globalOp = builder.create<memref::GlobalOp>(
          loc, StringAttr::get(ctx, "alloca"), StringAttr::get(ctx, "private"),
          TypeAttr::get(resultType), Attribute{}, UnitAttr{}, IntegerAttr{});
      symbolTable.insert(globalOp);
    }

    // Replace the `memref.alloca` with a `memref.get_global` accessing the
    // global symbol inserted above.
    rewriter.setInsertionPoint(alloca);
    auto getGlobalOp = rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(
        alloca, globalOp.getType(), globalOp.getName());

    globalOps.push_back(globalOp);
    getGlobalOps.push_back(getGlobalOp);
  }

  // Assemble results. Use the free-function `cast` (as done elsewhere in this
  // file) instead of the deprecated `Value::cast` member.
  results.set(cast<OpResult>(getGlobal()), globalOps);
  results.set(cast<OpResult>(getGetGlobal()), getGlobalOps);
  return DiagnosedSilenceableFailure::success();
}
/// Declares the side effects of this transform op: it produces the `global`
/// and `get_global` handles, consumes the `alloca` handle, and modifies the
/// payload.
void transform::MemRefAllocaToGlobalOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
producesHandle(getGlobal(), effects);
producesHandle(getGetGlobal(), effects);
consumesHandle(getAlloca(), effects);
modifiesPayload(effects);
}
//===----------------------------------------------------------------------===//
// MemRefMultiBufferOp
//===----------------------------------------------------------------------===//
/// Applies `memref::multiBuffer` to every `memref.alloc` payload op of the
/// target handle.
///
/// An allocation is skipped (without error) when any of its users other than
/// a `memref.dealloc` is not nested inside an op implementing
/// `LoopLikeOpInterface`. For the remaining allocations `memref::multiBuffer`
/// is called with this op's factor and skip-analysis setting. If it fails, a
/// silenceable failure is returned at the location of the offending
/// allocation; otherwise the new buffers are collected and set as the result
/// handle.
DiagnosedSilenceableFailure transform::MemRefMultiBufferOp::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &transformResults,
    transform::TransformState &state) {
  SmallVector<Operation *> results;
  for (Operation *op : state.getPayloadOps(getTarget())) {
    bool canApplyMultiBuffer = true;
    auto target = cast<memref::AllocOp>(op);
    LLVM_DEBUG(DBGS() << "Start multibuffer transform op: " << target << "\n";);

    // Skip allocations not used in a loop.
    for (Operation *user : target->getUsers()) {
      if (isa<memref::DeallocOp>(user))
        continue;
      auto loop = user->getParentOfType<LoopLikeOpInterface>();
      if (!loop) {
        // Terminate the "due to user" line so the following debug message
        // starts on its own line.
        LLVM_DEBUG(DBGS() << "--allocation not used in a loop\n";
                   DBGS() << "----due to user: " << *user << "\n";);
        canApplyMultiBuffer = false;
        break;
      }
    }
    if (!canApplyMultiBuffer) {
      LLVM_DEBUG(DBGS() << "--cannot apply multibuffering -> Skip\n";);
      continue;
    }

    auto newBuffer =
        memref::multiBuffer(rewriter, target, getFactor(), getSkipAnalysis());
    if (failed(newBuffer)) {
      LLVM_DEBUG(DBGS() << "--op failed to multibuffer\n";);
      return emitSilenceableFailure(target->getLoc())
             << "op failed to multibuffer";
    }
    results.push_back(*newBuffer);
  }
  transformResults.set(cast<OpResult>(getResult()), results);
  return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// MemRefEraseDeadAllocAndStoresOp
//===----------------------------------------------------------------------===//
/// Runs `vector::transferOpflowOpt` followed by
/// `memref::eraseDeadAllocAndStores` on a single payload op `target`.
/// Nothing is appended to `results`, `state` is not used, and the function
/// always returns success.
DiagnosedSilenceableFailure
transform::MemRefEraseDeadAllocAndStoresOp::applyToOne(
transform::TransformRewriter &rewriter, Operation *target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
// Apply store to load forwarding and dead store elimination.
vector::transferOpflowOpt(rewriter, target);
memref::eraseDeadAllocAndStores(rewriter, target);
return DiagnosedSilenceableFailure::success();
}
/// Declares the side effects of this transform op: the target handle is only
/// read, and the payload is modified.
void transform::MemRefEraseDeadAllocAndStoresOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::onlyReadsHandle(getTarget(), effects);
transform::modifiesPayload(effects);
}
/// Builder taking only the target handle: adds `target` as the sole operand
/// of `result`. `builder` is not used.
void transform::MemRefEraseDeadAllocAndStoresOp::build(OpBuilder &builder,
OperationState &result,
Value target) {
result.addOperands(target);
}
//===----------------------------------------------------------------------===//
// MemRefMakeLoopIndependentOp
//===----------------------------------------------------------------------===//
/// Makes a single payload op independent of its enclosing `scf.for` loops.
///
/// The induction variables of the `num_loops` closest enclosing `scf.for` ops
/// of `target` are collected (innermost first). If `target` is a
/// `memref.alloca`, it is passed together with those induction variables to
/// `memref::replaceWithIndependentOp`; the defining op of the returned value
/// is appended to `results`.
///
/// A silenceable error (with a note pointing at `target`) is emitted when
/// fewer than `num_loops` enclosing loops exist, when `target` is not a
/// `memref.alloca`, or when the replacement fails.
DiagnosedSilenceableFailure transform::MemRefMakeLoopIndependentOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  // Gather IVs, walking outwards from the target one loop at a time.
  SmallVector<Value> ivs;
  Operation *current = target;
  const uint64_t numLoops = getNumLoops();
  for (uint64_t depth = 0; depth < numLoops; ++depth) {
    auto enclosingLoop = current->getParentOfType<scf::ForOp>();
    if (!enclosingLoop) {
      DiagnosedSilenceableFailure diag = emitSilenceableError()
                                         << "could not find " << depth
                                         << "-th enclosing loop";
      diag.attachNote(target->getLoc()) << "target op";
      return diag;
    }
    ivs.push_back(enclosingLoop.getInductionVar());
    current = enclosingLoop;
  }

  // Only `memref.alloca` targets are handled.
  auto allocaOp = dyn_cast<memref::AllocaOp>(target);
  if (!allocaOp) {
    DiagnosedSilenceableFailure diag = emitSilenceableError()
                                       << "unsupported target op";
    diag.attachNote(target->getLoc()) << "target op";
    return diag;
  }

  // Rewrite IR.
  FailureOr<Value> replacement =
      memref::replaceWithIndependentOp(rewriter, allocaOp, ivs);
  if (failed(replacement)) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError() << "could not make target op loop-independent";
    diag.attachNote(target->getLoc()) << "target op";
    return diag;
  }

  results.push_back(replacement->getDefiningOp());
  return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
namespace {
/// Transform dialect extension that registers the MemRef transform ops
/// defined in this file.
class MemRefTransformDialectExtension
: public transform::TransformDialectExtension<
MemRefTransformDialectExtension> {
public:
using Base::Base;
/// Declares the dialects whose ops may be generated by these transforms
/// (affine, arith, memref, nvgpu, vector) and registers every op listed in
/// the generated `MemRefTransformOps.cpp.inc` op list.
void init() {
declareGeneratedDialect<affine::AffineDialect>();
declareGeneratedDialect<arith::ArithDialect>();
declareGeneratedDialect<memref::MemRefDialect>();
declareGeneratedDialect<nvgpu::NVGPUDialect>();
declareGeneratedDialect<vector::VectorDialect>();
// GET_OP_LIST makes the included file expand to the list of op classes,
// which becomes the template argument list of `registerTransformOps`.
registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc"
>();
}
};
} // namespace
#define GET_OP_CLASSES
#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc"
/// Adds `MemRefTransformDialectExtension` to `registry`.
void mlir::memref::registerTransformDialectExtension(
DialectRegistry &registry) {
registry.addExtensions<MemRefTransformDialectExtension>();
}