llvm-project/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
Alex Zinenko 1d9a1139fd [mlir] harden expensive-checks mode against ops with repeated operands
Transform operations may indicate that they may accept and consume
several handles pointing to the same or nested payload entities. The
initial implementation of the expensive-checks mode was simply ignoring
such cases as consuming the second handle would fail the check after the
first handle invalidated it by consuming the same payload. Additional
checks had been added since then, which could now trigger assertions in
the expensive-checks module itself (instead of or in addition to
use-after-free assertions down the road), specifically because the
payload associations for invalidated handles are removed from the state
to enable other kinds of checking.

Rework the handling of transform operations with repeated handles so
use-after-consume is still reported properly if the consumption happened
by a preceding operation, as opposed to a preceding operand of the
same operation that is still (correctly) ignored if the op requests that.

Depends on: D151560

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D151569
2023-05-30 07:53:39 +00:00

806 lines
30 KiB
C++

//===- TestTransformDialectExtension.cpp ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines an extension of the MLIR Transform dialect for testing
// purposes.
//
//===----------------------------------------------------------------------===//
#include "TestTransformDialectExtension.h"
#include "TestTransformStateExtension.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
namespace {
/// A bare-bones transform op written by hand in C++ rather than through ODS.
/// It exists to check that C++-defined ops can also be injected into the
/// Transform dialect. Applying it only emits a remark, optionally followed by
/// the value of the "message" attribute.
class TestTransformOp
    : public Op<TestTransformOp, transform::TransformOpInterface::Trait,
                MemoryEffectOpInterface::Trait> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformOp)

  using Op::Op;

  /// The op declares no attribute names.
  static ArrayRef<StringRef> getAttributeNames() { return {}; }

  static constexpr llvm::StringLiteral getOperationName() {
    return llvm::StringLiteral("transform.test_transform_op");
  }

  /// Returns whatever is stored under the "message" attribute; callers below
  /// test the result for null before using it.
  Attribute getMessage() { return getOperation()->getAttr("message"); }

  /// Emits the "applying transformation" remark, appending the message
  /// attribute when one is attached. Always succeeds.
  DiagnosedSilenceableFailure apply(transform::TransformResults &results,
                                    transform::TransformState &state) {
    InFlightDiagnostic diag = emitRemark() << "applying transformation";
    Attribute text = getMessage();
    if (text)
      diag << " " << text;
    return DiagnosedSilenceableFailure::success();
  }

  /// Parses an optional string attribute and stores it as "message".
  static ParseResult parse(OpAsmParser &parser, OperationState &state) {
    StringAttr message;
    OptionalParseResult parsed = parser.parseOptionalAttribute(message);
    // Nothing to parse: the message is optional, so this is not an error.
    if (!parsed.has_value())
      return success();
    ParseResult status = parsed.value();
    if (status.succeeded())
      state.addAttribute("message", message);
    return status;
  }

  /// Prints the message attribute, if any, preceded by a space.
  void print(OpAsmPrinter &printer) {
    Attribute text = getMessage();
    if (!text)
      return;
    printer << " " << text;
  }

  // No side effects.
  void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
};
/// A test op to exercise the verifier of the PossibleTopLevelTransformOpTrait
/// in cases where it is attached to ops that do not comply with the trait
/// requirements. This op cannot be defined in ODS because ODS generates strict
/// verifiers that overlap with those in the trait and run earlier.
class TestTransformUnrestrictedOpNoInterface
: public Op<TestTransformUnrestrictedOpNoInterface,
transform::PossibleTopLevelTransformOpTrait,
transform::TransformOpInterface::Trait,
MemoryEffectOpInterface::Trait> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
TestTransformUnrestrictedOpNoInterface)
using Op::Op;
// The op declares no attribute names.
static ArrayRef<StringRef> getAttributeNames() { return {}; }
static constexpr llvm::StringLiteral getOperationName() {
return llvm::StringLiteral(
"transform.test_transform_unrestricted_op_no_interface");
}
// Applying the op does nothing and always succeeds.
DiagnosedSilenceableFailure apply(transform::TransformResults &results,
transform::TransformState &state) {
return DiagnosedSilenceableFailure::success();
}
// No side effects.
void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
};
} // namespace
/// Associates the result handle with a single payload op: the op defining the
/// first operand when the transform op has operands, and the transform op
/// itself otherwise.
DiagnosedSilenceableFailure
mlir::test::TestProduceSelfHandleOrForwardOperandOp::apply(
    transform::TransformResults &results, transform::TransformState &state) {
  Operation *self = getOperation();
  Operation *payload = self->getNumOperands() == 0
                           ? self
                           : self->getOperand(0).getDefiningOp();
  results.set(cast<OpResult>(getResult()), {payload});
  return DiagnosedSilenceableFailure::success();
}
/// Declares that the operand handle, when present, is only read and that the
/// result handle is produced.
void mlir::test::TestProduceSelfHandleOrForwardOperandOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // The operand may be absent; only declare a read on it when it exists.
  if (Value forwarded = getOperand())
    transform::onlyReadsHandle(forwarded, effects);
  transform::producesHandle(getRes(), effects);
}
/// Makes the `out` value handle point to the op's own `in` operand.
DiagnosedSilenceableFailure
mlir::test::TestProduceValueHandleToSelfOperand::apply(
    transform::TransformResults &results, transform::TransformState &state) {
  OpResult outHandle = llvm::cast<OpResult>(getOut());
  results.setValues(outHandle, getIn());
  return DiagnosedSilenceableFailure::success();
}
/// Declares the op's effects: the `in` handle is only read, the `out` handle
/// is produced, and the payload is only read.
void mlir::test::TestProduceValueHandleToSelfOperand::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::onlyReadsHandle(getIn(), effects);
transform::producesHandle(getOut(), effects);
transform::onlyReadsPayload(effects);
}
/// Appends the `number`-th result of `target` to the result list. Fails
/// silenceably when `target` does not have that many results.
DiagnosedSilenceableFailure
mlir::test::TestProduceValueHandleToResult::applyToOne(
    Operation *target, transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  const auto resultNo = getNumber();
  if (target->getNumResults() <= resultNo)
    return emitSilenceableError() << "payload has no result #" << getNumber();
  results.push_back(target->getResult(resultNo));
  return DiagnosedSilenceableFailure::success();
}
/// Declares the op's effects: the `in` handle is only read, the `out` handle
/// is produced, and the payload is only read.
void mlir::test::TestProduceValueHandleToResult::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::onlyReadsHandle(getIn(), effects);
transform::producesHandle(getOut(), effects);
transform::onlyReadsPayload(effects);
}
/// Appends the `number`-th argument of the block containing `target` to the
/// result list. Fails silenceably when `target` has no parent block or the
/// block does not have that many arguments.
DiagnosedSilenceableFailure
mlir::test::TestProduceValueHandleToArgumentOfParentBlock::applyToOne(
    Operation *target, transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  Block *parentBlock = target->getBlock();
  if (!parentBlock)
    return emitSilenceableError() << "payload has no parent block";

  if (parentBlock->getNumArguments() <= getNumber())
    return emitSilenceableError()
           << "parent of the payload has no argument #" << getNumber();

  results.push_back(parentBlock->getArgument(getNumber()));
  return DiagnosedSilenceableFailure::success();
}
/// Declares the op's effects: the `in` handle is only read, the `out` handle
/// is produced, and the payload is only read.
void mlir::test::TestProduceValueHandleToArgumentOfParentBlock::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::onlyReadsHandle(getIn(), effects);
transform::producesHandle(getOut(), effects);
transform::onlyReadsPayload(effects);
}
/// Reports whether repeated handle operands are allowed, as controlled by the
/// value returned from `getAllowRepeatedHandles()`.
bool mlir::test::TestConsumeOperand::allowsRepeatedHandleOperands() {
return getAllowRepeatedHandles();
}
/// Does nothing and always succeeds; the op's behavior of interest is in the
/// effects it declares, not in its application.
DiagnosedSilenceableFailure
mlir::test::TestConsumeOperand::apply(transform::TransformResults &results,
transform::TransformState &state) {
return DiagnosedSilenceableFailure::success();
}
/// Declares that the op consumes its operand handle, also consumes the second
/// operand handle when one is present, and modifies the payload.
void mlir::test::TestConsumeOperand::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  transform::consumesHandle(getOperand(), effects);
  // The second operand may be absent.
  if (Value second = getSecondOperand())
    transform::consumesHandle(second, effects);
  transform::modifiesPayload(effects);
}
/// Checks that the single payload op associated with the operand has the name
/// given by `op_kind`. Emits a "succeeded" remark on a match and a silenceable
/// error otherwise.
DiagnosedSilenceableFailure mlir::test::TestConsumeOperandOfOpKindOrFail::apply(
    transform::TransformResults &results, transform::TransformState &state) {
  auto payload = state.getPayloadOps(getOperand());
  assert(llvm::hasSingleElement(payload) && "expected a single target op");

  Operation *target = *payload.begin();
  if (target->getName().getStringRef() != getOpKind()) {
    return emitSilenceableError()
           << "op expected the operand to be associated a payload op of kind "
           << getOpKind() << " got " << target->getName().getStringRef();
  }

  emitRemark() << "succeeded";
  return DiagnosedSilenceableFailure::success();
}
/// Declares that the op consumes its operand handle and modifies the payload.
void mlir::test::TestConsumeOperandOfOpKindOrFail::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::consumesHandle(getOperand(), effects);
transform::modifiesPayload(effects);
}
/// Succeeds when the name of `op` equals `op_kind`; otherwise produces a
/// silenceable error naming both the expected and the actual kind.
DiagnosedSilenceableFailure
mlir::test::TestSucceedIfOperandOfOpKind::matchOperation(
    Operation *op, transform::TransformResults &results,
    transform::TransformState &state) {
  if (op->getName().getStringRef() == getOpKind())
    return DiagnosedSilenceableFailure::success();

  return emitSilenceableError()
         << "op expected the operand to be associated with a payload op of "
            "kind "
         << getOpKind() << " got " << op->getName().getStringRef();
}
/// Declares that the op only reads its operand handle and only reads the
/// payload.
void mlir::test::TestSucceedIfOperandOfOpKind::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::onlyReadsHandle(getOperand(), effects);
transform::onlyReadsPayload(effects);
}
/// Emits a remark carrying the op's message at every payload op associated
/// with the operand handle.
DiagnosedSilenceableFailure mlir::test::TestPrintRemarkAtOperandOp::apply(
    transform::TransformResults &results, transform::TransformState &state) {
  for (Operation *payloadOp : state.getPayloadOps(getOperand()))
    payloadOp->emitRemark() << getMessage();
  return DiagnosedSilenceableFailure::success();
}
/// Declares that the op only reads its operand handle and only reads the
/// payload.
void mlir::test::TestPrintRemarkAtOperandOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::onlyReadsHandle(getOperand(), effects);
transform::onlyReadsPayload(effects);
}
/// For every payload value associated with the `in` handle, emits a remark
/// with the op's message at the value's location and attaches a note saying
/// whether the value is a block argument (with its argument, block and region
/// numbers) or an op result (with its result number).
DiagnosedSilenceableFailure mlir::test::TestPrintRemarkAtOperandValue::apply(
    transform::TransformResults &results, transform::TransformState &state) {
  ArrayRef<Value> payloadValues = state.getPayloadValues(getIn());
  for (Value payloadValue : payloadValues) {
    std::string description;
    llvm::raw_string_ostream stream(description);

    if (auto blockArg = llvm::dyn_cast<BlockArgument>(payloadValue)) {
      Block *owner = blockArg.getOwner();
      Region *region = owner->getParent();
      // The block number is the position of the owner block in its region.
      stream << "a block argument #" << blockArg.getArgNumber()
             << " in block #"
             << std::distance(region->begin(), owner->getIterator())
             << " in region #" << region->getRegionNumber();
    } else {
      auto opResult = llvm::cast<OpResult>(payloadValue);
      stream << "an op result #" << opResult.getResultNumber();
    }

    InFlightDiagnostic diag =
        ::emitRemark(payloadValue.getLoc()) << getMessage();
    diag.attachNote() << "value handle points to " << stream.str();
  }
  return DiagnosedSilenceableFailure::success();
}
/// Declares that the op only reads the `in` handle and only reads the payload.
void mlir::test::TestPrintRemarkAtOperandValue::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::onlyReadsHandle(getIn(), effects);
transform::onlyReadsPayload(effects);
}
/// Attaches a TestTransformStateExtension, constructed from the op's message
/// attribute, to the transform state. Always succeeds.
DiagnosedSilenceableFailure
mlir::test::TestAddTestExtensionOp::apply(transform::TransformResults &results,
transform::TransformState &state) {
state.addExtension<TestTransformStateExtension>(getMessageAttr());
return DiagnosedSilenceableFailure::success();
}
/// Emits a remark reporting whether a TestTransformStateExtension is attached
/// to the transform state. When it is, the remark carries the extension's
/// message and one note per payload op associated with the operand handle.
DiagnosedSilenceableFailure
mlir::test::TestCheckIfTestExtensionPresentOp::apply(
transform::TransformResults &results, transform::TransformState &state) {
auto *extension = state.getExtension<TestTransformStateExtension>();
if (!extension) {
emitRemark() << "extension absent";
return DiagnosedSilenceableFailure::success();
}
// The diagnostic is kept in a local so that notes can be attached to it in
// the loop below.
InFlightDiagnostic diag = emitRemark()
<< "extension present, " << extension->getMessage();
for (Operation *payload : state.getPayloadOps(getOperand())) {
diag.attachNote(payload->getLoc()) << "associated payload op";
// In builds with assertions enabled, additionally check that the handles
// recorded for the payload op include the operand handle.
#ifndef NDEBUG
SmallVector<Value> handles;
assert(succeeded(state.getHandlesForPayloadOp(payload, handles)));
assert(llvm::is_contained(handles, getOperand()) &&
"inconsistent mapping between transform IR handles and payload IR "
"operations");
#endif // NDEBUG
}
return DiagnosedSilenceableFailure::success();
}
/// Declares that the op only reads its operand handle and only reads the
/// payload.
void mlir::test::TestCheckIfTestExtensionPresentOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::onlyReadsHandle(getOperand(), effects);
transform::onlyReadsPayload(effects);
}
/// Asks the TestTransformStateExtension to remap the first payload op
/// associated with the operand handle to this transform op. If the op has
/// results, the first result handle is also associated with this op.
///
/// Fails definitely when the extension is not attached to the state, when the
/// operand handle has no associated payload ops, or when the extension fails
/// to update the mapping.
DiagnosedSilenceableFailure mlir::test::TestRemapOperandPayloadToSelfOp::apply(
    transform::TransformResults &results, transform::TransformState &state) {
  auto *extension = state.getExtension<TestTransformStateExtension>();
  if (!extension)
    return emitDefiniteFailure("TestTransformStateExtension missing");

  // A handle may legitimately be associated with an empty list of payload
  // ops; dereferencing `begin()` of an empty range would be undefined
  // behavior, so diagnose that case explicitly.
  auto payloadOps = state.getPayloadOps(getOperand());
  if (payloadOps.begin() == payloadOps.end())
    return emitDefiniteFailure(
        "expected the operand to be associated with at least one payload op");

  if (failed(extension->updateMapping(*payloadOps.begin(), getOperation())))
    return DiagnosedSilenceableFailure::definiteFailure();

  if (getNumResults() > 0)
    results.set(cast<OpResult>(getResult(0)), {getOperation()});
  return DiagnosedSilenceableFailure::success();
}
/// Declares the op's effects: the operand handle is only read, the `out`
/// handle is produced, and the payload is only read.
void mlir::test::TestRemapOperandPayloadToSelfOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::onlyReadsHandle(getOperand(), effects);
transform::producesHandle(getOut(), effects);
transform::onlyReadsPayload(effects);
}
/// Removes the TestTransformStateExtension from the transform state. Always
/// succeeds.
DiagnosedSilenceableFailure mlir::test::TestRemoveTestExtensionOp::apply(
transform::TransformResults &results, transform::TransformState &state) {
state.removeExtension<TestTransformStateExtension>();
return DiagnosedSilenceableFailure::success();
}
/// Associates the result handle with the payload ops of the target handle,
/// listed in reverse order.
DiagnosedSilenceableFailure
mlir::test::TestReversePayloadOpsOp::apply(transform::TransformResults &results,
                                           transform::TransformState &state) {
  auto inOrder = state.getPayloadOps(getTarget());
  auto backwards = llvm::to_vector(llvm::reverse(inOrder));
  OpResult resultHandle = llvm::cast<OpResult>(getResult());
  results.set(resultHandle, backwards);
  return DiagnosedSilenceableFailure::success();
}
/// Does nothing and always succeeds.
DiagnosedSilenceableFailure mlir::test::TestTransformOpWithRegions::apply(
transform::TransformResults &results, transform::TransformState &state) {
return DiagnosedSilenceableFailure::success();
}
/// Declares no effects.
void mlir::test::TestTransformOpWithRegions::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
/// Does nothing and always succeeds.
DiagnosedSilenceableFailure
mlir::test::TestBranchingTransformOpTerminator::apply(
transform::TransformResults &results, transform::TransformState &state) {
return DiagnosedSilenceableFailure::success();
}
/// Declares no effects.
void mlir::test::TestBranchingTransformOpTerminator::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
/// Emits the configured remark, erases every payload op associated with the
/// target handle, and then either succeeds or, when `fail_after_erase` is
/// set, reports a silenceable error.
DiagnosedSilenceableFailure mlir::test::TestEmitRemarkAndEraseOperandOp::apply(
    transform::TransformResults &results, transform::TransformState &state) {
  emitRemark() << getRemark();

  for (Operation *payloadOp : state.getPayloadOps(getTarget()))
    payloadOp->erase();

  if (!getFailAfterErase())
    return DiagnosedSilenceableFailure::success();
  return emitSilenceableError() << "silenceable error";
}
/// Declares that the op consumes the target handle and modifies the payload.
void mlir::test::TestEmitRemarkAndEraseOperandOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::consumesHandle(getTarget(), effects);
transform::modifiesPayload(effects);
}
/// Creates one "foo" op next to `target` and reports it as the only result of
/// this application.
DiagnosedSilenceableFailure mlir::test::TestWrongNumberOfResultsOp::applyToOne(
    Operation *target, transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  OperationState fooState(target->getLoc(), "foo");
  OpBuilder builder(target);
  Operation *created = builder.create(fooState);
  results.push_back(created);
  return DiagnosedSilenceableFailure::success();
}
/// Creates and reports a "foo" op only on the very first invocation of this
/// function; every later invocation reports no results. The invocation counter
/// is a function-local static, so it is shared by all instances of the op.
DiagnosedSilenceableFailure
mlir::test::TestWrongNumberOfMultiResultsOp::applyToOne(
    Operation *target, transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  static int count = 0;
  const bool isFirstCall = (count++ == 0);
  if (!isFirstCall)
    return DiagnosedSilenceableFailure::success();

  OperationState fooState(target->getLoc(), "foo");
  OpBuilder builder(target);
  results.push_back(builder.create(fooState));
  return DiagnosedSilenceableFailure::success();
}
/// Creates two "foo" ops next to `target` and reports both as results.
DiagnosedSilenceableFailure
mlir::test::TestCorrectNumberOfMultiResultsOp::applyToOne(
    Operation *target, transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  OperationState fooState(target->getLoc(), "foo");
  Operation *first = OpBuilder(target).create(fooState);
  results.push_back(first);
  Operation *second = OpBuilder(target).create(fooState);
  results.push_back(second);
  return DiagnosedSilenceableFailure::success();
}
/// Reports a null first result followed by a freshly created "foo" op as the
/// second result.
DiagnosedSilenceableFailure
mlir::test::TestMixedNullAndNonNullResultsOp::applyToOne(
    Operation *target, transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  OperationState fooState(target->getLoc(), "foo");
  results.push_back(nullptr);
  Operation *created = OpBuilder(target).create(fooState);
  results.push_back(created);
  return DiagnosedSilenceableFailure::success();
}
/// Succeeds for targets carrying the "target_me" attribute and produces the
/// default silenceable failure for all others.
DiagnosedSilenceableFailure
mlir::test::TestMixedSuccessAndSilenceableOp::applyToOne(
    Operation *target, transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  if (!target->hasAttr("target_me"))
    return emitDefaultSilenceableFailure(target);
  return DiagnosedSilenceableFailure::success();
}
/// Emits a remark containing the number of payload ops associated with the
/// handle operand, or `0` when the op has no handle. Always succeeds.
DiagnosedSilenceableFailure
mlir::test::TestPrintNumberOfAssociatedPayloadIROps::apply(
    transform::TransformResults &results, transform::TransformState &state) {
  // Without a handle there is nothing to look up: report zero and stop here
  // rather than falling through and querying the state with a null value.
  if (!getHandle()) {
    emitRemark() << 0;
    return DiagnosedSilenceableFailure::success();
  }

  emitRemark() << llvm::range_size(state.getPayloadOps(getHandle()));
  return DiagnosedSilenceableFailure::success();
}
/// Declares that the op only reads its handle operand.
void mlir::test::TestPrintNumberOfAssociatedPayloadIROps::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::onlyReadsHandle(getHandle(), effects);
}
/// Associates the `copy` result handle with the same payload ops as the input
/// handle.
DiagnosedSilenceableFailure
mlir::test::TestCopyPayloadOp::apply(transform::TransformResults &results,
                                     transform::TransformState &state) {
  OpResult copyHandle = llvm::cast<OpResult>(getCopy());
  results.set(copyHandle, state.getPayloadOps(getHandle()));
  return DiagnosedSilenceableFailure::success();
}
/// Declares the op's effects: the input handle is only read, the `copy` handle
/// is produced, and the payload is only read.
void mlir::test::TestCopyPayloadOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::onlyReadsHandle(getHandle(), effects);
transform::producesHandle(getCopy(), effects);
transform::onlyReadsPayload(effects);
}
/// Accepts a payload only if every op in it belongs to the 'test' dialect; an
/// empty payload is accepted. Otherwise reports a silenceable error at `loc`.
DiagnosedSilenceableFailure mlir::transform::TestDialectOpType::checkPayload(
    Location loc, ArrayRef<Operation *> payload) const {
  auto isFromTestDialect = [](Operation *op) {
    return op->getName().getDialectNamespace() == "test";
  };
  if (llvm::all_of(payload, isFromTestDialect))
    return DiagnosedSilenceableFailure::success();

  return emitSilenceableError(loc) << "expected the payload operation to "
                                      "belong to the 'test' dialect";
}
/// Accepts a payload only if every attribute in it is an integer attribute of
/// signless 32-bit integer type. Otherwise reports a silenceable error at
/// `loc` for the first offending attribute.
DiagnosedSilenceableFailure mlir::transform::TestDialectParamType::checkPayload(
    Location loc, ArrayRef<Attribute> payload) const {
  for (Attribute attr : payload) {
    auto integerAttr = llvm::dyn_cast<IntegerAttr>(attr);
    const bool isI32 =
        integerAttr && integerAttr.getType().isSignlessInteger(32);
    if (!isI32)
      return emitSilenceableError(loc)
             << "expected the parameter to be a i32 integer attribute";
  }
  return DiagnosedSilenceableFailure::success();
}
/// Declares that the op only reads the target handle.
void mlir::test::TestReportNumberOfTrackedHandlesNestedUnder::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::onlyReadsHandle(getTarget(), effects);
}
/// Walks every payload op associated with the target handle, sums the number
/// of handles the state reports for each visited op, and emits the total as a
/// remark.
DiagnosedSilenceableFailure
mlir::test::TestReportNumberOfTrackedHandlesNestedUnder::apply(
    transform::TransformResults &results, transform::TransformState &state) {
  int64_t numHandles = 0;
  auto accumulate = [&](Operation *nested) {
    SmallVector<Value> handles;
    // The lookup result is deliberately ignored; only the number of handles
    // written into the vector matters here.
    (void)state.getHandlesForPayloadOp(nested, handles);
    numHandles += handles.size();
  };

  for (Operation *root : state.getPayloadOps(getTarget()))
    root->walk(accumulate);

  emitRemark() << numHandles << " handles nested under";
  return DiagnosedSilenceableFailure::success();
}
/// Declares that the op only reads the parameter, only reads the anchor handle
/// when one is present, and only reads the payload.
void mlir::test::TestPrintParamOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  transform::onlyReadsHandle(getParam(), effects);
  // The anchor may be absent.
  if (Value anchor = getAnchor())
    transform::onlyReadsHandle(anchor, effects);
  transform::onlyReadsPayload(effects);
}
/// Formats the optional message followed by the comma-separated parameter
/// values and emits the text as a remark: at every payload op of the anchor
/// handle when an anchor is given, and at this op otherwise.
DiagnosedSilenceableFailure
mlir::test::TestPrintParamOp::apply(transform::TransformResults &results,
                                    transform::TransformState &state) {
  std::string text;
  llvm::raw_string_ostream stream(text);
  if (getMessage())
    stream << *getMessage() << " ";
  llvm::interleaveComma(state.getParams(getParam()), stream);

  if (getAnchor()) {
    for (Operation *anchorOp : state.getPayloadOps(getAnchor()))
      ::mlir::emitRemark(anchorOp->getLoc()) << stream.str();
    return DiagnosedSilenceableFailure::success();
  }

  emitRemark() << stream.str();
  return DiagnosedSilenceableFailure::success();
}
/// Produces a list of i32 integer attributes obtained by adding the op's
/// addendum to each input value. The input values are the values of the
/// parameter operand, clamped to UINT32_MAX, or a single zero when no
/// parameter operand is given.
DiagnosedSilenceableFailure
mlir::test::TestAddToParamOp::apply(transform::TransformResults &results,
                                    transform::TransformState &state) {
  SmallVector<uint32_t> values;
  if (Value param = getParam()) {
    for (Attribute attr : state.getParams(param)) {
      values.push_back(static_cast<uint32_t>(
          llvm::cast<IntegerAttr>(attr).getValue().getLimitedValue(
              UINT32_MAX)));
    }
  } else {
    // No parameter operand: start from a single zero.
    values.push_back(0);
  }

  Builder builder(getContext());
  SmallVector<Attribute> result;
  for (uint32_t value : values)
    result.push_back(builder.getI32IntegerAttr(value + getAddendum()));

  results.setParams(llvm::cast<OpResult>(getResult()), result);
  return DiagnosedSilenceableFailure::success();
}
/// For each payload op associated with the handle, counts the ops visited by
/// walking it whose dialect namespace is "test", and produces the counts as a
/// list of i32 integer attributes.
DiagnosedSilenceableFailure
mlir::test::TestProduceParamWithNumberOfTestOps::apply(
    transform::TransformResults &results, transform::TransformState &state) {
  Builder builder(getContext());
  SmallVector<Attribute> result;
  for (Operation *payload : state.getPayloadOps(getHandle())) {
    int32_t numTestOps = 0;
    payload->walk([&numTestOps](Operation *op) {
      if (op->getName().getDialectNamespace() == "test")
        ++numTestOps;
    });
    result.push_back(builder.getI32IntegerAttr(numTestOps));
  }
  results.setParams(llvm::cast<OpResult>(getResult()), result);
  return DiagnosedSilenceableFailure::success();
}
/// Produces a single parameter: an integer attribute with value 0 of the type
/// returned by `getType()`.
DiagnosedSilenceableFailure
mlir::test::TestProduceIntegerParamWithTypeOp::apply(
transform::TransformResults &results, transform::TransformState &state) {
Attribute zero = IntegerAttr::get(getType(), 0);
results.setParams(llvm::cast<OpResult>(getResult()), zero);
return DiagnosedSilenceableFailure::success();
}
/// Verifies that the type returned by `getType()` is an integer type.
LogicalResult mlir::test::TestProduceIntegerParamWithTypeOp::verify() {
  if (llvm::isa<IntegerType>(getType()))
    return success();
  return emitOpError() << "expects an integer type";
}
/// Declares the op's effects: the `in` handle is only read, and both the `out`
/// and `param` results are produced.
void mlir::test::TestProduceTransformParamOrForwardOperandOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::onlyReadsHandle(getIn(), effects);
transform::producesHandle(getOut(), effects);
transform::producesHandle(getParam(), effects);
}
/// Reports two results whose kinds depend on the op's flags. The first is an
/// i64 attribute 0, a null pointer, or the first payload op of the `in`
/// handle. The second is the first payload op of the `in` handle or an i64
/// attribute 42.
DiagnosedSilenceableFailure
mlir::test::TestProduceTransformParamOrForwardOperandOp::applyToOne(
    Operation *target, ::transform::ApplyToEachResultList &results,
    ::transform::TransformState &state) {
  Builder builder(getContext());

  // First result.
  if (getFirstResultIsParam())
    results.push_back(builder.getI64IntegerAttr(0));
  else if (getFirstResultIsNull())
    results.push_back(nullptr);
  else
    results.push_back(*state.getPayloadOps(getIn()).begin());

  // Second result.
  if (!getSecondResultIsHandle())
    results.push_back(builder.getI64IntegerAttr(42));
  else
    results.push_back(*state.getPayloadOps(getIn()).begin());

  return DiagnosedSilenceableFailure::success();
}
/// Declares that the op produces the `out` handle.
void mlir::test::TestProduceNullPayloadOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::producesHandle(getOut(), effects);
}
/// Associates the `out` handle with a one-element payload list whose only
/// element is a null operation pointer.
DiagnosedSilenceableFailure mlir::test::TestProduceNullPayloadOp::apply(
transform::TransformResults &results, transform::TransformState &state) {
SmallVector<Operation *, 1> null({nullptr});
results.set(llvm::cast<OpResult>(getOut()), null);
return DiagnosedSilenceableFailure::success();
}
/// Associates the `out` handle with an empty payload list.
DiagnosedSilenceableFailure mlir::test::TestProduceEmptyPayloadOp::apply(
transform::TransformResults &results, transform::TransformState &state) {
results.set(cast<OpResult>(getOut()), {});
return DiagnosedSilenceableFailure::success();
}
/// Declares that the op produces the `out` result.
void mlir::test::TestProduceNullParamOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::producesHandle(getOut(), effects);
}
/// Sets the `out` parameter to a default-constructed (null) attribute.
DiagnosedSilenceableFailure
mlir::test::TestProduceNullParamOp::apply(transform::TransformResults &results,
transform::TransformState &state) {
results.setParams(llvm::cast<OpResult>(getOut()), Attribute());
return DiagnosedSilenceableFailure::success();
}
/// Declares that the op produces the `out` handle.
void mlir::test::TestProduceNullValueOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::producesHandle(getOut(), effects);
}
/// Sets the `out` value handle to a default-constructed (null) value.
DiagnosedSilenceableFailure
mlir::test::TestProduceNullValueOp::apply(transform::TransformResults &results,
transform::TransformState &state) {
results.setValues(llvm::cast<OpResult>(getOut()), Value());
return DiagnosedSilenceableFailure::success();
}
/// Declares effects selectively, depending on the op's flags:
///  - the `in` handle is consumed only if `getHasOperandEffect()` is set;
///  - the `out` handle is produced if `getHasResultEffect()` is set and is
///    declared as only read otherwise;
///  - the payload is modified only if `getModifiesPayload()` is set.
void mlir::test::TestRequiredMemoryEffectsOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
if (getHasOperandEffect())
transform::consumesHandle(getIn(), effects);
if (getHasResultEffect())
transform::producesHandle(getOut(), effects);
else
transform::onlyReadsHandle(getOut(), effects);
if (getModifiesPayload())
transform::modifiesPayload(effects);
}
/// Associates the `out` handle with the payload ops of the `in` handle.
DiagnosedSilenceableFailure mlir::test::TestRequiredMemoryEffectsOp::apply(
transform::TransformResults &results, transform::TransformState &state) {
results.set(llvm::cast<OpResult>(getOut()), state.getPayloadOps(getIn()));
return DiagnosedSilenceableFailure::success();
}
/// Declares that the op only reads the `in` handle and modifies the payload.
void mlir::test::TestTrackedRewriteOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::onlyReadsHandle(getIn(), effects);
transform::modifiesPayload(effects);
}
namespace {
/// Tracking listener used by the tests. A replacement is recognized only when
/// the replaced op is substituted by exactly one value defined by a
/// "test.update_mapping" op; that op is then used as the replacement in the
/// transform state mapping. In every other case no replacement is reported and
/// the original op is simply dropped from the mapping.
class TestTrackingListener : public transform::TrackingListener {
  using transform::TrackingListener::TrackingListener;

protected:
  Operation *findReplacementOp(Operation *op,
                               ValueRange newValues) const override {
    if (newValues.size() != 1)
      return nullptr;

    // Block arguments have no defining op and therefore never qualify.
    Operation *candidate = newValues[0].getDefiningOp();
    const bool isUpdateMapping =
        candidate &&
        candidate->getName().getStringRef() == "test.update_mapping";
    return isUpdateMapping ? candidate : nullptr;
  }
};
} // namespace
/// Replaces every "test.replace_me" payload op that carries a string
/// "replacement" attribute with a newly created op named after that attribute,
/// using a rewriter hooked up to TestTrackingListener. Finally emits a remark
/// with the number of iterations performed by the outer loop.
DiagnosedSilenceableFailure
mlir::test::TestTrackedRewriteOp::apply(transform::TransformResults &results,
transform::TransformState &state) {
TestTrackingListener listener(state, *this);
IRRewriter rewriter(getContext(), &listener);
int64_t numIterations = 0;
// `getPayloadOps` returns an iterator that skips ops that are erased in the
// loop body. Replacement ops are not enumerated.
for (Operation *op : state.getPayloadOps(getIn())) {
++numIterations;
rewriter.setInsertionPointToEnd(op->getBlock());
// Erase all payload ops. The outer loop should have only one iteration.
// Note that the inner `op` intentionally shadows the outer one.
for (Operation *op : state.getPayloadOps(getIn())) {
if (op->getName().getStringRef() != "test.replace_me")
continue;
auto replacementName = op->getAttrOfType<StringAttr>("replacement");
if (!replacementName)
continue;
// The new op records the name of the op it replaces in "original_op" and
// has the same result types, so it can replace all results one-to-one.
SmallVector<NamedAttribute> attributes;
attributes.emplace_back(rewriter.getStringAttr("original_op"),
op->getName().getIdentifier());
OperationState opState(op->getLoc(), replacementName,
/*operands=*/ValueRange(),
/*types=*/op->getResultTypes(), attributes);
Operation *newOp = rewriter.create(opState);
rewriter.replaceOp(op, newOp->getResults());
}
}
emitRemark() << numIterations << " iterations";
return DiagnosedSilenceableFailure::success();
}
namespace {
/// Test extension of the Transform dialect. Registers additional ops and
/// types, declares PDL as dependent dialect since the additional ops are using
/// PDL types for operands and results, and installs a "verbose_constraint" PDL
/// constraint that emits a warning at every operation it is given.
class TestTransformDialectExtension
    : public transform::TransformDialectExtension<
          TestTransformDialectExtension> {
public:
  using Base::Base;

  void init() {
    declareDependentDialect<pdl::PDLDialect>();
    registerTransformOps<TestTransformOp,
                         TestTransformUnrestrictedOpNoInterface,
#define GET_OP_LIST
#include "TestTransformDialectExtension.cpp.inc"
                         >();
    registerTypes<
#define GET_TYPEDEF_LIST
#include "TestTransformDialectExtensionTypes.cpp.inc"
        >();

    // PDL constraint that always succeeds and emits a warning at each
    // operation among its arguments.
    auto verboseConstraint = [](PatternRewriter &rewriter,
                                ArrayRef<PDLValue> pdlValues) {
      for (const PDLValue &pdlValue : pdlValues) {
        if (Operation *op = pdlValue.dyn_cast<Operation *>()) {
          op->emitWarning() << "from PDL constraint";
        }
      }
      return success();
    };

    // The initializer is handed over to the extension and may be invoked
    // after `init` has returned, so it must not hold references to locals of
    // this function: capture the constraint by value.
    addDialectDataInitializer<transform::PDLMatchHooks>(
        [verboseConstraint](transform::PDLMatchHooks &hooks) {
          llvm::StringMap<PDLConstraintFunction> constraints;
          constraints.try_emplace("verbose_constraint", verboseConstraint);
          hooks.mergeInPDLMatchHooks(std::move(constraints));
        });
  }
};
} // namespace
// The type parser and printer declared below are automatically generated by
// ODS (their definitions come from the generated typedef classes) but are not
// used as the Transform dialect uses a different dispatch mechanism to support
// dialect extensions. They are declared here ahead of the generated code and
// marked LLVM_ATTRIBUTE_UNUSED, presumably to avoid unused-function warnings.
LLVM_ATTRIBUTE_UNUSED static OptionalParseResult
generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value);
LLVM_ATTRIBUTE_UNUSED static LogicalResult
generatedTypePrinter(Type def, AsmPrinter &printer);
#define GET_TYPEDEF_CLASSES
#include "TestTransformDialectExtensionTypes.cpp.inc"
#define GET_OP_CLASSES
#include "TestTransformDialectExtension.cpp.inc"
/// Adds the TestTransformDialectExtension to the given dialect registry.
void ::test::registerTestTransformDialectExtension(DialectRegistry &registry) {
registry.addExtensions<TestTransformDialectExtension>();
}