llvm-project/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
Alex Zinenko 4b455a71b7 [mlir] adapt TransformEachOpTrait to parameter values
Adapt the implementation of TransformEachOpTrait to the existence of
parameter values recently introduced into the transform dialect. In
particular, allow `applyToOne` hooks to return a list containing a mix
of `Operation *` that will be associated with handles and `Attribute`
that will be associated with parameter values by the trait
implementation of the transform interface's `apply` method.

Disentangle the "transposition" of the list of per-payload op partial
results to decrease its overall complexity and detemplatize the code
that doesn't really need templates. This removes the poorly documented
special handling for single-result ops with TransformEachOpTrait that
could have assigned null pointer values to handles.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D140979
2023-01-06 12:23:41 +00:00

69 lines
2.4 KiB
C++

//===- MemRefTransformOps.cpp - Implementation of Memref transform ops ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// MemRefMultiBufferOp
//===----------------------------------------------------------------------===//
DiagnosedSilenceableFailure transform::MemRefMultiBufferOp::applyToOne(
memref::AllocOp target, transform::ApplyToEachResultList &results,
transform::TransformState &state) {
auto newBuffer = memref::multiBuffer(target, getFactor());
if (failed(newBuffer)) {
Diagnostic diag(target->getLoc(), DiagnosticSeverity::Note);
diag << "op failed to multibuffer";
return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
}
results.push_back(*newBuffer);
return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
namespace {
class MemRefTransformDialectExtension
: public transform::TransformDialectExtension<
MemRefTransformDialectExtension> {
public:
using Base::Base;
void init() {
declareDependentDialect<pdl::PDLDialect>();
declareGeneratedDialect<AffineDialect>();
declareGeneratedDialect<arith::ArithDialect>();
registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc"
>();
}
};
} // namespace
#define GET_OP_CLASSES
#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc"
void mlir::memref::registerTransformDialectExtension(
DialectRegistry &registry) {
registry.addExtensions<MemRefTransformDialectExtension>();
}