Introduce support for the third kind of values in the transform dialect:
value handles. Similarly to operation handles, value handles point to a
set of values in the payload IR. This enables transformations to target
specific values, such as individual results of a multi-result payload
operation (without indirecting through the producing op) or block
arguments, which previously could not be easily addressed. This is
expected to support a broad class of memory-oriented transformations
such as selective bufferization, buffer assignment, and memory transfer
management.
Value handles are functionally similar to operation handles and require
similar implementation logic. The most important change concerns the
handle invalidation mechanism, where operation and value handles can
affect each other.
This patch includes two cleanups that make it easier to introduce value
handles:
- a `RaggedArray` structure that encapsulates the "SmallVector of
ArrayRef backed by a flat SmallVector" logic frequently used in the
transform interfaces implementation;
- a rewrite of the tests that associated payload handles with an integer
value `reinterpret_cast` to a pointer; these tests were a frequent
source of confusion and crashes when adding more debugging
facilities that can inspect the payload.
Reviewed By: springerm
Differential Revision: https://reviews.llvm.org/D143385
132 lines
4.8 KiB
C++
132 lines
4.8 KiB
C++
//===- TransformDialect.cpp - Transform Dialect Definition ----------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
|
|
#include "mlir/Dialect/PDL/IR/PDL.h"
|
|
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
|
|
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
|
|
#include "mlir/Dialect/Transform/IR/TransformOps.h"
|
|
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
|
|
#include "mlir/IR/DialectImplementation.h"
|
|
|
|
using namespace mlir;
|
|
|
|
#include "mlir/Dialect/Transform/IR/TransformDialect.cpp.inc"
|
|
|
|
#ifndef NDEBUG
|
|
/// Debug-only sanity check (this code is compiled only when NDEBUG is not
/// defined) for an operation added to the Transform dialect.
///
/// Asserts that:
///   - an operation named `name` is registered in `context`;
///   - it implements TransformOpInterface, unless it is a terminator;
///   - it implements MemoryEffectOpInterface.
void transform::detail::checkImplementsTransformOpInterface(
    StringRef name, MLIRContext *context) {
  // Since the operation is being inserted into the Transform dialect and the
  // dialect does not implement the interface fallback, only check for the op
  // itself having the interface implementation.
  auto maybeOpName = RegisteredOperationName::lookup(name, context);
  // Dereferencing an empty lookup result is undefined behavior; report a
  // readable assertion failure instead if the op is not registered.
  assert(maybeOpName &&
         "op injected into the transform dialect must be registered before "
         "its interfaces are checked");
  RegisteredOperationName opName = *maybeOpName;
  assert((opName.hasInterface<TransformOpInterface>() ||
          opName.hasTrait<OpTrait::IsTerminator>()) &&
         "non-terminator ops injected into the transform dialect must "
         "implement TransformOpInterface");
  // The message names the interface class actually being checked.
  assert(opName.hasInterface<MemoryEffectOpInterface>() &&
         "ops injected into the transform dialect must implement "
         "MemoryEffectOpInterface");
}
|
|
|
|
/// Debug-only sanity check (compiled only when NDEBUG is not defined) that
/// the type identified by `typeID` implements at least one of the three
/// Transform dialect type interfaces: TransformHandleTypeInterface,
/// TransformParamTypeInterface, or TransformValueHandleTypeInterface.
void transform::detail::checkImplementsTransformHandleTypeInterface(
    TypeID typeID, MLIRContext *context) {
  const auto &registered = AbstractType::lookup(typeID, context);
  // Small helper so the three interface checks below read uniformly.
  auto implements = [&registered](TypeID interfaceID) {
    return registered.hasInterface(interfaceID);
  };
  assert((implements(TransformHandleTypeInterface::getInterfaceID()) ||
          implements(TransformParamTypeInterface::getInterfaceID()) ||
          implements(TransformValueHandleTypeInterface::getInterfaceID())) &&
         "expected Transform dialect type to implement one of the three "
         "interfaces");
}
|
|
#endif // NDEBUG
|
|
|
|
namespace {
|
|
struct PDLOperationTypeTransformHandleTypeInterfaceImpl
|
|
: public transform::TransformHandleTypeInterface::ExternalModel<
|
|
PDLOperationTypeTransformHandleTypeInterfaceImpl,
|
|
pdl::OperationType> {
|
|
DiagnosedSilenceableFailure
|
|
checkPayload(Type type, Location loc, ArrayRef<Operation *> payload) const {
|
|
return DiagnosedSilenceableFailure::success();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
/// Registers the dialect's own operations and types. It also attaches the
/// transform handle type interface model to `pdl::OperationType`.
void transform::TransformDialect::initialize() {
  // Register the dialect's own ops through the checked entry point so that
  // they undergo the same assertions as ops injected by extensions.
  addOperationsChecked<
#define GET_OP_LIST
#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
      >();
  initializeTypes();

  // Allow `pdl::OperationType` to act as a transform operation handle type.
  MLIRContext *ctx = getContext();
  pdl::OperationType::attachInterface<
      PDLOperationTypeTransformHandleTypeInterfaceImpl>(*ctx);
}
|
|
|
|
/// Moves every constraint function out of `constraintFns` and registers it
/// with the dialect's PDL match hooks, under the same name as its map key.
/// The entries left in `constraintFns` are in a moved-from state afterwards.
void transform::TransformDialect::mergeInPDLMatchHooks(
    llvm::StringMap<PDLConstraintFunction> &&constraintFns) {
  for (auto &entry : constraintFns) {
    StringRef hookName = entry.getKey();
    pdlMatchHooks.registerConstraintFunction(hookName,
                                             std::move(entry.getValue()));
  }
}
|
|
|
|
/// Returns a read-only view of the constraint functions registered with the
/// dialect's PDL match hooks, keyed by name.
const llvm::StringMap<PDLConstraintFunction> &
transform::TransformDialect::getPDLConstraintHooks() const {
  const auto &registeredConstraints = pdlMatchHooks.getConstraintFunctions();
  return registeredConstraints;
}
|
|
|
|
/// Parses a Transform dialect type. It reads the type mnemonic keyword and
/// dispatches to the parsing hook registered for it. Returns a null type if
/// the keyword cannot be parsed or if no hook is registered for the mnemonic.
/// In the latter case, an error is emitted at the mnemonic's location.
Type transform::TransformDialect::parseType(DialectAsmParser &parser) const {
  // Capture the location before consuming the keyword so that diagnostics
  // point at the mnemonic itself.
  SMLoc mnemonicLoc = parser.getCurrentLocation();
  StringRef mnemonic;
  if (failed(parser.parseKeyword(&mnemonic)))
    return nullptr;

  auto hookIt = typeParsingHooks.find(mnemonic);
  if (hookIt != typeParsingHooks.end())
    return hookIt->getValue()(parser);

  parser.emitError(mnemonicLoc) << "unknown type mnemonic: " << mnemonic;
  return nullptr;
}
|
|
|
|
/// Prints `type` by invoking the printing hook registered for its TypeID.
/// In builds with assertions enabled, asserts that such a hook exists.
void transform::TransformDialect::printType(Type type,
                                            DialectAsmPrinter &printer) const {
  const TypeID id = type.getTypeID();
  auto hookIt = typePrintingHooks.find(id);
  assert(hookIt != typePrintingHooks.end() && "printing unknown type");
  hookIt->getSecond()(type, printer);
}
|
|
|
|
/// Aborts via `llvm::report_fatal_error` to report that a type with the given
/// `mnemonic` is already registered with a different implementation.
void transform::TransformDialect::reportDuplicateTypeRegistration(
    StringRef mnemonic) {
  const std::string message =
      "extensible dialect type '" + mnemonic.str() +
      "' is already registered with a different implementation";
  llvm::report_fatal_error(StringRef(message));
}
|
|
|
|
/// Aborts via `llvm::report_fatal_error` to report that an operation named
/// `opName` is already registered with a mismatching TypeID.
void transform::TransformDialect::reportDuplicateOpRegistration(
    StringRef opName) {
  const std::string message =
      "extensible dialect operation '" + opName.str() +
      "' is already registered with a mismatching TypeID";
  llvm::report_fatal_error(StringRef(message));
}
|
|
|
|
#include "mlir/Dialect/Transform/IR/TransformDialectEnums.cpp.inc"
|