llvm-project/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
donald chen 2c1ae801e1
[mlir][side effect] refactor(*): Include more precise side effects (#94213)
This patch adds more precise side effects to the current ops with memory
effects, allowing us to determine which OpOperand/OpResult/BlockArgument
the
operation reads or writes, rather than just recording the reading and
writing
of values. This allows for convenient use of precise side effects to
achieve
analysis and optimization.

Related discussions:
https://discourse.llvm.org/t/rfc-add-operandindex-to-sideeffect-instance/79243
2024-06-19 22:10:34 +08:00

336 lines
13 KiB
C++

//===- MemRefTransformOps.cpp - Implementation of Memref transform ops ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"
using namespace mlir;
#define DEBUG_TYPE "memref-transforms"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
//===----------------------------------------------------------------------===//
// Apply...ConversionPatternsOp
//===----------------------------------------------------------------------===//
std::unique_ptr<TypeConverter>
transform::MemrefToLLVMTypeConverterOp::getTypeConverter() {
LowerToLLVMOptions options(getContext());
options.allocLowering =
(getUseAlignedAlloc() ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
: LowerToLLVMOptions::AllocLowering::Malloc);
options.useGenericFunctions = getUseGenericFunctions();
if (getIndexBitwidth() != kDeriveIndexBitwidthFromDataLayout)
options.overrideIndexBitwidth(getIndexBitwidth());
// TODO: the following two options don't really make sense for
// memref_to_llvm_type_converter specifically but we should have a single
// to_llvm_type_converter.
if (getDataLayout().has_value())
options.dataLayout = llvm::DataLayout(getDataLayout().value());
options.useBarePtrCallConv = getUseBarePtrCallConv();
return std::make_unique<LLVMTypeConverter>(getContext(), options);
}
StringRef transform::MemrefToLLVMTypeConverterOp::getTypeConverterType() {
return "LLVMTypeConverter";
}
//===----------------------------------------------------------------------===//
// Apply...PatternsOp
//===----------------------------------------------------------------------===//
namespace {
class AllocToAllocaPattern : public OpRewritePattern<memref::AllocOp> {
public:
explicit AllocToAllocaPattern(Operation *analysisRoot, int64_t maxSize = 0)
: OpRewritePattern<memref::AllocOp>(analysisRoot->getContext()),
dataLayoutAnalysis(analysisRoot), maxSize(maxSize) {}
LogicalResult matchAndRewrite(memref::AllocOp op,
PatternRewriter &rewriter) const override {
return success(memref::allocToAlloca(
rewriter, op, [this](memref::AllocOp alloc, memref::DeallocOp dealloc) {
MemRefType type = alloc.getMemref().getType();
if (!type.hasStaticShape())
return false;
const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(alloc);
int64_t elementSize = dataLayout.getTypeSize(type.getElementType());
return maxSize == 0 || type.getNumElements() * elementSize < maxSize;
}));
}
private:
DataLayoutAnalysis dataLayoutAnalysis;
int64_t maxSize;
};
} // namespace
void transform::ApplyAllocToAllocaOp::populatePatterns(
RewritePatternSet &patterns) {}
void transform::ApplyAllocToAllocaOp::populatePatternsWithState(
RewritePatternSet &patterns, transform::TransformState &state) {
patterns.insert<AllocToAllocaPattern>(
state.getTopLevel(), static_cast<int64_t>(getSizeLimit().value_or(0)));
}
void transform::ApplyExpandOpsPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
memref::populateExpandOpsPatterns(patterns);
}
void transform::ApplyExpandStridedMetadataPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
memref::populateExpandStridedMetadataPatterns(patterns);
}
void transform::ApplyExtractAddressComputationsPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
memref::populateExtractAddressComputationsPatterns(patterns);
}
void transform::ApplyFoldMemrefAliasOpsPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
memref::populateFoldMemRefAliasOpPatterns(patterns);
}
void transform::ApplyResolveRankedShapedTypeResultDimsPatternsOp::
populatePatterns(RewritePatternSet &patterns) {
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
}
//===----------------------------------------------------------------------===//
// AllocaToGlobalOp
//===----------------------------------------------------------------------===//
DiagnosedSilenceableFailure
transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
auto allocaOps = state.getPayloadOps(getAlloca());
SmallVector<memref::GlobalOp> globalOps;
SmallVector<memref::GetGlobalOp> getGlobalOps;
// Transform `memref.alloca`s.
for (auto *op : allocaOps) {
auto alloca = cast<memref::AllocaOp>(op);
MLIRContext *ctx = rewriter.getContext();
Location loc = alloca->getLoc();
memref::GlobalOp globalOp;
{
// Find nearest symbol table.
Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(op);
assert(symbolTableOp && "expected alloca payload to be in symbol table");
SymbolTable symbolTable(symbolTableOp);
// Insert a `memref.global` into the symbol table.
Type resultType = alloca.getResult().getType();
OpBuilder builder(rewriter.getContext());
// TODO: Add a better builder for this.
globalOp = builder.create<memref::GlobalOp>(
loc, StringAttr::get(ctx, "alloca"), StringAttr::get(ctx, "private"),
TypeAttr::get(resultType), Attribute{}, UnitAttr{}, IntegerAttr{});
symbolTable.insert(globalOp);
}
// Replace the `memref.alloca` with a `memref.get_global` accessing the
// global symbol inserted above.
rewriter.setInsertionPoint(alloca);
auto getGlobalOp = rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(
alloca, globalOp.getType(), globalOp.getName());
globalOps.push_back(globalOp);
getGlobalOps.push_back(getGlobalOp);
}
// Assemble results.
results.set(cast<OpResult>(getGlobal()), globalOps);
results.set(cast<OpResult>(getGetGlobal()), getGlobalOps);
return DiagnosedSilenceableFailure::success();
}
void transform::MemRefAllocaToGlobalOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
producesHandle(getOperation()->getOpResults(), effects);
consumesHandle(getAllocaMutable(), effects);
modifiesPayload(effects);
}
//===----------------------------------------------------------------------===//
// MemRefMultiBufferOp
//===----------------------------------------------------------------------===//
DiagnosedSilenceableFailure transform::MemRefMultiBufferOp::apply(
transform::TransformRewriter &rewriter,
transform::TransformResults &transformResults,
transform::TransformState &state) {
SmallVector<Operation *> results;
for (Operation *op : state.getPayloadOps(getTarget())) {
bool canApplyMultiBuffer = true;
auto target = cast<memref::AllocOp>(op);
LLVM_DEBUG(DBGS() << "Start multibuffer transform op: " << target << "\n";);
// Skip allocations not used in a loop.
for (Operation *user : target->getUsers()) {
if (isa<memref::DeallocOp>(user))
continue;
auto loop = user->getParentOfType<LoopLikeOpInterface>();
if (!loop) {
LLVM_DEBUG(DBGS() << "--allocation not used in a loop\n";
DBGS() << "----due to user: " << *user;);
canApplyMultiBuffer = false;
break;
}
}
if (!canApplyMultiBuffer) {
LLVM_DEBUG(DBGS() << "--cannot apply multibuffering -> Skip\n";);
continue;
}
auto newBuffer =
memref::multiBuffer(rewriter, target, getFactor(), getSkipAnalysis());
if (failed(newBuffer)) {
LLVM_DEBUG(DBGS() << "--op failed to multibuffer\n";);
return emitSilenceableFailure(target->getLoc())
<< "op failed to multibuffer";
}
results.push_back(*newBuffer);
}
transformResults.set(cast<OpResult>(getResult()), results);
return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// MemRefEraseDeadAllocAndStoresOp
//===----------------------------------------------------------------------===//
DiagnosedSilenceableFailure
transform::MemRefEraseDeadAllocAndStoresOp::applyToOne(
transform::TransformRewriter &rewriter, Operation *target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
// Apply store to load forwarding and dead store elimination.
vector::transferOpflowOpt(rewriter, target);
memref::eraseDeadAllocAndStores(rewriter, target);
return DiagnosedSilenceableFailure::success();
}
void transform::MemRefEraseDeadAllocAndStoresOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::onlyReadsHandle(getTargetMutable(), effects);
transform::modifiesPayload(effects);
}
void transform::MemRefEraseDeadAllocAndStoresOp::build(OpBuilder &builder,
OperationState &result,
Value target) {
result.addOperands(target);
}
//===----------------------------------------------------------------------===//
// MemRefMakeLoopIndependentOp
//===----------------------------------------------------------------------===//
DiagnosedSilenceableFailure transform::MemRefMakeLoopIndependentOp::applyToOne(
transform::TransformRewriter &rewriter, Operation *target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
// Gather IVs.
SmallVector<Value> ivs;
Operation *nextOp = target;
for (uint64_t i = 0, e = getNumLoops(); i < e; ++i) {
nextOp = nextOp->getParentOfType<scf::ForOp>();
if (!nextOp) {
DiagnosedSilenceableFailure diag = emitSilenceableError()
<< "could not find " << i
<< "-th enclosing loop";
diag.attachNote(target->getLoc()) << "target op";
return diag;
}
ivs.push_back(cast<scf::ForOp>(nextOp).getInductionVar());
}
// Rewrite IR.
FailureOr<Value> replacement = failure();
if (auto allocaOp = dyn_cast<memref::AllocaOp>(target)) {
replacement = memref::replaceWithIndependentOp(rewriter, allocaOp, ivs);
} else {
DiagnosedSilenceableFailure diag = emitSilenceableError()
<< "unsupported target op";
diag.attachNote(target->getLoc()) << "target op";
return diag;
}
if (failed(replacement)) {
DiagnosedSilenceableFailure diag =
emitSilenceableError() << "could not make target op loop-independent";
diag.attachNote(target->getLoc()) << "target op";
return diag;
}
results.push_back(replacement->getDefiningOp());
return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
namespace {
class MemRefTransformDialectExtension
: public transform::TransformDialectExtension<
MemRefTransformDialectExtension> {
public:
using Base::Base;
void init() {
declareGeneratedDialect<affine::AffineDialect>();
declareGeneratedDialect<arith::ArithDialect>();
declareGeneratedDialect<memref::MemRefDialect>();
declareGeneratedDialect<nvgpu::NVGPUDialect>();
declareGeneratedDialect<vector::VectorDialect>();
registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc"
>();
}
};
} // namespace
#define GET_OP_CLASSES
#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc"
void mlir::memref::registerTransformDialectExtension(
DialectRegistry &registry) {
registry.addExtensions<MemRefTransformDialectExtension>();
}