Currently, a sequence of Transform dialect operations only supports a single use of each operand (verified by the `transform.sequence` operation). This was originally motivated by the need to guard against accessing a payload IR operation associated with a transform IR value after this operation has likely been rewritten by a transformation. However, not all Transform dialect operations rewrite payload IR; in particular, "navigation" operations such as `transform.pdl_match` do not. Introduce memory effects to the Transform dialect operations to describe their effect on the payload IR and on the mapping between payload IR operations and transform IR values. Use these effects to replace the single-use rule, allowing repeated reads and disallowing use-after-free (operations with the "free" effect are considered to "consume" the transform IR value and rewrite the corresponding payload IR operations). As an additional improvement, this enables code motion transformations on the transform IR itself. Reviewed By: Mogball Differential Revision: https://reviews.llvm.org/D124181
170 lines
5.9 KiB
C++
170 lines
5.9 KiB
C++
//===- TestTransformDialectExtension.cpp ----------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file defines an extension of the MLIR Transform dialect for testing
|
|
// purposes.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "TestTransformDialectExtension.h"
|
|
#include "mlir/Dialect/PDL/IR/PDL.h"
|
|
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
|
|
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/OpImplementation.h"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace {
|
|
/// Simple transform op defined outside of the dialect. Just emits a remark when
|
|
/// applied. This op is defined in C++ to test that C++ definitions also work
|
|
/// for op injection into the Transform dialect.
|
|
class TestTransformOp
|
|
: public Op<TestTransformOp, transform::TransformOpInterface::Trait,
|
|
MemoryEffectOpInterface::Trait> {
|
|
public:
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformOp)
|
|
|
|
using Op::Op;
|
|
|
|
static ArrayRef<StringRef> getAttributeNames() { return {}; }
|
|
|
|
static constexpr llvm::StringLiteral getOperationName() {
|
|
return llvm::StringLiteral("transform.test_transform_op");
|
|
}
|
|
|
|
LogicalResult apply(transform::TransformResults &results,
|
|
transform::TransformState &state) {
|
|
InFlightDiagnostic remark = emitRemark() << "applying transformation";
|
|
if (Attribute message = getMessage())
|
|
remark << " " << message;
|
|
|
|
return success();
|
|
}
|
|
|
|
Attribute getMessage() { return getOperation()->getAttr("message"); }
|
|
|
|
static ParseResult parse(OpAsmParser &parser, OperationState &state) {
|
|
StringAttr message;
|
|
OptionalParseResult result = parser.parseOptionalAttribute(message);
|
|
if (!result.hasValue())
|
|
return success();
|
|
|
|
if (result.getValue().succeeded())
|
|
state.addAttribute("message", message);
|
|
return result.getValue();
|
|
}
|
|
|
|
void print(OpAsmPrinter &printer) {
|
|
if (getMessage())
|
|
printer << " " << getMessage();
|
|
}
|
|
|
|
// No side effects.
|
|
void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
|
|
};
|
|
|
|
/// A test op to exercise the verifier of the PossibleTopLevelTransformOpTrait
|
|
/// in cases where it is attached to ops that do not comply with the trait
|
|
/// requirements. This op cannot be defined in ODS because ODS generates strict
|
|
/// verifiers that overalp with those in the trait and run earlier.
|
|
class TestTransformUnrestrictedOpNoInterface
|
|
: public Op<TestTransformUnrestrictedOpNoInterface,
|
|
transform::PossibleTopLevelTransformOpTrait,
|
|
transform::TransformOpInterface::Trait,
|
|
MemoryEffectOpInterface::Trait> {
|
|
public:
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
|
|
TestTransformUnrestrictedOpNoInterface)
|
|
|
|
using Op::Op;
|
|
|
|
static ArrayRef<StringRef> getAttributeNames() { return {}; }
|
|
|
|
static constexpr llvm::StringLiteral getOperationName() {
|
|
return llvm::StringLiteral(
|
|
"transform.test_transform_unrestricted_op_no_interface");
|
|
}
|
|
|
|
LogicalResult apply(transform::TransformResults &results,
|
|
transform::TransformState &state) {
|
|
return success();
|
|
}
|
|
|
|
// No side effects.
|
|
void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
|
|
};
|
|
} // namespace
|
|
|
|
/// Associates the result handle with exactly one "operation". When the op has
/// an operand, that is the operand's defining op. Otherwise it is the integer
/// `parameter` reinterpreted as an `Operation *`. That pointer is not a real
/// operation and must never be dereferenced; TestConsumeOperandIfMatchesParamOrFail
/// only compares its value back against a parameter. Always succeeds.
///
/// NOTE(review): `getDefiningOp()` yields the transform-IR op that produced the
/// operand, not the payload ops associated with it. It is null if the operand
/// is a block argument. Confirm both are intended.
LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::apply(
    transform::TransformResults &results, transform::TransformState &state) {
  Operation *produced =
      getOperation()->getNumOperands() == 0
          ? reinterpret_cast<Operation *>(*getParameter())
          : getOperation()->getOperand(0).getDefiningOp();
  results.set(getResult().cast<OpResult>(), produced);
  return success();
}
|
|
|
|
/// Verifies that exactly one source for the result is provided: either the
/// `parameter` attribute with no operands, or a single operand with no
/// parameter.
LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::verify() {
  bool hasParam = getParameter().hasValue();
  bool hasSingleOperand = getNumOperands() == 1;
  // Valid iff the two conditions differ: param without operand, or operand
  // without param.
  if (hasParam == hasSingleOperand)
    return emitOpError() << "expects either a parameter or an operand";
  return success();
}
|
|
|
|
/// Checks that the operand handle is associated with exactly one payload
/// "operation" whose pointer value, read as an integer, equals the `parameter`
/// attribute. On a match, emits a "succeeded" remark on this op and returns
/// success. Otherwise, emits an op error and returns failure.
LogicalResult mlir::test::TestConsumeOperandIfMatchesParamOrFail::apply(
    transform::TransformResults &results, transform::TransformState &state) {
  ArrayRef<Operation *> payload = state.getPayloadOps(getOperand());
  // The number of payload ops depends on the input IR, not on this op's
  // verifier, so report a mismatch as a diagnostic. An assert would be compiled
  // out in release builds and leave `payload[0]` reading out of bounds.
  if (payload.size() != 1) {
    return emitOpError() << "expected the operand to be associated with a "
                            "single payload op, got "
                         << payload.size();
  }

  // The payload "op" may be an integer token smuggled through a pointer (see
  // TestProduceParamOrForwardOperandOp), so compare raw pointer values only.
  auto value = reinterpret_cast<intptr_t>(payload.front());
  if (static_cast<uint64_t>(value) != getParameter()) {
    return emitOpError() << "expected the operand to be associated with "
                         << getParameter() << " got " << value;
  }

  emitRemark() << "succeeded";
  return success();
}
|
|
|
|
/// Emits the op's `message` as a remark located at each payload op associated
/// with the operand handle, in payload order. Always succeeds.
LogicalResult mlir::test::TestPrintRemarkAtOperandOp::apply(
    transform::TransformResults &results, transform::TransformState &state) {
  for (Operation *target : state.getPayloadOps(getOperand())) {
    InFlightDiagnostic diag = target->emitRemark();
    diag << getMessage();
  }
  return success();
}
|
|
|
|
namespace {
|
|
/// Test extension of the Transform dialect. It registers the additional test
|
|
/// ops and declares PDL as a dependent dialect, because those ops use PDL
|
|
/// types for their operands and results.
|
|
class TestTransformDialectExtension
|
|
: public transform::TransformDialectExtension<
|
|
TestTransformDialectExtension> {
|
|
public:
|
|
TestTransformDialectExtension() {
|
|
// PDL must be loaded whenever this extension is applied: see class comment.
declareDependentDialect<pdl::PDLDialect>();
|
|
// The hand-written C++ ops come first. The rest of the template argument
// list is spliced in by the ODS-generated GET_OP_LIST section, which
// presumably expands to a comma-separated list of the ODS-defined test op
// classes (it sits between a trailing comma and `>`).
registerTransformOps<TestTransformOp,
|
|
TestTransformUnrestrictedOpNoInterface,
|
|
#define GET_OP_LIST
|
|
#include "TestTransformDialectExtension.cpp.inc"
|
|
>();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
#define GET_OP_CLASSES
|
|
#include "TestTransformDialectExtension.cpp.inc"
|
|
|
|
void ::test::registerTestTransformDialectExtension(DialectRegistry ®istry) {
|
|
registry.addExtensions<TestTransformDialectExtension>();
|
|
}
|