This is the second (and final) step of making "destination style" usable without depending on the Linalg dialect. (The first step was D135129.) This change allows us to provide default bufferization implementations for all destination-style ops. It also allows us to simplify `TilingInterface`. (E.g., `getDestinationOperands` can be removed.) Differential Revision: https://reviews.llvm.org/D136179
161 lines
6.3 KiB
C++
161 lines
6.3 KiB
C++
//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
|
|
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
|
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/IR/Dialect.h"
|
|
#include "mlir/IR/Operation.h"
|
|
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
|
|
|
using namespace mlir;
|
|
using namespace linalg;
|
|
using namespace mlir::bufferization;
|
|
|
|
namespace {
|
|
|
|
/// Generic conversion for any DestinationStyleOpInterface on tensors.
|
|
static LogicalResult
|
|
bufferizeDestinationStyleOpInterface(RewriterBase &rewriter,
|
|
DestinationStyleOpInterface op,
|
|
const BufferizationOptions &options) {
|
|
// Take a guard before anything else.
|
|
OpBuilder::InsertionGuard g(rewriter);
|
|
rewriter.setInsertionPoint(op);
|
|
|
|
// Nothing to do. This op is already bufferized.
|
|
if (op.hasBufferSemantics())
|
|
return success();
|
|
|
|
// Ensure op has only tensors. Allow mixed tensor-buffer mode on a per-need
|
|
// basis.
|
|
if (!op.hasTensorSemantics())
|
|
return op->emitError() << "op does not have tensor semantics";
|
|
|
|
// New input operands for the cloned op.
|
|
SmallVector<Value> newInputBuffers;
|
|
newInputBuffers.reserve(op.getNumInputs());
|
|
for (OpOperand *opOperand : op.getInputOperands()) {
|
|
if (op.isScalar(opOperand)) {
|
|
newInputBuffers.push_back(opOperand->get());
|
|
continue;
|
|
}
|
|
FailureOr<Value> buffer = getBuffer(rewriter, opOperand->get(), options);
|
|
if (failed(buffer))
|
|
return failure();
|
|
newInputBuffers.push_back(*buffer);
|
|
}
|
|
|
|
// New output operands for the cloned op.
|
|
SmallVector<Value> newOutputBuffers;
|
|
for (OpResult opResult : op->getOpResults()) {
|
|
OpOperand *opOperand = op.getOutputOperand(opResult.getResultNumber());
|
|
FailureOr<Value> resultBuffer =
|
|
getBuffer(rewriter, opOperand->get(), options);
|
|
if (failed(resultBuffer))
|
|
return failure();
|
|
newOutputBuffers.push_back(*resultBuffer);
|
|
}
|
|
|
|
// Merge input/output operands.
|
|
SmallVector<Value> newOperands = newInputBuffers;
|
|
newOperands.append(newOutputBuffers.begin(), newOutputBuffers.end());
|
|
|
|
// Set insertion point now that potential alloc/dealloc are introduced.
|
|
rewriter.setInsertionPoint(op);
|
|
// Clone the op, but use the new operands. Move the existing block into the
|
|
// new op. Since the new op does not have any tensor results, it does not
|
|
// return anything.
|
|
assert(op->getNumRegions() == 1 && "expected that op has 1 region");
|
|
auto newOp = cast<DestinationStyleOpInterface>(op.cloneWithoutRegions(
|
|
rewriter, op.getLoc(), /*resultTypes=*/TypeRange{}, newOperands));
|
|
rewriter.inlineRegionBefore(op->getRegion(0), newOp->getRegion(0),
|
|
newOp->getRegion(0).begin());
|
|
|
|
// Replace the results of the old op with the new output buffers.
|
|
replaceOpWithBufferizedValues(rewriter, op, newOutputBuffers);
|
|
|
|
return success();
|
|
}
|
|
|
|
/// Bufferization of linalg.generic. Replace with a new linalg.generic that
|
|
/// operates entirely on memrefs.
|
|
template <typename OpTy>
|
|
struct LinalgOpInterface
|
|
: public BufferizableOpInterface::ExternalModel<LinalgOpInterface<OpTy>,
|
|
OpTy> {
|
|
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
|
const AnalysisState &state) const {
|
|
// Operand is read if it is used in the computation.
|
|
auto genericOp = cast<linalg::LinalgOp>(op);
|
|
return genericOp.payloadUsesValueFromOperand(&opOperand);
|
|
}
|
|
|
|
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
|
const AnalysisState &state) const {
|
|
// Operand is written to if it has an aliasing OpResult.
|
|
auto bufferizableOp = cast<BufferizableOpInterface>(op);
|
|
return !bufferizableOp.getAliasingOpResult(opOperand, state).empty();
|
|
}
|
|
|
|
SmallVector<OpOperand *>
|
|
getAliasingOpOperand(Operation *op, OpResult opResult,
|
|
const AnalysisState &state) const {
|
|
auto genericOp = cast<DestinationStyleOpInterface>(op);
|
|
|
|
// The i-th OpResult may alias with the i-th "out" tensor.
|
|
return {genericOp.getOutputOperand(opResult.getResultNumber())};
|
|
}
|
|
|
|
SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
|
const AnalysisState &state) const {
|
|
auto genericOp = cast<DestinationStyleOpInterface>(op);
|
|
|
|
// The i-th "out" tensor may alias with the i-th OpResult.
|
|
if (genericOp.isOutput(&opOperand))
|
|
return {genericOp.getTiedOpResult(&opOperand)};
|
|
return {};
|
|
}
|
|
|
|
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
|
const AnalysisState &state) const {
|
|
return BufferRelation::Equivalent;
|
|
}
|
|
|
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
|
const BufferizationOptions &options) const {
|
|
return bufferizeDestinationStyleOpInterface(
|
|
rewriter, cast<DestinationStyleOpInterface>(op), options);
|
|
}
|
|
};
|
|
|
|
/// Helper structure that iterates over all LinalgOps in `OpTys` and registers
|
|
/// the `BufferizableOpInterface` with each of them.
|
|
template <typename... Ops>
|
|
struct LinalgOpInterfaceHelper {
|
|
static void registerOpInterface(MLIRContext *ctx) {
|
|
(Ops::template attachInterface<LinalgOpInterface<Ops>>(*ctx), ...);
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
/// Adds an extension to `registry` that attaches the `BufferizableOpInterface`
/// external model (`LinalgOpInterface`) to all Linalg structured ops.
///
/// The attachment happens inside the extension callback, which receives the
/// `MLIRContext` and the `LinalgDialect` (presumably invoked once the dialect
/// is loaded into a context; see `DialectRegistry::addExtension`).
///
/// \param registry The dialect registry the extension is added to.
void mlir::linalg::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
    // Register all Linalg structured ops. `LinalgOp` is an interface and it is
    // not possible to attach an external interface to an existing interface.
    // Therefore, attach the `BufferizableOpInterface` to all ops one-by-one.
    // The op list is supplied by the generated include below.
    LinalgOpInterfaceHelper<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
        >::registerOpInterface(ctx);
  });
}
|