
with a small stack-use-after-scope fix in getConstraintPredicates() This reverts commit c80e6edba4a9593f0587e27fa0ac825ebe174afd.
926 lines
35 KiB
C++
926 lines
35 KiB
C++
//===- TestTransformDialectExtension.cpp ----------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file defines an extension of the MLIR Transform dialect for testing
|
|
// purposes.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "TestTransformDialectExtension.h"
|
|
#include "TestTransformStateExtension.h"
|
|
#include "mlir/Dialect/PDL/IR/PDL.h"
|
|
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
|
|
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
|
|
#include "mlir/Dialect/Transform/IR/TransformOps.h"
|
|
#include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h"
|
|
#include "mlir/IR/OpImplementation.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "llvm/ADT/STLExtras.h"
|
|
#include "llvm/ADT/TypeSwitch.h"
|
|
#include "llvm/Support/Compiler.h"
|
|
#include "llvm/Support/raw_ostream.h"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace {
|
|
/// Simple transform op defined outside of the dialect. Just emits a remark when
|
|
/// applied. This op is defined in C++ to test that C++ definitions also work
|
|
/// for op injection into the Transform dialect.
|
|
class TestTransformOp
|
|
: public Op<TestTransformOp, transform::TransformOpInterface::Trait,
|
|
MemoryEffectOpInterface::Trait> {
|
|
public:
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformOp)
|
|
|
|
using Op::Op;
|
|
|
|
static ArrayRef<StringRef> getAttributeNames() { return {}; }
|
|
|
|
static constexpr llvm::StringLiteral getOperationName() {
|
|
return llvm::StringLiteral("transform.test_transform_op");
|
|
}
|
|
|
|
DiagnosedSilenceableFailure apply(transform::TransformRewriter &rewriter,
|
|
transform::TransformResults &results,
|
|
transform::TransformState &state) {
|
|
InFlightDiagnostic remark = emitRemark() << "applying transformation";
|
|
if (Attribute message = getMessage())
|
|
remark << " " << message;
|
|
|
|
return DiagnosedSilenceableFailure::success();
|
|
}
|
|
|
|
Attribute getMessage() {
|
|
return getOperation()->getDiscardableAttr("message");
|
|
}
|
|
|
|
static ParseResult parse(OpAsmParser &parser, OperationState &state) {
|
|
StringAttr message;
|
|
OptionalParseResult result = parser.parseOptionalAttribute(message);
|
|
if (!result.has_value())
|
|
return success();
|
|
|
|
if (result.value().succeeded())
|
|
state.addAttribute("message", message);
|
|
return result.value();
|
|
}
|
|
|
|
void print(OpAsmPrinter &printer) {
|
|
if (getMessage())
|
|
printer << " " << getMessage();
|
|
}
|
|
|
|
// No side effects.
|
|
void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
|
|
};
|
|
|
|
/// A test op to exercise the verifier of the PossibleTopLevelTransformOpTrait
|
|
/// in cases where it is attached to ops that do not comply with the trait
|
|
/// requirements. This op cannot be defined in ODS because ODS generates strict
|
|
/// verifiers that overalp with those in the trait and run earlier.
|
|
class TestTransformUnrestrictedOpNoInterface
|
|
: public Op<TestTransformUnrestrictedOpNoInterface,
|
|
transform::PossibleTopLevelTransformOpTrait,
|
|
transform::TransformOpInterface::Trait,
|
|
MemoryEffectOpInterface::Trait> {
|
|
public:
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
|
|
TestTransformUnrestrictedOpNoInterface)
|
|
|
|
using Op::Op;
|
|
|
|
static ArrayRef<StringRef> getAttributeNames() { return {}; }
|
|
|
|
static constexpr llvm::StringLiteral getOperationName() {
|
|
return llvm::StringLiteral(
|
|
"transform.test_transform_unrestricted_op_no_interface");
|
|
}
|
|
|
|
DiagnosedSilenceableFailure apply(transform::TransformRewriter &rewriter,
|
|
transform::TransformResults &results,
|
|
transform::TransformState &state) {
|
|
return DiagnosedSilenceableFailure::success();
|
|
}
|
|
|
|
// No side effects.
|
|
void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
|
|
};
|
|
} // namespace
|
|
|
|
DiagnosedSilenceableFailure
|
|
mlir::test::TestProduceSelfHandleOrForwardOperandOp::apply(
|
|
transform::TransformRewriter &rewriter,
|
|
transform::TransformResults &results, transform::TransformState &state) {
|
|
if (getOperation()->getNumOperands() != 0) {
|
|
results.set(cast<OpResult>(getResult()),
|
|
{getOperation()->getOperand(0).getDefiningOp()});
|
|
} else {
|
|
results.set(cast<OpResult>(getResult()), {getOperation()});
|
|
}
|
|
return DiagnosedSilenceableFailure::success();
|
|
}
|
|
|
|
void mlir::test::TestProduceSelfHandleOrForwardOperandOp::getEffects(
|
|
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
|
if (getOperand())
|
|
transform::onlyReadsHandle(getOperand(), effects);
|
|
transform::producesHandle(getRes(), effects);
|
|
}
|
|
|
|
DiagnosedSilenceableFailure
|
|
mlir::test::TestProduceValueHandleToSelfOperand::apply(
|
|
transform::TransformRewriter &rewriter,
|
|
transform::TransformResults &results, transform::TransformState &state) {
|
|
results.setValues(llvm::cast<OpResult>(getOut()), {getIn()});
|
|
return DiagnosedSilenceableFailure::success();
|
|
}
|
|
|
|
void mlir::test::TestProduceValueHandleToSelfOperand::getEffects(
|
|
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
|
transform::onlyReadsHandle(getIn(), effects);
|
|
transform::producesHandle(getOut(), effects);
|
|
transform::onlyReadsPayload(effects);
|
|
}
|
|
|
|
DiagnosedSilenceableFailure
|
|
mlir::test::TestProduceValueHandleToResult::applyToOne(
|
|
transform::TransformRewriter &rewriter, Operation *target,
|
|
transform::ApplyToEachResultList &results,
|
|
transform::TransformState &state) {
|
|
if (target->getNumResults() <= getNumber())
|
|
return emitSilenceableError() << "payload has no result #" << getNumber();
|
|
results.push_back(target->getResult(getNumber()));
|
|
return DiagnosedSilenceableFailure::success();
|
|
}
|
|
|
|
void mlir::test::TestProduceValueHandleToResult::getEffects(
|
|
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
|
transform::onlyReadsHandle(getIn(), effects);
|
|
transform::producesHandle(getOut(), effects);
|
|
transform::onlyReadsPayload(effects);
|
|
}
|
|
|
|
DiagnosedSilenceableFailure
|
|
mlir::test::TestProduceValueHandleToArgumentOfParentBlock::applyToOne(
|
|
transform::TransformRewriter &rewriter, Operation *target,
|
|
transform::ApplyToEachResultList &results,
|
|
transform::TransformState &state) {
|
|
if (!target->getBlock())
|
|
return emitSilenceableError() << "payload has no parent block";
|
|
if (target->getBlock()->getNumArguments() <= getNumber())
|
|
return emitSilenceableError()
|
|
<< "parent of the payload has no argument #" << getNumber();
|
|
results.push_back(target->getBlock()->getArgument(getNumber()));
|
|
return DiagnosedSilenceableFailure::success();
|
|
}
|
|
|
|
void mlir::test::TestProduceValueHandleToArgumentOfParentBlock::getEffects(
|
|
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
|
transform::onlyReadsHandle(getIn(), effects);
|
|
transform::producesHandle(getOut(), effects);
|
|
transform::onlyReadsPayload(effects);
|
|
}
|
|
|
|
bool mlir::test::TestConsumeOperand::allowsRepeatedHandleOperands() {
|
|
return getAllowRepeatedHandles();
|
|
}
|
|
|
|
DiagnosedSilenceableFailure
|
|
mlir::test::TestConsumeOperand::apply(transform::TransformRewriter &rewriter,
|
|
transform::TransformResults &results,
|
|
transform::TransformState &state) {
|
|
return DiagnosedSilenceableFailure::success();
|
|
}
|
|
|
|
void mlir::test::TestConsumeOperand::getEffects(
|
|
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
|
transform::consumesHandle(getOperand(), effects);
|
|
if (getSecondOperand())
|
|
transform::consumesHandle(getSecondOperand(), effects);
|
|
transform::modifiesPayload(effects);
|
|
}
|
|
|
|
DiagnosedSilenceableFailure mlir::test::TestConsumeOperandOfOpKindOrFail::apply(
|
|
transform::TransformRewriter &rewriter,
|
|
transform::TransformResults &results, transform::TransformState &state) {
|
|
auto payload = state.getPayloadOps(getOperand());
|
|
assert(llvm::hasSingleElement(payload) && "expected a single target op");
|
|
if ((*payload.begin())->getName().getStringRef() != getOpKind()) {
|
|
return emitSilenceableError()
|
|
<< "op expected the operand to be associated a payload op of kind "
|
|
<< getOpKind() << " got "
|
|
<< (*payload.begin())->getName().getStringRef();
|
|
}
|
|
|
|
emitRemark() << "succeeded";
|
|
return DiagnosedSilenceableFailure::success();
|
|
}
|
|
|
|
void mlir::test::TestConsumeOperandOfOpKindOrFail::getEffects(
|
|
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
|
transform::consumesHandle(getOperand(), effects);
|
|
transform::modifiesPayload(effects);
|
|
}
|
|
|
|
DiagnosedSilenceableFailure
|
|
mlir::test::TestSucceedIfOperandOfOpKind::matchOperation(
|
|
Operation *op, transform::TransformResults &results,
|
|
transform::TransformState &state) {
|
|
if (op->getName().getStringRef() != getOpKind()) {
|
|
return emitSilenceableError()
|
|
<< "op expected the operand to be associated with a payload op of "
|
|
"kind "
|
|
<< getOpKind() << " got " << op->getName().getStringRef();
|
|
}
|
|
return DiagnosedSilenceableFailure::success();
|
|
}
|
|
|
|
void mlir::test::TestSucceedIfOperandOfOpKind::getEffects(
|
|
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
|
transform::onlyReadsHandle(getOperand(), effects);
|
|
transform::onlyReadsPayload(effects);
|
|
}
|
|
|
|
/// Attaches a TestTransformStateExtension carrying this op's message
/// attribute to the transform state.
DiagnosedSilenceableFailure mlir::test::TestAddTestExtensionOp::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  auto message = getMessageAttr();
  state.addExtension<TestTransformStateExtension>(message);
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
DiagnosedSilenceableFailure
|
|
mlir::test::TestCheckIfTestExtensionPresentOp::apply(
|
|
transform::TransformRewriter &rewriter,
|
|
transform::TransformResults &results, transform::TransformState &state) {
|
|
auto *extension = state.getExtension<TestTransformStateExtension>();
|
|
if (!extension) {
|
|
emitRemark() << "extension absent";
|
|
return DiagnosedSilenceableFailure::success();
|
|
}
|
|
|
|
InFlightDiagnostic diag = emitRemark()
|
|
<< "extension present, " << extension->getMessage();
|
|
for (Operation *payload : state.getPayloadOps(getOperand())) {
|
|
diag.attachNote(payload->getLoc()) << "associated payload op";
|
|
#ifndef NDEBUG
|
|
SmallVector<Value> handles;
|
|
assert(succeeded(state.getHandlesForPayloadOp(payload, handles)));
|
|
assert(llvm::is_contained(handles, getOperand()) &&
|
|
"inconsistent mapping between transform IR handles and payload IR "
|
|
"operations");
|
|
#endif // NDEBUG
|
|
}
|
|
|
|
return DiagnosedSilenceableFailure::success();
|
|
}
|
|
|
|
void mlir::test::TestCheckIfTestExtensionPresentOp::getEffects(
|
|
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
|
transform::onlyReadsHandle(getOperand(), effects);
|
|
transform::onlyReadsPayload(effects);
|
|
}
|
|
|
|
/// Uses the TestTransformStateExtension to remap the first payload op of the
/// operand to this transform op. If the op has a result, that result is bound
/// to this transform op as well.
///
/// Fails definitely if the extension is missing, if the operand has no
/// payload ops, or if the extension rejects the remapping.
DiagnosedSilenceableFailure mlir::test::TestRemapOperandPayloadToSelfOp::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  auto *extension = state.getExtension<TestTransformStateExtension>();
  if (!extension)
    return emitDefiniteFailure("TestTransformStateExtension missing");

  // Dereferencing begin() of an empty payload range is undefined behavior,
  // so reject an empty mapping explicitly.
  auto payloadOps = state.getPayloadOps(getOperand());
  if (payloadOps.begin() == payloadOps.end())
    return emitDefiniteFailure(
        "expected the operand to be associated with at least one payload op");

  if (failed(extension->updateMapping(*payloadOps.begin(), getOperation())))
    return DiagnosedSilenceableFailure::definiteFailure();
  if (getNumResults() > 0)
    results.set(cast<OpResult>(getResult(0)), {getOperation()});
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// Reads the operand handle, produces the result handle when it exists, and
/// only reads the payload IR.
void mlir::test::TestRemapOperandPayloadToSelfOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  transform::onlyReadsHandle(getOperand(), effects);
  // apply() binds the result only when getNumResults() > 0, so the result may
  // be absent. Declaring an effect on a null value would claim the effect on
  // the whole mapping resource instead of a specific handle.
  if (Value out = getOut())
    transform::producesHandle(out, effects);
  transform::onlyReadsPayload(effects);
}
|
|
|
|
/// Detaches the TestTransformStateExtension from the transform state.
DiagnosedSilenceableFailure mlir::test::TestRemoveTestExtensionOp::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  state.removeExtension<TestTransformStateExtension>();
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// Binds the result handle to the target's payload ops in reverse order.
DiagnosedSilenceableFailure mlir::test::TestReversePayloadOpsOp::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  auto targetOps = state.getPayloadOps(getTarget());
  auto backToFront = llvm::to_vector(llvm::reverse(targetOps));
  OpResult result = llvm::cast<OpResult>(getResult());
  results.set(result, backToFront);
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
DiagnosedSilenceableFailure mlir::test::TestTransformOpWithRegions::apply(
|
|
transform::TransformRewriter &rewriter,
|
|
transform::TransformResults &results, transform::TransformState &state) {
|
|
return DiagnosedSilenceableFailure::success();
|
|
}
|
|
|
|
void mlir::test::TestTransformOpWithRegions::getEffects(
|
|
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
|
|
|
|
DiagnosedSilenceableFailure
|
|
mlir::test::TestBranchingTransformOpTerminator::apply(
|
|
transform::TransformRewriter &rewriter,
|
|
transform::TransformResults &results, transform::TransformState &state) {
|
|
return DiagnosedSilenceableFailure::success();
|
|
}
|
|
|
|
void mlir::test::TestBranchingTransformOpTerminator::getEffects(
|
|
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
|
|
|
|
DiagnosedSilenceableFailure mlir::test::TestEmitRemarkAndEraseOperandOp::apply(
|
|
transform::TransformRewriter &rewriter,
|
|
transform::TransformResults &results, transform::TransformState &state) {
|
|
emitRemark() << getRemark();
|
|
for (Operation *op : state.getPayloadOps(getTarget()))
|
|
rewriter.eraseOp(op);
|
|
|
|
if (getFailAfterErase())
|
|
return emitSilenceableError() << "silenceable error";
|
|
return DiagnosedSilenceableFailure::success();
|
|
}
|
|
|
|
void mlir::test::TestEmitRemarkAndEraseOperandOp::getEffects(
|
|
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
|
transform::consumesHandle(getTarget(), effects);
|
|
transform::modifiesPayload(effects);
|
|
}
|
|
|
|
/// For each target, creates one generic "foo" op just before it and pushes
/// exactly that one op as the result.
DiagnosedSilenceableFailure mlir::test::TestWrongNumberOfResultsOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  OpBuilder builder(target);
  OperationState fooState(target->getLoc(), "foo");
  Operation *created = builder.create(fooState);
  results.push_back(created);
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// Pushes one newly created generic "foo" op only on the very first call,
/// and nothing on later calls.
DiagnosedSilenceableFailure
mlir::test::TestWrongNumberOfMultiResultsOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  // Function-local static, so the counter persists across all invocations in
  // the process, not just across targets of one application.
  static int count = 0;
  const bool isFirstCall = (count++ == 0);
  if (!isFirstCall)
    return DiagnosedSilenceableFailure::success();

  OperationState fooState(target->getLoc(), "foo");
  results.push_back(OpBuilder(target).create(fooState));
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// For each target, creates two generic "foo" ops before it (in creation
/// order) and pushes both as results.
DiagnosedSilenceableFailure
mlir::test::TestCorrectNumberOfMultiResultsOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  OperationState fooState(target->getLoc(), "foo");
  // The builder's insertion point stays right before `target`, so the second
  // op lands after the first one, as with two separate builders.
  OpBuilder builder(target);
  for (int i = 0; i < 2; ++i)
    results.push_back(builder.create(fooState));
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// Pushes a null result followed by a newly created generic "foo" op.
DiagnosedSilenceableFailure
mlir::test::TestMixedNullAndNonNullResultsOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  results.push_back(nullptr);

  OperationState fooState(target->getLoc(), "foo");
  Operation *created = OpBuilder(target).create(fooState);
  results.push_back(created);
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// Succeeds on targets that carry the "target_me" attribute and fails with
/// the default silenceable failure on all others.
DiagnosedSilenceableFailure
mlir::test::TestMixedSuccessAndSilenceableOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  if (!target->hasAttr("target_me"))
    return emitDefaultSilenceableFailure(target);
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
DiagnosedSilenceableFailure
|
|
mlir::test::TestCopyPayloadOp::apply(transform::TransformRewriter &rewriter,
|
|
transform::TransformResults &results,
|
|
transform::TransformState &state) {
|
|
results.set(llvm::cast<OpResult>(getCopy()),
|
|
state.getPayloadOps(getHandle()));
|
|
return DiagnosedSilenceableFailure::success();
|
|
}
|
|
|
|
void mlir::test::TestCopyPayloadOp::getEffects(
|
|
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
|
transform::onlyReadsHandle(getHandle(), effects);
|
|
transform::producesHandle(getCopy(), effects);
|
|
transform::onlyReadsPayload(effects);
|
|
}
|
|
|
|
/// Accepts a payload only if every op belongs to the "test" dialect. An empty
/// payload is trivially accepted.
DiagnosedSilenceableFailure mlir::transform::TestDialectOpType::checkPayload(
    Location loc, ArrayRef<Operation *> payload) const {
  auto isTestDialectOp = [](Operation *op) {
    return op->getName().getDialectNamespace() == "test";
  };
  if (llvm::all_of(payload, isTestDialectOp))
    return DiagnosedSilenceableFailure::success();

  return emitSilenceableError(loc) << "expected the payload operation to "
                                      "belong to the 'test' dialect";
}
|
|
|
|
/// Accepts a parameter payload only if every attribute is an IntegerAttr of
/// signless 32-bit integer type.
DiagnosedSilenceableFailure mlir::transform::TestDialectParamType::checkPayload(
    Location loc, ArrayRef<Attribute> payload) const {
  for (Attribute attr : payload) {
    auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
    const bool isI32 = intAttr && intAttr.getType().isSignlessInteger(32);
    if (!isI32)
      return emitSilenceableError(loc)
             << "expected the parameter to be a i32 integer attribute";
  }
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
void mlir::test::TestReportNumberOfTrackedHandlesNestedUnder::getEffects(
|
|
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
|
transform::onlyReadsHandle(getTarget(), effects);
|
|
}
|
|
|
|
DiagnosedSilenceableFailure
|
|
mlir::test::TestReportNumberOfTrackedHandlesNestedUnder::apply(
|
|
transform::TransformRewriter &rewriter,
|
|
transform::TransformResults &results, transform::TransformState &state) {
|
|
int64_t count = 0;
|
|
for (Operation *op : state.getPayloadOps(getTarget())) {
|
|
op->walk([&](Operation *nested) {
|
|
SmallVector<Value> handles;
|
|
(void)state.getHandlesForPayloadOp(nested, handles);
|
|
count += handles.size();
|
|
});
|
|
}
|
|
emitRemark() << count << " handles nested under";
|
|
return DiagnosedSilenceableFailure::success();
|
|
}
|
|
|
|
/// Produces i32 params equal to each input param value plus the addendum. If
/// no param operand is given, the input is a single value 0. Input values
/// larger than UINT32_MAX are clamped first.
DiagnosedSilenceableFailure
mlir::test::TestAddToParamOp::apply(transform::TransformRewriter &rewriter,
                                    transform::TransformResults &results,
                                    transform::TransformState &state) {
  SmallVector<uint32_t> values(/*Size=*/1, /*Value=*/0);
  if (Value param = getParam()) {
    values.clear();
    for (Attribute attr : state.getParams(param)) {
      const uint32_t clamped = static_cast<uint32_t>(
          llvm::cast<IntegerAttr>(attr).getValue().getLimitedValue(
              UINT32_MAX));
      values.push_back(clamped);
    }
  }

  Builder builder(getContext());
  SmallVector<Attribute> sums;
  sums.reserve(values.size());
  for (uint32_t value : values)
    sums.push_back(builder.getI32IntegerAttr(value + getAddendum()));

  results.setParams(llvm::cast<OpResult>(getResult()), sums);
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// For each payload op of the handle, produces an i32 param that counts the
/// "test"-dialect ops nested under it, including the op itself.
DiagnosedSilenceableFailure
mlir::test::TestProduceParamWithNumberOfTestOps::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  Builder builder(getContext());
  SmallVector<Attribute> counts;
  for (Operation *payload : state.getPayloadOps(getHandle())) {
    int32_t numTestOps = 0;
    payload->walk([&numTestOps](Operation *nested) {
      if (nested->getName().getDialectNamespace() == "test")
        ++numTestOps;
    });
    counts.push_back(builder.getI32IntegerAttr(numTestOps));
  }
  results.setParams(llvm::cast<OpResult>(getResult()), counts);
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// Binds the result param to the op's `attr` attribute.
DiagnosedSilenceableFailure
mlir::test::TestProduceParamOp::apply(transform::TransformRewriter &rewriter,
                                      transform::TransformResults &results,
                                      transform::TransformState &state) {
  OpResult result = llvm::cast<OpResult>(getResult());
  results.setParams(result, getAttr());
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
void mlir::test::TestProduceTransformParamOrForwardOperandOp::getEffects(
|
|
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
|
transform::onlyReadsHandle(getIn(), effects);
|
|
transform::producesHandle(getOut(), effects);
|
|
transform::producesHandle(getParam(), effects);
|
|
}
|
|
|
|
DiagnosedSilenceableFailure
|
|
mlir::test::TestProduceTransformParamOrForwardOperandOp::applyToOne(
|
|
transform::TransformRewriter &rewriter, Operation *target,
|
|
::transform::ApplyToEachResultList &results,
|
|
::transform::TransformState &state) {
|
|
Builder builder(getContext());
|
|
if (getFirstResultIsParam()) {
|
|
results.push_back(builder.getI64IntegerAttr(0));
|
|
} else if (getFirstResultIsNull()) {
|
|
results.push_back(nullptr);
|
|
} else {
|
|
results.push_back(*state.getPayloadOps(getIn()).begin());
|
|
}
|
|
|
|
if (getSecondResultIsHandle()) {
|
|
results.push_back(*state.getPayloadOps(getIn()).begin());
|
|
} else {
|
|
results.push_back(builder.getI64IntegerAttr(42));
|
|
}
|
|
|
|
return DiagnosedSilenceableFailure::success();
|
|
}
|
|
|
|
void mlir::test::TestProduceNullPayloadOp::getEffects(
|
|
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
|
transform::producesHandle(getOut(), effects);
|
|
}
|
|
|
|
DiagnosedSilenceableFailure mlir::test::TestProduceNullPayloadOp::apply(
|
|
transform::TransformRewriter &rewriter,
|
|
transform::TransformResults &results, transform::TransformState &state) {
|
|
SmallVector<Operation *, 1> null({nullptr});
|
|
results.set(llvm::cast<OpResult>(getOut()), null);
|
|
return DiagnosedSilenceableFailure::success();
|
|
}
|
|
|
|
/// Binds `out` to an empty list of payload ops.
DiagnosedSilenceableFailure mlir::test::TestProduceEmptyPayloadOp::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  OpResult out = cast<OpResult>(getOut());
  results.set(out, {});
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
void mlir::test::TestProduceNullParamOp::getEffects(
|
|
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
|
transform::producesHandle(getOut(), effects);
|
|
}
|
|
|
|
DiagnosedSilenceableFailure mlir::test::TestProduceNullParamOp::apply(
|
|
transform::TransformRewriter &rewriter,
|
|
transform::TransformResults &results, transform::TransformState &state) {
|
|
results.setParams(llvm::cast<OpResult>(getOut()), Attribute());
|
|
return DiagnosedSilenceableFailure::success();
|
|
}
|
|
|
|
void mlir::test::TestProduceNullValueOp::getEffects(
|
|
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
|
transform::producesHandle(getOut(), effects);
|
|
}
|
|
|
|
DiagnosedSilenceableFailure mlir::test::TestProduceNullValueOp::apply(
|
|
transform::TransformRewriter &rewriter,
|
|
transform::TransformResults &results, transform::TransformState &state) {
|
|
results.setValues(llvm::cast<OpResult>(getOut()), {Value()});
|
|
return DiagnosedSilenceableFailure::success();
|
|
}
|
|
|
|
void mlir::test::TestRequiredMemoryEffectsOp::getEffects(
|
|
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
|
if (getHasOperandEffect())
|
|
transform::consumesHandle(getIn(), effects);
|
|
|
|
if (getHasResultEffect())
|
|
transform::producesHandle(getOut(), effects);
|
|
else
|
|
transform::onlyReadsHandle(getOut(), effects);
|
|
|
|
if (getModifiesPayload())
|
|
transform::modifiesPayload(effects);
|
|
}
|
|
|
|
DiagnosedSilenceableFailure mlir::test::TestRequiredMemoryEffectsOp::apply(
|
|
transform::TransformRewriter &rewriter,
|
|
transform::TransformResults &results, transform::TransformState &state) {
|
|
results.set(llvm::cast<OpResult>(getOut()), state.getPayloadOps(getIn()));
|
|
return DiagnosedSilenceableFailure::success();
|
|
}
|
|
|
|
/// Reads the `in` handle and modifies the payload IR.
void mlir::test::TestTrackedRewriteOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  Value in = getIn();
  transform::onlyReadsHandle(in, effects);
  transform::modifiesPayload(effects);
}
|
|
|
|
void mlir::test::TestDummyPayloadOp::getEffects(
|
|
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
|
for (OpResult result : getResults())
|
|
transform::producesHandle(result, effects);
|
|
}
|
|
|
|
LogicalResult mlir::test::TestDummyPayloadOp::verify() {
|
|
if (getFailToVerify())
|
|
return emitOpError() << "fail_to_verify is set";
|
|
return success();
|
|
}
|
|
|
|
/// Exercises tracked rewriting. Inside an outer iteration over the `in`
/// payload, every payload op is either erased ("erase_me"), replaced by a
/// fresh op of the same name tagged "new_op" ("replace_me"), or left alone.
/// Emits a remark with the number of outer iterations.
DiagnosedSilenceableFailure
mlir::test::TestTrackedRewriteOp::apply(transform::TransformRewriter &rewriter,
                                        transform::TransformResults &results,
                                        transform::TransformState &state) {
  int64_t numIterations = 0;

  // `getPayloadOps` returns an iterator that skips ops that are erased in the
  // loop body. Replacement ops are not enumerated.
  for (Operation *outerOp : state.getPayloadOps(getIn())) {
    (void)outerOp;
    ++numIterations;

    // Rewrite all payload ops; the outer loop should therefore run only once.
    for (Operation *payloadOp : state.getPayloadOps(getIn())) {
      rewriter.setInsertionPoint(payloadOp);
      if (payloadOp->hasAttr("erase_me")) {
        rewriter.eraseOp(payloadOp);
        continue;
      }
      if (!payloadOp->hasAttr("replace_me"))
        continue;

      SmallVector<NamedAttribute> newAttrs;
      newAttrs.emplace_back(rewriter.getStringAttr("new_op"),
                            rewriter.getUnitAttr());
      OperationState replacementState(
          payloadOp->getLoc(), payloadOp->getName().getIdentifier(),
          /*operands=*/ValueRange(),
          /*types=*/payloadOp->getResultTypes(), newAttrs);
      Operation *replacement = rewriter.create(replacementState);
      rewriter.replaceOp(payloadOp, replacement->getResults());
    }
  }

  emitRemark() << numIterations << " iterations";
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
namespace {
|
|
// Test pattern to replace an operation with a new op.
|
|
class ReplaceWithNewOp : public RewritePattern {
|
|
public:
|
|
ReplaceWithNewOp(MLIRContext *context)
|
|
: RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
|
|
|
|
LogicalResult matchAndRewrite(Operation *op,
|
|
PatternRewriter &rewriter) const override {
|
|
auto newName = op->getAttrOfType<StringAttr>("replace_with_new_op");
|
|
if (!newName)
|
|
return failure();
|
|
Operation *newOp = rewriter.create(
|
|
op->getLoc(), OperationName(newName, op->getContext()).getIdentifier(),
|
|
op->getOperands(), op->getResultTypes());
|
|
rewriter.replaceOp(op, newOp->getResults());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
// Test pattern to erase an operation.
|
|
class EraseOp : public RewritePattern {
|
|
public:
|
|
EraseOp(MLIRContext *context)
|
|
: RewritePattern("test.erase_op", /*benefit=*/1, context) {}
|
|
LogicalResult matchAndRewrite(Operation *op,
|
|
PatternRewriter &rewriter) const override {
|
|
rewriter.eraseOp(op);
|
|
return success();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
/// Registers the test rewrite patterns defined in this file.
void mlir::test::ApplyTestPatternsOp::populatePatterns(
    RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.insert<ReplaceWithNewOp, EraseOp>(ctx);
}
|
|
|
|
void mlir::test::TestReEnterRegionOp::getEffects(
|
|
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
|
transform::consumesHandle(getOperands(), effects);
|
|
transform::modifiesPayload(effects);
|
|
}
|
|
|
|
DiagnosedSilenceableFailure
|
|
mlir::test::TestReEnterRegionOp::apply(transform::TransformRewriter &rewriter,
|
|
transform::TransformResults &results,
|
|
transform::TransformState &state) {
|
|
|
|
SmallVector<SmallVector<transform::MappedValue>> mappings;
|
|
for (BlockArgument arg : getBody().front().getArguments()) {
|
|
mappings.emplace_back(llvm::to_vector(llvm::map_range(
|
|
state.getPayloadOps(getOperand(arg.getArgNumber())),
|
|
[](Operation *op) -> transform::MappedValue { return op; })));
|
|
}
|
|
|
|
for (int i = 0; i < 4; ++i) {
|
|
auto scope = state.make_region_scope(getBody());
|
|
for (BlockArgument arg : getBody().front().getArguments()) {
|
|
if (failed(state.mapBlockArgument(arg, mappings[arg.getArgNumber()])))
|
|
return DiagnosedSilenceableFailure::definiteFailure();
|
|
}
|
|
for (Operation &op : getBody().front().without_terminator()) {
|
|
DiagnosedSilenceableFailure diag =
|
|
state.applyTransform(cast<transform::TransformOpInterface>(op));
|
|
if (!diag.succeeded())
|
|
return diag;
|
|
}
|
|
}
|
|
return DiagnosedSilenceableFailure::success();
|
|
}
|
|
|
|
LogicalResult mlir::test::TestReEnterRegionOp::verify() {
|
|
if (getNumOperands() != getBody().front().getNumArguments()) {
|
|
return emitOpError() << "expects as many operands as block arguments";
|
|
}
|
|
return success();
|
|
}
|
|
|
|
DiagnosedSilenceableFailure mlir::test::TestNotifyPayloadOpReplacedOp::apply(
|
|
transform::TransformRewriter &rewriter,
|
|
transform::TransformResults &results, transform::TransformState &state) {
|
|
auto originalOps = state.getPayloadOps(getOriginal());
|
|
auto replacementOps = state.getPayloadOps(getReplacement());
|
|
if (llvm::range_size(originalOps) != llvm::range_size(replacementOps))
|
|
return emitSilenceableError() << "expected same number of original and "
|
|
"replacement payload operations";
|
|
for (const auto &[original, replacement] :
|
|
llvm::zip(originalOps, replacementOps)) {
|
|
if (failed(
|
|
rewriter.notifyPayloadOperationReplaced(original, replacement))) {
|
|
auto diag = emitSilenceableError()
|
|
<< "unable to replace payload op in transform mapping";
|
|
diag.attachNote(original->getLoc()) << "original payload op";
|
|
diag.attachNote(replacement->getLoc()) << "replacement payload op";
|
|
return diag;
|
|
}
|
|
}
|
|
return DiagnosedSilenceableFailure::success();
|
|
}
|
|
|
|
void mlir::test::TestNotifyPayloadOpReplacedOp::getEffects(
|
|
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
|
transform::onlyReadsHandle(getOriginal(), effects);
|
|
transform::onlyReadsHandle(getReplacement(), effects);
|
|
}
|
|
|
|
/// Inserts a TestDummyPayloadOp with `failToVerify` set at the start of the
/// target's first region, so the payload IR no longer verifies.
///
/// Fails silenceably if the target has no regions or its first region has
/// no block. Without this check, `getRegion(0).front()` would be invalid.
DiagnosedSilenceableFailure mlir::test::TestProduceInvalidIR::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  if (target->getNumRegions() == 0 || target->getRegion(0).empty())
    return emitSilenceableError()
           << "expected the target to have a non-empty first region";

  // Provide some IR that does not verify.
  rewriter.setInsertionPointToStart(&target->getRegion(0).front());
  rewriter.create<TestDummyPayloadOp>(target->getLoc(), TypeRange(),
                                      ValueRange(), /*failToVerify=*/true);
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// Reads the target handle and modifies the payload IR.
void mlir::test::TestProduceInvalidIR::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  Value target = getTarget();
  transform::onlyReadsHandle(target, effects);
  transform::modifiesPayload(effects);
}
|
|
|
|
namespace {
|
|
/// Test conversion pattern that replaces ops with the "replace_with_new_op"
|
|
/// attribute with "test.new_op".
|
|
class ReplaceWithNewOpConversion : public ConversionPattern {
|
|
public:
|
|
ReplaceWithNewOpConversion(TypeConverter &typeConverter, MLIRContext *context)
|
|
: ConversionPattern(typeConverter, RewritePattern::MatchAnyOpTypeTag(),
|
|
/*benefit=*/1, context) {}
|
|
|
|
LogicalResult
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
if (!op->hasAttr("replace_with_new_op"))
|
|
return failure();
|
|
SmallVector<Type> newResultTypes;
|
|
if (failed(getTypeConverter()->convertTypes(op->getResultTypes(),
|
|
newResultTypes)))
|
|
return failure();
|
|
Operation *newOp = rewriter.create(
|
|
op->getLoc(),
|
|
OperationName("test.new_op", op->getContext()).getIdentifier(),
|
|
operands, newResultTypes);
|
|
rewriter.replaceOp(op, newOp->getResults());
|
|
return success();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
/// Adds the test `ReplaceWithNewOpConversion` pattern to `patterns`; the
/// pattern uses `typeConverter` to compute the result types of the ops it
/// creates.
void mlir::test::ApplyTestConversionPatternsOp::populatePatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  // `add` is the primary RewritePatternSet API; `insert` is the legacy
  // spelling of the same operation.
  patterns.add<ReplaceWithNewOpConversion>(typeConverter,
                                           patterns.getContext());
}
|
|
|
|
namespace {
|
|
/// Test type converter that converts tensor types to memref types.
|
|
class TestTypeConverter : public TypeConverter {
|
|
public:
|
|
TestTypeConverter() {
|
|
addConversion([](Type t) { return t; });
|
|
addConversion([](RankedTensorType type) -> Type {
|
|
return MemRefType::get(type.getShape(), type.getElementType());
|
|
});
|
|
auto unrealizedCastConverter = [&](OpBuilder &builder, Type resultType,
|
|
ValueRange inputs,
|
|
Location loc) -> std::optional<Value> {
|
|
if (inputs.size() != 1)
|
|
return std::nullopt;
|
|
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
|
|
.getResult(0);
|
|
};
|
|
addSourceMaterialization(unrealizedCastConverter);
|
|
addTargetMaterialization(unrealizedCastConverter);
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
std::unique_ptr<::mlir::TypeConverter>
|
|
mlir::test::TestTypeConverterOp::getTypeConverter() {
|
|
return std::make_unique<TestTypeConverter>();
|
|
}
|
|
|
|
namespace {
/// Test extension of the Transform dialect. Registers additional ops and
/// declares PDL as dependent dialect since the additional ops are using PDL
/// types for operands and results.
class TestTransformDialectExtension
    : public transform::TransformDialectExtension<
          TestTransformDialectExtension> {
public:
  using Base::Base;

  void init() {
    declareDependentDialect<pdl::PDLDialect>();
    registerTransformOps<TestTransformOp,
                         TestTransformUnrestrictedOpNoInterface,
#define GET_OP_LIST
#include "TestTransformDialectExtension.cpp.inc"
                         >();
    registerTypes<
#define GET_TYPEDEF_LIST
#include "TestTransformDialectExtensionTypes.cpp.inc"
        >();

    // PDL constraint that emits a "from PDL constraint" warning on every
    // Operation among its arguments and always succeeds, so tests can observe
    // when it is evaluated.
    auto verboseConstraint = [](PatternRewriter &rewriter, PDLResultList &,
                                ArrayRef<PDLValue> pdlValues) {
      for (const PDLValue &pdlValue : pdlValues) {
        if (Operation *op = pdlValue.dyn_cast<Operation *>()) {
          op->emitWarning() << "from PDL constraint";
        }
      }
      return success();
    };

    // The initializer is registered for later use and may run after `init`
    // has returned, so it must not refer to locals of this function by
    // reference. Capture the (stateless) constraint by value instead.
    addDialectDataInitializer<transform::PDLMatchHooks>(
        [verboseConstraint](transform::PDLMatchHooks &hooks) {
          llvm::StringMap<PDLConstraintFunction> constraints;
          constraints.try_emplace("verbose_constraint", verboseConstraint);
          hooks.mergeInPDLMatchHooks(std::move(constraints));
        });
  }
};
} // namespace
|
|
|
|
// These are automatically generated by ODS but are not used as the Transform
// dialect uses a different dispatch mechanism to support dialect extensions.
// NOTE(review): the LLVM_ATTRIBUTE_UNUSED forward declarations presumably
// exist to silence unused-function warnings for the generated definitions
// pulled in by the GET_TYPEDEF_CLASSES include below — confirm.
LLVM_ATTRIBUTE_UNUSED static OptionalParseResult
generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value);
LLVM_ATTRIBUTE_UNUSED static LogicalResult
generatedTypePrinter(Type def, AsmPrinter &printer);

// ODS-generated type definitions of the test extension (GET_TYPEDEF_CLASSES
// section of the generated types file).
#define GET_TYPEDEF_CLASSES
#include "TestTransformDialectExtensionTypes.cpp.inc"

// ODS-generated op definitions of the test extension (GET_OP_CLASSES section
// of the generated ops file).
#define GET_OP_CLASSES
#include "TestTransformDialectExtension.cpp.inc"
|
|
|
|
void ::test::registerTestTransformDialectExtension(DialectRegistry ®istry) {
|
|
registry.addExtensions<TestTransformDialectExtension>();
|
|
}
|