Oleksandr "Alex" Zinenko e4b04b391f
[mlir] make transform.foreach_match forward arguments (#89920)
It may be useful to have access to additional handles or parameters when
performing matches and actions in `foreach_match`, for example, to
parameterize the matcher by rank or restrict it in a non-trivial way.
Enable `foreach_match` to forward additional handles from operands to
matcher symbols and from action symbols to results.
2024-05-03 10:15:44 +02:00

2793 lines
110 KiB
C++

//===- TransformOps.cpp - Transform dialect operations --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/CSE.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include <optional>
#define DEBUG_TYPE "transform-dialect"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
#define DEBUG_TYPE_MATCHER "transform-matcher"
#define DBGS_MATCHER() (llvm::dbgs() << "[" DEBUG_TYPE_MATCHER "] ")
#define DEBUG_MATCHER(x) DEBUG_WITH_TYPE(DEBUG_TYPE_MATCHER, x)
using namespace mlir;
// Forward declarations of custom parse/print helpers. They are declared ahead
// of the generated op definitions included below, presumably because ODS
// `custom<...>` directives in the generated code call them; their definitions
// are expected later in this file (not visible in this chunk).

/// Parses the operands of a sequence op: an optional `root` operand with its
/// `rootType`, plus `extraBindings` with their `extraBindingTypes`.
static ParseResult parseSequenceOpOperands(
    OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
    Type &rootType,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
    SmallVectorImpl<Type> &extraBindingTypes);
/// Printer counterpart of `parseSequenceOpOperands`.
static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op,
                                    Value root, Type rootType,
                                    ValueRange extraBindings,
                                    TypeRange extraBindingTypes);
/// Prints the `matchers` and `actions` symbol arrays of a `foreach_match` op.
static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op,
                                     ArrayAttr matchers, ArrayAttr actions);
/// Parses the `matchers` and `actions` symbol arrays of a `foreach_match` op.
static ParseResult parseForeachMatchSymbols(OpAsmParser &parser,
                                            ArrayAttr &matchers,
                                            ArrayAttr &actions);
