
`alloc`s that have users outside of loops are guaranteed to fail in `multibuffer`. Instead of exposing ourselves to that failure in the transform dialect, filter out the `alloc`s that fall into this category. To be able to do this filtering, we have to change the `multibuffer` transform op from `TransformEachOpTrait` to a plain `TransformOp`, because `TransformEachOpTrait` expects every successful `applyToOne` to return a non-empty result. A couple of notes: - I changed the assembly syntax to make sure we only get `alloc` ops as input (and added a test case to make sure we reject invalid inputs). - `multibuffer` can still fail pretty easily once you know its limitations; see the updated `op failed to multibuffer` test case, for instance. Longer term, instead of leaking/coupling the actual implementation (in this case, the checks normally done in `memref::multiBuffer`) into the transform dialect (the added check in `::apply`), we may want to refactor how we structure the underlying implementation. E.g., we could imagine a `canApply` method for all the implementations that we want to hook up in the transform dialect. This has some implications for how to avoid duplicating work between `canApply` and the actual implementation, but I thought I'd throw that out here to get us thinking about it :). Differential Revision: https://reviews.llvm.org/D143747
86 lines
2.9 KiB
C++
86 lines
2.9 KiB
C++
//===- MemRefTransformOps.cpp - Implementation of Memref transform ops ----===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
|
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
|
|
#include "mlir/Dialect/PDL/IR/PDL.h"
|
|
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
|
|
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
|
|
#include "mlir/Interfaces/LoopLikeInterface.h"
|
|
|
|
using namespace mlir;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MemRefMultiBufferOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Multi-buffers each `memref.alloc` payload op associated with the target
/// handle, using the factor returned by `getFactor()`.
///
/// An allocation that has at least one user not nested inside any
/// `LoopLikeOpInterface` op is skipped instead of being handed to
/// `memref::multiBuffer` (multi-buffering is known to fail for it). Skipped
/// allocations contribute nothing to the result handle.
///
/// If `memref::multiBuffer` fails on an allocation it was given, a silenceable
/// failure located at that allocation is returned immediately and the result
/// handle is not set. Otherwise the result handle is mapped to the newly
/// created allocations, in the order their originals appeared in the payload.
DiagnosedSilenceableFailure transform::MemRefMultiBufferOp::apply(
    transform::TransformResults &transformResults,
    transform::TransformState &state) {
  // True when every user of `alloc` sits somewhere inside a loop-like op.
  auto allUsersInsideLoops = [](memref::AllocOp alloc) {
    for (Operation *user : alloc->getUsers())
      if (!user->getParentOfType<LoopLikeOpInterface>())
        return false;
    return true;
  };

  SmallVector<Operation *> multiBufferedAllocs;
  for (Operation *payloadOp : state.getPayloadOps(getTarget())) {
    // The op's assembly format only admits `memref.alloc` payloads.
    auto alloc = cast<memref::AllocOp>(payloadOp);

    // Leave allocations with a user outside any loop untouched rather than
    // reporting a failure for them.
    if (!allUsersInsideLoops(alloc))
      continue;

    auto newAlloc = memref::multiBuffer(alloc, getFactor());
    if (failed(newAlloc))
      return emitSilenceableFailure(alloc->getLoc())
             << "op failed to multibuffer";
    multiBufferedAllocs.push_back(*newAlloc);
  }

  transformResults.set(getResult().cast<OpResult>(), multiBufferedAllocs);
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Transform op registration
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
/// Transform dialect extension that registers the MemRef transform ops and
/// declares the dialects they depend on or may generate.
struct MemRefTransformDialectExtension
    : transform::TransformDialectExtension<MemRefTransformDialectExtension> {
  using Base::Base;

  /// Extension initialization hook: declares the involved dialects, then
  /// registers the ops (same order as a single inline body would use).
  void init() {
    declareDialects();
    registerOps();
  }

private:
  /// Declares PDL as a dependent dialect and Affine/Arith as dialects whose
  /// ops the transformations may generate.
  void declareDialects() {
    declareDependentDialect<pdl::PDLDialect>();
    declareGeneratedDialect<AffineDialect>();
    declareGeneratedDialect<arith::ArithDialect>();
  }

  /// Registers every op from the ODS-generated op list of this file.
  void registerOps() {
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc"
        >();
  }
};
} // namespace
|
|
|
|
#define GET_OP_CLASSES
|
|
#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc"
|
|
|
|
void mlir::memref::registerTransformDialectExtension(
|
|
DialectRegistry ®istry) {
|
|
registry.addExtensions<MemRefTransformDialectExtension>();
|
|
}
|