In D115022, we introduced an optimization where OpResults of a `linalg.generic` may bufferize in-place with an "in" OpOperand if the corresponding "out" OpOperand is not used in the computation. This optimization can lead to unexpected behavior if the newly chosen OpOperand is in the same alias set as another OpOperand that is used in the computation. In that case, the newly chosen OpOperand must bufferize out-of-place. This can be confusing to users: with the notion of "destination-passing style" in mind, they would expect the "out" OpOperand to always be chosen, regardless of whether it is used. With this change, we go back to always bufferizing in-place with "out" OpOperands by default, while letting users override this behavior with a bufferization option. Differential Revision: https://reviews.llvm.org/D120182
134 lines
5.1 KiB
C++
134 lines
5.1 KiB
C++
//===- ComprehensiveBufferize.cpp - Single pass bufferization -------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "PassDetail.h"
|
|
|
|
#include "mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h"
|
|
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
|
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
|
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
|
|
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h"
|
|
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h"
|
|
#include "mlir/Dialect/Linalg/Passes.h"
|
|
#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
|
|
#include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h"
|
|
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
|
|
#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Pass/PassManager.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
#include "mlir/Transforms/Passes.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::bufferization;
|
|
using namespace mlir::linalg;
|
|
using namespace mlir::linalg::comprehensive_bufferize;
|
|
|
|
namespace {
|
|
struct LinalgComprehensiveModuleBufferize
|
|
: public LinalgComprehensiveModuleBufferizeBase<
|
|
LinalgComprehensiveModuleBufferize> {
|
|
LinalgComprehensiveModuleBufferize() = default;
|
|
|
|
LinalgComprehensiveModuleBufferize(
|
|
const LinalgComprehensiveModuleBufferize &p) = default;
|
|
|
|
explicit LinalgComprehensiveModuleBufferize(
|
|
AnalysisBufferizationOptions options)
|
|
: options(options) {}
|
|
|
|
void runOnOperation() override;
|
|
|
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
registry
|
|
.insert<bufferization::BufferizationDialect, linalg::LinalgDialect,
|
|
memref::MemRefDialect, tensor::TensorDialect,
|
|
vector::VectorDialect, scf::SCFDialect,
|
|
arith::ArithmeticDialect, StandardOpsDialect, AffineDialect>();
|
|
affine_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
|
arith::registerBufferizableOpInterfaceExternalModels(registry);
|
|
linalg::registerBufferizableOpInterfaceExternalModels(registry);
|
|
scf::registerBufferizableOpInterfaceExternalModels(registry);
|
|
std_ext::registerModuleBufferizationExternalModels(registry);
|
|
tensor::registerBufferizableOpInterfaceExternalModels(registry);
|
|
vector::registerBufferizableOpInterfaceExternalModels(registry);
|
|
}
|
|
|
|
private:
|
|
llvm::Optional<AnalysisBufferizationOptions> options;
|
|
};
|
|
} // namespace
|
|
|
|
/// Greedily applies the rewrites that prepare `moduleOp` for bufferization.
/// Currently this is only GeneralizePadOpPattern.
static void applyEnablingTransformations(ModuleOp moduleOp) {
  MLIRContext *ctx = moduleOp.getContext();
  RewritePatternSet enablingPatterns(ctx);
  enablingPatterns.add<GeneralizePadOpPattern>(ctx);
  // The status returned by the greedy driver is deliberately discarded.
  (void)applyPatternsAndFoldGreedily(moduleOp, std::move(enablingPatterns));
}
|
|
|
|
/// Allocation callback that creates buffers with `memref.alloca` (instead of
/// the default allocation), forwarding the requested alignment as an i64
/// attribute. The pass installs it when `useAlloca` is set, together with a
/// deallocation callback that does nothing.
static FailureOr<Value> allocationFnUsingAlloca(OpBuilder &b, Location loc,
                                                MemRefType type,
                                                ValueRange dynShape,
                                                unsigned int bufferAlignment) {
  IntegerAttr alignment = b.getI64IntegerAttr(bufferAlignment);
  Value stackBuffer =
      b.create<memref::AllocaOp>(loc, type, dynShape, alignment);
  return stackBuffer;
}
|
|
|
|
void LinalgComprehensiveModuleBufferize::runOnOperation() {
|
|
AnalysisBufferizationOptions opt;
|
|
if (!options) {
|
|
// Make new bufferization options if none were provided when creating the
|
|
// pass.
|
|
if (useAlloca) {
|
|
opt.allocationFn = allocationFnUsingAlloca;
|
|
opt.deallocationFn = [](OpBuilder &b, Location loc, Value v) {
|
|
return success();
|
|
};
|
|
}
|
|
opt.allowReturnMemref = allowReturnMemref;
|
|
opt.allowUnknownOps = allowUnknownOps;
|
|
opt.analysisFuzzerSeed = analysisFuzzerSeed;
|
|
opt.createDeallocs = createDeallocs;
|
|
opt.fullyDynamicLayoutMaps = fullyDynamicLayoutMaps;
|
|
opt.printConflicts = printConflicts;
|
|
opt.testAnalysisOnly = testAnalysisOnly;
|
|
opt.alwaysAliasingWithDest = alwaysAliasingWithDest;
|
|
if (initTensorElimination) {
|
|
opt.addPostAnalysisStep(insertSliceAnchoredInitTensorEliminationStep);
|
|
}
|
|
} else {
|
|
opt = *options;
|
|
}
|
|
|
|
ModuleOp moduleOp = getOperation();
|
|
applyEnablingTransformations(moduleOp);
|
|
|
|
if (failed(runModuleBufferize(moduleOp, opt))) {
|
|
signalPassFailure();
|
|
return;
|
|
}
|
|
|
|
if (opt.testAnalysisOnly)
|
|
return;
|
|
|
|
OpPassManager cleanupPipeline("builtin.module");
|
|
cleanupPipeline.addPass(createCanonicalizerPass());
|
|
cleanupPipeline.addPass(createCSEPass());
|
|
cleanupPipeline.addPass(createLoopInvariantCodeMotionPass());
|
|
(void)runPipeline(cleanupPipeline, moduleOp);
|
|
}
|
|
|
|
std::unique_ptr<Pass> mlir::createLinalgComprehensiveModuleBufferizePass() {
|
|
return std::make_unique<LinalgComprehensiveModuleBufferize>();
|
|
}
|
|
|
|
std::unique_ptr<Pass> mlir::createLinalgComprehensiveModuleBufferizePass(
|
|
const AnalysisBufferizationOptions &options) {
|
|
return std::make_unique<LinalgComprehensiveModuleBufferize>(options);
|
|
}
|