llvm-project/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
Quentin Colombet b838b62876 [mlir][MemRef][Transform] Don't apply multibuffer on "useless" allocs
`alloc`s that have users outside of loops are guaranteed to fail in
`multibuffer`.

Instead of exposing ourselves to that failure in the transform dialect,
filter out the `alloc`s that fall in this category.

To be able to do this filtering we have to change the `multibuffer`
transform op from `TransformEachOpTrait` to a plain `TransformOp`. This is
because `TransformEachOpTrait` expects that every successful `applyToOne`
returns a non-empty result.

A couple of notes:
- I changed the assembly syntax to make sure we only get `alloc` ops as
  input. (And added a test case to make sure we reject invalid inputs.)
- `multibuffer` can still fail pretty easily when you know its limitations.
  See the updated `op failed to multibuffer` test case for instance.
  Longer term, instead of leaking/coupling the actual implementation (in
  this case the checks normally done in `memref::multiBuffer`) with the
  transform dialect (the added check in `::apply`), we may want to refactor
  how we structure the underlying implementation. E.g., we could imagine a
  `canApply` method for all the implementations that we want to hook up in
  the transform dialect.
  This has some implications for how to avoid duplicating work between
  `canApply` and the actual implementation, but I thought I'd throw that
  out here to get us thinking about it :).

Differential Revision: https://reviews.llvm.org/D143747
2023-02-13 14:19:10 +01:00

86 lines
2.9 KiB
C++

//===- MemRefTransformOps.cpp - Implementation of Memref transform ops ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// MemRefMultiBufferOp
//===----------------------------------------------------------------------===//
/// Multi-buffers the `memref.alloc` payload ops associated with the target
/// handle, using the factor returned by `getFactor()`.
///
/// An allocation that has at least one user outside of every
/// `LoopLikeOpInterface` op is dropped from the result instead of being handed
/// to `memref::multiBuffer`. The result handle is set to the ops produced by
/// the `memref::multiBuffer` calls, in payload order; it may therefore be
/// empty.
///
/// Returns a silenceable failure, reported at the location of the offending
/// allocation, as soon as `memref::multiBuffer` fails. On that path this
/// function returns before setting the result handle.
DiagnosedSilenceableFailure transform::MemRefMultiBufferOp::apply(
    transform::TransformResults &transformResults,
    transform::TransformState &state) {
  // True when every user of `alloc` is nested in some loop-like op. This holds
  // vacuously for an allocation without users, which is thus still forwarded
  // to `memref::multiBuffer`.
  auto hasOnlyUsersInLoops = [](memref::AllocOp alloc) {
    for (Operation *user : alloc->getUsers())
      if (!user->getParentOfType<LoopLikeOpInterface>())
        return false;
    return true;
  };

  SmallVector<Operation *> multiBuffered;
  ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
  for (Operation *payload : targets) {
    // Payload ops are cast unconditionally: anything other than an alloc op
    // trips the assertion inside `cast`.
    auto alloc = cast<memref::AllocOp>(payload);
    if (!hasOnlyUsersInLoops(alloc))
      continue;

    auto replacement = memref::multiBuffer(alloc, getFactor());
    if (failed(replacement))
      return emitSilenceableFailure(alloc->getLoc())
             << "op failed to multibuffer";

    multiBuffered.push_back(*replacement);
  }

  transformResults.set(getResult().cast<OpResult>(), multiBuffered);
  return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
namespace {
/// Transform dialect extension that registers the MemRef transform ops.
class MemRefTransformDialectExtension
    : public transform::TransformDialectExtension<
          MemRefTransformDialectExtension> {
public:
  using Base::Base;

  /// Declares PDL as a dependent dialect and Affine and Arith as generated
  /// dialects, then registers the transform ops of this extension.
  // NOTE(review): Affine/Arith are presumably declared because the payload
  // rewrites create ops from those dialects -- confirm against the
  // implementation of memref::multiBuffer.
  void init() {
    declareDependentDialect<pdl::PDLDialect>();

    declareGeneratedDialect<AffineDialect>();
    declareGeneratedDialect<arith::ArithDialect>();

    // With GET_OP_LIST defined, the generated .inc file is spliced in as the
    // template argument list of registerTransformOps.
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc"
        >();
  }
};
} // namespace
#define GET_OP_CLASSES
#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc"
/// Adds the MemRef transform dialect extension to `registry`.
void mlir::memref::registerTransformDialectExtension(
    DialectRegistry &registry) {
  registry.addExtensions<MemRefTransformDialectExtension>();
}