/// Guards against a transform rewriting the transform IR that is currently
/// being interpreted. Returns a definite failure (with a note pointing at
/// `payload`) if `payload` is the `transform` op itself or one of its
/// ancestors; returns success otherwise.
static DiagnosedSilenceableFailure
ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform,
                                     Operation *payload) {
  for (Operation *ancestor = transform.getOperation(); ancestor != nullptr;
       ancestor = ancestor->getParentOp()) {
    if (ancestor != payload)
      continue;
    DiagnosedDefiniteFailure diag =
        transform.emitDefiniteFailure()
        << "cannot apply transform to itself (or one of its ancestors)";
    diag.attachNote(payload->getLoc()) << "target payload op";
    return diag;
  }
  return DiagnosedSilenceableFailure::success();
}
#define GET_OP_CLASSES
#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
//===----------------------------------------------------------------------===//
// AlternativesOp
//===----------------------------------------------------------------------===//
/// Returns the operands forwarded when entering a successor from `point`: when
/// `point` is a region and the op has exactly one operand, that operand is
/// forwarded; otherwise an empty range (anchored at the operand end) is
/// returned.
OperandRange
transform::AlternativesOp::getEntrySuccessorOperands(RegionBranchPoint point) {
  Operation *op = getOperation();
  bool forwardsOperand = !point.isParent() && op->getNumOperands() == 1;
  if (!forwardsOperand)
    return OperandRange(op->operand_end(), op->operand_end());
  return op->getOperands();
}
/// Populates `regions` with the successors of `point`. These are every
/// alternative region after the one `point` designates, or all alternatives
/// when `point` is the parent op. Each alternative receives its block
/// arguments as inputs only if the op has operands. When `point` is a region,
/// the parent op's results are a successor as well.
void transform::AlternativesOp::getSuccessorRegions(
    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
  unsigned firstCandidate =
      point.isParent() ? 0 : point.getRegionOrNull()->getRegionNumber() + 1;
  bool hasOperands = !getOperands().empty();
  for (Region &alternative :
       llvm::drop_begin(getAlternatives(), firstCandidate)) {
    if (hasOperands)
      regions.emplace_back(&alternative, alternative.getArguments());
    else
      regions.emplace_back(&alternative, Block::BlockArgListType());
  }
  if (!point.isParent())
    regions.emplace_back(getOperation()->getResults());
}
/// Reports invocation bounds for the alternative regions. The first
/// alternative always runs exactly once. Every later alternative runs at most
/// once, because `apply` stops at the first alternative that succeeds.
/// `operands` is unused.
void transform::AlternativesOp::getRegionInvocationBounds(
    ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
  (void)operands;
  unsigned numRegions = getNumRegions();
  bounds.reserve(numRegions);
  bounds.emplace_back(/*lb=*/1, /*ub=*/1);
  bounds.resize(numRegions, InvocationBounds(/*lb=*/0, /*ub=*/1));
}
/// Associates every result of the op owning `block` with an empty list of
/// payload ops. `state` is not used.
static void forwardEmptyOperands(Block *block, transform::TransformState &state,
                                 transform::TransformResults &results) {
  Operation *owner = block->getParentOp();
  for (OpResult result : owner->getOpResults())
    results.set(result, {});
}
/// Tries the alternative regions in order on clones of the scope payload ops.
/// The scope ops are those associated with the `scope` operand, or the
/// top-level op when no scope is given.
///
/// The first alternative whose transforms all succeed wins: its clones are
/// inserted in place of the original scope ops, and its terminator operands
/// are forwarded as the op's results. A silenceable failure inside an
/// alternative discards that alternative's clones and moves on to the next
/// one. Definite failures are propagated immediately. A silenceable error is
/// emitted if every alternative fails.
DiagnosedSilenceableFailure
transform::AlternativesOp::apply(transform::TransformRewriter &rewriter,
                                 transform::TransformResults &results,
                                 transform::TransformState &state) {
  SmallVector<Operation *> originals;
  if (Value scopeHandle = getScope())
    llvm::append_range(originals, state.getPayloadOps(scopeHandle));
  else
    originals.push_back(state.getTopLevel());

  // Validate every scope: it must not contain this op, and it must be isolated
  // from above.
  for (Operation *original : originals) {
    if (original->isAncestor(getOperation())) {
      auto diag = emitDefiniteFailure()
                  << "scope must not contain the transforms being applied";
      diag.attachNote(original->getLoc()) << "scope";
      return diag;
    }
    if (!original->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
      auto diag = emitDefiniteFailure()
                  << "only isolated-from-above ops can be alternative scopes";
      diag.attachNote(original->getLoc()) << "scope";
      return diag;
    }
  }

  for (Region &reg : getAlternatives()) {
    // Clone the scope operations and make the transforms in this alternative
    // region apply to them by virtue of mapping the block argument (the only
    // visible handle) to the cloned scope operations. This effectively prevents
    // the transformation from accessing any IR outside the scope.
    auto scope = state.make_region_scope(reg);
    auto clones = llvm::to_vector(
        llvm::map_range(originals, [](Operation *op) { return op->clone(); }));
    // The clones are detached. Erase them on every exit path, unless this
    // alternative succeeds and the deletion is released below.
    auto deleteClones = llvm::make_scope_exit([&] {
      for (Operation *clone : clones)
        clone->erase();
    });
    if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones)))
      return DiagnosedSilenceableFailure::definiteFailure();

    // Apply this alternative's transforms in order. A silenceable failure
    // abandons the alternative (its message is only printed under LLVM_DEBUG).
    // Any other non-success result aborts the whole op.
    bool failed = false;
    for (Operation &transform : reg.front().without_terminator()) {
      DiagnosedSilenceableFailure result =
          state.applyTransform(cast<TransformOpInterface>(transform));
      if (result.isSilenceableFailure()) {
        LLVM_DEBUG(DBGS() << "alternative failed: " << result.getMessage()
                          << "\n");
        failed = true;
        break;
      }
      // The local `failed` flag shadows the free function, hence the explicit
      // `::mlir::` qualification.
      if (::mlir::failed(result.silence()))
        return DiagnosedSilenceableFailure::definiteFailure();
    }

    // If all operations in the given alternative succeeded, no need to consider
    // the rest. Replace the original scoping operation with the clone on which
    // the transformations were performed.
    if (!failed) {
      // We will be using the clones, so cancel their scheduled deletion.
      deleteClones.release();
      // NOTE(review): this local IRRewriter shadows the `rewriter` parameter.
      // It is wired to a TrackingListener, presumably so that handles pointing
      // to the replaced originals get updated.
      TrackingListener listener(state, *this);
      IRRewriter rewriter(getContext(), &listener);
      for (const auto &kvp : llvm::zip(originals, clones)) {
        Operation *original = std::get<0>(kvp);
        Operation *clone = std::get<1>(kvp);
        // Insert the clone right before the original, then replace the
        // original with the clone's results.
        // NOTE(review): `original->getBlock()` would be null for a scope op
        // without a parent block; confirm such scopes cannot reach this point.
        original->getBlock()->getOperations().insert(original->getIterator(),
                                                     clone);
        rewriter.replaceOp(original, clone->getResults());
      }
      detail::forwardTerminatorOperands(&reg.front(), state, results);
      return DiagnosedSilenceableFailure::success();
    }
  }
  return emitSilenceableError() << "all alternatives failed";
}
/// Declares the op's memory effects:
///  - operands are consumed;
///  - results are produced;
///  - the entry block arguments of every non-empty alternative are produced;
///  - the payload IR is modified.
void transform::AlternativesOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  consumesHandle(getOperands(), effects);
  producesHandle(getResults(), effects);
  for (Region *region : getRegions()) {
    if (region->empty())
      continue;
    producesHandle(region->front().getArguments(), effects);
  }
  modifiesPayload(effects);
}
/// Verifies that the terminator of every alternative yields operands whose
/// types match the op's result types.
LogicalResult transform::AlternativesOp::verify() {
  for (Region &alternative : getAlternatives()) {
    Operation *yield = alternative.front().getTerminator();
    if (yield->getOperands().getTypes() == getResults().getTypes())
      continue;
    InFlightDiagnostic diag = emitOpError()
                              << "expects terminator operands to have the "
                                 "same type as results of the operation";
    diag.attachNote(yield->getLoc()) << "terminator";
    return diag;
  }
  return success();
}
//===----------------------------------------------------------------------===//
// AnnotateOp
//===----------------------------------------------------------------------===//
/// Sets the attribute named `getName()` on every target payload op:
///  - without a `param` operand, a unit attribute is set on each target;
///  - with a single parameter value, that value is set on each target;
///  - otherwise the parameter must carry exactly one value per target, and
///    values are assigned pairwise. A length mismatch is a silenceable error.
DiagnosedSilenceableFailure
transform::AnnotateOp::apply(transform::TransformRewriter &rewriter,
                             transform::TransformResults &results,
                             transform::TransformState &state) {
  SmallVector<Operation *> targets =
      llvm::to_vector(state.getPayloadOps(getTarget()));

  Value paramHandle = getParam();
  if (!paramHandle) {
    Attribute unit = UnitAttr::get(getContext());
    for (Operation *target : targets)
      target->setAttr(getName(), unit);
    return DiagnosedSilenceableFailure::success();
  }

  ArrayRef<Attribute> params = state.getParams(paramHandle);
  if (params.size() == 1) {
    for (Operation *target : targets)
      target->setAttr(getName(), params.front());
    return DiagnosedSilenceableFailure::success();
  }

  if (targets.size() != params.size()) {
    return emitSilenceableError()
           << "parameter and target have different payload lengths ("
           << params.size() << " vs " << targets.size() << ")";
  }
  for (auto &&[target, value] : llvm::zip_equal(targets, params))
    target->setAttr(getName(), value);
  return DiagnosedSilenceableFailure::success();
}
/// Declares the op's memory effects: the target handle and, when present, the
/// parameter handle are only read, and the payload IR is modified.
void transform::AnnotateOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  onlyReadsHandle(getTarget(), effects);
  // `param` is optional (see `apply`). Only report an effect when it is
  // present; a null Value would otherwise be wrapped into a one-element
  // ValueRange, recording a read effect on a null value.
  if (Value param = getParam())
    onlyReadsHandle(param, effects);
  modifiesPayload(effects);
}
//===----------------------------------------------------------------------===//
// ApplyCommonSubexpressionEliminationOp
//===----------------------------------------------------------------------===//
/// Runs common subexpression elimination on `target`, performing all IR
/// changes through `rewriter`. Targets that contain (or are) this transform op
/// are rejected with a definite failure.
DiagnosedSilenceableFailure
transform::ApplyCommonSubexpressionEliminationOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    ApplyToEachResultList &results, transform::TransformState &state) {
  // Rewriting the transform IR while it is being interpreted is unsafe, so
  // refuse targets that enclose this op.
  if (DiagnosedSilenceableFailure separationCheck =
          ensurePayloadIsSeparateFromTransform(*this, target);
      !separationCheck.succeeded())
    return separationCheck;

  DominanceInfo dominance;
  mlir::eliminateCommonSubExpressions(rewriter, dominance, target);
  return DiagnosedSilenceableFailure::success();
}
/// Declares the op's memory effects: the target handle is only read and the
/// payload IR is modified in place.
void transform::ApplyCommonSubexpressionEliminationOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  Value targetHandle = getTarget();
  transform::onlyReadsHandle(targetHandle, effects);
  transform::modifiesPayload(effects);
}
//===----------------------------------------------------------------------===//
// ApplyDeadCodeEliminationOp
//===----------------------------------------------------------------------===//
/// Erases trivially dead operations nested in `target` (never `target` itself).
/// Erasure continues until no further nested operation has become dead as a
/// consequence of earlier erasures. All erasures go through `rewriter`.
/// Targets that contain (or are) this transform op are rejected with a
/// definite failure.
DiagnosedSilenceableFailure transform::ApplyDeadCodeEliminationOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    ApplyToEachResultList &results, transform::TransformState &state) {
  // Make sure that this transform is not applied to itself. Modifying the
  // transform IR while it is being interpreted is generally dangerous.
  DiagnosedSilenceableFailure payloadCheck =
      ensurePayloadIsSeparateFromTransform(*this, target);
  if (!payloadCheck.succeeded())
    return payloadCheck;

  // Maintain a worklist of potentially dead ops.
  SetVector<Operation *> worklist;

  // Queues the defining ops of all values used by `op` or by ops nested in it.
  // Only ops properly nested in `target` are queued. Erasing those users may
  // make the queued ops dead.
  auto addDefiningOpsToWorklist = [&](Operation *op) {
    op->walk([&](Operation *nested) {
      for (Value v : nested->getOperands())
        if (Operation *defOp = v.getDefiningOp())
          if (target->isProperAncestor(defOp))
            worklist.insert(defOp);
    });
  };

  // Erases `op` through the rewriter. First drops `op` and every op nested in
  // it from the worklist, so that no erased op is popped later.
  // `SetVector::remove` consults the set first and only scans the vector when
  // the op is actually queued. Most nested ops are not queued, so this avoids
  // a linear scan per walked op.
  auto eraseOp = [&](Operation *op) {
    op->walk([&](Operation *nested) { worklist.remove(nested); });
    rewriter.eraseOp(op);
  };

  // Initial post-order walk: erase every nested op that is already trivially
  // dead.
  target->walk<WalkOrder::PostOrder>([&](Operation *op) {
    if (op != target && isOpTriviallyDead(op)) {
      addDefiningOpsToWorklist(op);
      eraseOp(op);
    }
  });

  // Erase all ops that have become dead.
  while (!worklist.empty()) {
    Operation *op = worklist.pop_back_val();
    if (!isOpTriviallyDead(op))
      continue;
    addDefiningOpsToWorklist(op);
    eraseOp(op);
  }
  return DiagnosedSilenceableFailure::success();
}
void transform::ApplyDeadCodeEliminationOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::onlyReadsHandle(getTarget(), effects);
transform::modifiesPayload(effects);
}
//===----------------------------------------------------------------------===//
// ApplyPatternsOp
//===----------------------------------------------------------------------===//
/// Applies the patterns populated by the PatternDescriptorOpInterface ops in
/// this op's region to `target` using the greedy pattern rewrite driver.
/// Rewrite notifications are forwarded to the listener of `rewriter`.
///
/// When CSE is requested, pattern application and CSE alternate until CSE no
/// longer changes the IR. Failing to reach that point within
/// kNumMaxIterations rounds is a definite failure. Failure of the greedy
/// driver itself is reported as a silenceable failure. Targets that contain
/// (or are) this transform op are rejected with a definite failure.
DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    ApplyToEachResultList &results, transform::TransformState &state) {
  // Make sure that this transform is not applied to itself. Modifying the
  // transform IR while it is being interpreted is generally dangerous. Even
  // more so for the ApplyPatternsOp because the GreedyPatternRewriteDriver
  // performs many additional simplifications such as dead code elimination.
  DiagnosedSilenceableFailure payloadCheck =
      ensurePayloadIsSeparateFromTransform(*this, target);
  if (!payloadCheck.succeeded())
    return payloadCheck;

  // Gather all specified patterns.
  MLIRContext *ctx = target->getContext();
  RewritePatternSet patterns(ctx);
  if (!getRegion().empty()) {
    for (Operation &op : getRegion().front()) {
      cast<transform::PatternDescriptorOpInterface>(&op)
          .populatePatternsWithState(patterns, state);
    }
  }

  // Configure the GreedyPatternRewriteDriver.
  GreedyRewriteConfig config;
  // Route rewrite notifications through the transform rewriter's listener.
  config.listener =
      static_cast<RewriterBase::Listener *>(rewriter.getListener());
  FrozenRewritePatternSet frozenPatterns(std::move(patterns));

  // An all-ones attribute value (static_cast<uint64_t>(-1)) means "no limit".
  config.maxIterations = getMaxIterations() == static_cast<uint64_t>(-1)
                             ? GreedyRewriteConfig::kNoLimit
                             : getMaxIterations();
  config.maxNumRewrites = getMaxNumRewrites() == static_cast<uint64_t>(-1)
                              ? GreedyRewriteConfig::kNoLimit
                              : getMaxNumRewrites();

  // Apply patterns and CSE repetitively until a fixpoint is reached. If no CSE
  // was requested, apply the greedy pattern rewrite only once. (The greedy
  // pattern rewrite driver already iterates to a fixpoint internally.)
  bool cseChanged = false;
  // One or two iterations should be sufficient. Stop iterating after a certain
  // threshold to make debugging easier.
  static const int64_t kNumMaxIterations = 50;
  int64_t iteration = 0;
  do {
    LogicalResult result = failure();
    if (target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
      // Op is isolated from above. Apply patterns and also perform region
      // simplification.
      result = applyPatternsAndFoldGreedily(target, frozenPatterns, config);
    } else {
      // Manually gather list of ops because the other
      // GreedyPatternRewriteDriver overloads only accepts ops that are isolated
      // from above. This way, patterns can be applied to ops that are not
      // isolated from above. Regions are not being simplified. Furthermore,
      // only a single greedy rewrite iteration is performed.
      SmallVector<Operation *> ops;
      target->walk([&](Operation *nestedOp) {
        if (target != nestedOp)
          ops.push_back(nestedOp);
      });
      result = applyOpPatternsAndFold(ops, frozenPatterns, config);
    }

    // A failure typically indicates that the pattern application did not
    // converge.
    if (failed(result)) {
      return emitSilenceableFailure(target)
             << "greedy pattern application failed";
    }

    // `cseChanged` is only written here, so without CSE the loop runs once.
    if (getApplyCse()) {
      DominanceInfo domInfo;
      mlir::eliminateCommonSubExpressions(rewriter, domInfo, target,
                                          &cseChanged);
    }
  } while (cseChanged && ++iteration < kNumMaxIterations);

  // `iteration` only reaches the threshold when CSE still changed the IR in the
  // last allowed round.
  if (iteration == kNumMaxIterations)
    return emitDefiniteFailure() << "fixpoint iteration did not converge";

  return DiagnosedSilenceableFailure::success();
}
/// Verifies that every op in the (possibly empty) pattern region implements
/// PatternDescriptorOpInterface.
LogicalResult transform::ApplyPatternsOp::verify() {
  if (getRegion().empty())
    return success();
  for (Operation &child : getRegion().front()) {
    if (isa<transform::PatternDescriptorOpInterface>(&child))
      continue;
    InFlightDiagnostic diag = emitOpError()
                              << "expected children ops to implement "
                                 "PatternDescriptorOpInterface";
    diag.attachNote(child.getLoc()) << "op without interface";
    return diag;
  }
  return success();
}
/// Declares the op's memory effects: the target handle is only read and the
/// payload IR is modified in place.
void transform::ApplyPatternsOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  Value targetHandle = getTarget();
  transform::onlyReadsHandle(targetHandle, effects);
  transform::modifiesPayload(effects);
}
/// Builds the op from the `target` operand and a single-block pattern region.
/// The block is populated by `bodyBuilder` when one is provided; otherwise it
/// is left empty.
void transform::ApplyPatternsOp::build(
    OpBuilder &builder, OperationState &result, Value target,
    function_ref<void(OpBuilder &, Location)> bodyBuilder) {
  result.addOperands(target);
  // Restore the caller's insertion point once the body has been built.
  OpBuilder::InsertionGuard guard(builder);
  builder.createBlock(result.addRegion());
  if (bodyBuilder)
    bodyBuilder(builder, result.location);
}
//===----------------------------------------------------------------------===//
// ApplyCanonicalizationPatternsOp
//===----------------------------------------------------------------------===//
/// Populates `patterns` with the canonicalization patterns of every loaded
/// dialect, followed by those of every registered operation.
void transform::ApplyCanonicalizationPatternsOp::populatePatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  for (Dialect *loaded : context->getLoadedDialects())
    loaded->getCanonicalizationPatterns(patterns);
  for (RegisteredOperationName opName : context->getRegisteredOperations())
    opName.getCanonicalizationPatterns(patterns, context);
}
//===----------------------------------------------------------------------===//
// ApplyConversionPatternsOp
//===----------------------------------------------------------------------===//
/// Applies a dialect conversion to each target payload op.
///
/// The conversion target is assembled from the legal/illegal op and dialect
/// attributes, plus rules contributed by each pattern descriptor in the
/// patterns region. Each descriptor uses its own type converter if it provides
/// one, and the default type converter otherwise. Having neither is a definite
/// failure.
///
/// Runs a partial or full conversion depending on `getPartialConversion()`.
/// When handles are preserved, a tracking listener is attached; its errors are
/// reported on their own or as a note on the conversion failure. Processing
/// stops at the first target that fails.
DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  MLIRContext *ctx = getContext();

  // Instantiate the default type converter if a type converter builder is
  // specified.
  std::unique_ptr<TypeConverter> defaultTypeConverter;
  transform::TypeConverterBuilderOpInterface typeConverterBuilder =
      getDefaultTypeConverter();
  if (typeConverterBuilder)
    defaultTypeConverter = typeConverterBuilder.getTypeConverter();

  // Configure conversion target.
  ConversionTarget conversionTarget(*getContext());
  if (getLegalOps())
    for (Attribute attr : cast<ArrayAttr>(*getLegalOps()))
      conversionTarget.addLegalOp(
          OperationName(cast<StringAttr>(attr).getValue(), ctx));
  if (getIllegalOps())
    for (Attribute attr : cast<ArrayAttr>(*getIllegalOps()))
      conversionTarget.addIllegalOp(
          OperationName(cast<StringAttr>(attr).getValue(), ctx));
  if (getLegalDialects())
    for (Attribute attr : cast<ArrayAttr>(*getLegalDialects()))
      conversionTarget.addLegalDialect(cast<StringAttr>(attr).getValue());
  if (getIllegalDialects())
    for (Attribute attr : cast<ArrayAttr>(*getIllegalDialects()))
      conversionTarget.addIllegalDialect(cast<StringAttr>(attr).getValue());

  // Gather all specified patterns.
  RewritePatternSet patterns(ctx);
  // Need to keep the converters alive until after pattern application because
  // the patterns take a reference to an object that would otherwise get out of
  // scope.
  SmallVector<std::unique_ptr<TypeConverter>> keepAliveConverters;
  if (!getPatterns().empty()) {
    for (Operation &op : getPatterns().front()) {
      auto descriptor =
          cast<transform::ConversionPatternDescriptorOpInterface>(&op);

      // Check if this pattern set specifies a type converter.
      std::unique_ptr<TypeConverter> typeConverter =
          descriptor.getTypeConverter();
      TypeConverter *converter = nullptr;
      if (typeConverter) {
        keepAliveConverters.emplace_back(std::move(typeConverter));
        converter = keepAliveConverters.back().get();
      } else {
        // No type converter specified: Use the default type converter.
        if (!defaultTypeConverter) {
          auto diag = emitDefiniteFailure()
                      << "pattern descriptor does not specify type "
                         "converter and apply_conversion_patterns op has "
                         "no default type converter";
          diag.attachNote(op.getLoc()) << "pattern descriptor op";
          return diag;
        }
        converter = defaultTypeConverter.get();
      }

      // Add descriptor-specific updates to the conversion target, which may
      // depend on the final type converter. In structural converters, the
      // legality of types dictates the dynamic legality of an operation.
      descriptor.populateConversionTargetRules(*converter, conversionTarget);

      descriptor.populatePatterns(*converter, patterns);
    }
  }

  // Attach a tracking listener if handles should be preserved. We configure the
  // listener to allow op replacements with different names, as conversion
  // patterns typically replace ops with replacement ops that have a different
  // name.
  TrackingListenerConfig trackingConfig;
  trackingConfig.requireMatchingReplacementOpName = false;
  ErrorCheckingTrackingListener trackingListener(state, *this, trackingConfig);
  ConversionConfig conversionConfig;
  if (getPreserveHandles())
    conversionConfig.listener = &trackingListener;

  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
  for (Operation *target : state.getPayloadOps(getTarget())) {
    // Make sure that this transform is not applied to itself. Modifying the
    // transform IR while it is being interpreted is generally dangerous.
    DiagnosedSilenceableFailure payloadCheck =
        ensurePayloadIsSeparateFromTransform(*this, target);
    if (!payloadCheck.succeeded())
      return payloadCheck;

    LogicalResult status = failure();
    if (getPartialConversion()) {
      status = applyPartialConversion(target, conversionTarget, frozenPatterns,
                                      conversionConfig);
    } else {
      status = applyFullConversion(target, conversionTarget, frozenPatterns,
                                   conversionConfig);
    }

    // Check dialect conversion state.
    DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
    if (failed(status)) {
      diag = emitSilenceableError() << "dialect conversion failed";
      diag.attachNote(target->getLoc()) << "target op";
    }

    // Check tracking listener error state. The error is reset so that the next
    // target starts from a clean listener.
    DiagnosedSilenceableFailure trackingFailure =
        trackingListener.checkAndResetError();
    if (!trackingFailure.succeeded()) {
      if (diag.succeeded()) {
        // Tracking failure is the only failure.
        return trackingFailure;
      } else {
        // Fold the tracking failure into the conversion diagnostic as a note
        // and silence it so it is reported only once.
        diag.attachNote() << "tracking listener also failed: "
                          << trackingFailure.getMessage();
        (void)trackingFailure.silence();
      }
    }

    if (!diag.succeeded())
      return diag;
  }

  return DiagnosedSilenceableFailure::success();
}
/// Verifies the structure of the op:
///  - it has one or two regions;
///  - every op in the patterns region implements
///    ConversionPatternDescriptorOpInterface;
///  - when a default type converter region is present, it holds exactly one op.
///    That op must implement TypeConverterBuilderOpInterface and be accepted
///    by every pattern descriptor's `verifyTypeConverter`.
LogicalResult transform::ApplyConversionPatternsOp::verify() {
  if (getNumRegions() != 1 && getNumRegions() != 2)
    return emitOpError() << "expected 1 or 2 regions";
  if (!getPatterns().empty()) {
    for (Operation &op : getPatterns().front()) {
      if (!isa<transform::ConversionPatternDescriptorOpInterface>(&op)) {
        InFlightDiagnostic diag =
            emitOpError() << "expected pattern children ops to implement "
                             "ConversionPatternDescriptorOpInterface";
        diag.attachNote(op.getLoc()) << "op without interface";
        return diag;
      }
    }
  }
  if (getNumRegions() == 2) {
    Region &typeConverterRegion = getRegion(1);
    // The region may have no block at all; check before calling front().
    if (typeConverterRegion.empty() ||
        !llvm::hasSingleElement(typeConverterRegion.front()))
      return emitOpError()
             << "expected exactly one op in default type converter region";
    Operation &typeConverterChild = typeConverterRegion.front().front();
    auto typeConverterOp =
        dyn_cast<transform::TypeConverterBuilderOpInterface>(
            &typeConverterChild);
    if (!typeConverterOp) {
      InFlightDiagnostic diag = emitOpError()
                                << "expected default converter child op to "
                                   "implement TypeConverterBuilderOpInterface";
      // `typeConverterOp` is null here, so anchor the note on the child op
      // itself instead of dereferencing the failed cast.
      diag.attachNote(typeConverterChild.getLoc()) << "op without interface";
      return diag;
    }
    // Check default type converter type.
    if (!getPatterns().empty()) {
      for (Operation &op : getPatterns().front()) {
        auto descriptor =
            cast<transform::ConversionPatternDescriptorOpInterface>(&op);
        if (failed(descriptor.verifyTypeConverter(typeConverterOp)))
          return failure();
      }
    }
  }
  return success();
}
/// Declares the op's memory effects:
///  - the target handle is only read when handles are preserved, and consumed
///    otherwise;
///  - the payload IR is modified.
void transform::ApplyConversionPatternsOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  if (getPreserveHandles())
    transform::onlyReadsHandle(getTarget(), effects);
  else
    transform::consumesHandle(getTarget(), effects);
  transform::modifiesPayload(effects);
}
/// Builds the op from the `target` operand and two single-block regions, in
/// this order:
///  - the patterns region, filled by `patternsBodyBuilder`;
///  - the default type converter region, filled by `typeConverterBodyBuilder`.
/// A null builder leaves the corresponding block empty.
void transform::ApplyConversionPatternsOp::build(
    OpBuilder &builder, OperationState &result, Value target,
    function_ref<void(OpBuilder &, Location)> patternsBodyBuilder,
    function_ref<void(OpBuilder &, Location)> typeConverterBodyBuilder) {
  result.addOperands(target);

  // Appends a region holding one block and lets `bodyBuilder` (if any) fill
  // it. The caller's insertion point is restored afterwards.
  auto addRegionWithBody =
      [&](function_ref<void(OpBuilder &, Location)> bodyBuilder) {
        OpBuilder::InsertionGuard guard(builder);
        builder.createBlock(result.addRegion());
        if (bodyBuilder)
          bodyBuilder(builder, result.location);
      };

  addRegionWithBody(patternsBodyBuilder);
  addRegionWithBody(typeConverterBodyBuilder);
}
//===----------------------------------------------------------------------===//
// ApplyToLLVMConversionPatternsOp
//===----------------------------------------------------------------------===//
/// Populates `patterns` with the convert-to-LLVM patterns of the dialect named
/// by `getDialectName()`. `typeConverter` must be an LLVMTypeConverter, which
/// `verifyTypeConverter` enforces for default converters.
void transform::ApplyToLLVMConversionPatternsOp::populatePatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  MLIRContext *context = getContext();
  Dialect *dialect = context->getLoadedDialect(getDialectName());
  assert(dialect && "expected that dialect is loaded");
  auto *patternProvider = cast<ConvertToLLVMPatternInterface>(dialect);
  // The enclosing apply_conversion_patterns op sets up its own
  // ConversionTarget, so the target filled in here is discarded.
  ConversionTarget ignoredTarget(*context);
  auto &llvmConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  patternProvider->populateConvertToLLVMConversionPatterns(
      ignoredTarget, llvmConverter, patterns);
}
/// Accepts only default type converter builders that produce an
/// LLVMTypeConverter, because `populatePatterns` static-casts the converter to
/// that type.
LogicalResult transform::ApplyToLLVMConversionPatternsOp::verifyTypeConverter(
    transform::TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() == "LLVMTypeConverter")
    return success();
  return emitOpError("expected LLVMTypeConverter");
}
/// Verifies that the dialect named by `getDialectName()` is loaded and
/// implements ConvertToLLVMPatternInterface.
LogicalResult transform::ApplyToLLVMConversionPatternsOp::verify() {
  Dialect *dialect = getContext()->getLoadedDialect(getDialectName());
  if (dialect == nullptr)
    return emitOpError("unknown dialect or dialect not loaded: ")
           << getDialectName();
  if (dyn_cast<ConvertToLLVMPatternInterface>(dialect) == nullptr)
    return emitOpError(
               "dialect does not implement ConvertToLLVMPatternInterface or "
               "extension was not loaded: ")
           << getDialectName();
  return success();
}
//===----------------------------------------------------------------------===//
// ApplyLoopInvariantCodeMotionOp
//===----------------------------------------------------------------------===//
/// Moves loop-invariant code out of the `target` loop. This always succeeds;
/// `rewriter`, `results` and `state` are not used.
DiagnosedSilenceableFailure
transform::ApplyLoopInvariantCodeMotionOp::applyToOne(
    transform::TransformRewriter &rewriter, LoopLikeOpInterface target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  // LICM currently only moves operations and never removes any, so it is run
  // without the tracking rewriter. Should that change, add a LICM entry point
  // that takes a rewriter.
  (void)moveLoopInvariantCode(target);
  return DiagnosedSilenceableFailure::success();
}
/// Declares the op's memory effects: the target handle is only read and the
/// payload IR is modified in place.
void transform::ApplyLoopInvariantCodeMotionOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  Value loopHandle = getTarget();
  transform::onlyReadsHandle(loopHandle, effects);
  transform::modifiesPayload(effects);
}
//===----------------------------------------------------------------------===//
// ApplyRegisteredPassOp
//===----------------------------------------------------------------------===//
/// Runs the registered pass pipeline or pass named `getPassName()`, configured
/// with `getOptions()`, on `target`, and returns `target` as the result.
///  - An unknown name is a definite failure.
///  - Failing to add the entry to the pass manager is a definite failure.
///  - A failing pipeline run is a silenceable failure.
/// Targets that contain (or are) this transform op are rejected.
DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    ApplyToEachResultList &results, transform::TransformState &state) {
  // Passes may perform arbitrary IR modifications, so refuse to run them on IR
  // that encloses the transform currently being interpreted.
  DiagnosedSilenceableFailure separationCheck =
      ensurePayloadIsSeparateFromTransform(*this, target);
  if (!separationCheck.succeeded())
    return separationCheck;

  // Pipelines are looked up first, then passes with the same name.
  const PassRegistryEntry *entry = PassPipelineInfo::lookup(getPassName());
  if (entry == nullptr)
    entry = PassInfo::lookup(getPassName());
  if (entry == nullptr)
    return emitDefiniteFailure()
           << "unknown pass or pass pipeline: " << getPassName();

  PassManager pm(getContext());
  auto reportError = [&](const Twine &msg) {
    emitError(msg);
    return failure();
  };
  if (failed(entry->addToPipeline(pm, getOptions(), reportError)))
    return emitDefiniteFailure()
           << "failed to add pass or pass pipeline to pipeline: "
           << getPassName();

  if (failed(pm.run(target))) {
    auto diag = emitSilenceableError() << "pass pipeline failed";
    diag.attachNote(target->getLoc()) << "target op";
    return diag;
  }

  results.push_back(target);
  return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//
/// Forwards each target payload op unchanged to the result handle; only the
/// handle type changes. Always succeeds.
DiagnosedSilenceableFailure
transform::CastOp::applyToOne(transform::TransformRewriter &rewriter,
                              Operation *target, ApplyToEachResultList &results,
                              transform::TransformState &state) {
  Operation *forwarded = target;
  results.push_back(forwarded);
  return DiagnosedSilenceableFailure::success();
}
/// Declares the op's memory effects:
///  - the payload is only read;
///  - the input handle is only read;
///  - the output handle is produced.
void transform::CastOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  Value input = getInput();
  Value output = getOutput();
  onlyReadsPayload(effects);
  onlyReadsHandle(input, effects);
  producesHandle(output, effects);
}
/// A cast is valid only between transform handle types: the single input
/// type and the single output type must both implement
/// TransformHandleTypeInterface.
bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  assert(inputs.size() == 1 && "expected one input");
  assert(outputs.size() == 1 && "expected one output");
  return isa<transform::TransformHandleTypeInterface>(inputs.front()) &&
         isa<transform::TransformHandleTypeInterface>(outputs.front());
}
//===----------------------------------------------------------------------===//
// Shared matching helpers
//===----------------------------------------------------------------------===//
/// Applies matcher operations from the given `block`, using
/// `blockArgumentMapping` to initialize block arguments, and updates `state`
/// accordingly.
///
/// Failure modes:
///  - If mapping the block arguments fails, a definite failure is returned.
///  - Every non-terminator op of the block must implement MatchOpInterface;
///    otherwise a definite failure is emitted and returned.
///  - Matchers are applied in order. The first one that does not succeed stops
///    matching, and its (silenceable or definite) failure is returned to the
///    caller unchanged.
///
/// If all matchers in the block succeed, `mappings` is cleared and then
/// populated with the payload entities associated with the block terminator
/// operands, and success is returned. `mappings` is left untouched when
/// matching fails.
static DiagnosedSilenceableFailure
matchBlock(Block &block,
           ArrayRef<SmallVector<transform::MappedValue>> blockArgumentMapping,
           transform::TransformState &state,
           SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings) {
  assert(block.getParent() && "cannot match using a detached block");
  // Open a region scope for the block's region that lasts for this call.
  auto matchScope = state.make_region_scope(*block.getParent());
  if (failed(
          state.mapBlockArguments(block.getArguments(), blockArgumentMapping)))
    return DiagnosedSilenceableFailure::definiteFailure();

  for (Operation &match : block.without_terminator()) {
    if (!isa<transform::MatchOpInterface>(match)) {
      return emitDefiniteFailure(match.getLoc())
             << "expected operations in the match part to "
                "implement MatchOpInterface";
    }
    DiagnosedSilenceableFailure diag =
        state.applyTransform(cast<transform::TransformOpInterface>(match));
    if (diag.succeeded())
      continue;

    // Propagate the first non-success result unchanged.
    return diag;
  }

  // Remember the values mapped to the terminator operands so we can
  // forward them to the action.
  ValueRange yieldedValues = block.getTerminator()->getOperands();
  // Our contract with the caller is that the mappings will contain only the
  // newly mapped values, clear the rest.
  mappings.clear();
  transform::detail::prepareValueMappings(mappings, yieldedValues, state);
  return DiagnosedSilenceableFailure::success();
}
/// Checks whether `t1` and `t2` both implement at least one common interface
/// from the `Tys` pack. An empty pack yields `false`.
template <typename... Tys>
static bool implementSameInterface(Type t1, Type t2) {
  return (... || (isa<Tys>(t1) && isa<Tys>(t2)));
}
/// Checks whether `t1` and `t2` belong to the same transform dialect type
/// category: both operation handles, both parameters, or both value handles.
static bool implementSameTransformInterface(Type t1, Type t2) {
  return implementSameInterface<
      transform::TransformHandleTypeInterface,
      transform::TransformParamTypeInterface,
      transform::TransformValueHandleTypeInterface>(t1, t2);
}
//===----------------------------------------------------------------------===//
// CollectMatchingOp
//===----------------------------------------------------------------------===//
/// Walks every payload op nested in (and including) the ops associated with
/// the root handle and runs the matcher on each of them. For every successful
/// match, each matcher result must be associated with exactly one payload
/// entity. That entity is appended to the corresponding op result. The op
/// results are set once, after all roots have been walked.
///
/// Returns a definite failure if the matcher is an external declaration or if
/// a matcher fails definitely. Returns a silenceable failure if a matcher
/// result is associated with other than one payload entity. Silenceable
/// matcher failures only mean "no match" and are skipped.
DiagnosedSilenceableFailure
transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,
                                    transform::TransformResults &results,
                                    transform::TransformState &state) {
  auto matcher = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
      getOperation(), getMatcher());
  if (matcher.isExternal()) {
    return emitDefiniteFailure()
           << "unresolved external symbol " << getMatcher();
  }

  // Payload entities collected for each op result, accumulated across all
  // payload roots.
  SmallVector<SmallVector<MappedValue>, 2> rawResults;
  rawResults.resize(getOperation()->getNumResults());
  // Holds the silenceable error that interrupted the walk, if any. It stays
  // empty when the walk is interrupted by a definite matcher failure.
  std::optional<DiagnosedSilenceableFailure> maybeFailure;
  for (Operation *root : state.getPayloadOps(getRoot())) {
    WalkResult walkResult = root->walk([&](Operation *op) {
      DEBUG_MATCHER({
        DBGS_MATCHER() << "matching ";
        op->print(llvm::dbgs(),
                  OpPrintingFlags().assumeVerified().skipRegions());
        llvm::dbgs() << " @" << op << "\n";
      });

      // Try matching.
      SmallVector<SmallVector<MappedValue>> mappings;
      SmallVector<transform::MappedValue> inputMapping({op});
      DiagnosedSilenceableFailure diag = matchBlock(
          matcher.getFunctionBody().front(),
          ArrayRef<SmallVector<transform::MappedValue>>(inputMapping), state,
          mappings);
      if (diag.isDefiniteFailure())
        return WalkResult::interrupt();
      if (diag.isSilenceableFailure()) {
        DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName()
                                     << " failed: " << diag.getMessage());
        return WalkResult::advance();
      }

      // If succeeded, collect results.
      for (auto &&[i, mapping] : llvm::enumerate(mappings)) {
        if (mapping.size() != 1) {
          maybeFailure.emplace(emitSilenceableError()
                               << "result #" << i << ", associated with "
                               << mapping.size()
                               << " payload objects, expected 1");
          return WalkResult::interrupt();
        }
        rawResults[i].push_back(mapping[0]);
      }
      return WalkResult::advance();
    });
    if (walkResult.wasInterrupted()) {
      // A definite matcher failure interrupts the walk without recording
      // anything in `maybeFailure`, so it must not be dereferenced blindly.
      if (maybeFailure)
        return std::move(*maybeFailure);
      return DiagnosedSilenceableFailure::definiteFailure();
    }
    assert(!maybeFailure && "failure set but the walk was not interrupted");
  }

  // Set each result exactly once, after all roots are processed. Setting it
  // inside the root loop would set it once per root and never when there are
  // no roots.
  for (auto &&[opResult, rawResult] :
       llvm::zip_equal(getOperation()->getResults(), rawResults)) {
    results.setMappedValues(opResult, rawResult);
  }
  return DiagnosedSilenceableFailure::success();
}
/// The root handle and the payload are only read; every result is a freshly
/// produced handle.
void transform::CollectMatchingOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // Handle effects first: the root is inspected, the results are created.
  onlyReadsHandle(getRoot(), effects);
  producesHandle(getResults(), effects);
  // Payload effect: matching never modifies the payload.
  onlyReadsPayload(effects);
}
/// Verifies that the matcher symbol resolves to a function-like transform op.
/// The matcher must take a single readonly operation handle and yield one
/// value per op result, each with a matching transform type interface.
LogicalResult transform::CollectMatchingOp::verifySymbolUses(
    SymbolTableCollection &symbolTable) {
  Operation *thisOp = getOperation();
  auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
      symbolTable.lookupNearestSymbolFrom(thisOp, getMatcher()));
  bool isTransformMatcher =
      matcherSymbol &&
      isa<TransformOpInterface>(matcherSymbol.getOperation());
  if (!isTransformMatcher)
    return emitError() << "unresolved matcher symbol " << getMatcher();

  // The matcher is invoked on one payload op at a time and must not consume
  // the handle it receives.
  ArrayRef<Type> argumentTypes = matcherSymbol.getArgumentTypes();
  bool takesSingleOpHandle =
      argumentTypes.size() == 1 &&
      isa<TransformHandleTypeInterface>(argumentTypes.front());
  if (!takesSingleOpHandle) {
    return emitError()
           << "expected the matcher to take one operation handle argument";
  }
  if (!matcherSymbol.getArgAttr(
          0, transform::TransformDialect::kArgReadOnlyAttrName)) {
    return emitError() << "expected the matcher argument to be marked readonly";
  }

  // Matcher result #k feeds op result #k.
  ArrayRef<Type> matcherResultTypes = matcherSymbol.getResultTypes();
  unsigned numOpResults = thisOp->getNumResults();
  if (matcherResultTypes.size() != numOpResults) {
    return emitError()
           << "expected the matcher to yield as many values as op has results ("
           << numOpResults << "), got " << matcherResultTypes.size();
  }

  for (auto &&[i, matcherType, resultType] :
       llvm::enumerate(matcherResultTypes, thisOp->getResultTypes())) {
    if (!implementSameTransformInterface(matcherType, resultType)) {
      return emitError()
             << "mismatching type interfaces for matcher result and op result #"
             << i;
    }
  }
  return success();
}
//===----------------------------------------------------------------------===//
// ForeachMatchOp
//===----------------------------------------------------------------------===//
// Repeating a handle among the operands is accepted. The forwarded inputs are
// only read. The root handle is declared as consumed in getEffects only so
// that handles to payload nested inside it are invalidated after the actions
// run.
// NOTE(review): the previous rationale claimed nothing is consumed by this op,
// which does not match getEffects; confirm that a root handle repeated as a
// forwarded input is indeed safe.
bool transform::ForeachMatchOp::allowsRepeatedHandleOperands() {
  return true;
}
/// Walks the payload nested in each op associated with the root handle. For
/// every walked op, the matcher/action pairs are tried in order: the action of
/// the first matcher that succeeds is applied, and the remaining pairs are
/// skipped for that op. Unless `getRestrictRoot()` is set, the root payload op
/// itself is not matched.
///
/// Silenceable matcher failures mean "no match" and are discarded (they are
/// only printed to the debug stream). Silenceable action failures are
/// collected into one "actions failed" diagnostic, which is returned after the
/// walk; the walk itself continues. Definite failures interrupt the walk.
/// Values yielded by the actions are appended, across all matched ops, to the
/// forwarded results.
DiagnosedSilenceableFailure
transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
                                 transform::TransformResults &results,
                                 transform::TransformState &state) {
  // Resolve every matcher/action symbol before touching the payload, so that
  // external (body-less) declarations are rejected upfront.
  SmallVector<std::pair<FunctionOpInterface, FunctionOpInterface>>
      matchActionPairs;
  matchActionPairs.reserve(getMatchers().size());
  SymbolTableCollection symbolTable;
  for (auto &&[matcher, action] :
       llvm::zip_equal(getMatchers(), getActions())) {
    auto matcherSymbol =
        symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
            getOperation(), cast<SymbolRefAttr>(matcher));
    auto actionSymbol =
        symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
            getOperation(), cast<SymbolRefAttr>(action));
    assert(matcherSymbol && actionSymbol &&
           "unresolved symbols not caught by the verifier");
    if (matcherSymbol.isExternal())
      return emitDefiniteFailure() << "unresolved external symbol " << matcher;
    if (actionSymbol.isExternal())
      return emitDefiniteFailure() << "unresolved external symbol " << action;
    matchActionPairs.emplace_back(matcherSymbol, actionSymbol);
  }
  // Accumulates silenceable action failures; stays successful otherwise.
  DiagnosedSilenceableFailure overallDiag =
      DiagnosedSilenceableFailure::success();
  SmallVector<SmallVector<MappedValue>> matchInputMapping;
  SmallVector<SmallVector<MappedValue>> matchOutputMapping;
  SmallVector<SmallVector<MappedValue>> actionResultMapping;
  // Explicitly add the mapping for the first block argument (the op being
  // matched).
  matchInputMapping.emplace_back();
  // The remaining matcher arguments receive the forwarded inputs; they are
  // computed once and do not change during the walk.
  transform::detail::prepareValueMappings(matchInputMapping,
                                          getForwardedInputs(), state);
  // Refilled with the currently walked op before each matching attempt.
  SmallVector<MappedValue> &firstMatchArgument = matchInputMapping.front();
  // One accumulator per forwarded output, filled from the action terminators.
  actionResultMapping.resize(getForwardedOutputs().size());
  for (Operation *root : state.getPayloadOps(getRoot())) {
    WalkResult walkResult = root->walk([&](Operation *op) {
      // If getRestrictRoot is not present, skip over the root op itself so we
      // don't invalidate it.
      if (!getRestrictRoot() && op == root)
        return WalkResult::advance();
      DEBUG_MATCHER({
        DBGS_MATCHER() << "matching ";
        op->print(llvm::dbgs(),
                  OpPrintingFlags().assumeVerified().skipRegions());
        llvm::dbgs() << " @" << op << "\n";
      });
      firstMatchArgument.clear();
      firstMatchArgument.push_back(op);
      // Try all the match/action pairs until the first successful match.
      for (auto [matcher, action] : matchActionPairs) {
        DiagnosedSilenceableFailure diag =
            matchBlock(matcher.getFunctionBody().front(), matchInputMapping,
                       state, matchOutputMapping);
        if (diag.isDefiniteFailure())
          return WalkResult::interrupt();
        if (diag.isSilenceableFailure()) {
          DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName()
                                       << " failed: " << diag.getMessage());
          continue;
        }
        // The values yielded by the matcher become the action's arguments.
        auto scope = state.make_region_scope(action.getFunctionBody());
        if (failed(state.mapBlockArguments(
                action.getFunctionBody().front().getArguments(),
                matchOutputMapping))) {
          return WalkResult::interrupt();
        }
        for (Operation &transform :
             action.getFunctionBody().front().without_terminator()) {
          DiagnosedSilenceableFailure result =
              state.applyTransform(cast<TransformOpInterface>(transform));
          if (result.isDefiniteFailure())
            return WalkResult::interrupt();
          if (result.isSilenceableFailure()) {
            // Record the failure with notes pointing at the action and the
            // payload op, then keep applying the rest of the action.
            if (overallDiag.succeeded()) {
              overallDiag = emitSilenceableError() << "actions failed";
            }
            overallDiag.attachNote(action->getLoc())
                << "failed action: " << result.getMessage();
            overallDiag.attachNote(op->getLoc())
                << "when applied to this matching payload";
            (void)result.silence();
            continue;
          }
        }
        // Append the values yielded by the action to the forwarded outputs.
        if (failed(detail::appendValueMappings(
                MutableArrayRef<SmallVector<MappedValue>>(actionResultMapping),
                action.getFunctionBody().front().getTerminator()->getOperands(),
                state, getFlattenResults()))) {
          emitDefiniteFailure()
              << "action @" << action.getName()
              << " has results associated with multiple payload entities, "
                 "but flattening was not requested";
          return WalkResult::interrupt();
        }
        // Only the first successful match/action pair applies to an op.
        break;
      }
      return WalkResult::advance();
    });
    if (walkResult.wasInterrupted())
      return DiagnosedSilenceableFailure::definiteFailure();
  }
  // The root operation should not have been affected, so we can just reassign
  // the payload to the result. Note that we need to consume the root handle to
  // make sure any handles to operations inside, that could have been affected
  // by actions, are invalidated.
  results.set(llvm::cast<OpResult>(getUpdated()),
              state.getPayloadOps(getRoot()));
  for (auto &&[result, mapping] :
       llvm::zip_equal(getForwardedOutputs(), actionResultMapping)) {
    results.setMappedValues(result, mapping);
  }
  return overallDiag;
}
/// Names the first result after the root it re-exposes; every other result
/// carries values yielded by the actions.
void transform::ForeachMatchOp::getAsmResultNames(
    OpAsmSetValueNameFn setNameFn) {
  setNameFn(getUpdated(), "updated_root");
  llvm::for_each(getForwardedOutputs(),
                 [&](Value yielded) { setNameFn(yielded, "yielded"); });
}
/// Consumes the root handle, only reads the forwarded inputs, produces all
/// results and may modify the payload. When the op lacks operands or results
/// (presumably when queried on malformed IR during verification), only the
/// conservative payload-modification effect is reported.
void transform::ForeachMatchOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  Operation *op = getOperation();
  bool isWellFormed = op->getNumOperands() >= 1 && op->getNumResults() >= 1;
  if (isWellFormed) {
    // Consuming the root invalidates handles to payload nested inside it,
    // which the actions may have modified.
    consumesHandle(getRoot(), effects);
    onlyReadsHandle(getForwardedInputs(), effects);
    producesHandle(getResults(), effects);
  }
  modifiesPayload(effects);
}
/// Parses a non-empty, comma-separated list of `@matcher -> @action` pairs.
/// On success, `matchers` and `actions` hold the parsed symbol references as
/// two parallel arrays.
static ParseResult parseForeachMatchSymbols(OpAsmParser &parser,
                                            ArrayAttr &matchers,
                                            ArrayAttr &actions) {
  SmallVector<Attribute> matcherRefs;
  SmallVector<Attribute> actionRefs;
  while (true) {
    StringAttr matcherName;
    StringAttr actionName;
    bool parseFailed = parser.parseSymbolName(matcherName) ||
                       parser.parseArrow() ||
                       parser.parseSymbolName(actionName);
    if (parseFailed)
      return failure();
    matcherRefs.push_back(SymbolRefAttr::get(matcherName));
    actionRefs.push_back(SymbolRefAttr::get(actionName));
    // A comma announces another pair; anything else ends the list.
    if (!parser.parseOptionalComma().succeeded())
      break;
  }
  Builder &builder = parser.getBuilder();
  matchers = builder.getArrayAttr(matcherRefs);
  actions = builder.getArrayAttr(actionRefs);
  return success();
}
/// Prints `matchers` and `actions` pairwise as `@matcher -> @action`, one pair
/// per line and separated by commas, in the form accepted by
/// parseForeachMatchSymbols.
static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op,
                                     ArrayAttr matchers, ArrayAttr actions) {
  assert(matchers.size() == actions.size() && "mismatching pair lists");
  // Pairs are printed two indentation levels deeper than the op.
  printer.increaseIndent();
  printer.increaseIndent();
  for (unsigned idx = 0, e = matchers.size(); idx < e; ++idx) {
    if (idx != 0)
      printer << ", ";
    printer.printNewline();
    printer << cast<SymbolRefAttr>(matchers[idx]) << " -> "
            << cast<SymbolRefAttr>(actions[idx]);
  }
  printer.decreaseIndent();
  printer.decreaseIndent();
}
/// Requires a non-empty list of matcher/action pairs of equal length, and
/// warns about matchers listed more than once.
LogicalResult transform::ForeachMatchOp::verify() {
  size_t numMatchers = getMatchers().size();
  if (numMatchers != getActions().size())
    return emitOpError() << "expected the same number of matchers and actions";
  if (numMatchers == 0)
    return emitOpError() << "expected at least one match/action pair";

  // Pairs are tried in order and the first successful match wins, so a
  // repeated matcher is suspicious. It is not invalid, hence only a warning.
  llvm::SmallPtrSet<Attribute, 8> seenMatchers;
  for (Attribute name : getMatchers()) {
    if (!seenMatchers.insert(name).second) {
      emitWarning() << "matcher " << name
                    << " is used more than once, only the first match will apply";
    }
  }
  return success();
}
/// Checks that the arguments of the function-like transform operation `op`
/// carry consistent consumption annotations (the transform dialect's
/// "consumed" and "readonly" argument attributes).
///
/// - No argument may be annotated as both consumed and readonly.
/// - External functions, or every function when `alsoVerifyInternal` is set,
///   must annotate each argument with one of the two.
/// - For functions with a body, an argument consumed in the body but marked
///   readonly is an error. If `emitWarnings` is set, an argument marked
///   consumed but not consumed in the body produces a warning.
///
/// Returns a silenceable failure describing the first error found.
static DiagnosedSilenceableFailure
verifyFunctionLikeConsumeAnnotations(FunctionOpInterface op, bool emitWarnings,
                                     bool alsoVerifyInternal = false) {
  auto transformOp = cast<transform::TransformOpInterface>(op.getOperation());
  const bool isExternal = op.isExternal();

  // Arguments actually consumed by the body (none for external declarations).
  llvm::SmallDenseSet<unsigned> consumedInBody;
  if (!isExternal) {
    transform::getConsumedBlockArguments(op.getFunctionBody().front(),
                                         consumedInBody);
  }

  for (unsigned argIdx = 0, numArgs = op.getNumArguments(); argIdx < numArgs;
       ++argIdx) {
    const bool markedConsumed =
        op.getArgAttr(argIdx,
                      transform::TransformDialect::kArgConsumedAttrName) !=
        nullptr;
    const bool markedReadOnly =
        op.getArgAttr(argIdx,
                      transform::TransformDialect::kArgReadOnlyAttrName) !=
        nullptr;

    if (markedConsumed && markedReadOnly) {
      return transformOp.emitSilenceableError()
             << "argument #" << argIdx
             << " cannot be both readonly and consumed";
    }

    const bool mustBeAnnotated = isExternal || alsoVerifyInternal;
    if (mustBeAnnotated && !markedConsumed && !markedReadOnly) {
      return transformOp.emitSilenceableError()
             << "must provide consumed/readonly status for arguments of "
                "external or called ops";
    }

    // Without a body there is nothing to cross-check the annotations against.
    if (isExternal)
      continue;

    const bool consumedByBody = consumedInBody.contains(argIdx);
    if (consumedByBody && !markedConsumed && markedReadOnly) {
      return transformOp.emitSilenceableError()
             << "argument #" << argIdx
             << " is consumed in the body but is not marked as such";
    }
    if (emitWarnings && !consumedByBody && markedConsumed) {
      // Cannot use op.emitWarning() here as it would attempt to verify the op
      // before printing, resulting in infinite recursion.
      emitWarning(op->getLoc())
          << "op argument #" << argIdx
          << " is not consumed in the body but is marked as consumed";
    }
  }
  return DiagnosedSilenceableFailure::success();
}
/// Verifies each matcher/action pair.
///
/// Both symbols must resolve to function-like transform ops whose consumption
/// annotations are complete and consistent. The types must line up along the
/// whole forwarding chain:
///   - op operands (root first, then forwarded inputs) -> matcher arguments,
///     which must not be consumed;
///   - matcher results -> action arguments;
///   - action results -> op results after the first one (the updated root).
/// At each step, the paired types must implement the same transform type
/// interface.
LogicalResult transform::ForeachMatchOp::verifySymbolUses(
    SymbolTableCollection &symbolTable) {
  assert(getMatchers().size() == getActions().size());
  auto consumedAttr =
      StringAttr::get(getContext(), TransformDialect::kArgConsumedAttrName);
  for (auto &&[matcher, action] :
       llvm::zip_equal(getMatchers(), getActions())) {
    // Presence and typing.
    auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
        symbolTable.lookupNearestSymbolFrom(getOperation(),
                                            cast<SymbolRefAttr>(matcher)));
    auto actionSymbol = dyn_cast_or_null<FunctionOpInterface>(
        symbolTable.lookupNearestSymbolFrom(getOperation(),
                                            cast<SymbolRefAttr>(action)));
    if (!matcherSymbol ||
        !isa<TransformOpInterface>(matcherSymbol.getOperation()))
      return emitError() << "unresolved matcher symbol " << matcher;
    if (!actionSymbol ||
        !isa<TransformOpInterface>(actionSymbol.getOperation()))
      return emitError() << "unresolved action symbol " << action;

    if (failed(verifyFunctionLikeConsumeAnnotations(matcherSymbol,
                                                    /*emitWarnings=*/false,
                                                    /*alsoVerifyInternal=*/true)
                   .checkAndReport())) {
      return failure();
    }
    if (failed(verifyFunctionLikeConsumeAnnotations(actionSymbol,
                                                    /*emitWarnings=*/false,
                                                    /*alsoVerifyInternal=*/true)
                   .checkAndReport())) {
      return failure();
    }

    // Input -> matcher forwarding.
    TypeRange operandTypes = getOperandTypes();
    TypeRange matcherArguments = matcherSymbol.getArgumentTypes();
    if (operandTypes.size() != matcherArguments.size()) {
      InFlightDiagnostic diag =
          emitError() << "the number of operands (" << operandTypes.size()
                      << ") doesn't match the number of matcher arguments ("
                      << matcherArguments.size() << ") for " << matcher;
      diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
      return diag;
    }
    for (auto &&[i, operand, argument] :
         llvm::enumerate(operandTypes, matcherArguments)) {
      // Matchers are applied repeatedly to the same forwarded inputs, so they
      // must not consume them.
      if (matcherSymbol.getArgAttr(i, consumedAttr)) {
        InFlightDiagnostic diag =
            emitOpError()
            << "does not expect matcher symbol to consume its operand #" << i;
        diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
        return diag;
      }

      if (implementSameTransformInterface(operand, argument))
        continue;

      InFlightDiagnostic diag =
          emitError()
          << "mismatching type interfaces for operand and matcher argument #"
          << i << " of matcher " << matcher;
      diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
      return diag;
    }

    // Matcher -> action forwarding.
    TypeRange matcherResults = matcherSymbol.getResultTypes();
    TypeRange actionArguments = actionSymbol.getArgumentTypes();
    if (matcherResults.size() != actionArguments.size()) {
      return emitError() << "mismatching number of matcher results and "
                            "action arguments between "
                         << matcher << " (" << matcherResults.size() << ") and "
                         << action << " (" << actionArguments.size() << ")";
    }
    for (auto &&[i, matcherType, actionType] :
         llvm::enumerate(matcherResults, actionArguments)) {
      if (implementSameTransformInterface(matcherType, actionType))
        continue;

      return emitError() << "mismatching type interfaces for matcher result "
                            "and action argument #"
                         << i << " of matcher " << matcher << " and action "
                         << action;
    }

    // Action -> result forwarding. The first op result is the updated root,
    // so only the remaining results receive action results.
    TypeRange actionResults = actionSymbol.getResultTypes();
    auto resultTypes = TypeRange(getResultTypes()).drop_front();
    if (actionResults.size() != resultTypes.size()) {
      InFlightDiagnostic diag =
          emitError() << "the number of action results ("
                      << actionResults.size() << ") for " << action
                      << " doesn't match the number of extra op results ("
                      << resultTypes.size() << ")";
      diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
      return diag;
    }
    for (auto &&[i, resultType, actionType] :
         llvm::enumerate(resultTypes, actionResults)) {
      if (implementSameTransformInterface(resultType, actionType))
        continue;

      InFlightDiagnostic diag =
          emitError() << "mismatching type interfaces for action result #" << i
                      << " of action " << action << " and op result";
      diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
      return diag;
    }
  }
  return success();
}
//===----------------------------------------------------------------------===//
// ForeachOp
//===----------------------------------------------------------------------===//
/// Runs the body once per payload op associated with the target, binding the
/// iteration variable to that single op each time. Result #k receives, in
/// iteration order, the payload ops yielded as operand #k of the terminator.
/// The first body op that does not succeed aborts the loop and its failure is
/// returned as is.
DiagnosedSilenceableFailure
transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
                            transform::TransformResults &results,
                            transform::TransformState &state) {
  unsigned numResults = getNumResults();
  SmallVector<SmallVector<Operation *>> collected(numResults, {});

  // Snapshot the targets: the TrackingRewriter may drop ops from the mapping
  // while the body runs, which would disturb iteration over the live range.
  SmallVector<Operation *> targetOps =
      llvm::to_vector(state.getPayloadOps(getTarget()));

  Block &body = getBody().front();
  for (Operation *targetOp : targetOps) {
    // The iteration variable is bound only for the duration of one iteration.
    auto iterationScope = state.make_region_scope(getBody());
    if (failed(state.mapBlockArguments(getIterationVariable(), {targetOp})))
      return DiagnosedSilenceableFailure::definiteFailure();

    for (Operation &bodyOp : body.without_terminator()) {
      DiagnosedSilenceableFailure status = state.applyTransform(
          cast<transform::TransformOpInterface>(bodyOp));
      if (!status.succeeded())
        return status;
    }

    transform::YieldOp yieldOp = getYieldOp();
    for (unsigned idx = 0; idx < numResults; ++idx) {
      auto yielded = state.getPayloadOps(yieldOp.getOperand(idx));
      collected[idx].append(yielded.begin(), yielded.end());
    }
  }

  for (unsigned idx = 0; idx < numResults; ++idx)
    results.set(llvm::cast<OpResult>(getResult(idx)), collected[idx]);
  return DiagnosedSilenceableFailure::success();
}
/// Derives the op's effects from its body:
/// - the target is consumed if any body op consumes the iteration variable,
///   otherwise it is only read;
/// - the payload is modified if any body op modifies it, otherwise it is read
///   if any body op reads it;
/// - all results are produced.
void transform::ForeachOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  BlockArgument iterVar = getIterationVariable();
  auto bodyOps = getBody().front().without_terminator();
  auto asTransform = [](Operation &op) {
    return cast<TransformOpInterface>(&op);
  };

  bool consumesIterVar = llvm::any_of(bodyOps, [&](Operation &op) {
    return isHandleConsumed(iterVar, asTransform(op));
  });
  if (consumesIterVar)
    consumesHandle(getTarget(), effects);
  else
    onlyReadsHandle(getTarget(), effects);

  bool modifies = llvm::any_of(bodyOps, [&](Operation &op) {
    return doesModifyPayload(asTransform(op));
  });
  if (modifies) {
    modifiesPayload(effects);
  } else if (llvm::any_of(bodyOps, [&](Operation &op) {
               return doesReadPayload(asTransform(op));
             })) {
    onlyReadsPayload(effects);
  }

  producesHandle(getResults(), effects);
}
/// From the parent op, control enters the body. From the body, control either
/// loops back into the body or returns to the parent.
void transform::ForeachOp::getSuccessorRegions(
    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
  Region *bodyRegion = &getBody();
  bool fromParent = point.isParent();
  assert((fromParent || point == getBody()) && "unexpected region index");
  // The body is always a successor and receives its block arguments.
  regions.emplace_back(bodyRegion, bodyRegion->getArguments());
  if (!fromParent)
    regions.emplace_back();
}
/// All operands flow into the body. The iteration variable is bound to a
/// subset (exactly one op per iteration) of the payload ops of the operand.
OperandRange
transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) {
  assert(point == getBody() && "unexpected region index");
  Operation *self = getOperation();
  return self->getOperands();
}
/// Returns the terminator of the body block, which is expected to be a
/// transform.yield (the cast asserts this).
transform::YieldOp transform::ForeachOp::getYieldOp() {
  Block &body = getBody().front();
  return cast<transform::YieldOp>(body.getTerminator());
}
/// The terminator must yield one operation handle per op result.
LogicalResult transform::ForeachOp::verify() {
  transform::YieldOp yieldOp = getYieldOp();
  if (yieldOp.getNumOperands() != getNumResults())
    return emitOpError() << "expects the same number of results as the "
                            "terminator has operands";
  bool yieldsOnlyHandles = llvm::all_of(
      yieldOp->getOperandTypes(),
      llvm::IsaPred<TransformHandleTypeInterface>);
  if (!yieldsOnlyHandles)
    return yieldOp->emitOpError("expects operands to have types implementing "
                                "TransformHandleTypeInterface");
  return success();
}
//===----------------------------------------------------------------------===//
// GetParentOp
//===----------------------------------------------------------------------===//
/// For each target, climbs `nth_parent` times to the closest ancestor that
/// satisfies the optional isolated-from-above and op-name filters. With
/// deduplication enabled, each parent is reported only once.
///
/// If some climb runs out of ancestors, the op either fails silenceably or, when
/// empty results are allowed, returns successfully with the parents collected
/// so far.
/// NOTE(review): in the latter case the remaining targets are skipped
/// entirely; confirm this partial result is intended.
DiagnosedSilenceableFailure
transform::GetParentOp::apply(transform::TransformRewriter &rewriter,
                              transform::TransformResults &results,
                              transform::TransformState &state) {
  auto satisfiesFilters = [&](Operation *candidate) {
    if (getIsolatedFromAbove() &&
        !candidate->hasTrait<OpTrait::IsIsolatedFromAbove>())
      return false;
    auto requiredName = getOpName();
    return !requiredName ||
           candidate->getName().getStringRef() == *requiredName;
  };

  SmallVector<Operation *> parents;
  DenseSet<Operation *> seen;
  for (Operation *target : state.getPayloadOps(getTarget())) {
    Operation *parent = target;
    for (int64_t step = 0, steps = getNthParent(); step < steps; ++step) {
      // Move up at least one level, then keep going until the filters match.
      parent = parent->getParentOp();
      while (parent && !satisfiesFilters(parent))
        parent = parent->getParentOp();
      if (parent)
        continue;

      if (getAllowEmptyResults()) {
        results.set(llvm::cast<OpResult>(getResult()), parents);
        return DiagnosedSilenceableFailure::success();
      }
      DiagnosedSilenceableFailure diag =
          emitSilenceableError()
          << "could not find a parent op that matches all requirements";
      diag.attachNote(target->getLoc()) << "target op";
      return diag;
    }

    // With deduplication, keep only the first occurrence of each parent.
    if (!getDeduplicate() || seen.insert(parent).second)
      parents.push_back(parent);
  }
  results.set(llvm::cast<OpResult>(getResult()), parents);
  return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// GetConsumersOfResult
//===----------------------------------------------------------------------===//
/// Returns the users of result #`result_number` of the single payload op
/// associated with the target handle. An empty handle produces an empty
/// result. A handle with more than one op, or a position that is negative or
/// out of range, is a definite failure.
DiagnosedSilenceableFailure
transform::GetConsumersOfResult::apply(transform::TransformRewriter &rewriter,
                                       transform::TransformResults &results,
                                       transform::TransformState &state) {
  int64_t resultNumber = getResultNumber();
  auto payloadOps = state.getPayloadOps(getTarget());
  if (std::empty(payloadOps)) {
    results.set(cast<OpResult>(getResult()), {});
    return DiagnosedSilenceableFailure::success();
  }
  if (!llvm::hasSingleElement(payloadOps))
    return emitDefiniteFailure()
           << "handle must be mapped to exactly one payload op";

  Operation *target = *payloadOps.begin();
  // Also reject negative positions: getResult() takes an unsigned index, so a
  // negative value would wrap around and index out of bounds.
  if (resultNumber < 0 || target->getNumResults() <= resultNumber)
    return emitDefiniteFailure() << "result number overflow";
  results.set(llvm::cast<OpResult>(getResult()),
              llvm::to_vector(target->getResult(resultNumber).getUsers()));
  return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// GetDefiningOp
//===----------------------------------------------------------------------===//
/// Maps each payload value of the target to the operation defining it. Fails
/// silenceably on the first value that is a block argument.
DiagnosedSilenceableFailure
transform::GetDefiningOp::apply(transform::TransformRewriter &rewriter,
                                transform::TransformResults &results,
                                transform::TransformState &state) {
  SmallVector<Operation *> producers;
  for (Value payloadValue : state.getPayloadValues(getTarget())) {
    if (!llvm::isa<BlockArgument>(payloadValue)) {
      producers.push_back(payloadValue.getDefiningOp());
      continue;
    }
    DiagnosedSilenceableFailure diag =
        emitSilenceableError() << "cannot get defining op of block argument";
    diag.attachNote(payloadValue.getLoc()) << "target value";
    return diag;
  }
  results.set(llvm::cast<OpResult>(getResult()), producers);
  return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// GetProducerOfOperand
//===----------------------------------------------------------------------===//
/// For each target payload op, returns the op defining operand
/// #`operand_number`. Fails silenceably if the position is negative or out of
/// range, or if the operand has no defining op (a block argument).
DiagnosedSilenceableFailure
transform::GetProducerOfOperand::apply(transform::TransformRewriter &rewriter,
                                       transform::TransformResults &results,
                                       transform::TransformState &state) {
  int64_t operandNumber = getOperandNumber();
  SmallVector<Operation *> producers;
  for (Operation *target : state.getPayloadOps(getTarget())) {
    // getOperand() takes an unsigned index, so a negative position must be
    // rejected here as well as one past the end.
    bool inRange =
        operandNumber >= 0 && operandNumber < target->getNumOperands();
    Operation *producer =
        inRange ? target->getOperand(operandNumber).getDefiningOp() : nullptr;
    if (!producer) {
      DiagnosedSilenceableFailure diag =
          emitSilenceableError()
          << "could not find a producer for operand number: " << operandNumber
          << " of " << *target;
      diag.attachNote(target->getLoc()) << "target op";
      return diag;
    }
    producers.push_back(producer);
  }
  results.set(llvm::cast<OpResult>(getResult()), producers);
  return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// GetOperandOp
//===----------------------------------------------------------------------===//
/// Collects, for every target payload op, the operands selected by the
/// position specification, resolved against that op's operand count.
DiagnosedSilenceableFailure
transform::GetOperandOp::apply(transform::TransformRewriter &rewriter,
                               transform::TransformResults &results,
                               transform::TransformState &state) {
  SmallVector<Value> selectedOperands;
  for (Operation *target : state.getPayloadOps(getTarget())) {
    SmallVector<int64_t> positions;
    DiagnosedSilenceableFailure diag = expandTargetSpecification(
        getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
        target->getNumOperands(), positions);
    if (diag.isSilenceableFailure()) {
      diag.attachNote(target->getLoc())
          << "while considering positions of this payload operation";
      return diag;
    }
    for (int64_t pos : positions)
      selectedOperands.push_back(target->getOperand(pos));
  }
  results.setValues(cast<OpResult>(getResult()), selectedOperands);
  return DiagnosedSilenceableFailure::success();
}
/// Validates the position list and its is-all / is-inverted flags using the
/// same helper as GetResultOp.
LogicalResult transform::GetOperandOp::verify() {
  Operation *self = getOperation();
  return verifyTransformMatchDimsOp(self, getRawPositionList(),
                                    getIsInverted(), getIsAll());
}
//===----------------------------------------------------------------------===//
// GetResultOp
//===----------------------------------------------------------------------===//
/// Collects, for every target payload op, the results selected by the
/// position specification, resolved against that op's result count.
DiagnosedSilenceableFailure
transform::GetResultOp::apply(transform::TransformRewriter &rewriter,
                              transform::TransformResults &results,
                              transform::TransformState &state) {
  SmallVector<Value> selectedResults;
  for (Operation *target : state.getPayloadOps(getTarget())) {
    SmallVector<int64_t> positions;
    DiagnosedSilenceableFailure diag = expandTargetSpecification(
        getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
        target->getNumResults(), positions);
    if (diag.isSilenceableFailure()) {
      diag.attachNote(target->getLoc())
          << "while considering positions of this payload operation";
      return diag;
    }
    for (int64_t pos : positions)
      selectedResults.push_back(target->getResult(pos));
  }
  results.setValues(cast<OpResult>(getResult()), selectedResults);
  return DiagnosedSilenceableFailure::success();
}
/// Validates the position list and its is-all / is-inverted flags using the
/// same helper as GetOperandOp.
LogicalResult transform::GetResultOp::verify() {
  Operation *self = getOperation();
  return verifyTransformMatchDimsOp(self, getRawPositionList(),
                                    getIsInverted(), getIsAll());
}
//===----------------------------------------------------------------------===//
// GetTypeOp
//===----------------------------------------------------------------------===//
/// Reads the value handle and the payload; produces the parameter result.
void transform::GetTypeOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // Handle effects: the input is inspected, the result is created.
  onlyReadsHandle(getValue(), effects);
  producesHandle(getResult(), effects);
  // Payload effect: types are only queried.
  onlyReadsPayload(effects);
}
/// Produces one TypeAttr parameter per payload value: the value's type or, in
/// elemental mode, the element type of a shaped type (other types are kept
/// unchanged).
DiagnosedSilenceableFailure
transform::GetTypeOp::apply(transform::TransformRewriter &rewriter,
                            transform::TransformResults &results,
                            transform::TransformState &state) {
  const bool elemental = getElemental();
  SmallVector<Attribute> typeAttrs;
  for (Value value : state.getPayloadValues(getValue())) {
    Type type = value.getType();
    if (auto shaped = dyn_cast<ShapedType>(type); elemental && shaped)
      type = shaped.getElementType();
    typeAttrs.push_back(TypeAttr::get(type));
  }
  results.setParams(cast<OpResult>(getResult()), typeAttrs);
  return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// IncludeOp
//===----------------------------------------------------------------------===//
/// Applies the transform ops of `block` (excluding its terminator) in order,
/// then maps `results` to the payload of the terminator operands.
///
/// A definite failure is returned immediately. A silenceable failure is either
/// returned after forwarding empty results (in `Propagate` mode) or silenced
/// so that the remaining ops still run (in any other mode).
static DiagnosedSilenceableFailure
applySequenceBlock(Block &block, transform::FailurePropagationMode mode,
                   transform::TransformState &state,
                   transform::TransformResults &results) {
  for (Operation &op : block.without_terminator()) {
    DiagnosedSilenceableFailure status =
        state.applyTransform(cast<transform::TransformOpInterface>(op));
    if (status.isDefiniteFailure())
      return status;
    if (!status.isSilenceableFailure())
      continue;
    if (mode != transform::FailurePropagationMode::Propagate) {
      (void)status.silence();
      continue;
    }
    // Propagate empty results in case of early exit.
    forwardEmptyOperands(&block, state, results);
    return status;
  }
  transform::detail::forwardTerminatorOperands(&block, state, results);
  return DiagnosedSilenceableFailure::success();
}
/// Applies the referenced named sequence with its block arguments bound to
/// this op's operands, then maps this op's results to the payload yielded by
/// the sequence. Results are mapped even when the body fails; the body's
/// status is returned.
DiagnosedSilenceableFailure
transform::IncludeOp::apply(transform::TransformRewriter &rewriter,
                            transform::TransformResults &results,
                            transform::TransformState &state) {
  auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
      getOperation(), getTarget());
  assert(callee && "unverified reference to unknown symbol");
  if (callee.isExternal())
    return emitDefiniteFailure() << "unresolved external named sequence";

  Block &calleeBody = callee.getBody().front();

  // Bind the callee arguments for the lifetime of `scope`.
  SmallVector<SmallVector<MappedValue>> argumentMappings;
  detail::prepareValueMappings(argumentMappings, getOperands(), state);
  auto scope = state.make_region_scope(callee.getBody());
  for (auto &&[blockArg, mapping] :
       llvm::zip_equal(calleeBody.getArguments(), argumentMappings)) {
    if (failed(state.mapBlockArgument(blockArg, mapping)))
      return DiagnosedSilenceableFailure::definiteFailure();
  }

  DiagnosedSilenceableFailure status = applySequenceBlock(
      calleeBody, getFailurePropagationMode(), state, results);

  SmallVector<SmallVector<MappedValue>> yieldedMappings;
  detail::prepareValueMappings(
      yieldedMappings, calleeBody.getTerminator()->getOperands(), state);
  for (auto &&[opResult, mapping] :
       llvm::zip_equal(getResults(), yieldedMappings))
    results.setMappedValues(opResult, mapping);
  return status;
}
static DiagnosedSilenceableFailure
verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings);
/// Always modifies the payload and produces all results. Each operand is
/// consumed if the callee marks the matching argument as consumed, and is
/// otherwise only read. If the callee cannot be resolved, fails its early
/// verification, or does not take as many arguments as this op has operands,
/// every operand is conservatively reported as only read.
void transform::IncludeOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // Always mark as modifying the payload.
  // TODO: a mechanism to annotate effects on payload. Even when all handles are
  // only read, the payload may still be modified, so we currently stay on the
  // conservative side and always indicate modification. This may prevent some
  // code reordering.
  modifiesPayload(effects);

  // Results are always produced.
  producesHandle(getResults(), effects);

  // Adds default effects to operands and results. This will be added if
  // preconditions fail so the trait verifier doesn't complain about missing
  // effects and the real precondition failure is reported later on.
  auto defaultEffects = [&] { onlyReadsHandle(getOperands(), effects); };

  // Bail if the callee is unknown. This may run as part of the verification
  // process before we verified the validity of the callee or of this op.
  auto target =
      getOperation()->getAttrOfType<SymbolRefAttr>(getTargetAttrName());
  if (!target)
    return defaultEffects();
  auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
      getOperation(), getTarget());
  if (!callee)
    return defaultEffects();

  // The operand count is only checked against the callee signature in
  // verifySymbolUses, which may not have run yet. Bail out on a mismatch
  // instead of querying argument attributes out of bounds below.
  if (callee.getFunctionType().getNumInputs() != getNumOperands())
    return defaultEffects();

  DiagnosedSilenceableFailure earlyVerifierResult =
      verifyNamedSequenceOp(callee, /*emitWarnings=*/false);
  if (!earlyVerifierResult.succeeded()) {
    (void)earlyVerifierResult.silence();
    return defaultEffects();
  }

  for (unsigned i = 0, e = getNumOperands(); i < e; ++i) {
    if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName))
      consumesHandle(getOperand(i), effects);
    else
      onlyReadsHandle(getOperand(i), effects);
  }
}
/// Checks that the `target` attribute names a named transform sequence whose
/// signature is compatible with this op. Operand types must match the callee
/// argument types exactly. Result types only need to implement the same
/// transform dialect interface as the callee results. Finally, the callee's
/// consumption annotations are verified.
LogicalResult
transform::IncludeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // This may run before the main op verifier, so read the attribute through
  // the generic accessor and check its presence explicitly.
  auto calleeRef = getOperation()->getAttrOfType<SymbolRefAttr>("target");
  if (!calleeRef)
    return emitOpError() << "expects a 'target' symbol reference attribute";

  auto callee = symbolTable.lookupNearestSymbolFrom<transform::NamedSequenceOp>(
      *this, calleeRef);
  if (!callee)
    return emitOpError() << "does not reference a named transform sequence";

  FunctionType calleeType = callee.getFunctionType();

  // Operands: same count, identical types.
  if (calleeType.getNumInputs() != getNumOperands())
    return emitError("incorrect number of operands for callee");
  for (auto &&[idx, expected] : llvm::enumerate(calleeType.getInputs())) {
    Type provided = getOperand(idx).getType();
    if (provided == expected)
      continue;
    return emitOpError("operand type mismatch: expected operand type ")
           << expected << ", but provided " << provided
           << " for operand number " << idx;
  }

  // Results: same count, types implementing the same transform interface.
  if (calleeType.getNumResults() != getNumResults())
    return emitError("incorrect number of results for callee");
  for (auto &&[idx, declared] : llvm::enumerate(calleeType.getResults())) {
    if (implementSameTransformInterface(getResult(idx).getType(), declared))
      continue;
    return emitOpError() << "type of result #" << idx
                         << " must implement the same transform dialect "
                            "interface as the corresponding callee result";
  }

  return verifyFunctionLikeConsumeAnnotations(
             cast<FunctionOpInterface>(*callee), /*emitWarnings=*/false,
             /*alsoVerifyInternal=*/true)
      .checkAndReport();
}
//===----------------------------------------------------------------------===//
// MatchOperationEmptyOp
//===----------------------------------------------------------------------===//
/// Matches only when no payload operation is provided.
DiagnosedSilenceableFailure transform::MatchOperationEmptyOp::matchOperation(
    ::std::optional<::mlir::Operation *> maybeCurrent,
    transform::TransformResults &results, transform::TransformState &state) {
  // Any present operation means the match fails.
  if (maybeCurrent.has_value()) {
    DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp failure\n"; });
    return emitSilenceableError() << "operation is not empty";
  }
  DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp success\n"; });
  return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// MatchOperationNameOp
//===----------------------------------------------------------------------===//
/// Succeeds if the name of `current` is one of the names listed in the op's
/// `op_names` attribute, and fails silenceably otherwise.
DiagnosedSilenceableFailure transform::MatchOperationNameOp::matchOperation(
    Operation *current, transform::TransformResults &results,
    transform::TransformState &state) {
  StringRef actualName = current->getName().getStringRef();
  auto acceptedNames = getOpNames().getAsRange<StringAttr>();
  bool isAccepted = llvm::any_of(acceptedNames, [&](StringAttr candidate) {
    return candidate.getValue() == actualName;
  });
  if (!isAccepted)
    return emitSilenceableError() << "wrong operation name";
  return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// MatchParamCmpIOp
//===----------------------------------------------------------------------===//
/// Compares, element by element, the integer parameters associated with
/// `param` against those associated with `reference` using the op's predicate
/// (signed comparison). The two parameter lists must have equal lengths,
/// otherwise this is a silenceable failure. Non-integer values and mismatching
/// integer types are definite failures. The first element that does not
/// satisfy the predicate produces a silenceable failure with a note pointing at
/// the parameter.
DiagnosedSilenceableFailure
transform::MatchParamCmpIOp::apply(transform::TransformRewriter &rewriter,
                                   transform::TransformResults &results,
                                   transform::TransformState &state) {
  // Renders an APInt with its signed interpretation, for diagnostics.
  auto toSignedString = [](const APInt &v) {
    std::string buffer;
    llvm::raw_string_ostream stream(buffer);
    v.print(stream, /*isSigned=*/true);
    return stream.str();
  };

  ArrayRef<Attribute> lhsParams = state.getParams(getParam());
  ArrayRef<Attribute> rhsParams = state.getParams(getReference());
  if (lhsParams.size() != rhsParams.size()) {
    return emitSilenceableError()
           << "parameters have different payload lengths (" << lhsParams.size()
           << " vs " << rhsParams.size() << ")";
  }

  for (auto &&[idx, lhs, rhs] : llvm::enumerate(lhsParams, rhsParams)) {
    auto lhsInt = llvm::dyn_cast<IntegerAttr>(lhs);
    auto rhsInt = llvm::dyn_cast<IntegerAttr>(rhs);
    if (!lhsInt || !rhsInt)
      return emitDefiniteFailure()
             << "non-integer parameter value not expected";
    if (lhsInt.getType() != rhsInt.getType())
      return emitDefiniteFailure()
             << "mismatching integer attribute types in parameter #" << idx;

    APInt value = lhsInt.getValue();
    APInt refValue = rhsInt.getValue();

    // Evaluate the predicate and remember how to describe it on failure.
    bool satisfied = true;
    StringRef relation;
    switch (getPredicate()) {
    case MatchCmpIPredicate::eq:
      satisfied = value.eq(refValue);
      relation = "equal to";
      break;
    case MatchCmpIPredicate::ne:
      satisfied = value.ne(refValue);
      relation = "not equal to";
      break;
    case MatchCmpIPredicate::lt:
      satisfied = value.slt(refValue);
      relation = "less than";
      break;
    case MatchCmpIPredicate::le:
      satisfied = value.sle(refValue);
      relation = "less than or equal to";
      break;
    case MatchCmpIPredicate::gt:
      satisfied = value.sgt(refValue);
      relation = "greater than";
      break;
    case MatchCmpIPredicate::ge:
      satisfied = value.sge(refValue);
      relation = "greater than or equal to";
      break;
    }
    if (satisfied)
      continue;

    DiagnosedSilenceableFailure diag =
        emitSilenceableError() << "expected parameter to be " << relation
                               << " " << toSignedString(refValue) << ", got "
                               << toSignedString(value);
    diag.attachNote(getParam().getLoc())
        << "value # " << idx << " associated with the parameter defined here";
    return diag;
  }
  return DiagnosedSilenceableFailure::success();
}
/// Both compared handles are only inspected, never consumed.
void transform::MatchParamCmpIOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  Value inspected[] = {getParam(), getReference()};
  for (Value handle : inspected)
    onlyReadsHandle(handle, effects);
}
//===----------------------------------------------------------------------===//
// ParamConstantOp
//===----------------------------------------------------------------------===//
/// Associates the result parameter with the single constant attribute stored
/// on the op.
DiagnosedSilenceableFailure
transform::ParamConstantOp::apply(transform::TransformRewriter &rewriter,
                                  transform::TransformResults &results,
                                  transform::TransformState &state) {
  OpResult paramResult = cast<OpResult>(getParam());
  Attribute constant = getValue();
  results.setParams(paramResult, {constant});
  return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// MergeHandlesOp
//===----------------------------------------------------------------------===//
/// Concatenates the payload of all operand handles, in operand order, into the
/// result. Depending on the operand kind, the payload is operations,
/// parameters, or values. When deduplication is requested, only the first
/// occurrence of each payload entity is kept.
DiagnosedSilenceableFailure
transform::MergeHandlesOp::apply(transform::TransformRewriter &rewriter,
                                 transform::TransformResults &results,
                                 transform::TransformState &state) {
  ValueRange handles = getHandles();
  // The kind is taken from the first operand; the others are presumably of
  // the same kind (not re-checked here).
  Type kind = handles.front().getType();
  OpResult merged = llvm::cast<OpResult>(getResult());
  bool dedup = getDeduplicate();

  if (isa<TransformHandleTypeInterface>(kind)) {
    SmallVector<Operation *> ops;
    for (Value handle : handles)
      llvm::append_range(ops, state.getPayloadOps(handle));
    if (dedup) {
      // SetVector keeps insertion order, dropping later duplicates.
      SetVector<Operation *> unique(ops.begin(), ops.end());
      results.set(merged, unique.getArrayRef());
    } else {
      results.set(merged, ops);
    }
    return DiagnosedSilenceableFailure::success();
  }

  if (llvm::isa<TransformParamTypeInterface>(kind)) {
    SmallVector<Attribute> params;
    for (Value handle : handles)
      llvm::append_range(params, state.getParams(handle));
    if (dedup) {
      SetVector<Attribute> unique(params.begin(), params.end());
      results.setParams(merged, unique.getArrayRef());
    } else {
      results.setParams(merged, params);
    }
    return DiagnosedSilenceableFailure::success();
  }

  assert(llvm::isa<TransformValueHandleTypeInterface>(kind) &&
         "expected value handle type");
  SmallVector<Value> values;
  for (Value handle : handles)
    llvm::append_range(values, state.getPayloadValues(handle));
  if (dedup) {
    SetVector<Value> unique(values.begin(), values.end());
    results.setValues(merged, unique.getArrayRef());
  } else {
    results.setValues(merged, values);
  }
  return DiagnosedSilenceableFailure::success();
}
bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() {
// Handles may be the same if deduplicating is enabled.
return getDeduplicate();
}
/// Pure handle manipulation: operand handles are read and the merged handle is
/// produced. No effect on the payload IR is declared.
void transform::MergeHandlesOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  onlyReadsHandle(getHandles(), effects);
  producesHandle(getResult(), effects);
}
/// Merging a single handle without deduplication is the identity, so the op
/// folds to its operand in that case.
OpFoldResult transform::MergeHandlesOp::fold(FoldAdaptor adaptor) {
  bool isIdentity = !getDeduplicate() && getHandles().size() == 1;
  if (isIdentity)
    return getHandles().front();
  return {};
}
//===----------------------------------------------------------------------===//
// NamedSequenceOp
//===----------------------------------------------------------------------===//
/// Applies the body of the named sequence with failure propagation, after
/// binding its entry block arguments. External declarations cannot be applied.
DiagnosedSilenceableFailure
transform::NamedSequenceOp::apply(transform::TransformRewriter &rewriter,
                                  transform::TransformResults &results,
                                  transform::TransformState &state) {
  if (isExternal())
    return emitDefiniteFailure() << "unresolved external named sequence";

  // Entry block arguments are bound the same way PossibleTopLevelTransformOp
  // does it. The interface/trait itself is not attached because it is tailored
  // to dangling top-level ops, whereas named sequences get "called".
  Region &body = getBody();
  auto scope = state.make_region_scope(body);
  if (failed(detail::mapPossibleTopLevelTransformOpBlockArguments(
          state, getOperation(), body)))
    return DiagnosedSilenceableFailure::definiteFailure();

  return applySequenceBlock(body.front(), FailurePropagationMode::Propagate,
                            state, results);
}
/// No effects are reported for the named sequence definition itself.
void transform::NamedSequenceOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // Intentionally empty.
}
/// Parses the op with the generic function-like syntax. Variadic signatures
/// are rejected, so the signature maps directly to a builtin function type.
ParseResult transform::NamedSequenceOp::parse(OpAsmParser &parser,
                                              OperationState &result) {
  auto buildFuncType = [](Builder &builder, ArrayRef<Type> argTypes,
                          ArrayRef<Type> resultTypes,
                          function_interface_impl::VariadicFlag,
                          std::string &) {
    return builder.getFunctionType(argTypes, resultTypes);
  };
  return function_interface_impl::parseFunctionOp(
      parser, result, /*allowVariadic=*/false,
      getFunctionTypeAttrName(result.name), buildFuncType,
      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
}
/// Prints the op with the generic, non-variadic function-like syntax.
void transform::NamedSequenceOp::print(OpAsmPrinter &printer) {
  auto funcOp = cast<FunctionOpInterface>(getOperation());
  function_interface_impl::printFunctionOp(
      printer, funcOp, /*isVariadic=*/false,
      getFunctionTypeAttrName().getValue(), getArgAttrsAttrName(),
      getResAttrsAttrName());
}
/// Verifies that a function-like transform dialect symbol is not nested in
/// another transform op and that its signature and terminator have conforming
/// types. Concretely, the body's single block must end in `transform.yield`,
/// and the yield operand types must equal the declared result types. If
/// `allowExternal` is set, external symbols (declarations) are accepted and
/// the terminator is not checked, since it may not exist.
static DiagnosedSilenceableFailure
verifyYieldingSingleBlockOp(FunctionOpInterface op, bool allowExternal) {
  if (auto ancestor = op->getParentOfType<transform::TransformOpInterface>()) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableFailure(op)
        << "cannot be defined inside another transform op";
    diag.attachNote(ancestor.getLoc()) << "ancestor transform op";
    return diag;
  }

  Region &body = op.getFunctionBody();
  if (op.isExternal() || body.empty()) {
    if (!allowExternal)
      return emitSilenceableFailure(op) << "cannot be external";
    return DiagnosedSilenceableFailure::success();
  }

  Block &entry = body.front();
  if (entry.empty())
    return emitSilenceableFailure(op) << "expected a non-empty body block";

  Operation *terminator = &entry.back();
  if (!isa<transform::YieldOp>(terminator)) {
    DiagnosedSilenceableFailure diag = emitSilenceableFailure(op)
                                       << "expected '"
                                       << transform::YieldOp::getOperationName()
                                       << "' as terminator";
    diag.attachNote(terminator->getLoc()) << "terminator";
    return diag;
  }

  auto declaredTypes = op.getResultTypes();
  if (terminator->getNumOperands() != declaredTypes.size()) {
    return emitSilenceableFailure(terminator)
           << "expected terminator to have as many operands as the parent op "
              "has results";
  }

  for (unsigned idx = 0, e = terminator->getNumOperands(); idx != e; ++idx) {
    Type yielded = terminator->getOperand(idx).getType();
    Type declared = declaredTypes[idx];
    if (yielded != declared) {
      return emitSilenceableFailure(terminator)
             << "the type of the terminator operand #" << idx
             << " must match the type of the corresponding parent op result ("
             << yielded << " vs " << declared << ")";
    }
  }
  return DiagnosedSilenceableFailure::success();
}
/// Verification of a NamedSequenceOp. This does not report the error
/// immediately, so it can be used to check for op's well-formedness before the
/// verifier runs, e.g., during trait verification.
///
/// Checks, in order:
///   1. the enclosing symbol table (if any) carries the
///      `kWithNamedSequenceAttrName` attribute;
///   2. the structural checks of verifyYieldingSingleBlockOp (not nested in
///      another transform op; if defined, a non-empty body terminated by
///      `transform.yield` whose operand types match the declared results);
///   3. the argument consumption annotations.
static DiagnosedSilenceableFailure
verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings) {
  if (Operation *parent = op->getParentWithTrait<OpTrait::SymbolTable>()) {
    if (!parent->getAttr(
            transform::TransformDialect::kWithNamedSequenceAttrName)) {
      DiagnosedSilenceableFailure diag =
          emitSilenceableFailure(op)
          << "expects the parent symbol table to have the '"
          << transform::TransformDialect::kWithNamedSequenceAttrName
          << "' attribute";
      diag.attachNote(parent->getLoc()) << "symbol table operation";
      return diag;
    }
  }

  auto funcOp = cast<FunctionOpInterface>(*op);

  // Structural checks are shared with other function-like transform ops.
  // External declarations pass them since they have no body to inspect.
  // Previously these checks were duplicated inline here and then re-run
  // through this same helper.
  DiagnosedSilenceableFailure structure =
      verifyYieldingSingleBlockOp(funcOp, /*allowExternal=*/true);
  if (!structure.succeeded())
    return structure;

  return verifyFunctionLikeConsumeAnnotations(funcOp, emitWarnings);
}
/// The checks live in a standalone helper so they can also be run, without
/// reporting, before the op verifier (IncludeOp::getEffects does so).
LogicalResult transform::NamedSequenceOp::verify() {
  DiagnosedSilenceableFailure status =
      verifyNamedSequenceOp(*this, /*emitWarnings=*/true);
  return status.checkAndReport();
}
/// Creates the entry block of the last region added to `state`. The block has
/// one argument of `bbArgType` followed by one argument per type in
/// `extraBindingTypes`, all located at `state.location`. `bodyBuilder` is then
/// invoked with the insertion point at the start of that block. Builders taking
/// three arguments receive only the first block argument. Others also receive
/// the remaining arguments.
template <typename FnTy>
static void buildSequenceBody(OpBuilder &builder, OperationState &state,
                              Type bbArgType, TypeRange extraBindingTypes,
                              FnTy bodyBuilder) {
  SmallVector<Type> argTypes{bbArgType};
  llvm::append_range(argTypes, extraBindingTypes);
  SmallVector<Location> argLocs(argTypes.size(), state.location);

  // Restore the caller's insertion point once the body is populated.
  OpBuilder::InsertionGuard guard(builder);
  Region &region = *state.regions.back();
  Block *entry =
      builder.createBlock(&region, region.begin(), argTypes, argLocs);
  builder.setInsertionPointToStart(entry);

  BlockArgument rootArg = entry->getArgument(0);
  if constexpr (llvm::function_traits<FnTy>::num_args == 3)
    bodyBuilder(builder, state.location, rootArg);
  else
    bodyBuilder(builder, state.location, rootArg,
                entry->getArguments().drop_front());
}
/// Builds a named sequence `symName` that takes a single argument of
/// `rootType` and returns `resultTypes`. `attrs` are appended to the op's
/// attributes. `argAttrs`, when non-empty, gives the attribute dictionary of
/// that single argument (e.g., the consumption annotation). The body is
/// populated by `bodyBuilder`.
void transform::NamedSequenceOp::build(OpBuilder &builder,
                                       OperationState &state, StringRef symName,
                                       Type rootType, TypeRange resultTypes,
                                       SequenceBodyBuilderFn bodyBuilder,
                                       ArrayRef<NamedAttribute> attrs,
                                       ArrayRef<DictionaryAttr> argAttrs) {
  state.addAttribute(SymbolTable::getSymbolAttrName(),
                     builder.getStringAttr(symName));
  state.addAttribute(getFunctionTypeAttrName(state.name),
                     TypeAttr::get(FunctionType::get(builder.getContext(),
                                                     rootType, resultTypes)));
  state.attributes.append(attrs.begin(), attrs.end());

  // Attach the per-argument attributes, mirroring func::FuncOp::build. They
  // used to be accepted and silently dropped.
  if (!argAttrs.empty()) {
    assert(argAttrs.size() == 1 &&
           "expected attributes for exactly one argument (the root)");
    function_interface_impl::addArgAndResultAttrs(
        builder, state, argAttrs, /*resultAttrs=*/ArrayRef<DictionaryAttr>(),
        getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
  }

  state.addRegion();
  buildSequenceBody(builder, state, rootType,
                    /*extraBindingTypes=*/TypeRange(), bodyBuilder);
}
//===----------------------------------------------------------------------===//
// NumAssociationsOp
//===----------------------------------------------------------------------===//
/// Sets the result parameter to an i64 attribute holding the number of payload
/// entities associated with the operand handle. These are operations, values,
/// or parameters, depending on the kind of the handle's transform type.
DiagnosedSilenceableFailure
transform::NumAssociationsOp::apply(transform::TransformRewriter &rewriter,
                                    transform::TransformResults &results,
                                    transform::TransformState &state) {
  Value handle = getHandle();
  Type handleType = handle.getType();
  size_t count = 0;
  if (isa<TransformHandleTypeInterface>(handleType))
    count = llvm::range_size(state.getPayloadOps(handle));
  else if (isa<TransformValueHandleTypeInterface>(handleType))
    count = llvm::range_size(state.getPayloadValues(handle));
  else if (isa<TransformParamTypeInterface>(handleType))
    count = llvm::range_size(state.getParams(handle));
  else
    llvm_unreachable("unknown kind of transform dialect type");

  Attribute countAttr = rewriter.getI64IntegerAttr(count);
  results.setParams(cast<OpResult>(getNum()), countAttr);
  return DiagnosedSilenceableFailure::success();
}
/// The count is produced as an i64 integer attribute, so the result parameter
/// type must accept such a payload.
LogicalResult transform::NumAssociationsOp::verify() {
  auto paramType = cast<TransformParamTypeInterface>(getNum().getType());
  Attribute sample = Builder(getContext()).getI64IntegerAttr(0);
  return paramType.checkPayload(getLoc(), {sample}).checkAndReport();
}
//===----------------------------------------------------------------------===//
// SelectOp
//===----------------------------------------------------------------------===//
/// Keeps, in their original order, the payload ops of the target handle whose
/// operation name equals `op_name`.
DiagnosedSilenceableFailure
transform::SelectOp::apply(transform::TransformRewriter &rewriter,
                           transform::TransformResults &results,
                           transform::TransformState &state) {
  StringRef wanted = getOpName();
  SmallVector<Operation *> selected;
  for (Operation *candidate : state.getPayloadOps(getTarget())) {
    if (candidate->getName().getStringRef() != wanted)
      continue;
    selected.push_back(candidate);
  }
  results.set(cast<OpResult>(getResult()), selected);
  return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// SplitHandleOp
//===----------------------------------------------------------------------===//
/// Builds a split of `target` into `numResultHandles` handles, each having the
/// same type as `target`.
void transform::SplitHandleOp::build(OpBuilder &builder, OperationState &result,
                                     Value target, int64_t numResultHandles) {
  SmallVector<Type> resultTypes(numResultHandles, target.getType());
  result.addOperands(target);
  result.addTypes(resultTypes);
}
/// Distributes the payload ops of the input handle one per result, in order.
/// Excess payload ops go to the overflow result when one is designated;
/// otherwise having more payload ops than results is a silenceable failure.
/// Having fewer payload ops than results is a silenceable failure unless
/// "fail_on_payload_too_small" is false, or "pass_through_empty_handle" is true
/// and the handle is empty. Results left without a payload op are set to empty.
DiagnosedSilenceableFailure
transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter,
                                transform::TransformResults &results,
                                transform::TransformState &state) {
  int64_t numPayloadOps = llvm::range_size(state.getPayloadOps(getHandle()));
  auto produceNumOpsError = [&]() {
    return emitSilenceableError()
           << getHandle() << " expected to contain " << this->getNumResults()
           << " payload ops but it contains " << numPayloadOps
           << " payload ops";
  };

  // Fail if there are more payload ops than results and no overflow result was
  // specified.
  if (numPayloadOps > getNumResults() && !getOverflowResult().has_value())
    return produceNumOpsError();

  // Fail if there are more results than payload ops. Unless:
  // - "fail_on_payload_too_small" is set to "false", or
  // - "pass_through_empty_handle" is set to "true" and there are 0 payload ops.
  if (numPayloadOps < getNumResults() && getFailOnPayloadTooSmall() &&
      (numPayloadOps != 0 || !getPassThroughEmptyHandle()))
    return produceNumOpsError();

  // Distribute payload ops.
  int64_t numResults = getNumResults();
  SmallVector<SmallVector<Operation *, 1>> resultHandles(getNumResults(), {});
  // Pre-size the overflow handle only when there is actual overflow. Fewer
  // payload ops than results is allowed above. The difference would then be
  // negative and, converted to the unsigned argument of `reserve`, would
  // request an enormous allocation and abort.
  if (getOverflowResult() && numPayloadOps > numResults)
    resultHandles[*getOverflowResult()].reserve(numPayloadOps - numResults);
  for (auto &&en : llvm::enumerate(state.getPayloadOps(getHandle()))) {
    int64_t resultNum = en.index();
    if (resultNum >= numResults)
      resultNum = *getOverflowResult();
    resultHandles[resultNum].push_back(en.value());
  }

  // Set transform op results.
  for (auto &&it : llvm::enumerate(resultHandles))
    results.set(llvm::cast<OpResult>(getResult(it.index())), it.value());

  return DiagnosedSilenceableFailure::success();
}
/// Splitting only rearranges handles: the input is read and the pieces are
/// produced. No payload effect is declared.
void transform::SplitHandleOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  onlyReadsHandle(getHandle(), effects);
  producesHandle(getResults(), effects);
}
/// The overflow result, when specified, must designate one of the op results.
LogicalResult transform::SplitHandleOp::verify() {
  auto overflow = getOverflowResult();
  if (overflow.has_value() && *overflow >= getNumResults())
    return emitOpError("overflow_result is not a valid result index");
  return success();
}
//===----------------------------------------------------------------------===//
// ReplicateOp
//===----------------------------------------------------------------------===//
/// For each operand handle, sets the matching replicated result to the
/// handle's payload repeated N times, where N is the number of payload ops
/// associated with `pattern`. Operation handles and parameter handles are
/// supported.
DiagnosedSilenceableFailure
transform::ReplicateOp::apply(transform::TransformRewriter &rewriter,
                              transform::TransformResults &results,
                              transform::TransformState &state) {
  unsigned numRepetitions = llvm::range_size(state.getPayloadOps(getPattern()));

  // Appends `chunk` to `out` `numRepetitions` times.
  auto repeatInto = [numRepetitions](auto &out, const auto &chunk) {
    out.reserve(numRepetitions * chunk.size());
    for (unsigned rep = 0; rep < numRepetitions; ++rep)
      llvm::append_range(out, chunk);
  };

  for (auto &&[idx, handle] : llvm::enumerate(getHandles())) {
    auto replicated = llvm::cast<OpResult>(getReplicated()[idx]);
    if (isa<TransformHandleTypeInterface>(handle.getType())) {
      SmallVector<Operation *> current =
          llvm::to_vector(state.getPayloadOps(handle));
      SmallVector<Operation *> payload;
      repeatInto(payload, current);
      results.set(replicated, payload);
      continue;
    }

    assert(llvm::isa<TransformParamTypeInterface>(handle.getType()) &&
           "expected param type");
    ArrayRef<Attribute> current = state.getParams(handle);
    SmallVector<Attribute> params;
    repeatInto(params, current);
    results.setParams(replicated, params);
  }
  return DiagnosedSilenceableFailure::success();
}
/// Reads the pattern and the handles to replicate, and produces the replicated
/// handles. No payload effect is declared.
void transform::ReplicateOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // Inputs.
  onlyReadsHandle(getPattern(), effects);
  onlyReadsHandle(getHandles(), effects);
  // Outputs.
  producesHandle(getReplicated(), effects);
}
//===----------------------------------------------------------------------===//
// SequenceOp
//===----------------------------------------------------------------------===//
/// Binds the body block arguments and applies the body ops in order, using the
/// op's failure propagation mode.
DiagnosedSilenceableFailure
transform::SequenceOp::apply(transform::TransformRewriter &rewriter,
                             transform::TransformResults &results,
                             transform::TransformState &state) {
  Block *body = getBodyBlock();
  // Open a region scope for the body before binding its arguments.
  auto scope = state.make_region_scope(*body->getParent());
  if (failed(mapBlockArguments(state)))
    return DiagnosedSilenceableFailure::definiteFailure();
  return applySequenceBlock(*body, getFailurePropagationMode(), state, results);
}
/// Parses the optional operand clause of a sequence op:
///   [root (`,` extra-binding-list)? `:` `(`? root-type (`,` type-list)? `)`?]
/// When no operand is present, `root` is reset to std::nullopt and nothing
/// else is parsed. Every extra binding must have a type.
static ParseResult parseSequenceOpOperands(
    OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
    Type &rootType,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
    SmallVectorImpl<Type> &extraBindingTypes) {
  OpAsmParser::UnresolvedOperand rootOperand;
  OptionalParseResult rootResult = parser.parseOptionalOperand(rootOperand);
  if (!rootResult.has_value()) {
    root = std::nullopt;
    return success();
  }
  if (failed(rootResult.value()))
    return failure();
  root = rootOperand;

  // Extra bindings follow the root after a comma.
  if (succeeded(parser.parseOptionalComma()) &&
      failed(parser.parseOperandList(extraBindings)))
    return failure();

  if (failed(parser.parseColon()))
    return failure();

  // Parentheses around the types are accepted but not required.
  (void)parser.parseOptionalLParen();
  if (failed(parser.parseType(rootType)))
    return failure();
  bool hasExtras = !extraBindings.empty();
  if (hasExtras &&
      (parser.parseComma() || parser.parseTypeList(extraBindingTypes)))
    return failure();
  if (extraBindingTypes.size() != extraBindings.size())
    return parser.emitError(parser.getNameLoc(),
                            "expected types to be provided for all operands");
  (void)parser.parseOptionalRParen();
  return success();
}
/// Prints the operand clause of a sequence op, if it has a root operand. The
/// types are parenthesized only when extra bindings are present.
static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op,
                                    Value root, Type rootType,
                                    ValueRange extraBindings,
                                    TypeRange extraBindingTypes) {
  if (!root)
    return;

  printer << root;
  if (extraBindings.empty()) {
    printer << " : " << rootType;
    return;
  }

  printer << ", ";
  printer.printOperands(extraBindings);
  printer << " : (" << rootType << ", ";
  llvm::interleaveComma(extraBindingTypes, printer.getStream());
  printer << ")";
}
/// Returns `true` if the given op operand may consume the handle value in the
/// Transform IR, i.e., if it may have a Free effect on it. Owners that do not
/// implement TransformOpInterface are conservatively assumed to consume it.
static bool isValueUsePotentialConsumer(OpOperand &use) {
  if (auto transformOp =
          dyn_cast<transform::TransformOpInterface>(use.getOwner()))
    return isHandleConsumed(use.get(), transformOp);
  return true;
}
/// Fails if `value` has more than one use that may consume it. The diagnostic
/// is created by `reportError` and gets notes pointing at the first two such
/// uses.
LogicalResult
checkDoubleConsume(Value value,
                   function_ref<InFlightDiagnostic()> reportError) {
  OpOperand *firstConsumer = nullptr;
  for (OpOperand &use : value.getUses()) {
    if (!isValueUsePotentialConsumer(use))
      continue;
    if (firstConsumer) {
      InFlightDiagnostic diag =
          reportError() << " has more than one potential consumer";
      diag.attachNote(firstConsumer->getOwner()->getLoc())
          << "used here as operand #" << firstConsumer->getOperandNumber();
      diag.attachNote(use.getOwner()->getLoc())
          << "used here as operand #" << use.getOperandNumber();
      return diag;
    }
    firstConsumer = &use;
  }
  return success();
}
/// Verifies a sequence op:
///   - extra operands are only allowed together with a root operand;
///   - no block argument and no result of a nested op has more than one
///     potential consumer;
///   - every nested op except the last one implements TransformOpInterface;
///   - the body has a terminator whose operand types match the result types.
LogicalResult transform::SequenceOp::verify() {
  assert(getBodyBlock()->getNumArguments() >= 1 &&
         "the number of arguments must have been verified to be at least 1 by "
         "PossibleTopLevelTransformOpTrait");

  if (!getRoot() && !getExtraBindings().empty()) {
    return emitOpError()
           << "does not expect extra operands when used as top-level";
  }

  // Check if a block argument has more than one consuming use.
  for (BlockArgument arg : getBodyBlock()->getArguments()) {
    if (failed(checkDoubleConsume(arg, [this, arg]() {
          return (emitOpError() << "block argument #" << arg.getArgNumber());
        }))) {
      return failure();
    }
  }

  // Check properties of the nested operations they cannot check themselves.
  for (Operation &child : *getBodyBlock()) {
    // The last op is expected to be the terminator and is exempt.
    if (!isa<TransformOpInterface>(child) &&
        &child != &getBodyBlock()->back()) {
      InFlightDiagnostic diag =
          emitOpError()
          << "expected children ops to implement TransformOpInterface";
      diag.attachNote(child.getLoc()) << "op without interface";
      return diag;
    }

    for (OpResult result : child.getResults()) {
      auto report = [&]() {
        return (child.emitError() << "result #" << result.getResultNumber());
      };
      if (failed(checkDoubleConsume(result, report)))
        return failure();
    }
  }

  if (!getBodyBlock()->mightHaveTerminator())
    return emitOpError() << "expects to have a terminator in the body";

  if (getBodyBlock()->getTerminator()->getOperandTypes() !=
      getOperation()->getResultTypes()) {
    InFlightDiagnostic diag = emitOpError()
                              << "expects the types of the terminator operands "
                                 "to match the types of the result";
    diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator";
    return diag;
  }
  return success();
}
/// A sequence reports the effects of a possibly top-level transform op,
/// computed by the shared helper.
void transform::SequenceOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  getPotentialTopLevelEffects(effects);
}
/// All operands are forwarded to the body. Without operands (top-level use),
/// an empty range anchored at the end of the operand list is returned.
OperandRange
transform::SequenceOp::getEntrySuccessorOperands(RegionBranchPoint point) {
  assert(point == getBody() && "unexpected region index");
  Operation *op = getOperation();
  if (op->getNumOperands() == 0)
    return OperandRange(op->operand_end(), op->operand_end());
  return op->getOperands();
}
/// Control enters the body from the parent, receiving the body arguments as
/// inputs only when the op has operands, and returns to the parent results
/// from the body.
void transform::SequenceOp::getSuccessorRegions(
    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
  if (!point.isParent()) {
    assert(point == getBody() && "unexpected region index");
    regions.emplace_back(getOperation()->getResults());
    return;
  }

  Region *body = &getBody();
  Block::BlockArgListType inputs;
  if (getNumOperands() != 0)
    inputs = body->getArguments();
  regions.emplace_back(body, inputs);
}
/// The body region is invoked exactly once, regardless of the operands.
void transform::SequenceOp::getRegionInvocationBounds(
    ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
  (void)operands;
  bounds.emplace_back(/*lb=*/1, /*ub=*/1);
}
/// Builds a sequence applied to `root` without extra bindings. The entry block
/// has a single argument typed like `root`.
void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
                                  TypeRange resultTypes,
                                  FailurePropagationMode failurePropagationMode,
                                  Value root,
                                  SequenceBodyBuilderFn bodyBuilder) {
  build(builder, state, resultTypes, failurePropagationMode, root,
        /*extra_bindings=*/ValueRange());
  buildSequenceBody(builder, state, root.getType(),
                    /*extraBindingTypes=*/TypeRange(), bodyBuilder);
}
/// Builds a sequence applied to `root` with `extraBindings`. The entry block
/// has one argument for the root followed by one per extra binding, with
/// matching types.
void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
                                  TypeRange resultTypes,
                                  FailurePropagationMode failurePropagationMode,
                                  Value root, ValueRange extraBindings,
                                  SequenceBodyBuilderArgsFn bodyBuilder) {
  build(builder, state, resultTypes, failurePropagationMode, root,
        extraBindings);
  Type rootType = root.getType();
  buildSequenceBody(builder, state, rootType, extraBindings.getTypes(),
                    bodyBuilder);
}
/// Builds a top-level sequence (no root operand) whose entry block has a
/// single argument of `bbArgType`.
void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
                                  TypeRange resultTypes,
                                  FailurePropagationMode failurePropagationMode,
                                  Type bbArgType,
                                  SequenceBodyBuilderFn bodyBuilder) {
  Value noRoot;
  build(builder, state, resultTypes, failurePropagationMode, /*root=*/noRoot,
        /*extra_bindings=*/ValueRange());
  buildSequenceBody(builder, state, bbArgType,
                    /*extraBindingTypes=*/TypeRange(), bodyBuilder);
}
/// Builds a top-level sequence (no operands) whose entry block has an argument
/// of `bbArgType` followed by one argument per type in `extraBindingTypes`.
void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
                                  TypeRange resultTypes,
                                  FailurePropagationMode failurePropagationMode,
                                  Type bbArgType, TypeRange extraBindingTypes,
                                  SequenceBodyBuilderArgsFn bodyBuilder) {
  Value noRoot;
  build(builder, state, resultTypes, failurePropagationMode, /*root=*/noRoot,
        /*extra_bindings=*/ValueRange());
  buildSequenceBody(builder, state, bbArgType, extraBindingTypes, bodyBuilder);
}
//===----------------------------------------------------------------------===//
// PrintOp
//===----------------------------------------------------------------------===//
/// Builds a print op. The optional `name` property is only set when `name` is
/// non-empty.
void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
                               StringRef name) {
  if (name.empty())
    return;
  result.getOrAddProperties<Properties>().name = builder.getStringAttr(name);
}
void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
                               Value target, StringRef name) {
  // Populate the optional name property via the name-only builder, then add
  // the target handle as the op's operand. The two steps touch independent
  // parts of the OperationState, so their order does not matter.
  build(builder, result, name);
  result.addOperands({target});
}
DiagnosedSilenceableFailure
transform::PrintOp::apply(transform::TransformRewriter &rewriter,
                          transform::TransformResults &results,
                          transform::TransformState &state) {
  // Translate the op's optional unit attributes into printer flags.
  OpPrintingFlags flags;
  if (getAssumeVerified().value_or(false))
    flags.assumeVerified();
  if (getUseLocalScope().value_or(false))
    flags.useLocalScope();
  if (getSkipRegions().value_or(false))
    flags.skipRegions();

  // All output goes to stdout, preceded by a banner carrying the optional
  // name.
  llvm::raw_ostream &os = llvm::outs();
  os << "[[[ IR printer: ";
  if (auto label = getName())
    os << *label << " ";

  // With no target handle, dump the whole top-level payload op.
  if (!getTarget()) {
    os << "top-level ]]]\n";
    state.getTopLevel()->print(os, flags);
    os << "\n";
    return DiagnosedSilenceableFailure::success();
  }

  // Otherwise print every payload op associated with the target handle.
  os << "]]]\n";
  for (Operation *payload : state.getPayloadOps(getTarget())) {
    payload->print(os, flags);
    os << "\n";
  }
  return DiagnosedSilenceableFailure::success();
}
void transform::PrintOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // The target operand is optional. Inspect the operand list through the
  // mutable accessor rather than `getTarget`: mutability is irrelevant here,
  // but `getTarget` unconditionally casts to a specific type, and this hook
  // may run before verification has rejected a malformed op.
  if (!getTargetMutable().empty())
    onlyReadsHandle(getTargetMutable()[0].get(), effects);
  // Printing only inspects the payload IR.
  onlyReadsPayload(effects);
  // `apply` writes to stdout via `llvm::outs()`. There is no dedicated
  // resource for the stdout stream, so declare the write on the default
  // resource.
  effects.emplace_back(MemoryEffects::Write::get());
}
//===----------------------------------------------------------------------===//
// VerifyOp
//===----------------------------------------------------------------------===//
DiagnosedSilenceableFailure
transform::VerifyOp::applyToOne(transform::TransformRewriter &rewriter,
                                Operation *target,
                                transform::ApplyToEachResultList &results,
                                transform::TransformState &state) {
  // Payload op verifies cleanly: nothing to report.
  if (succeeded(::mlir::verify(target)))
    return DiagnosedSilenceableFailure::success();

  // A verification failure is reported as a definite (non-silenceable)
  // error on this transform op, with a note pointing at the payload op.
  DiagnosedDefiniteFailure failure =
      emitDefiniteFailure() << "failed to verify payload op";
  failure.attachNote(target->getLoc()) << "payload op";
  return failure;
}
void transform::VerifyOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // The target handle is only read; verification does not consume it.
  transform::onlyReadsHandle(getTarget(), effects);
  // `applyToOne` runs the IR verifier over the payload op. That inspects the
  // payload without modifying it, so declare a payload read.
  transform::onlyReadsPayload(effects);
}
//===----------------------------------------------------------------------===//
// YieldOp
//===----------------------------------------------------------------------===//
void transform::YieldOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // Every yielded handle is only read by the terminator, never consumed.
  ValueRange yielded = getOperands();
  transform::onlyReadsHandle(yielded, effects);
}