
Adapt the implementation of TransformEachOpTrait to the existence of parameter values recently introduced into the transform dialect. In particular, allow `applyToOne` hooks to return a list containing a mix of `Operation *` that will be associated with handles and `Attribute` that will be associated with parameter values by the trait implementation of the transform interface's `apply` method. Disentangle the "transposition" of the list of per-payload op partial results to decrease its overall complexity and detemplatize the code that doesn't really need templates. This removes the poorly documented special handling for single-result ops with TransformEachOpTrait that could have assigned null pointer values to handles. Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D140979
69 lines
2.4 KiB
C++
69 lines
2.4 KiB
C++
//===- MemRefTransformOps.cpp - Implementation of Memref transform ops ----===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
|
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
|
|
#include "mlir/Dialect/PDL/IR/PDL.h"
|
|
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
|
|
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
|
|
|
|
using namespace mlir;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MemRefMultiBufferOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Applies multi-buffering to a single payload `memref.alloc` operation.
///
/// Calls `memref::multiBuffer` on `target` with the value returned by this
/// op's `getFactor()` accessor. When that call succeeds, the value it
/// produced is appended to `results` and success is returned. Otherwise a
/// silenceable failure is returned that carries a note-severity diagnostic
/// attached to the location of `target`. The `state` argument is not used by
/// this implementation.
DiagnosedSilenceableFailure transform::MemRefMultiBufferOp::applyToOne(
    memref::AllocOp target, transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  auto multiBuffered = memref::multiBuffer(target, getFactor());

  // Happy path first: record the result produced for this payload op.
  if (succeeded(multiBuffered)) {
    results.push_back(*multiBuffered);
    return DiagnosedSilenceableFailure::success();
  }

  // The rewrite did not apply; report it as a failure the caller may choose
  // to silence rather than as a definite error.
  Diagnostic note(target->getLoc(), DiagnosticSeverity::Note);
  note << "op failed to multibuffer";
  return DiagnosedSilenceableFailure::silenceableFailure(std::move(note));
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Transform op registration
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
/// Transform dialect extension that contributes the MemRef transform ops.
///
/// The type is file-local; it is handed to a `DialectRegistry` by
/// `mlir::memref::registerTransformDialectExtension`.
struct MemRefTransformDialectExtension
    : public transform::TransformDialectExtension<
          MemRefTransformDialectExtension> {
  // Reuse the constructors of the CRTP base.
  using Base::Base;

  /// Populates the extension: declares the PDL dialect as a dependent
  /// dialect, declares the Affine and Arith dialects as generated dialects,
  /// and registers every transform op listed in the generated
  /// `MemRefTransformOps.cpp.inc` op list.
  void init() {
    declareDependentDialect<pdl::PDLDialect>();
    declareGeneratedDialect<AffineDialect>();
    declareGeneratedDialect<arith::ArithDialect>();

    // With GET_OP_LIST defined, the included file is used as the template
    // argument list of `registerTransformOps`.
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc"
        >();
  }
};
} // namespace
|
|
|
|
#define GET_OP_CLASSES
|
|
#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc"
|
|
|
|
void mlir::memref::registerTransformDialectExtension(
|
|
DialectRegistry ®istry) {
|
|
registry.addExtensions<MemRefTransformDialectExtension>();
|
|
}
|