Rolf Morel fe7bf4b90b
[MLIR][Transform] apply_registered_pass op's options as a dict (#143159)
Improve ApplyRegisteredPassOp's support for taking options by taking
them as a dict (vs a list of string-valued key-value pairs).

Values of options are provided as either static attributes or as params
(which pass in attributes at interpreter runtime). In either case, the
keys and value attributes are converted to strings, and a single
options string, in the format used on the command line, is constructed and
passed to the pass registry's `addToPipeline` API.
2025-06-11 17:33:55 +01:00

3087 lines
124 KiB
C++

//===- TransformOps.cpp - Transform dialect operations --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/CSE.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/InterleavedRange.h"
#include <optional>
#define DEBUG_TYPE "transform-dialect"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
#define DEBUG_TYPE_MATCHER "transform-matcher"
#define DBGS_MATCHER() (llvm::dbgs() << "[" DEBUG_TYPE_MATCHER "] ")
#define DEBUG_MATCHER(x) DEBUG_WITH_TYPE(DEBUG_TYPE_MATCHER, x)
using namespace mlir;
// Custom assembly-format hooks. They are declared before the ODS-generated op
// definitions are included (TransformOps.cpp.inc) -- presumably because the
// generated parsers/printers call them via `custom<...>` directives; confirm
// against TransformOps.td.

/// Parses the options of ApplyRegisteredPassOp into `options`; SSA operands
/// used as option values are collected into `dynamicOptions`.
static ParseResult parseApplyRegisteredPassOptions(
OpAsmParser &parser, DictionaryAttr &options,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicOptions);
/// Printing counterpart of parseApplyRegisteredPassOptions.
static void printApplyRegisteredPassOptions(OpAsmPrinter &printer,
Operation *op,
DictionaryAttr options,
ValueRange dynamicOptions);
/// Parses the operand list of a sequence op: judging by the parameter names,
/// an optional root operand with its type plus extra bindings and their types.
static ParseResult parseSequenceOpOperands(
OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
Type &rootType,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
SmallVectorImpl<Type> &extraBindingTypes);
/// Printing counterpart of parseSequenceOpOperands.
static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op,
Value root, Type rootType,
ValueRange extraBindings,
TypeRange extraBindingTypes);
/// Prints the matcher/action symbol arrays (presumably of foreach_match).
static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op,
ArrayAttr matchers, ArrayAttr actions);
/// Parsing counterpart of printForeachMatchSymbols.
static ParseResult parseForeachMatchSymbols(OpAsmParser &parser,
ArrayAttr &matchers,
ArrayAttr &actions);
/// Guards against a transform op rewriting the IR that is currently being
/// interpreted: if `payload` is `transform` itself or any of its ancestors, a
/// definite failure (with a note at the payload location) is returned;
/// otherwise success.
static DiagnosedSilenceableFailure
ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform,
                                     Operation *payload) {
  for (Operation *ancestor = transform.getOperation(); ancestor != nullptr;
       ancestor = ancestor->getParentOp()) {
    if (ancestor != payload)
      continue;
    DiagnosedDefiniteFailure diag =
        transform.emitDefiniteFailure()
        << "cannot apply transform to itself (or one of its ancestors)";
    diag.attachNote(payload->getLoc()) << "target payload op";
    return diag;
  }
  return DiagnosedSilenceableFailure::success();
}
#define GET_OP_CLASSES
#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
//===----------------------------------------------------------------------===//
// AlternativesOp
//===----------------------------------------------------------------------===//
/// Returns the operands forwarded when entering successor `point`: the op's
/// single operand when `point` is a region (not the parent op) and exactly one
/// operand exists, an empty range otherwise.
OperandRange
transform::AlternativesOp::getEntrySuccessorOperands(RegionBranchPoint point) {
  Operation *op = getOperation();
  bool forwardsScopeOperand = !point.isParent() && op->getNumOperands() == 1;
  if (!forwardsScopeOperand)
    return OperandRange(op->operand_end(), op->operand_end());
  return op->getOperands();
}
/// Alternatives are tried in order. From the parent op, every alternative is a
/// possible successor; from alternative #i, the alternatives after it and the
/// parent op (via the results) are. Entry arguments are passed only when the
/// op has operands.
void transform::AlternativesOp::getSuccessorRegions(
    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
  unsigned firstCandidate =
      point.isParent() ? 0 : point.getRegionOrNull()->getRegionNumber() + 1;
  bool hasOperands = !getOperands().empty();
  for (Region &alternative :
       llvm::drop_begin(getAlternatives(), firstCandidate)) {
    if (hasOperands)
      regions.emplace_back(&alternative, alternative.getArguments());
    else
      regions.emplace_back(&alternative, Block::BlockArgListType());
  }
  // Leaving an alternative may also return control to the parent op.
  if (!point.isParent())
    regions.emplace_back(getOperation()->getResults());
}
/// The first alternative region always runs exactly once; every later
/// alternative runs at most once. `operands` is not consulted.
void transform::AlternativesOp::getRegionInvocationBounds(
    ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
  (void)operands;
  const unsigned numAlternatives = getNumRegions();
  bounds.reserve(numAlternatives);
  bounds.emplace_back(/*lb=*/1, /*ub=*/1);
  bounds.resize(numAlternatives, InvocationBounds(/*lb=*/0, /*ub=*/1));
}
/// Maps every result of the op that owns `block` to an empty payload list.
/// `state` is accepted but not used.
static void forwardEmptyOperands(Block *block, transform::TransformState &state,
                                 transform::TransformResults &results) {
  Operation *owner = block->getParentOp();
  for (OpResult result : owner->getOpResults())
    results.set(result, {});
}
/// Tries each alternative region in order on clones of the scope op(s): the
/// payload of the optional scope handle, or the top-level op otherwise. The
/// first alternative whose transforms all succeed wins; its clones are
/// inserted in place of the originals and its terminator operands are
/// forwarded as results. A silenceable failure inside an alternative moves on
/// to the next one; a definite failure aborts immediately. If every
/// alternative fails, a silenceable error is returned.
DiagnosedSilenceableFailure
transform::AlternativesOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
SmallVector<Operation *> originals;
if (Value scopeHandle = getScope())
llvm::append_range(originals, state.getPayloadOps(scopeHandle));
else
originals.push_back(state.getTopLevel());
// Every scope must be isolated from above and must not contain this op;
// either violation is a definite failure.
for (Operation *original : originals) {
if (original->isAncestor(getOperation())) {
auto diag = emitDefiniteFailure()
<< "scope must not contain the transforms being applied";
diag.attachNote(original->getLoc()) << "scope";
return diag;
}
if (!original->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
auto diag = emitDefiniteFailure()
<< "only isolated-from-above ops can be alternative scopes";
diag.attachNote(original->getLoc()) << "scope";
return diag;
}
}
for (Region &reg : getAlternatives()) {
// Clone the scope operations and make the transforms in this alternative
// region apply to them by virtue of mapping the block argument (the only
// visible handle) to the cloned scope operations. This effectively prevents
// the transformation from accessing any IR outside the scope.
auto scope = state.make_region_scope(reg);
auto clones = llvm::to_vector(
llvm::map_range(originals, [](Operation *op) { return op->clone(); }));
// Unless released below, the clones are erased when this iteration ends
// (i.e. when the alternative failed).
auto deleteClones = llvm::make_scope_exit([&] {
for (Operation *clone : clones)
clone->erase();
});
if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones)))
return DiagnosedSilenceableFailure::definiteFailure();
// Note: this local shadows mlir::failed, hence the qualified call below.
bool failed = false;
for (Operation &transform : reg.front().without_terminator()) {
DiagnosedSilenceableFailure result =
state.applyTransform(cast<TransformOpInterface>(transform));
if (result.isSilenceableFailure()) {
// Silenceable failures are only logged; try the next alternative.
LLVM_DEBUG(DBGS() << "alternative failed: " << result.getMessage()
<< "\n");
failed = true;
break;
}
// silence() reports failure only for a definite failure, which aborts.
if (::mlir::failed(result.silence()))
return DiagnosedSilenceableFailure::definiteFailure();
}
// If all operations in the given alternative succeeded, no need to consider
// the rest. Replace the original scoping operation with the clone on which
// the transformations were performed.
if (!failed) {
// We will be using the clones, so cancel their scheduled deletion.
deleteClones.release();
// This local rewriter (shadowing the `rewriter` parameter) reports the
// replacements to a TrackingListener -- presumably so handles to the
// originals get updated to the clones; confirm in TrackingListener.
TrackingListener listener(state, *this);
IRRewriter rewriter(getContext(), &listener);
for (const auto &kvp : llvm::zip(originals, clones)) {
Operation *original = std::get<0>(kvp);
Operation *clone = std::get<1>(kvp);
// Insert the clone right before the original, then replace the latter.
original->getBlock()->getOperations().insert(original->getIterator(),
clone);
rewriter.replaceOp(original, clone->getResults());
}
detail::forwardTerminatorOperands(&reg.front(), state, results);
return DiagnosedSilenceableFailure::success();
}
}
return emitSilenceableError() << "all alternatives failed";
}
/// Consumes the scope operand, produces the result handles and the entry block
/// arguments of every non-empty alternative, and modifies the payload.
void transform::AlternativesOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  Operation *op = getOperation();
  consumesHandle(op->getOpOperands(), effects);
  producesHandle(op->getOpResults(), effects);
  for (Region *alternative : getRegions()) {
    if (alternative->empty())
      continue;
    producesHandle(alternative->front().getArguments(), effects);
  }
  modifiesPayload(effects);
}
/// Each alternative's terminator must yield exactly the op's result types.
LogicalResult transform::AlternativesOp::verify() {
  for (Region &alternative : getAlternatives()) {
    Operation *yield = alternative.front().getTerminator();
    if (!(yield->getOperands().getTypes() != getResults().getTypes()))
      continue;
    InFlightDiagnostic diag = emitOpError()
                              << "expects terminator operands to have the "
                                 "same type as results of the operation";
    diag.attachNote(yield->getLoc()) << "terminator";
    return diag;
  }
  return success();
}
//===----------------------------------------------------------------------===//
// AnnotateOp
//===----------------------------------------------------------------------===//
/// Sets the attribute named by the op on every target payload op.
/// - Without a param, the value is a unit attribute.
/// - With a param holding exactly one value, that value goes on every target.
/// - Otherwise the param values are paired one-to-one with the targets; a
///   length mismatch is a silenceable error.
DiagnosedSilenceableFailure
transform::AnnotateOp::apply(transform::TransformRewriter &rewriter,
                             transform::TransformResults &results,
                             transform::TransformState &state) {
  SmallVector<Operation *> targets =
      llvm::to_vector(state.getPayloadOps(getTarget()));
  Attribute attr = UnitAttr::get(getContext());
  if (auto paramHandle = getParam()) {
    ArrayRef<Attribute> params = state.getParams(paramHandle);
    if (params.size() == 1) {
      attr = params.front();
    } else {
      if (targets.size() != params.size()) {
        return emitSilenceableError()
               << "parameter and target have different payload lengths ("
               << params.size() << " vs " << targets.size() << ")";
      }
      for (auto &&[target, value] : llvm::zip_equal(targets, params))
        target->setAttr(getName(), value);
      return DiagnosedSilenceableFailure::success();
    }
  }
  for (Operation *target : targets)
    target->setAttr(getName(), attr);
  return DiagnosedSilenceableFailure::success();
}
/// The target and optional param handles are only read; payload ops are
/// modified by setting an attribute on them.
void transform::AnnotateOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  transform::onlyReadsHandle(getTargetMutable(), effects);
  transform::onlyReadsHandle(getParamMutable(), effects);
  transform::modifiesPayload(effects);
}
//===----------------------------------------------------------------------===//
// ApplyCommonSubexpressionEliminationOp
//===----------------------------------------------------------------------===//
/// Runs common subexpression elimination on `target` through `rewriter`,
/// after checking that `target` does not contain this transform op.
DiagnosedSilenceableFailure
transform::ApplyCommonSubexpressionEliminationOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    ApplyToEachResultList &results, transform::TransformState &state) {
  // Rewriting the transform IR while it is being interpreted is unsafe.
  if (DiagnosedSilenceableFailure separation =
          ensurePayloadIsSeparateFromTransform(*this, target);
      !separation.succeeded())
    return separation;
  DominanceInfo dominance;
  mlir::eliminateCommonSubExpressions(rewriter, dominance, target);
  return DiagnosedSilenceableFailure::success();
}
/// Reads (does not consume) the target handle; modifies the payload in place.
void transform::ApplyCommonSubexpressionEliminationOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  transform::onlyReadsHandle(getTargetMutable(), effects);
  transform::modifiesPayload(effects);
}
//===----------------------------------------------------------------------===//
// ApplyDeadCodeEliminationOp
//===----------------------------------------------------------------------===//
/// Erases trivially dead ops strictly nested under `target`, iterating until
/// no op whose producer chain lives inside `target` becomes newly dead.
/// Returns a definite failure if `target` contains this transform op.
DiagnosedSilenceableFailure transform::ApplyDeadCodeEliminationOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    ApplyToEachResultList &results, transform::TransformState &state) {
  // Make sure that this transform is not applied to itself. Modifying the
  // transform IR while it is being interpreted is generally dangerous.
  DiagnosedSilenceableFailure payloadCheck =
      ensurePayloadIsSeparateFromTransform(*this, target);
  if (!payloadCheck.succeeded())
    return payloadCheck;

  // Ops strictly nested under `target` that may have become dead.
  SetVector<Operation *> worklist;

  // Enqueues the in-target producers of every operand used by `op` or by any
  // op nested in it: erasing `op` may leave those producers without users.
  auto addDefiningOpsToWorklist = [&](Operation *op) {
    op->walk([&](Operation *nested) {
      for (Value operand : nested->getOperands())
        if (Operation *defOp = operand.getDefiningOp())
          if (target->isProperAncestor(defOp))
            worklist.insert(defOp);
    });
  };

  // Erases `op` after dropping it and all ops nested in it from the worklist,
  // so that no pointer to an erased op is popped later. SetVector::remove
  // checks set membership first, so ops that were never enqueued cost O(1)
  // instead of a linear scan of the worklist.
  auto eraseOp = [&](Operation *op) {
    op->walk([&](Operation *nested) { worklist.remove(nested); });
    rewriter.eraseOp(op);
  };

  // Initial walk over the IR. Post-order lets the callback erase the op it is
  // visiting.
  target->walk<WalkOrder::PostOrder>([&](Operation *op) {
    if (op != target && isOpTriviallyDead(op)) {
      addDefiningOpsToWorklist(op);
      eraseOp(op);
    }
  });

  // Erase all ops that have become dead.
  while (!worklist.empty()) {
    Operation *op = worklist.pop_back_val();
    if (!isOpTriviallyDead(op))
      continue;
    addDefiningOpsToWorklist(op);
    eraseOp(op);
  }
  return DiagnosedSilenceableFailure::success();
}
/// Reads (does not consume) the target handle; erases payload ops in place.
void transform::ApplyDeadCodeEliminationOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  transform::onlyReadsHandle(getTargetMutable(), effects);
  transform::modifiesPayload(effects);
}
//===----------------------------------------------------------------------===//
// ApplyPatternsOp
//===----------------------------------------------------------------------===//
/// Applies the patterns populated by the PatternDescriptorOpInterface ops in
/// this op's region to `target` with the greedy rewrite driver, optionally
/// interleaved with CSE until a fixpoint is reached.
/// - Returns a silenceable failure if greedy application fails.
/// - Returns a definite failure if `target` contains this op, or if the
///   pattern/CSE loop does not settle within kNumMaxIterations rounds.
DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne(
transform::TransformRewriter &rewriter, Operation *target,
ApplyToEachResultList &results, transform::TransformState &state) {
// Make sure that this transform is not applied to itself. Modifying the
// transform IR while it is being interpreted is generally dangerous. Even
// more so for the ApplyPatternsOp because the GreedyPatternRewriteDriver
// performs many additional simplifications such as dead code elimination.
DiagnosedSilenceableFailure payloadCheck =
ensurePayloadIsSeparateFromTransform(*this, target);
if (!payloadCheck.succeeded())
return payloadCheck;
// Gather all specified patterns.
MLIRContext *ctx = target->getContext();
RewritePatternSet patterns(ctx);
if (!getRegion().empty()) {
for (Operation &op : getRegion().front()) {
cast<transform::PatternDescriptorOpInterface>(&op)
.populatePatternsWithState(patterns, state);
}
}
// Configure the GreedyPatternRewriteDriver. Rewrites are reported to the
// transform rewriter's listener.
GreedyRewriteConfig config;
config.setListener(
static_cast<RewriterBase::Listener *>(rewriter.getListener()));
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
// An attribute value of static_cast<uint64_t>(-1) (all bits set) means
// "no limit".
config.setMaxIterations(getMaxIterations() == static_cast<uint64_t>(-1)
? GreedyRewriteConfig::kNoLimit
: getMaxIterations());
config.setMaxNumRewrites(getMaxNumRewrites() == static_cast<uint64_t>(-1)
? GreedyRewriteConfig::kNoLimit
: getMaxNumRewrites());
// Apply patterns and CSE repetitively until a fixpoint is reached. If no CSE
// was requested, apply the greedy pattern rewrite only once. (The greedy
// pattern rewrite driver already iterates to a fixpoint internally.)
bool cseChanged = false;
// One or two iterations should be sufficient. Stop iterating after a certain
// threshold to make debugging easier.
static const int64_t kNumMaxIterations = 50;
int64_t iteration = 0;
do {
LogicalResult result = failure();
if (target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
// Op is isolated from above. Apply patterns and also perform region
// simplification.
result = applyPatternsGreedily(target, frozenPatterns, config);
} else {
// Manually gather list of ops because the other
// GreedyPatternRewriteDriver overloads only accepts ops that are isolated
// from above. This way, patterns can be applied to ops that are not
// isolated from above. Regions are not being simplified. Furthermore,
// only a single greedy rewrite iteration is performed.
SmallVector<Operation *> ops;
target->walk([&](Operation *nestedOp) {
if (target != nestedOp)
ops.push_back(nestedOp);
});
result = applyOpPatternsGreedily(ops, frozenPatterns, config);
}
// A failure typically indicates that the pattern application did not
// converge.
if (failed(result)) {
return emitSilenceableFailure(target)
<< "greedy pattern application failed";
}
if (getApplyCse()) {
DominanceInfo domInfo;
mlir::eliminateCommonSubExpressions(rewriter, domInfo, target,
&cseChanged);
}
// `++iteration` is evaluated only while CSE keeps changing the IR, so
// `iteration` reaches kNumMaxIterations only when the loop did not settle.
} while (cseChanged && ++iteration < kNumMaxIterations);
if (iteration == kNumMaxIterations)
return emitDefiniteFailure() << "fixpoint iteration did not converge";
return DiagnosedSilenceableFailure::success();
}
/// Every op in the (optional) patterns region must implement
/// PatternDescriptorOpInterface.
LogicalResult transform::ApplyPatternsOp::verify() {
  if (getRegion().empty())
    return success();
  for (Operation &child : getRegion().front()) {
    if (isa<transform::PatternDescriptorOpInterface>(&child))
      continue;
    InFlightDiagnostic diag = emitOpError()
                              << "expected children ops to implement "
                                 "PatternDescriptorOpInterface";
    diag.attachNote(child.getLoc()) << "op without interface";
    return diag;
  }
  return success();
}
/// Reads (does not consume) the target handle; rewrites the payload in place.
void transform::ApplyPatternsOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  transform::onlyReadsHandle(getTargetMutable(), effects);
  transform::modifiesPayload(effects);
}
/// Builds the op on `target` with a single-block patterns region, which
/// `bodyBuilder` (if provided) populates. The builder's insertion point is
/// restored on return.
void transform::ApplyPatternsOp::build(
    OpBuilder &builder, OperationState &result, Value target,
    function_ref<void(OpBuilder &, Location)> bodyBuilder) {
  result.addOperands(target);
  OpBuilder::InsertionGuard guard(builder);
  builder.createBlock(result.addRegion());
  if (bodyBuilder)
    bodyBuilder(builder, result.location);
}
//===----------------------------------------------------------------------===//
// ApplyCanonicalizationPatternsOp
//===----------------------------------------------------------------------===//
/// Adds the canonicalization patterns of every loaded dialect, followed by
/// those of every registered operation.
void transform::ApplyCanonicalizationPatternsOp::populatePatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  for (Dialect *loadedDialect : context->getLoadedDialects())
    loadedDialect->getCanonicalizationPatterns(patterns);
  for (RegisteredOperationName opName : context->getRegisteredOperations())
    opName.getCanonicalizationPatterns(patterns, context);
}
//===----------------------------------------------------------------------===//
// ApplyConversionPatternsOp
//===----------------------------------------------------------------------===//
/// Runs a dialect conversion (partial or full, per the op's flag) on each
/// target payload op.
/// - The conversion target is configured from the legal/illegal op and
///   dialect lists.
/// - Patterns come from the ConversionPatternDescriptorOpInterface ops in the
///   patterns region. Each descriptor uses its own type converter or, failing
///   that, the op's default type converter.
/// - When handles are preserved, a tracking listener observes the conversion,
///   and its errors are reported alongside conversion failures.
DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply(
transform::TransformRewriter &rewriter,
transform::TransformResults &results, transform::TransformState &state) {
MLIRContext *ctx = getContext();
// Instantiate the default type converter if a type converter builder is
// specified.
std::unique_ptr<TypeConverter> defaultTypeConverter;
transform::TypeConverterBuilderOpInterface typeConverterBuilder =
getDefaultTypeConverter();
if (typeConverterBuilder)
defaultTypeConverter = typeConverterBuilder.getTypeConverter();
// Configure conversion target.
ConversionTarget conversionTarget(*getContext());
if (getLegalOps())
for (Attribute attr : cast<ArrayAttr>(*getLegalOps()))
conversionTarget.addLegalOp(
OperationName(cast<StringAttr>(attr).getValue(), ctx));
if (getIllegalOps())
for (Attribute attr : cast<ArrayAttr>(*getIllegalOps()))
conversionTarget.addIllegalOp(
OperationName(cast<StringAttr>(attr).getValue(), ctx));
if (getLegalDialects())
for (Attribute attr : cast<ArrayAttr>(*getLegalDialects()))
conversionTarget.addLegalDialect(cast<StringAttr>(attr).getValue());
if (getIllegalDialects())
for (Attribute attr : cast<ArrayAttr>(*getIllegalDialects()))
conversionTarget.addIllegalDialect(cast<StringAttr>(attr).getValue());
// Gather all specified patterns.
RewritePatternSet patterns(ctx);
// Need to keep the converters alive until after pattern application because
// the patterns take a reference to an object that would otherwise get out of
// scope.
SmallVector<std::unique_ptr<TypeConverter>> keepAliveConverters;
if (!getPatterns().empty()) {
for (Operation &op : getPatterns().front()) {
auto descriptor =
cast<transform::ConversionPatternDescriptorOpInterface>(&op);
// Check if this pattern set specifies a type converter.
std::unique_ptr<TypeConverter> typeConverter =
descriptor.getTypeConverter();
TypeConverter *converter = nullptr;
if (typeConverter) {
keepAliveConverters.emplace_back(std::move(typeConverter));
converter = keepAliveConverters.back().get();
} else {
// No type converter specified: Use the default type converter.
if (!defaultTypeConverter) {
auto diag = emitDefiniteFailure()
<< "pattern descriptor does not specify type "
"converter and apply_conversion_patterns op has "
"no default type converter";
diag.attachNote(op.getLoc()) << "pattern descriptor op";
return diag;
}
converter = defaultTypeConverter.get();
}
// Add descriptor-specific updates to the conversion target, which may
// depend on the final type converter. In structural converters, the
// legality of types dictates the dynamic legality of an operation.
descriptor.populateConversionTargetRules(*converter, conversionTarget);
descriptor.populatePatterns(*converter, patterns);
}
}
// Attach a tracking listener if handles should be preserved. We configure the
// listener to allow op replacements with different names, as conversion
// patterns typically replace ops with replacement ops that have a different
// name.
TrackingListenerConfig trackingConfig;
trackingConfig.requireMatchingReplacementOpName = false;
ErrorCheckingTrackingListener trackingListener(state, *this, trackingConfig);
ConversionConfig conversionConfig;
if (getPreserveHandles())
conversionConfig.listener = &trackingListener;
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
for (Operation *target : state.getPayloadOps(getTarget())) {
// Make sure that this transform is not applied to itself. Modifying the
// transform IR while it is being interpreted is generally dangerous.
DiagnosedSilenceableFailure payloadCheck =
ensurePayloadIsSeparateFromTransform(*this, target);
if (!payloadCheck.succeeded())
return payloadCheck;
LogicalResult status = failure();
if (getPartialConversion()) {
status = applyPartialConversion(target, conversionTarget, frozenPatterns,
conversionConfig);
} else {
status = applyFullConversion(target, conversionTarget, frozenPatterns,
conversionConfig);
}
// Check dialect conversion state.
DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
if (failed(status)) {
diag = emitSilenceableError() << "dialect conversion failed";
diag.attachNote(target->getLoc()) << "target op";
}
// Check tracking listener error state. The listener's error is reset here
// so that each target is checked independently.
DiagnosedSilenceableFailure trackingFailure =
trackingListener.checkAndResetError();
if (!trackingFailure.succeeded()) {
if (diag.succeeded()) {
// Tracking failure is the only failure.
return trackingFailure;
} else {
// Fold the tracking error into the conversion diagnostic as a note.
diag.attachNote() << "tracking listener also failed: "
<< trackingFailure.getMessage();
(void)trackingFailure.silence();
}
}
if (!diag.succeeded())
return diag;
}
return DiagnosedSilenceableFailure::success();
}
/// Verifies the region structure:
/// - the op has 1 or 2 regions;
/// - every op in the patterns region implements
///   ConversionPatternDescriptorOpInterface;
/// - when present, the default type converter region has a block holding
///   exactly one TypeConverterBuilderOpInterface op, and every pattern
///   descriptor accepts that converter.
LogicalResult transform::ApplyConversionPatternsOp::verify() {
  if (getNumRegions() != 1 && getNumRegions() != 2)
    return emitOpError() << "expected 1 or 2 regions";
  if (!getPatterns().empty()) {
    for (Operation &op : getPatterns().front()) {
      if (!isa<transform::ConversionPatternDescriptorOpInterface>(&op)) {
        InFlightDiagnostic diag =
            emitOpError() << "expected pattern children ops to implement "
                             "ConversionPatternDescriptorOpInterface";
        diag.attachNote(op.getLoc()) << "op without interface";
        return diag;
      }
    }
  }
  if (getNumRegions() == 2) {
    Region &typeConverterRegion = getRegion(1);
    // Reject a block-less region before calling front(); Region::front() on
    // an empty region would dereference a non-existent block.
    if (typeConverterRegion.empty() ||
        !llvm::hasSingleElement(typeConverterRegion.front()))
      return emitOpError()
             << "expected exactly one op in default type converter region";
    Operation *maybeTypeConverter = &typeConverterRegion.front().front();
    auto typeConverterOp = dyn_cast<transform::TypeConverterBuilderOpInterface>(
        maybeTypeConverter);
    if (!typeConverterOp) {
      InFlightDiagnostic diag = emitOpError()
                                << "expected default converter child op to "
                                   "implement TypeConverterBuilderOpInterface";
      diag.attachNote(maybeTypeConverter->getLoc()) << "op without interface";
      return diag;
    }
    // Check default type converter type.
    if (!getPatterns().empty()) {
      for (Operation &op : getPatterns().front()) {
        auto descriptor =
            cast<transform::ConversionPatternDescriptorOpInterface>(&op);
        if (failed(descriptor.verifyTypeConverter(typeConverterOp)))
          return failure();
      }
    }
  }
  return success();
}
/// The target handle is only read when handles are preserved and is consumed
/// otherwise; the payload is modified in either case.
void transform::ApplyConversionPatternsOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  if (getPreserveHandles())
    transform::onlyReadsHandle(getTargetMutable(), effects);
  else
    transform::consumesHandle(getTargetMutable(), effects);
  transform::modifiesPayload(effects);
}
/// Builds the op on `target` with two single-block regions: the patterns
/// region, then the default type converter region. Each is populated by its
/// body builder when one is provided.
void transform::ApplyConversionPatternsOp::build(
    OpBuilder &builder, OperationState &result, Value target,
    function_ref<void(OpBuilder &, Location)> patternsBodyBuilder,
    function_ref<void(OpBuilder &, Location)> typeConverterBodyBuilder) {
  result.addOperands(target);
  // Appends a region holding one fresh block and lets `fill` populate it. The
  // builder's insertion point is restored afterwards.
  auto addRegionWithBlock =
      [&](function_ref<void(OpBuilder &, Location)> fill) {
        OpBuilder::InsertionGuard guard(builder);
        builder.createBlock(result.addRegion());
        if (fill)
          fill(builder, result.location);
      };
  addRegionWithBlock(patternsBodyBuilder);
  addRegionWithBlock(typeConverterBodyBuilder);
}
//===----------------------------------------------------------------------===//
// ApplyToLLVMConversionPatternsOp
//===----------------------------------------------------------------------===//
/// Adds the convert-to-LLVM patterns of the named dialect through its
/// ConvertToLLVMPatternInterface. `typeConverter` is assumed to be an
/// LLVMTypeConverter (verifyTypeConverter checks the builder's converter type).
void transform::ApplyToLLVMConversionPatternsOp::populatePatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  MLIRContext *ctx = getContext();
  Dialect *dialect = ctx->getLoadedDialect(getDialectName());
  assert(dialect && "expected that dialect is loaded");
  // The enclosing apply_conversion_patterns op sets up its own
  // ConversionTarget; this local one only satisfies the interface and is
  // discarded.
  ConversionTarget ignoredTarget(*ctx);
  auto *iface = cast<ConvertToLLVMPatternInterface>(dialect);
  iface->populateConvertToLLVMConversionPatterns(
      ignoredTarget, static_cast<LLVMTypeConverter &>(typeConverter), patterns);
}
/// Accepts only type converter builders that produce an "LLVMTypeConverter".
LogicalResult transform::ApplyToLLVMConversionPatternsOp::verifyTypeConverter(
    transform::TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() == "LLVMTypeConverter")
    return success();
  return emitOpError("expected LLVMTypeConverter");
}
/// The named dialect must be loaded and must implement
/// ConvertToLLVMPatternInterface.
LogicalResult transform::ApplyToLLVMConversionPatternsOp::verify() {
  Dialect *dialect = getContext()->getLoadedDialect(getDialectName());
  if (dialect == nullptr) {
    return emitOpError("unknown dialect or dialect not loaded: ")
           << getDialectName();
  }
  if (dyn_cast<ConvertToLLVMPatternInterface>(dialect) != nullptr)
    return success();
  return emitOpError(
             "dialect does not implement ConvertToLLVMPatternInterface or "
             "extension was not loaded: ")
         << getDialectName();
}
//===----------------------------------------------------------------------===//
// ApplyLoopInvariantCodeMotionOp
//===----------------------------------------------------------------------===//
/// Hoists loop-invariant code out of `target`; always succeeds.
DiagnosedSilenceableFailure
transform::ApplyLoopInvariantCodeMotionOp::applyToOne(
    transform::TransformRewriter &rewriter, LoopLikeOpInterface target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  // LICM does not currently erase operations, so no tracking through
  // `rewriter` is needed. Should that change, add a LICM entry point that
  // takes a rewriter.
  moveLoopInvariantCode(target);
  return DiagnosedSilenceableFailure::success();
}
/// Reads (does not consume) the loop handle; moves payload ops in place.
void transform::ApplyLoopInvariantCodeMotionOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  transform::onlyReadsHandle(getTargetMutable(), effects);
  transform::modifiesPayload(effects);
}
//===----------------------------------------------------------------------===//
// ApplyRegisteredPassOp
//===----------------------------------------------------------------------===//
/// The target handle is consumed and a new result handle is produced. Param
/// operands carrying dynamic option values are only read. The pass modifies
/// the payload.
void transform::ApplyRegisteredPassOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  transform::consumesHandle(getTargetMutable(), effects);
  transform::onlyReadsHandle(getDynamicOptionsMutable(), effects);
  transform::producesHandle(getOperation()->getOpResults(), effects);
  transform::modifiesPayload(effects);
}
/// Runs a registered pass or pass pipeline, looked up by name, on every target
/// payload op and maps the result handle to those same ops.
///
/// Options are rendered as one command-line-style string (`k1=v1 k2=v2`). Each
/// value comes from one of two places:
/// - a static attribute in the options dictionary;
/// - for a ParamOperandAttr, the single attribute associated with the
///   dynamic-option param operand at the index it names.
/// StringAttr values are emitted unquoted; other attributes are printed with
/// their type elided.
///
/// Errors:
/// - Silenceable: a dynamic option param does not hold exactly one value, or
///   the pipeline fails on a target.
/// - Definite: unknown pass name, failure to add the pass to the pipeline, or
///   a target containing this transform op.
DiagnosedSilenceableFailure
transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
                                        transform::TransformResults &results,
                                        transform::TransformState &state) {
  // Obtain a single options-string to pass to the pass(-pipeline).
  std::string options;
  llvm::raw_string_ostream optionsStream(options); // For "printing" attrs.
  OperandRange dynamicOptions = getDynamicOptions();
  for (auto [idx, namedAttribute] : llvm::enumerate(getOptions())) {
    if (idx > 0)
      optionsStream << " "; // Interleave options separator.
    // Append the key and the key-value separator.
    optionsStream << namedAttribute.getName().getValue() << "=";

    Attribute valueAttrToAppend = namedAttribute.getValue();
    if (auto paramOperandIndex =
            dyn_cast<transform::ParamOperandAttr>(valueAttrToAppend)) {
      // The corresponding value attribute is passed in via a param.
      // Obtain the param-operand via its specified index.
      size_t dynamicOptionIdx = paramOperandIndex.getIndex().getInt();
      assert(dynamicOptionIdx < dynamicOptions.size() &&
             "ParamOperandAttr index in the options dictionary must refer to "
             "one of the op's dynamic-option operands");
      ArrayRef<Attribute> dynamicOption =
          state.getParams(dynamicOptions[dynamicOptionIdx]);
      if (dynamicOption.size() != 1)
        return emitSilenceableError()
               << "options passed as a param must have "
                  "a single value associated, param "
               << dynamicOptionIdx << " associates " << dynamicOption.size();
      valueAttrToAppend = dynamicOption[0];
    }

    // Append string representation of value attribute.
    if (auto strAttr = dyn_cast<StringAttr>(valueAttrToAppend))
      optionsStream << strAttr.getValue();
    else
      valueAttrToAppend.print(optionsStream, /*elideType=*/true);
  }
  optionsStream.flush();

  // Get pass or pass pipeline from registry.
  const PassRegistryEntry *info = PassPipelineInfo::lookup(getPassName());
  if (!info)
    info = PassInfo::lookup(getPassName());
  if (!info)
    return emitDefiniteFailure()
           << "unknown pass or pass pipeline: " << getPassName();

  // Create pass manager and add the pass or pass pipeline.
  PassManager pm(getContext());
  if (failed(info->addToPipeline(pm, options, [&](const Twine &msg) {
        emitError(msg);
        return failure();
      }))) {
    return emitDefiniteFailure()
           << "failed to add pass or pass pipeline to pipeline: "
           << getPassName();
  }

  auto targets = SmallVector<Operation *>(state.getPayloadOps(getTarget()));

  // Make sure that this transform is not applied to itself. Modifying the
  // transform IR while it is being interpreted is generally dangerous, even
  // more so when applying passes because they may perform a wide range of IR
  // modifications. Every target is validated before any pass runs, so a
  // rejected target does not leave earlier targets already transformed, and
  // no target is inspected after a pass on another target may have erased it.
  for (Operation *target : targets) {
    DiagnosedSilenceableFailure payloadCheck =
        ensurePayloadIsSeparateFromTransform(*this, target);
    if (!payloadCheck.succeeded())
      return payloadCheck;
  }

  // Run the pass or pass pipeline on each target operation.
  for (Operation *target : targets) {
    if (failed(pm.run(target))) {
      auto diag = emitSilenceableError() << "pass pipeline failed";
      diag.attachNote(target->getLoc()) << "target op";
      return diag;
    }
  }

  // The applied pass will have directly modified the payload IR(s).
  results.set(llvm::cast<OpResult>(getResult()), targets);
  return DiagnosedSilenceableFailure::success();
}
/// Parses the `{ key = value, ... }` options syntax of ApplyRegisteredPassOp
/// into a DictionaryAttr.
///
/// Keys are bare identifiers or strings. Values are either attributes or SSA
/// operands (params). An operand value is appended to `dynamicOptions` and
/// replaced in the dictionary by a `#transform.param_operand` marker that
/// records its index in `dynamicOptions`. Writing such a marker directly is
/// rejected, and duplicate keys are an error.
static ParseResult parseApplyRegisteredPassOptions(
    OpAsmParser &parser, DictionaryAttr &options,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicOptions) {
  MLIRContext *ctx = parser.getContext();
  SmallVector<NamedAttribute> entries;
  // Index that the next operand-valued option will occupy in `dynamicOptions`.
  size_t nextParamIdx = 0;

  // Parses one `key = value` entry and appends it to `entries`.
  auto parseEntry = [&]() -> ParseResult {
    std::string name;
    if (parser.parseOptionalKeywordOrString(&name))
      return parser.emitError(parser.getCurrentLocation())
             << "expected key to either be an identifier or a string";
    if (name.empty())
      return failure();
    if (parser.parseEqual())
      return parser.emitError(parser.getCurrentLocation())
             << "expected '=' after key in key-value pair";

    Attribute value;
    OptionalParseResult attrResult = parser.parseOptionalAttribute(value);
    if (attrResult.has_value()) {
      // An attribute was present but malformed; the attribute parser has
      // already reported the problem.
      if (failed(attrResult.value()))
        return failure();
      // The marker only exists to encode operands in the generic form.
      if (isa<transform::ParamOperandAttr>(value))
        return parser.emitError(parser.getCurrentLocation())
               << "the param_operand attribute is a marker reserved for "
               << "indicating a value will be passed via params and is only "
                  "used "
               << "in the generic print format";
    } else {
      // No attribute, so the value must be an SSA operand (a param).
      OpAsmParser::UnresolvedOperand operand;
      if (failed(parser.parseOperand(operand)))
        return parser.emitError(parser.getCurrentLocation())
               << "expected a valid attribute or operand as value associated "
               << "to key '" << name << "'";
      // SSA values cannot live inside attributes: remember the operand and
      // store a marker holding its index so it can be resolved later.
      dynamicOptions.push_back(operand);
      IntegerAttr index =
          IntegerAttr::get(IntegerType::get(ctx, 64), nextParamIdx++);
      value = transform::ParamOperandAttr::get(ctx, index);
    }
    entries.push_back(NamedAttribute(name, value));
    return success();
  };

  // Any failure here has already produced a diagnostic.
  if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Braces, parseEntry,
                                     " in options dictionary"))
    return failure();

  // `findDuplicate` with isSorted=false sorts `entries` in place, which is
  // what allows the `getWithSorted` call below.
  bool hasDuplicate =
      DictionaryAttr::findDuplicate(entries, /*isSorted=*/false).has_value();
  if (hasDuplicate)
    return parser.emitError(parser.getCurrentLocation())
           << "duplicate keys found in options dictionary";

  options = DictionaryAttr::getWithSorted(ctx, entries);
  return success();
}
/// Prints the options dictionary of ApplyRegisteredPassOp as
/// `{key = value, ...}`. A `#transform.param_operand` marker is printed as the
/// SSA operand it indexes in `dynamicOptions`. Nothing is printed for an empty
/// dictionary.
static void printApplyRegisteredPassOptions(OpAsmPrinter &printer,
                                            Operation *op,
                                            DictionaryAttr options,
                                            ValueRange dynamicOptions) {
  if (options.empty())
    return;

  auto printEntry = [&](NamedAttribute entry) {
    printer << entry.getName() << " = ";
    Attribute value = entry.getValue();
    auto marker = dyn_cast<transform::ParamOperandAttr>(value);
    if (!marker) {
      printer.printAttribute(value);
      return;
    }
    // The marker stands in for an SSA value: print the referenced operand.
    printer.printOperand(dynamicOptions[marker.getIndex().getInt()]);
  };

  printer << "{";
  llvm::interleaveComma(options, printer, printEntry);
  printer << "}";
}
/// Verifies that the `param_operand` markers in the options dictionary and the
/// dynamic-option param operands correspond one-to-one:
///  - every marker refers to an in-bounds operand index,
///  - no operand index is referenced twice,
///  - every operand is referenced by some marker.
LogicalResult transform::ApplyRegisteredPassOp::verify() {
  // Entries are nulled out once referenced, so any entry still non-null at the
  // end is an operand without a corresponding marker.
  auto dynamicOptions = SmallVector<Value>(getDynamicOptions());
  const auto numDynamicOptions = static_cast<int64_t>(dynamicOptions.size());
  for (NamedAttribute namedAttr : getOptions()) {
    auto paramOperand =
        dyn_cast<transform::ParamOperandAttr>(namedAttr.getValue());
    if (!paramOperand)
      continue;
    // Keep the index signed, as `getInt()` returns an int64_t. Storing it in a
    // `size_t` would turn a negative index into a huge positive one: the
    // `< 0` check would be dead and the diagnostic would print a wrapped value.
    int64_t dynamicOptionIdx = paramOperand.getIndex().getInt();
    if (dynamicOptionIdx < 0 || dynamicOptionIdx >= numDynamicOptions)
      return emitOpError()
             << "dynamic option index " << dynamicOptionIdx
             << " is out of bounds for the number of dynamic options: "
             << dynamicOptions.size();
    if (dynamicOptions[dynamicOptionIdx] == nullptr)
      return emitOpError() << "dynamic option index " << dynamicOptionIdx
                           << " is already used in options";
    dynamicOptions[dynamicOptionIdx] = nullptr; // Mark this option as used.
  }
  for (Value dynamicOption : dynamicOptions)
    if (dynamicOption)
      return emitOpError() << "a param operand does not have a corresponding "
                           << "param_operand attr in the options dict";
  return success();
}
//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//
/// Forwards `target` to the result unchanged. The cast only affects the static
/// type of the handle, not the payload it is associated with.
DiagnosedSilenceableFailure
transform::CastOp::applyToOne(transform::TransformRewriter &rewriter,
                              Operation *target, ApplyToEachResultList &results,
                              transform::TransformState &state) {
  Operation *forwarded = target;
  results.push_back(forwarded);
  return DiagnosedSilenceableFailure::success();
}
/// The cast reads the payload and its input handle, and produces new result
/// handles. It does not consume the input handle.
void transform::CastOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  Operation *self = getOperation();
  onlyReadsPayload(effects);
  onlyReadsHandle(getInputMutable(), effects);
  producesHandle(self->getOpResults(), effects);
}
/// A cast is allowed between any two operation-handle types, i.e. types
/// implementing TransformHandleTypeInterface. Exactly one input and one output
/// type are expected.
bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  assert(inputs.size() == 1 && "expected one input");
  assert(outputs.size() == 1 && "expected one output");
  Type source = inputs.front();
  Type destination = outputs.front();
  return isa<transform::TransformHandleTypeInterface>(source) &&
         isa<transform::TransformHandleTypeInterface>(destination);
}
//===----------------------------------------------------------------------===//
// Matcher helpers (shared by CollectMatchingOp and ForeachMatchOp)
//===----------------------------------------------------------------------===//
/// Binds the arguments of `block` to `blockArgumentMapping` and applies the
/// matcher operations in the block in order, updating `state`.
///
/// Every non-terminator operation must implement MatchOpInterface; otherwise a
/// definite failure is emitted. The first matcher that does not succeed stops
/// the process, and its (silenceable or definite) failure is returned to the
/// caller as is. Callers decide whether to discard silenceable failures.
/// When all matchers succeed, `mappings` is cleared and then filled with the
/// payload entities associated with the block terminator's operands.
static DiagnosedSilenceableFailure
matchBlock(Block &block,
           ArrayRef<SmallVector<transform::MappedValue>> blockArgumentMapping,
           transform::TransformState &state,
           SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings) {
  Region *parentRegion = block.getParent();
  assert(parentRegion && "cannot match using a detached block");
  auto matchScope = state.make_region_scope(*parentRegion);
  if (failed(
          state.mapBlockArguments(block.getArguments(), blockArgumentMapping)))
    return DiagnosedSilenceableFailure::definiteFailure();

  for (Operation &matchOp : block.without_terminator()) {
    if (!isa<transform::MatchOpInterface>(matchOp))
      return emitDefiniteFailure(matchOp.getLoc())
             << "expected operations in the match part to "
                "implement MatchOpInterface";
    DiagnosedSilenceableFailure outcome =
        state.applyTransform(cast<transform::TransformOpInterface>(matchOp));
    if (!outcome.succeeded())
      return outcome;
  }

  // The contract is that `mappings` holds only the values yielded by the
  // terminator, so drop whatever the caller left in it first.
  mappings.clear();
  transform::detail::prepareValueMappings(
      mappings, block.getTerminator()->getOperands(), state);
  return DiagnosedSilenceableFailure::success();
}
/// Returns `true` if there is at least one interface among `Tys` that both `t1`
/// and `t2` implement. Returns `false` for an empty pack.
template <typename... Tys>
static bool implementSameInterface(Type t1, Type t2) {
  // Unary fold over `||`: short-circuits on the first shared interface and
  // yields `false` when `Tys` is empty.
  return ((isa<Tys>(t1) && isa<Tys>(t2)) || ...);
}
/// Returns `true` if `t1` and `t2` are the same kind of transform type: both
/// operation handles, both params, or both value handles.
static bool implementSameTransformInterface(Type t1, Type t2) {
  using OpHandleIface = transform::TransformHandleTypeInterface;
  using ParamIface = transform::TransformParamTypeInterface;
  using ValueHandleIface = transform::TransformValueHandleTypeInterface;
  return implementSameInterface<OpHandleIface, ParamIface, ValueHandleIface>(
      t1, t2);
}
//===----------------------------------------------------------------------===//
// CollectMatchingOp
//===----------------------------------------------------------------------===//
/// Walks every payload op nested in (and including) each root and runs the
/// matcher on it. For each successful match, the single payload entity yielded
/// for each matcher result is appended to the corresponding op result. Payloads
/// are accumulated across all roots, and each result is set once at the end.
DiagnosedSilenceableFailure
transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,
                                    transform::TransformResults &results,
                                    transform::TransformState &state) {
  auto matcher = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
      getOperation(), getMatcher());
  if (matcher.isExternal()) {
    return emitDefiniteFailure()
           << "unresolved external symbol " << getMatcher();
  }

  // Payload entities collected for each op result, across all roots.
  SmallVector<SmallVector<MappedValue>, 2> rawResults;
  rawResults.resize(getOperation()->getNumResults());
  // Set only when this op itself interrupts the walk with a silenceable error.
  // Interruptions caused by a definite matcher failure leave it empty.
  std::optional<DiagnosedSilenceableFailure> maybeFailure;
  for (Operation *root : state.getPayloadOps(getRoot())) {
    WalkResult walkResult = root->walk([&](Operation *op) {
      DEBUG_MATCHER({
        DBGS_MATCHER() << "matching ";
        op->print(llvm::dbgs(),
                  OpPrintingFlags().assumeVerified().skipRegions());
        llvm::dbgs() << " @" << op << "\n";
      });

      // Try matching.
      SmallVector<SmallVector<MappedValue>> mappings;
      SmallVector<transform::MappedValue> inputMapping({op});
      DiagnosedSilenceableFailure diag = matchBlock(
          matcher.getFunctionBody().front(),
          ArrayRef<SmallVector<transform::MappedValue>>(inputMapping), state,
          mappings);
      if (diag.isDefiniteFailure())
        return WalkResult::interrupt();
      if (diag.isSilenceableFailure()) {
        DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName()
                                     << " failed: " << diag.getMessage());
        return WalkResult::advance();
      }

      // If succeeded, collect results.
      for (auto &&[i, mapping] : llvm::enumerate(mappings)) {
        if (mapping.size() != 1) {
          maybeFailure.emplace(emitSilenceableError()
                               << "result #" << i << ", associated with "
                               << mapping.size()
                               << " payload objects, expected 1");
          return WalkResult::interrupt();
        }
        rawResults[i].push_back(mapping[0]);
      }
      return WalkResult::advance();
    });

    if (walkResult.wasInterrupted()) {
      // A failure produced by this op is returned as is. Otherwise the walk was
      // interrupted by a definite failure from the matcher; `maybeFailure` is
      // empty in that case and must not be dereferenced. This mirrors how
      // ForeachMatchOp handles an interrupted walk.
      if (maybeFailure)
        return std::move(*maybeFailure);
      return DiagnosedSilenceableFailure::definiteFailure();
    }
    assert(!maybeFailure && "failure set but the walk was not interrupted");
  }

  // Set every result exactly once, after all roots have been processed. Setting
  // them per root would set them repeatedly with several roots and never with
  // zero roots.
  for (auto &&[opResult, rawResult] :
       llvm::zip_equal(getOperation()->getResults(), rawResults)) {
    results.setMappedValues(opResult, rawResult);
  }
  return DiagnosedSilenceableFailure::success();
}
/// Collecting matches only inspects the payload. The root handle is read (not
/// consumed), and fresh result handles are produced.
void transform::CollectMatchingOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  Operation *self = getOperation();
  onlyReadsHandle(getRootMutable(), effects);
  producesHandle(self->getOpResults(), effects);
  onlyReadsPayload(effects);
}
/// Verifies that the matcher symbol meets three requirements:
///  - it resolves to a function-like transform op,
///  - it takes a single read-only operation-handle argument,
///  - it yields one value per op result, each of the same transform-type kind
///    as that result.
LogicalResult transform::CollectMatchingOp::verifySymbolUses(
    SymbolTableCollection &symbolTable) {
  Operation *symbolOp =
      symbolTable.lookupNearestSymbolFrom(getOperation(), getMatcher());
  auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(symbolOp);
  bool isTransformFunction =
      matcherSymbol && isa<TransformOpInterface>(matcherSymbol.getOperation());
  if (!isTransformFunction)
    return emitError() << "unresolved matcher symbol " << getMatcher();

  ArrayRef<Type> argumentTypes = matcherSymbol.getArgumentTypes();
  bool takesSingleOpHandle =
      argumentTypes.size() == 1 &&
      isa<TransformHandleTypeInterface>(argumentTypes.front());
  if (!takesSingleOpHandle)
    return emitError()
           << "expected the matcher to take one operation handle argument";
  if (!matcherSymbol.getArgAttr(
          0, transform::TransformDialect::kArgReadOnlyAttrName))
    return emitError() << "expected the matcher argument to be marked readonly";

  ArrayRef<Type> resultTypes = matcherSymbol.getResultTypes();
  unsigned numOpResults = getOperation()->getNumResults();
  if (resultTypes.size() != numOpResults)
    return emitError()
           << "expected the matcher to yield as many values as op has results ("
           << numOpResults << "), got " << resultTypes.size();

  for (auto &&[i, matcherType, resultType] :
       llvm::enumerate(resultTypes, getOperation()->getResultTypes())) {
    if (!implementSameTransformInterface(matcherType, resultType))
      return emitError()
             << "mismatching type interfaces for matcher result and op result #"
             << i;
  }
  return success();
}
//===----------------------------------------------------------------------===//
// ForeachMatchOp
//===----------------------------------------------------------------------===//
/// Repeated handle operands are harmless here. Per the original author's note,
/// nothing is actually consumed by this op.
bool transform::ForeachMatchOp::allowsRepeatedHandleOperands() {
  return true;
}
/// Walks every payload op nested in each root and tries the matcher/action
/// pairs in order on it. Per the check inside the walk, the root op itself is
/// only visited when `restrict_root` is set.
///
/// For the first matcher that succeeds, the corresponding action is applied to
/// the values yielded by the matcher, and no further pairs are tried for that
/// op. A silenceable matcher failure moves on to the next pair. Silenceable
/// failures of individual action ops are silenced and folded into a single
/// "actions failed" diagnostic (the remaining action ops still run), which is
/// returned at the end. Any definite failure aborts the walk.
///
/// The first result receives the root payload. The forwarded outputs receive
/// the values yielded by all applied actions, accumulated over the walk.
DiagnosedSilenceableFailure
transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
                                 transform::TransformResults &results,
                                 transform::TransformState &state) {
  // Resolve all matcher/action symbols once, before walking the payload.
  SmallVector<std::pair<FunctionOpInterface, FunctionOpInterface>>
      matchActionPairs;
  matchActionPairs.reserve(getMatchers().size());
  SymbolTableCollection symbolTable;
  for (auto &&[matcher, action] :
       llvm::zip_equal(getMatchers(), getActions())) {
    auto matcherSymbol =
        symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
            getOperation(), cast<SymbolRefAttr>(matcher));
    auto actionSymbol =
        symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
            getOperation(), cast<SymbolRefAttr>(action));
    assert(matcherSymbol && actionSymbol &&
           "unresolved symbols not caught by the verifier");

    if (matcherSymbol.isExternal())
      return emitDefiniteFailure() << "unresolved external symbol " << matcher;
    if (actionSymbol.isExternal())
      return emitDefiniteFailure() << "unresolved external symbol " << action;

    matchActionPairs.emplace_back(matcherSymbol, actionSymbol);
  }

  // Accumulates notes about silenceable action failures; stays "success" if
  // none occur.
  DiagnosedSilenceableFailure overallDiag =
      DiagnosedSilenceableFailure::success();

  SmallVector<SmallVector<MappedValue>> matchInputMapping;
  SmallVector<SmallVector<MappedValue>> matchOutputMapping;
  SmallVector<SmallVector<MappedValue>> actionResultMapping;
  // Explicitly add the mapping for the first block argument (the op being
  // matched).
  matchInputMapping.emplace_back();
  transform::detail::prepareValueMappings(matchInputMapping,
                                          getForwardedInputs(), state);
  // Refilled with the current op at every step of the walk; the remaining
  // entries hold the forwarded inputs and stay fixed.
  SmallVector<MappedValue> &firstMatchArgument = matchInputMapping.front();
  actionResultMapping.resize(getForwardedOutputs().size());

  for (Operation *root : state.getPayloadOps(getRoot())) {
    WalkResult walkResult = root->walk([&](Operation *op) {
      // If getRestrictRoot is not present, skip over the root op itself so we
      // don't invalidate it.
      if (!getRestrictRoot() && op == root)
        return WalkResult::advance();

      DEBUG_MATCHER({
        DBGS_MATCHER() << "matching ";
        op->print(llvm::dbgs(),
                  OpPrintingFlags().assumeVerified().skipRegions());
        llvm::dbgs() << " @" << op << "\n";
      });

      firstMatchArgument.clear();
      firstMatchArgument.push_back(op);

      // Try all the match/action pairs until the first successful match.
      for (auto [matcher, action] : matchActionPairs) {
        DiagnosedSilenceableFailure diag =
            matchBlock(matcher.getFunctionBody().front(), matchInputMapping,
                       state, matchOutputMapping);
        if (diag.isDefiniteFailure())
          return WalkResult::interrupt();
        // A silenceable matcher failure just means "no match": try the next
        // pair.
        if (diag.isSilenceableFailure()) {
          DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName()
                                       << " failed: " << diag.getMessage());
          continue;
        }

        // Bind the action's arguments to the values yielded by the matcher.
        auto scope = state.make_region_scope(action.getFunctionBody());
        if (failed(state.mapBlockArguments(
                action.getFunctionBody().front().getArguments(),
                matchOutputMapping))) {
          return WalkResult::interrupt();
        }

        for (Operation &transform :
             action.getFunctionBody().front().without_terminator()) {
          DiagnosedSilenceableFailure result =
              state.applyTransform(cast<TransformOpInterface>(transform));
          if (result.isDefiniteFailure())
            return WalkResult::interrupt();
          // Record the failure, silence it, and keep applying the remaining
          // action ops.
          if (result.isSilenceableFailure()) {
            if (overallDiag.succeeded()) {
              overallDiag = emitSilenceableError() << "actions failed";
            }
            overallDiag.attachNote(action->getLoc())
                << "failed action: " << result.getMessage();
            overallDiag.attachNote(op->getLoc())
                << "when applied to this matching payload";
            (void)result.silence();
            continue;
          }
        }
        // Append what the action yields to the forwarded outputs, flattening
        // multi-entity results only if requested.
        if (failed(detail::appendValueMappings(
                MutableArrayRef<SmallVector<MappedValue>>(actionResultMapping),
                action.getFunctionBody().front().getTerminator()->getOperands(),
                state, getFlattenResults()))) {
          emitDefiniteFailure()
              << "action @" << action.getName()
              << " has results associated with multiple payload entities, "
                 "but flattening was not requested";
          return WalkResult::interrupt();
        }
        // Only the first successful matcher's action is applied to this op.
        break;
      }
      return WalkResult::advance();
    });
    if (walkResult.wasInterrupted())
      return DiagnosedSilenceableFailure::definiteFailure();
  }

  // The root operation should not have been affected, so we can just reassign
  // the payload to the result. Note that we need to consume the root handle to
  // make sure any handles to operations inside, that could have been affected
  // by actions, are invalidated.
  results.set(llvm::cast<OpResult>(getUpdated()),
              state.getPayloadOps(getRoot()));
  for (auto &&[result, mapping] :
       llvm::zip_equal(getForwardedOutputs(), actionResultMapping)) {
    results.setMappedValues(result, mapping);
  }
  return overallDiag;
}
/// Suggests SSA names for the results: `updated_root` for the updated root
/// handle, and `yielded` for every value forwarded from the actions.
void transform::ForeachMatchOp::getAsmResultNames(
    OpAsmSetValueNameFn setNameFn) {
  setNameFn(getUpdated(), "updated_root");
  llvm::for_each(getForwardedOutputs(),
                 [&](Value forwarded) { setNameFn(forwarded, "yielded"); });
}
/// Declares the op's memory effects:
///  - the root handle is consumed,
///  - the forwarded inputs are only read,
///  - all results are produced,
///  - the payload is modified.
/// On a malformed op (no operands or no results) only the payload
/// modification is reported.
void transform::ForeachMatchOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  Operation *self = getOperation();
  bool malformed = self->getNumOperands() < 1 || self->getNumResults() < 1;
  if (!malformed) {
    consumesHandle(getRootMutable(), effects);
    onlyReadsHandle(getForwardedInputsMutable(), effects);
    producesHandle(self->getOpResults(), effects);
  }
  modifiesPayload(effects);
}
/// Parses a non-empty, comma-separated list of `@matcher -> @action` pairs into
/// two parallel arrays of symbol references.
static ParseResult parseForeachMatchSymbols(OpAsmParser &parser,
                                            ArrayAttr &matchers,
                                            ArrayAttr &actions) {
  SmallVector<Attribute> matcherRefs;
  SmallVector<Attribute> actionRefs;

  // Parses a single `@matcher -> @action` pair.
  auto parsePair = [&]() -> ParseResult {
    StringAttr matcherName;
    StringAttr actionName;
    if (parser.parseSymbolName(matcherName) || parser.parseArrow() ||
        parser.parseSymbolName(actionName))
      return failure();
    matcherRefs.push_back(SymbolRefAttr::get(matcherName));
    actionRefs.push_back(SymbolRefAttr::get(actionName));
    return success();
  };

  // At least one pair is required; more follow after commas.
  if (failed(parsePair()))
    return failure();
  while (parser.parseOptionalComma().succeeded()) {
    if (failed(parsePair()))
      return failure();
  }

  Builder &builder = parser.getBuilder();
  matchers = builder.getArrayAttr(matcherRefs);
  actions = builder.getArrayAttr(actionRefs);
  return success();
}
/// Prints the `@matcher -> @action` pairs, each on its own line (indented two
/// extra levels), separated by commas.
static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op,
                                     ArrayAttr matchers, ArrayAttr actions) {
  printer.increaseIndent();
  printer.increaseIndent();
  llvm::interleave(
      llvm::zip_equal(matchers, actions),
      [&](auto pair) {
        printer.printNewline();
        printer << cast<SymbolRefAttr>(std::get<0>(pair)) << " -> "
                << cast<SymbolRefAttr>(std::get<1>(pair));
      },
      [&] { printer << ", "; });
  printer.decreaseIndent();
  printer.decreaseIndent();
}
/// Requires a non-empty, equally sized list of matchers and actions. Repeated
/// matchers only trigger a warning, because later occurrences can never apply.
LogicalResult transform::ForeachMatchOp::verify() {
  if (getMatchers().size() != getActions().size())
    return emitOpError() << "expected the same number of matchers and actions";
  if (getMatchers().empty())
    return emitOpError() << "expected at least one match/action pair";

  llvm::SmallPtrSet<Attribute, 8> seenMatchers;
  for (Attribute name : getMatchers()) {
    bool firstOccurrence = seenMatchers.insert(name).second;
    if (!firstOccurrence)
      emitWarning() << "matcher " << name
                    << " is used more than once, only the first match will apply";
  }
  return success();
}
/// Checks the consumed/readonly argument annotations of a function-like
/// transform op:
///  - no argument may be both consumed and readonly;
///  - external functions (and all functions when `alsoVerifyInternal` is set)
///    must annotate every argument;
///  - for functions with a body, an argument the body consumes must not be
///    marked readonly. When `emitWarnings` is set, a warning is emitted for an
///    argument marked consumed but not consumed by the body.
static DiagnosedSilenceableFailure
verifyFunctionLikeConsumeAnnotations(FunctionOpInterface op, bool emitWarnings,
                                     bool alsoVerifyInternal = false) {
  auto transformOp = cast<transform::TransformOpInterface>(op.getOperation());
  const bool isExternal = op.isExternal();
  const bool annotationsRequired = isExternal || alsoVerifyInternal;

  // Arguments actually consumed by the body; stays empty for declarations.
  llvm::SmallDenseSet<unsigned> consumedInBody;
  if (!isExternal)
    transform::getConsumedBlockArguments(op.getFunctionBody().front(),
                                         consumedInBody);

  for (unsigned argIdx = 0, numArgs = op.getNumArguments(); argIdx < numArgs;
       ++argIdx) {
    const bool markedConsumed =
        op.getArgAttr(argIdx,
                      transform::TransformDialect::kArgConsumedAttrName) !=
        nullptr;
    const bool markedReadOnly =
        op.getArgAttr(argIdx,
                      transform::TransformDialect::kArgReadOnlyAttrName) !=
        nullptr;

    if (markedConsumed && markedReadOnly)
      return transformOp.emitSilenceableError()
             << "argument #" << argIdx << " cannot be both readonly and consumed";
    if (annotationsRequired && !markedConsumed && !markedReadOnly)
      return transformOp.emitSilenceableError()
             << "must provide consumed/readonly status for arguments of "
                "external or called ops";

    // Declarations have no body to cross-check the annotations against.
    if (isExternal)
      continue;

    const bool bodyConsumes = consumedInBody.contains(argIdx);
    if (bodyConsumes && !markedConsumed && markedReadOnly)
      return transformOp.emitSilenceableError()
             << "argument #" << argIdx
             << " is consumed in the body but is not marked as such";
    if (emitWarnings && !bodyConsumes && markedConsumed) {
      // Cannot use op.emitWarning() here as it would attempt to verify the op
      // before printing, resulting in infinite recursion.
      emitWarning(op->getLoc())
          << "op argument #" << argIdx
          << " is not consumed in the body but is marked as consumed";
    }
  }
  return DiagnosedSilenceableFailure::success();
}
/// Verifies every `@matcher -> @action` pair:
///  - both symbols resolve to function-like transform ops with complete
///    consumed/readonly annotations;
///  - the op operands match the matcher arguments, and the matcher consumes
///    none of them;
///  - the matcher results match the action arguments;
///  - the action results match the op results after the first one.
/// "Match" means same count and the same transform-type kind at each position.
LogicalResult transform::ForeachMatchOp::verifySymbolUses(
    SymbolTableCollection &symbolTable) {
  assert(getMatchers().size() == getActions().size());
  auto consumedAttr =
      StringAttr::get(getContext(), TransformDialect::kArgConsumedAttrName);
  // These do not depend on the pair being verified, so compute them once.
  TypeRange operandTypes = getOperandTypes();
  auto resultTypes = TypeRange(getResultTypes()).drop_front();

  for (auto &&[matcher, action] :
       llvm::zip_equal(getMatchers(), getActions())) {
    // Presence and typing.
    auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
        symbolTable.lookupNearestSymbolFrom(getOperation(),
                                            cast<SymbolRefAttr>(matcher)));
    auto actionSymbol = dyn_cast_or_null<FunctionOpInterface>(
        symbolTable.lookupNearestSymbolFrom(getOperation(),
                                            cast<SymbolRefAttr>(action)));
    if (!matcherSymbol ||
        !isa<TransformOpInterface>(matcherSymbol.getOperation()))
      return emitError() << "unresolved matcher symbol " << matcher;
    if (!actionSymbol ||
        !isa<TransformOpInterface>(actionSymbol.getOperation()))
      return emitError() << "unresolved action symbol " << action;

    if (failed(verifyFunctionLikeConsumeAnnotations(matcherSymbol,
                                                    /*emitWarnings=*/false,
                                                    /*alsoVerifyInternal=*/true)
                   .checkAndReport())) {
      return failure();
    }
    if (failed(verifyFunctionLikeConsumeAnnotations(actionSymbol,
                                                    /*emitWarnings=*/false,
                                                    /*alsoVerifyInternal=*/true)
                   .checkAndReport())) {
      return failure();
    }

    // Input -> matcher forwarding.
    TypeRange matcherArguments = matcherSymbol.getArgumentTypes();
    if (operandTypes.size() != matcherArguments.size()) {
      InFlightDiagnostic diag =
          emitError() << "the number of operands (" << operandTypes.size()
                      << ") doesn't match the number of matcher arguments ("
                      << matcherArguments.size() << ") for " << matcher;
      diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
      return diag;
    }
    for (auto &&[i, operand, argument] :
         llvm::enumerate(operandTypes, matcherArguments)) {
      if (matcherSymbol.getArgAttr(i, consumedAttr)) {
        InFlightDiagnostic diag =
            emitOpError()
            << "does not expect matcher symbol to consume its operand #" << i;
        diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
        return diag;
      }

      if (implementSameTransformInterface(operand, argument))
        continue;

      InFlightDiagnostic diag =
          emitError()
          << "mismatching type interfaces for operand and matcher argument #"
          << i << " of matcher " << matcher;
      diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
      return diag;
    }

    // Matcher -> action forwarding.
    TypeRange matcherResults = matcherSymbol.getResultTypes();
    TypeRange actionArguments = actionSymbol.getArgumentTypes();
    if (matcherResults.size() != actionArguments.size()) {
      return emitError() << "mismatching number of matcher results and "
                            "action arguments between "
                         << matcher << " (" << matcherResults.size() << ") and "
                         << action << " (" << actionArguments.size() << ")";
    }
    for (auto &&[i, matcherType, actionType] :
         llvm::enumerate(matcherResults, actionArguments)) {
      if (implementSameTransformInterface(matcherType, actionType))
        continue;

      // NB: the leading space in " of matcher " was previously missing,
      // producing e.g. "argument #0of matcher".
      return emitError() << "mismatching type interfaces for matcher result "
                            "and action argument #"
                         << i << " of matcher " << matcher << " and action "
                         << action;
    }

    // Action -> result forwarding.
    TypeRange actionResults = actionSymbol.getResultTypes();
    if (actionResults.size() != resultTypes.size()) {
      InFlightDiagnostic diag =
          emitError() << "the number of action results ("
                      << actionResults.size() << ") for " << action
                      << " doesn't match the number of extra op results ("
                      << resultTypes.size() << ")";
      diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
      return diag;
    }
    for (auto &&[i, resultType, actionType] :
         llvm::enumerate(resultTypes, actionResults)) {
      if (implementSameTransformInterface(resultType, actionType))
        continue;

      InFlightDiagnostic diag =
          emitError() << "mismatching type interfaces for action result #" << i
                      << " of action " << action << " and op result";
      diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
      return diag;
    }
  }
  return success();
}
//===----------------------------------------------------------------------===//
// ForeachOp
//===----------------------------------------------------------------------===//
/// Executes the body once per "zipped" tuple of target payloads. Iteration `i`
/// binds each block argument to element `i` of the corresponding target's
/// payload.
///
/// Without `zip_shortest`, all targets must have payloads of the same size;
/// otherwise a silenceable failure is returned. With it, every payload is
/// truncated to the shortest one. The values yielded in each iteration are
/// appended to the corresponding op result.
DiagnosedSilenceableFailure
transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
                            transform::TransformResults &results,
                            transform::TransformState &state) {
  // We store the payloads before executing the body as ops may be removed from
  // the mapping by the TrackingRewriter while iteration is in progress.
  SmallVector<SmallVector<MappedValue>> payloads;
  detail::prepareValueMappings(payloads, getTargets(), state);
  size_t numIterations = payloads.empty() ? 0 : payloads.front().size();
  bool withZipShortest = getWithZipShortest();

  // In case of `zip_shortest`, set the number of iterations to the
  // smallest payload in the targets. With no targets, `min_element` would
  // return the end iterator, which must not be dereferenced; there is nothing
  // to truncate then, and the iteration count stays zero.
  if (withZipShortest && !payloads.empty()) {
    numIterations =
        llvm::min_element(payloads, [&](const SmallVector<MappedValue> &A,
                                        const SmallVector<MappedValue> &B) {
          return A.size() < B.size();
        })->size();

    for (size_t argIdx = 0; argIdx < payloads.size(); argIdx++)
      payloads[argIdx].resize(numIterations);
  }

  // As we will be "zipping" over them, check all payloads have the same size.
  // `zip_shortest` adjusts all payloads to the same size, so skip this check
  // when true.
  for (size_t argIdx = 1; !withZipShortest && argIdx < payloads.size();
       argIdx++) {
    if (payloads[argIdx].size() != numIterations) {
      return emitSilenceableError()
             << "prior targets' payload size (" << numIterations
             << ") differs from payload size (" << payloads[argIdx].size()
             << ") of target " << getTargets()[argIdx];
    }
  }

  // Start iterating, indexing into payloads to obtain the right arguments to
  // call the body with - each slice of payloads at the same argument index
  // corresponding to a tuple to use as the body's block arguments.
  ArrayRef<BlockArgument> blockArguments = getBody().front().getArguments();
  SmallVector<SmallVector<MappedValue>> zippedResults(getNumResults(), {});
  for (size_t iterIdx = 0; iterIdx < numIterations; iterIdx++) {
    auto scope = state.make_region_scope(getBody());
    // Set up arguments to the region's block.
    for (auto &&[argIdx, blockArg] : llvm::enumerate(blockArguments)) {
      MappedValue argument = payloads[argIdx][iterIdx];
      // Note that each blockArg's handle gets associated with just a single
      // element from the corresponding target's payload.
      if (failed(state.mapBlockArgument(blockArg, {argument})))
        return DiagnosedSilenceableFailure::definiteFailure();
    }

    // Execute loop body.
    for (Operation &transform : getBody().front().without_terminator()) {
      DiagnosedSilenceableFailure result = state.applyTransform(
          llvm::cast<transform::TransformOpInterface>(transform));
      if (!result.succeeded())
        return result;
    }

    // Append yielded payloads to corresponding results from prior iterations.
    OperandRange yieldOperands = getYieldOp().getOperands();
    for (auto &&[result, yieldOperand, resTuple] :
         llvm::zip_equal(getResults(), yieldOperands, zippedResults))
      // NB: each iteration we add any number of ops/vals/params to a result.
      if (isa<TransformHandleTypeInterface>(result.getType()))
        llvm::append_range(resTuple, state.getPayloadOps(yieldOperand));
      else if (isa<TransformValueHandleTypeInterface>(result.getType()))
        llvm::append_range(resTuple, state.getPayloadValues(yieldOperand));
      else if (isa<TransformParamTypeInterface>(result.getType()))
        llvm::append_range(resTuple, state.getParams(yieldOperand));
      else
        assert(false && "unhandled handle type");
  }

  // Associate the accumulated result payloads to the op's actual results.
  for (auto &&[result, resPayload] : zip_equal(getResults(), zippedResults))
    results.setMappedValues(llvm::cast<OpResult>(result), resPayload);

  return DiagnosedSilenceableFailure::success();
}
/// Derives the op's effects from its body:
///  - a target is consumed if any body op consumes the matching block argument,
///    and is only read otherwise;
///  - the payload is reported as modified if any body op modifies it, or else
///    as read if any body op reads it;
///  - results are always produced.
void transform::ForeachOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  Block &body = getBody().front();
  auto bodyOps = body.without_terminator();
  auto asTransform = [](Operation &op) {
    return cast<TransformOpInterface>(&op);
  };

  // NB: this `zip` should be `zip_equal` - while this op's verifier catches
  // arity errors, this method might get called before/in absence of `verify()`.
  for (auto &&[target, blockArg] :
       llvm::zip(getTargetsMutable(), body.getArguments())) {
    // Structured bindings cannot be captured by lambdas in C++17.
    BlockArgument argument = blockArg;
    bool consumedByBody = llvm::any_of(bodyOps, [&](Operation &op) {
      return isHandleConsumed(argument, asTransform(op));
    });
    if (consumedByBody)
      consumesHandle(target, effects);
    else
      onlyReadsHandle(target, effects);
  }

  bool bodyModifiesPayload = llvm::any_of(bodyOps, [&](Operation &op) {
    return doesModifyPayload(asTransform(op));
  });
  if (bodyModifiesPayload) {
    modifiesPayload(effects);
  } else if (llvm::any_of(bodyOps, [&](Operation &op) {
               return doesReadPayload(asTransform(op));
             })) {
    onlyReadsPayload(effects);
  }

  producesHandle(getOperation()->getOpResults(), effects);
}
/// From the parent op, control enters the body. From the end of the body,
/// control either re-enters the body or returns to the parent.
void transform::ForeachOp::getSuccessorRegions(
    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
  Region *bodyRegion = &getBody();
  const bool fromParent = point.isParent();
  assert((fromParent || point == getBody()) && "unexpected region index");

  // The body is a successor in both cases.
  regions.emplace_back(bodyRegion, bodyRegion->getArguments());
  // After an iteration, control may also return to the parent.
  if (!fromParent)
    regions.emplace_back();
}
/// All operands of the op (the `targets`) are forwarded to the body's block
/// arguments. Each block argument handle is mapped to a subset (exactly one
/// element) of the payload of the corresponding `targets` operand.
OperandRange
transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) {
  assert(point == getBody() && "unexpected region index");
  Operation *self = getOperation();
  return self->getOperands();
}
/// Returns the `transform.yield` terminating the body's (single) block.
transform::YieldOp transform::ForeachOp::getYieldOp() {
  Block &bodyBlock = getBody().front();
  Operation *terminator = bodyBlock.getTerminator();
  return cast<transform::YieldOp>(terminator);
}
/// Checks two correspondences, position by position:
///  - targets against the body's block arguments,
///  - results against the yield operands.
/// Each pair must agree in count and in type. Type mismatches within the common
/// prefix are reported before count mismatches, as the pairwise walk would
/// encounter them.
LogicalResult transform::ForeachOp::verify() {
  Block &body = getBody().front();

  // `llvm::zip` stops at the shorter range, so only co-indexed pairs are
  // compared here; the count check follows.
  for (auto [target, blockArg] : llvm::zip(getTargets(), body.getArguments()))
    if (target.getType() != blockArg.getType())
      return emitOpError(
          "expects co-indexed targets and the body's "
          "block arguments to have the same op/value/param type");
  if (getTargets().size() != body.getNumArguments())
    return emitOpError() << "expects the same number of targets as the body "
                            "has block arguments";

  OperandRange yieldOperands = getYieldOp().getOperands();
  for (auto [result, yieldOperand] : llvm::zip(getResults(), yieldOperands))
    if (result.getType() != yieldOperand.getType())
      return emitOpError("expects co-indexed results and yield "
                         "operands to have the same op/value/param type");
  if (getResults().size() != yieldOperands.size())
    return emitOpError() << "expects the same number of results as the "
                            "yield terminator has operands";

  return success();
}
//===----------------------------------------------------------------------===//
// GetParentOp
//===----------------------------------------------------------------------===//
/// For every payload op of `target`, walks up `nth_parent` ancestors. Ancestors
/// that do not satisfy the optional `isolated_from_above` and `op_name` filters
/// are skipped and do not count as a step.
///
/// If no qualifying ancestor exists and `allow_empty_results` is set, the
/// result is set to the parents collected so far (possibly none) and the op
/// succeeds. Otherwise a silenceable error is emitted. With `deduplicate`,
/// each parent is recorded only once.
DiagnosedSilenceableFailure
transform::GetParentOp::apply(transform::TransformRewriter &rewriter,
                              transform::TransformResults &results,
                              transform::TransformState &state) {
  // True if `candidate` satisfies every user-provided requirement.
  auto meetsRequirements = [&](Operation *candidate) {
    if (getIsolatedFromAbove() &&
        !candidate->hasTrait<OpTrait::IsIsolatedFromAbove>())
      return false;
    auto requiredName = getOpName();
    return !requiredName.has_value() ||
           candidate->getName().getStringRef() == *requiredName;
  };

  SmallVector<Operation *> parents;
  DenseSet<Operation *> alreadyAdded;
  for (Operation *target : state.getPayloadOps(getTarget())) {
    Operation *ancestor = target;
    for (int64_t step = 0, numSteps = getNthParent(); step < numSteps; ++step) {
      // Climb at least one level, then keep climbing past non-matching ops.
      do {
        ancestor = ancestor->getParentOp();
      } while (ancestor && !meetsRequirements(ancestor));
      if (ancestor)
        continue;

      if (getAllowEmptyResults()) {
        results.set(llvm::cast<OpResult>(getResult()), parents);
        return DiagnosedSilenceableFailure::success();
      }
      DiagnosedSilenceableFailure diag =
          emitSilenceableError()
          << "could not find a parent op that matches all requirements";
      diag.attachNote(target->getLoc()) << "target op";
      return diag;
    }

    // The set is only consulted (and populated) when deduplicating.
    if (!getDeduplicate() || alreadyAdded.insert(ancestor).second)
      parents.push_back(ancestor);
  }
  results.set(llvm::cast<OpResult>(getResult()), parents);
  return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// GetConsumersOfResult
//===----------------------------------------------------------------------===//
/// Maps the result to the users of result #`result_number` of the single
/// payload op associated with `target`. An empty `target` handle yields an
/// empty result. A handle with several payload ops, or an out-of-range result
/// number, is a definite failure.
DiagnosedSilenceableFailure
transform::GetConsumersOfResult::apply(transform::TransformRewriter &rewriter,
                                       transform::TransformResults &results,
                                       transform::TransformState &state) {
  auto targets = state.getPayloadOps(getTarget());
  OpResult consumersHandle = cast<OpResult>(getResult());
  if (std::empty(targets)) {
    results.set(consumersHandle, {});
    return DiagnosedSilenceableFailure::success();
  }
  if (!llvm::hasSingleElement(targets))
    return emitDefiniteFailure()
           << "handle must be mapped to exactly one payload op";

  Operation *target = *targets.begin();
  int64_t resultNumber = getResultNumber();
  if (resultNumber >= target->getNumResults())
    return emitDefiniteFailure() << "result number overflow";

  Value producedValue = target->getResult(resultNumber);
  results.set(consumersHandle, llvm::to_vector(producedValue.getUsers()));
  return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// GetDefiningOp
//===----------------------------------------------------------------------===//
/// Maps the result to the defining op of every payload value of `target`, in
/// order. Emits a silenceable error for the first value that is a block
/// argument, since block arguments have no defining op.
DiagnosedSilenceableFailure
transform::GetDefiningOp::apply(transform::TransformRewriter &rewriter,
                                transform::TransformResults &results,
                                transform::TransformState &state) {
  SmallVector<Operation *> producers;
  for (Value payloadValue : state.getPayloadValues(getTarget())) {
    if (!llvm::isa<BlockArgument>(payloadValue)) {
      producers.push_back(payloadValue.getDefiningOp());
      continue;
    }
    DiagnosedSilenceableFailure diag =
        emitSilenceableError() << "cannot get defining op of block argument";
    diag.attachNote(payloadValue.getLoc()) << "target value";
    return diag;
  }
  results.set(llvm::cast<OpResult>(getResult()), producers);
  return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// GetProducerOfOperand
//===----------------------------------------------------------------------===//
/// Maps the result to the op defining operand #`operand_number` of every payload
/// op of `target`. Emits a silenceable error if some payload op has no such
/// operand, or if that operand is not produced by an op (e.g. it is a block
/// argument).
DiagnosedSilenceableFailure
transform::GetProducerOfOperand::apply(transform::TransformRewriter &rewriter,
                                       transform::TransformResults &results,
                                       transform::TransformState &state) {
  int64_t operandNumber = getOperandNumber();
  SmallVector<Operation *> producers;
  for (Operation *target : state.getPayloadOps(getTarget())) {
    Operation *producer = nullptr;
    if (operandNumber < target->getNumOperands())
      producer = target->getOperand(operandNumber).getDefiningOp();

    if (producer) {
      producers.push_back(producer);
      continue;
    }

    DiagnosedSilenceableFailure diag =
        emitSilenceableError()
        << "could not find a producer for operand number: " << operandNumber
        << " of " << *target;
    diag.attachNote(target->getLoc()) << "target op";
    return diag;
  }
  results.set(llvm::cast<OpResult>(getResult()), producers);
  return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// GetOperandOp
//===----------------------------------------------------------------------===//
/// Maps the result to the operands selected by the position specification
/// (`raw_position_list`, `is_all`, `is_inverted`). The specification is
/// evaluated independently for each payload op, and the selected operands are
/// concatenated in payload order.
DiagnosedSilenceableFailure
transform::GetOperandOp::apply(transform::TransformRewriter &rewriter,
                               transform::TransformResults &results,
                               transform::TransformState &state) {
  SmallVector<Value> selectedOperands;
  for (Operation *target : state.getPayloadOps(getTarget())) {
    SmallVector<int64_t> positions;
    DiagnosedSilenceableFailure diag = expandTargetSpecification(
        getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
        target->getNumOperands(), positions);
    if (diag.isSilenceableFailure()) {
      diag.attachNote(target->getLoc())
          << "while considering positions of this payload operation";
      return diag;
    }
    for (int64_t position : positions)
      selectedOperands.push_back(target->getOperand(position));
  }
  results.setValues(cast<OpResult>(getResult()), selectedOperands);
  return DiagnosedSilenceableFailure::success();
}
/// Verifies the operand position specification by delegating to
/// verifyTransformMatchDimsOp (defined elsewhere). The raw position list is
/// checked together with the `is_inverted` and `is_all` flags.
LogicalResult transform::GetOperandOp::verify() {
return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
getIsInverted(), getIsAll());
}
//===----------------------------------------------------------------------===//
// GetResultOp
//===----------------------------------------------------------------------===//
/// Maps the result to the op results selected by the position specification
/// (`raw_position_list`, `is_all`, `is_inverted`). The specification is
/// evaluated independently for each payload op, and the selected results are
/// concatenated in payload order.
DiagnosedSilenceableFailure
transform::GetResultOp::apply(transform::TransformRewriter &rewriter,
                              transform::TransformResults &results,
                              transform::TransformState &state) {
  SmallVector<Value> selectedResults;
  for (Operation *target : state.getPayloadOps(getTarget())) {
    SmallVector<int64_t> positions;
    DiagnosedSilenceableFailure diag = expandTargetSpecification(
        getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
        target->getNumResults(), positions);
    if (diag.isSilenceableFailure()) {
      diag.attachNote(target->getLoc())
          << "while considering positions of this payload operation";
      return diag;
    }
    for (int64_t position : positions)
      selectedResults.push_back(target->getResult(position));
  }
  results.setValues(cast<OpResult>(getResult()), selectedResults);
  return DiagnosedSilenceableFailure::success();
}
/// Verifies the result position specification by delegating to
/// verifyTransformMatchDimsOp (defined elsewhere). The raw position list is
/// checked together with the `is_inverted` and `is_all` flags.
LogicalResult transform::GetResultOp::verify() {
return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
getIsInverted(), getIsAll());
}
//===----------------------------------------------------------------------===//
// GetTypeOp
//===----------------------------------------------------------------------===//
/// Declares the side effects of GetTypeOp:
///   - the `value` handle operand is only read;
///   - the result handle is produced;
///   - the payload IR is only read, since the op just inspects payload types.
void transform::GetTypeOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
onlyReadsHandle(getValueMutable(), effects);
producesHandle(getOperation()->getOpResults(), effects);
onlyReadsPayload(effects);
}
/// Maps the result param to one TypeAttr per payload value of `value`. Each
/// entry is the value's type or, when `elemental` is set and that type is a
/// ShapedType, its element type.
DiagnosedSilenceableFailure
transform::GetTypeOp::apply(transform::TransformRewriter &rewriter,
                            transform::TransformResults &results,
                            transform::TransformState &state) {
  SmallVector<Attribute> typeAttrs;
  for (Value payloadValue : state.getPayloadValues(getValue())) {
    Type valueType = payloadValue.getType();
    auto shapedType = dyn_cast<ShapedType>(valueType);
    Type reportedType = (getElemental() && shapedType)
                            ? shapedType.getElementType()
                            : valueType;
    typeAttrs.push_back(TypeAttr::get(reportedType));
  }
  results.setParams(cast<OpResult>(getResult()), typeAttrs);
  return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// IncludeOp
//===----------------------------------------------------------------------===//
/// Applies, in order, every transform op in `block` except its terminator.
///
/// - A definite failure aborts immediately and is returned.
/// - In `Propagate` mode, a silenceable failure stops the sequence. The
///   results of the block's parent op are set to empty, and the failure is
///   returned.
/// - In any other mode, silenceable failures are silenced and application
///   continues.
///
/// On completion, the values yielded by the block terminator are forwarded to
/// the results of the block's parent op.
static DiagnosedSilenceableFailure
applySequenceBlock(Block &block, transform::FailurePropagationMode mode,
                   transform::TransformState &state,
                   transform::TransformResults &results) {
  for (Operation &op : block.without_terminator()) {
    DiagnosedSilenceableFailure status =
        state.applyTransform(cast<transform::TransformOpInterface>(op));
    if (status.succeeded())
      continue;
    if (status.isDefiniteFailure())
      return status;

    // From here on, `status` is a silenceable failure.
    if (mode != transform::FailurePropagationMode::Propagate) {
      (void)status.silence();
      continue;
    }
    // Early exit: make sure the parent's results are still (emptily) mapped.
    forwardEmptyOperands(&block, state, results);
    return status;
  }

  transform::detail::forwardTerminatorOperands(&block, state, results);
  return DiagnosedSilenceableFailure::success();
}
/// Applies the named sequence referenced by `target` in place. The steps are:
///   1. Bind this op's operands to the callee's entry block arguments.
///   2. Run the callee body with this op's failure propagation mode.
///   3. On success, map this op's results to the values yielded by the
///      callee's terminator.
/// Returns a definite failure if the callee is an unresolved external
/// declaration or a block argument cannot be mapped.
DiagnosedSilenceableFailure
transform::IncludeOp::apply(transform::TransformRewriter &rewriter,
                            transform::TransformResults &results,
                            transform::TransformState &state) {
  auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
      getOperation(), getTarget());
  assert(callee && "unverified reference to unknown symbol");

  if (callee.isExternal())
    return emitDefiniteFailure() << "unresolved external named sequence";

  // Map operands to block arguments.
  SmallVector<SmallVector<MappedValue>> mappings;
  detail::prepareValueMappings(mappings, getOperands(), state);
  auto scope = state.make_region_scope(callee.getBody());
  for (auto &&[arg, map] :
       llvm::zip_equal(callee.getBody().front().getArguments(), mappings)) {
    if (failed(state.mapBlockArgument(arg, map)))
      return DiagnosedSilenceableFailure::definiteFailure();
  }

  DiagnosedSilenceableFailure result = applySequenceBlock(
      callee.getBody().front(), getFailurePropagationMode(), state, results);

  // If the callee body stopped early (a definite failure, or a propagated
  // silenceable failure), the transform ops after the failing one were never
  // applied. The values yielded by the terminator may then have no mapping,
  // and looking them up would fail. Return the failure unchanged instead of
  // reading those mappings.
  // NOTE(review): presumably the caller (TransformState::applyTransform) sets
  // this op's unset results to empty on a silenceable failure; confirm.
  if (!result.succeeded())
    return result;

  // Forward the values yielded by the callee terminator to this op's results.
  mappings.clear();
  detail::prepareValueMappings(
      mappings, callee.getBody().front().getTerminator()->getOperands(), state);
  for (auto &&[opResult, mapping] : llvm::zip_equal(getResults(), mappings))
    results.setMappedValues(opResult, mapping);
  return result;
}
/// Forward declaration. The definition appears later, in the NamedSequenceOp
/// section, but IncludeOp::getEffects needs it to pre-validate the callee
/// before the regular verifier has run.
static DiagnosedSilenceableFailure
verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings);
/// Computes the memory effects of IncludeOp:
///   - it always modifies the payload;
///   - it always produces its result handles;
///   - each operand handle is consumed or only read, following the
///     `kArgConsumedAttrName` annotation on the corresponding callee argument.
/// If the callee cannot be resolved, is malformed, or does not match this op's
/// operand count, all operands are conservatively marked as only read. The
/// actual error is then reported later by the verifier.
void transform::IncludeOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // Always mark as modifying the payload.
  // TODO: a mechanism to annotate effects on payload. Even when all handles are
  // only read, the payload may still be modified, so we currently stay on the
  // conservative side and always indicate modification. This may prevent some
  // code reordering.
  modifiesPayload(effects);

  // Results are always produced.
  producesHandle(getOperation()->getOpResults(), effects);

  // Adds default effects to operands and results. This will be added if
  // preconditions fail so the trait verifier doesn't complain about missing
  // effects and the real precondition failure is reported later on.
  auto defaultEffects = [&] {
    onlyReadsHandle(getOperation()->getOpOperands(), effects);
  };

  // Bail if the callee is unknown. This may run as part of the verification
  // process before we verified the validity of the callee or of this op.
  auto target =
      getOperation()->getAttrOfType<SymbolRefAttr>(getTargetAttrName());
  if (!target)
    return defaultEffects();
  auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
      getOperation(), getTarget());
  if (!callee)
    return defaultEffects();

  // The arity check in verifySymbolUses may not have run yet. Without this
  // guard, extra operands would make the loop below query callee argument
  // attributes past the end of the callee's argument list.
  if (callee.getFunctionType().getNumInputs() != getNumOperands())
    return defaultEffects();

  DiagnosedSilenceableFailure earlyVerifierResult =
      verifyNamedSequenceOp(callee, /*emitWarnings=*/false);
  if (!earlyVerifierResult.succeeded()) {
    (void)earlyVerifierResult.silence();
    return defaultEffects();
  }

  for (unsigned i = 0, e = getNumOperands(); i < e; ++i) {
    if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName))
      consumesHandle(getOperation()->getOpOperand(i), effects);
    else
      onlyReadsHandle(getOperation()->getOpOperand(i), effects);
  }
}
/// Checks that `target` names a transform.named_sequence whose signature is
/// compatible with this op:
///   - operand count and types must match the callee inputs exactly;
///   - result count must match, and each result must implement the same
///     transform interface as the corresponding callee result.
/// Finally verifies the callee's consume annotations.
LogicalResult
transform::IncludeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // This may run before the main op verifier, so read the attribute
  // generically and check that it exists instead of relying on getTarget().
  auto targetAttr = getOperation()->getAttrOfType<SymbolRefAttr>("target");
  if (!targetAttr)
    return emitOpError() << "expects a 'target' symbol reference attribute";

  auto callee =
      symbolTable.lookupNearestSymbolFrom<transform::NamedSequenceOp>(
          *this, targetAttr);
  if (!callee)
    return emitOpError() << "does not reference a named transform sequence";

  FunctionType calleeType = callee.getFunctionType();

  // Operands: exact count and exact types.
  unsigned numInputs = calleeType.getNumInputs();
  if (numInputs != getNumOperands())
    return emitError("incorrect number of operands for callee");
  for (unsigned idx = 0; idx != numInputs; ++idx) {
    Type expected = calleeType.getInput(idx);
    Type provided = getOperand(idx).getType();
    if (provided == expected)
      continue;
    return emitOpError("operand type mismatch: expected operand type ")
           << expected << ", but provided " << provided
           << " for operand number " << idx;
  }

  // Results: exact count, compatible transform interfaces.
  unsigned numCalleeResults = calleeType.getNumResults();
  if (numCalleeResults != getNumResults())
    return emitError("incorrect number of results for callee");
  for (unsigned idx = 0; idx != numCalleeResults; ++idx) {
    if (implementSameTransformInterface(getResult(idx).getType(),
                                        calleeType.getResult(idx)))
      continue;
    return emitOpError() << "type of result #" << idx
                         << " must implement the same transform dialect "
                            "interface as the corresponding callee result";
  }

  return verifyFunctionLikeConsumeAnnotations(
             cast<FunctionOpInterface>(*callee), /*emitWarnings=*/false,
             /*alsoVerifyInternal=*/true)
      .checkAndReport();
}
//===----------------------------------------------------------------------===//
// MatchOperationEmptyOp
//===----------------------------------------------------------------------===//
/// Succeeds only when there is no current operation, i.e. the matched handle
/// is empty. Otherwise emits a silenceable error.
DiagnosedSilenceableFailure transform::MatchOperationEmptyOp::matchOperation(
    ::std::optional<::mlir::Operation *> maybeCurrent,
    transform::TransformResults &results, transform::TransformState &state) {
  if (maybeCurrent.has_value()) {
    DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp failure\n"; });
    return emitSilenceableError() << "operation is not empty";
  }
  DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp success\n"; });
  return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// MatchOperationNameOp
//===----------------------------------------------------------------------===//
/// Succeeds if the name of `current` is one of the accepted `op_names`.
/// Otherwise emits a silenceable error.
DiagnosedSilenceableFailure transform::MatchOperationNameOp::matchOperation(
    Operation *current, transform::TransformResults &results,
    transform::TransformState &state) {
  StringRef currentOpName = current->getName().getStringRef();
  bool isAccepted = llvm::any_of(
      getOpNames().getAsRange<StringAttr>(),
      [&](StringAttr accepted) { return accepted.getValue() == currentOpName; });
  if (!isAccepted)
    return emitSilenceableError() << "wrong operation name";
  return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// MatchParamCmpIOp
//===----------------------------------------------------------------------===//
/// Compares the integer attributes of the `param` and `reference` handles
/// element by element, using the signed predicate `predicate`.
///
/// Emits a silenceable error if:
///   - the two handles have different lengths, or
///   - the comparison does not hold for some element (the first such element
///     is reported).
/// Emits a definite failure for non-integer values or for mismatched integer
/// attribute types.
DiagnosedSilenceableFailure
transform::MatchParamCmpIOp::apply(transform::TransformRewriter &rewriter,
                                   transform::TransformResults &results,
                                   transform::TransformState &state) {
  // Renders an APInt as a signed decimal for diagnostics.
  auto toSignedString = [](const APInt &v) {
    std::string buffer;
    llvm::raw_string_ostream os(buffer);
    v.print(os, /*isSigned=*/true);
    return buffer;
  };

  ArrayRef<Attribute> params = state.getParams(getParam());
  ArrayRef<Attribute> references = state.getParams(getReference());
  if (params.size() != references.size()) {
    return emitSilenceableError()
           << "parameters have different payload lengths (" << params.size()
           << " vs " << references.size() << ")";
  }

  for (auto &&[i, param, reference] : llvm::enumerate(params, references)) {
    auto intAttr = llvm::dyn_cast<IntegerAttr>(param);
    auto refAttr = llvm::dyn_cast<IntegerAttr>(reference);
    if (!intAttr || !refAttr) {
      return emitDefiniteFailure()
             << "non-integer parameter value not expected";
    }
    if (intAttr.getType() != refAttr.getType()) {
      return emitDefiniteFailure()
             << "mismatching integer attribute types in parameter #" << i;
    }
    APInt value = intAttr.getValue();
    APInt refValue = refAttr.getValue();

    // Evaluate the predicate and remember how to phrase a failure.
    bool holds = true;
    StringRef direction;
    switch (getPredicate()) {
    case MatchCmpIPredicate::eq:
      holds = value.eq(refValue);
      direction = "equal to";
      break;
    case MatchCmpIPredicate::ne:
      holds = value.ne(refValue);
      direction = "not equal to";
      break;
    case MatchCmpIPredicate::lt:
      holds = value.slt(refValue);
      direction = "less than";
      break;
    case MatchCmpIPredicate::le:
      holds = value.sle(refValue);
      direction = "less than or equal to";
      break;
    case MatchCmpIPredicate::gt:
      holds = value.sgt(refValue);
      direction = "greater than";
      break;
    case MatchCmpIPredicate::ge:
      holds = value.sge(refValue);
      direction = "greater than or equal to";
      break;
    }
    if (holds)
      continue;

    // Report the element index as a signed 64-bit value.
    int64_t position = i;
    DiagnosedSilenceableFailure diag =
        emitSilenceableError() << "expected parameter to be " << direction
                               << " " << toSignedString(refValue) << ", got "
                               << toSignedString(value);
    diag.attachNote(getParam().getLoc())
        << "value # " << position
        << " associated with the parameter defined here";
    return diag;
  }
  return DiagnosedSilenceableFailure::success();
}
/// Declares the side effects of MatchParamCmpIOp: both the `param` and
/// `reference` handles are only read. No payload effect is declared here.
void transform::MatchParamCmpIOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
onlyReadsHandle(getParamMutable(), effects);
onlyReadsHandle(getReferenceMutable(), effects);
}
//===----------------------------------------------------------------------===//
// ParamConstantOp
//===----------------------------------------------------------------------===//
/// Maps the result param handle to the single constant attribute `value`.
DiagnosedSilenceableFailure
transform::ParamConstantOp::apply(transform::TransformRewriter &rewriter,
                                  transform::TransformResults &results,
                                  transform::TransformState &state) {
  auto constant = getValue();
  OpResult paramHandle = cast<OpResult>(getParam());
  results.setParams(paramHandle, {constant});
  return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// MergeHandlesOp
//===----------------------------------------------------------------------===//
/// Concatenates the payloads of all `handles` into the result, keeping handle
/// order. Handles may be op handles, params or value handles; the kind of the
/// first handle decides which one applies. With `deduplicate`, only the first
/// occurrence of each payload entity is kept.
DiagnosedSilenceableFailure
transform::MergeHandlesOp::apply(transform::TransformRewriter &rewriter,
                                 transform::TransformResults &results,
                                 transform::TransformState &state) {
  ValueRange handles = getHandles();
  Type handleType = handles.front().getType();
  OpResult merged = llvm::cast<OpResult>(getResult());

  // Operation handles.
  if (isa<TransformHandleTypeInterface>(handleType)) {
    SmallVector<Operation *> ops;
    for (Value handle : handles)
      llvm::append_range(ops, state.getPayloadOps(handle));
    if (getDeduplicate()) {
      SetVector<Operation *> unique(llvm::from_range, ops);
      results.set(merged, unique.getArrayRef());
    } else {
      results.set(merged, ops);
    }
    return DiagnosedSilenceableFailure::success();
  }

  // Parameter handles.
  if (llvm::isa<TransformParamTypeInterface>(handleType)) {
    SmallVector<Attribute> params;
    for (Value handle : handles)
      llvm::append_range(params, state.getParams(handle));
    if (getDeduplicate()) {
      SetVector<Attribute> unique(llvm::from_range, params);
      results.setParams(merged, unique.getArrayRef());
    } else {
      results.setParams(merged, params);
    }
    return DiagnosedSilenceableFailure::success();
  }

  // Value handles.
  assert(llvm::isa<TransformValueHandleTypeInterface>(handleType) &&
         "expected value handle type");
  SmallVector<Value> values;
  for (Value handle : handles)
    llvm::append_range(values, state.getPayloadValues(handle));
  if (getDeduplicate()) {
    SetVector<Value> unique(llvm::from_range, values);
    results.setValues(merged, unique.getArrayRef());
  } else {
    results.setValues(merged, values);
  }
  return DiagnosedSilenceableFailure::success();
}
/// The same handle may appear several times among the operands only when
/// `deduplicate` is set.
bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() {
// Handles may be the same if deduplicating is enabled.
return getDeduplicate();
}
/// Declares the side effects of MergeHandlesOp: every input handle is only
/// read, and the merged result handle is produced.
void transform::MergeHandlesOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
onlyReadsHandle(getHandlesMutable(), effects);
producesHandle(getOperation()->getOpResults(), effects);
// There are no effects on the Payload IR as this is only a handle
// manipulation.
}
/// Folds `merge_handles` of a single handle without deduplication to that
/// handle, since merging is then the identity. All other cases do not fold.
OpFoldResult transform::MergeHandlesOp::fold(FoldAdaptor adaptor) {
  ValueRange handles = getHandles();
  bool isIdentity = !getDeduplicate() && handles.size() == 1;
  if (isIdentity)
    return handles.front();
  return {};
}
//===----------------------------------------------------------------------===//
// NamedSequenceOp
//===----------------------------------------------------------------------===//
/// Applies the named sequence as a top-level transform. Its entry block
/// arguments are mapped like those of a PossibleTopLevelTransformOp, then the
/// body is applied with failures propagated. Unresolved external declarations
/// are a definite failure.
DiagnosedSilenceableFailure
transform::NamedSequenceOp::apply(transform::TransformRewriter &rewriter,
                                  transform::TransformResults &results,
                                  transform::TransformState &state) {
  if (isExternal())
    return emitDefiniteFailure() << "unresolved external named sequence";

  // This reuses the PossibleTopLevelTransformOp argument mapping without
  // attaching that interface/trait. The interface is tailored to a dangling
  // top-level op that is never "called".
  Region &body = getBody();
  auto scope = state.make_region_scope(body);
  LogicalResult mapped = detail::mapPossibleTopLevelTransformOpBlockArguments(
      state, getOperation(), body);
  if (failed(mapped))
    return DiagnosedSilenceableFailure::definiteFailure();

  return applySequenceBlock(body.front(), FailurePropagationMode::Propagate,
                            state, results);
}
/// NamedSequenceOp declares no memory effects of its own.
/// NOTE(review): presumably because a named sequence only acts when invoked.
/// IncludeOp::getEffects derives the effects of a call from the callee's
/// argument annotations. Confirm this is the intended rationale.
void transform::NamedSequenceOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
/// Parses a named sequence using the generic function-op syntax. Variadic
/// signatures are not allowed, and the function type is built directly from
/// the parsed inputs and results.
ParseResult transform::NamedSequenceOp::parse(OpAsmParser &parser,
                                              OperationState &result) {
  auto buildFunctionType = [](Builder &builder, ArrayRef<Type> inputs,
                              ArrayRef<Type> results,
                              function_interface_impl::VariadicFlag,
                              std::string &) {
    return builder.getFunctionType(inputs, results);
  };
  return function_interface_impl::parseFunctionOp(
      parser, result, /*allowVariadic=*/false,
      getFunctionTypeAttrName(result.name), buildFunctionType,
      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
}
/// Prints a named sequence using the generic function-op syntax. This mirrors
/// NamedSequenceOp::parse.
void transform::NamedSequenceOp::print(OpAsmPrinter &printer) {
  auto asFunction = cast<FunctionOpInterface>(getOperation());
  function_interface_impl::printFunctionOp(
      printer, asFunction, /*isVariadic=*/false,
      getFunctionTypeAttrName().getValue(), getArgAttrsAttrName(),
      getResAttrsAttrName());
}
/// Verifies a symbol-defining, function-like transform op whose body is a
/// single block terminated by `transform.yield`. The checks are:
///   - the op is not nested inside another transform op;
///   - unless `allowExternal` is set, it has a body (when external
///     declarations are allowed, the body and terminator checks are skipped);
///   - the body block is non-empty and ends with `transform.yield`;
///   - the yield operands match the op's result types in number and type.
/// Errors are returned as silenceable failures rather than reported directly.
static DiagnosedSilenceableFailure
verifyYieldingSingleBlockOp(FunctionOpInterface op, bool allowExternal) {
  if (auto ancestor = op->getParentOfType<transform::TransformOpInterface>()) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableFailure(op)
        << "cannot be defined inside another transform op";
    diag.attachNote(ancestor.getLoc()) << "ancestor transform op";
    return diag;
  }

  Region &body = op.getFunctionBody();
  if (op.isExternal() || body.empty()) {
    if (!allowExternal)
      return emitSilenceableFailure(op) << "cannot be external";
    return DiagnosedSilenceableFailure::success();
  }

  Block &entry = body.front();
  if (entry.empty())
    return emitSilenceableFailure(op) << "expected a non-empty body block";

  Operation *terminator = &entry.back();
  if (!isa<transform::YieldOp>(terminator)) {
    DiagnosedSilenceableFailure diag = emitSilenceableFailure(op)
                                       << "expected '"
                                       << transform::YieldOp::getOperationName()
                                       << "' as terminator";
    diag.attachNote(terminator->getLoc()) << "terminator";
    return diag;
  }

  ArrayRef<Type> resultTypes = op.getResultTypes();
  unsigned numYielded = terminator->getNumOperands();
  if (numYielded != resultTypes.size()) {
    return emitSilenceableFailure(terminator)
           << "expected terminator to have as many operands as the parent op "
              "has results";
  }

  for (unsigned i = 0; i < numYielded; ++i) {
    Type operandType = terminator->getOperand(i).getType();
    Type resultType = resultTypes[i];
    if (operandType != resultType) {
      return emitSilenceableFailure(terminator)
             << "the type of the terminator operand #" << i
             << " must match the type of the corresponding parent op result ("
             << operandType << " vs " << resultType << ")";
    }
  }
  return DiagnosedSilenceableFailure::success();
}
/// Verification of a NamedSequenceOp. This does not report the error
/// immediately, so it can be used to check for op's well-formedness before the
/// verifier runs, e.g., during trait verification.
///
/// Checks, in order:
///   1. An enclosing symbol table, if any, carries the
///      `kWithNamedSequenceAttrName` attribute.
///   2. The op is not nested inside another transform op.
///   3. External (body-less) sequences only get their consume annotations
///      verified.
///   4. The body block is non-empty and ends with `transform.yield`, whose
///      operands match the declared result types in number and type.
///   5. The consume annotations are verified, then verifyYieldingSingleBlockOp
///      runs with external declarations allowed.
/// `emitWarnings` is forwarded to verifyFunctionLikeConsumeAnnotations.
/// NOTE(review): checks 2 and 4 are repeated by verifyYieldingSingleBlockOp at
/// the end. The duplicates are harmless because earlier failures return first.
static DiagnosedSilenceableFailure
verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings) {
// Named sequences must live in a symbol table that opted in via attribute.
if (Operation *parent = op->getParentWithTrait<OpTrait::SymbolTable>()) {
if (!parent->getAttr(
transform::TransformDialect::kWithNamedSequenceAttrName)) {
DiagnosedSilenceableFailure diag =
emitSilenceableFailure(op)
<< "expects the parent symbol table to have the '"
<< transform::TransformDialect::kWithNamedSequenceAttrName
<< "' attribute";
diag.attachNote(parent->getLoc()) << "symbol table operation";
return diag;
}
}
// Named sequences are top-level definitions, not nested transforms.
if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) {
DiagnosedSilenceableFailure diag =
emitSilenceableFailure(op)
<< "cannot be defined inside another transform op";
diag.attachNote(parent.getLoc()) << "ancestor transform op";
return diag;
}
// Declarations have no body to inspect; only their annotations are checked.
if (op.isExternal() || op.getBody().empty())
return verifyFunctionLikeConsumeAnnotations(cast<FunctionOpInterface>(*op),
emitWarnings);
if (op.getBody().front().empty())
return emitSilenceableFailure(op) << "expected a non-empty body block";
// The terminator must be a transform.yield matching the function results.
Operation *terminator = &op.getBody().front().back();
if (!isa<transform::YieldOp>(terminator)) {
DiagnosedSilenceableFailure diag = emitSilenceableFailure(op)
<< "expected '"
<< transform::YieldOp::getOperationName()
<< "' as terminator";
diag.attachNote(terminator->getLoc()) << "terminator";
return diag;
}
if (terminator->getNumOperands() != op.getFunctionType().getNumResults()) {
return emitSilenceableFailure(terminator)
<< "expected terminator to have as many operands as the parent op "
"has results";
}
// Types must match exactly (not just implement the same interface).
for (auto [i, operandType, resultType] :
llvm::zip_equal(llvm::seq<unsigned>(0, terminator->getNumOperands()),
terminator->getOperands().getType(),
op.getFunctionType().getResults())) {
if (operandType == resultType)
continue;
return emitSilenceableFailure(terminator)
<< "the type of the terminator operand #" << i
<< " must match the type of the corresponding parent op result ("
<< operandType << " vs " << resultType << ")";
}
auto funcOp = cast<FunctionOpInterface>(*op);
DiagnosedSilenceableFailure diag =
verifyFunctionLikeConsumeAnnotations(funcOp, emitWarnings);
if (!diag.succeeded())
return diag;
return verifyYieldingSingleBlockOp(funcOp,
/*allowExternal=*/true);
}
/// Op verifier. Runs verifyNamedSequenceOp with warnings enabled and reports
/// any resulting diagnostic.
LogicalResult transform::NamedSequenceOp::verify() {
// Actual verification happens in a separate function for reusability.
return verifyNamedSequenceOp(*this, /*emitWarnings=*/true).checkAndReport();
}
/// Creates the entry block of the sequence-like op being built in `state` (in
/// its last region). The block has a first argument of type `bbArgType`,
/// followed by one argument per type in `extraBindingTypes`, all located at
/// `state.location`.
///
/// `bodyBuilder` is then called with the insertion point at the start of that
/// block. It may take either (builder, loc, firstArg) or
/// (builder, loc, firstArg, remainingArgs). The builder's insertion point is
/// restored on exit.
template <typename FnTy>
static void buildSequenceBody(OpBuilder &builder, OperationState &state,
                              Type bbArgType, TypeRange extraBindingTypes,
                              FnTy bodyBuilder) {
  SmallVector<Type> argTypes;
  argTypes.reserve(extraBindingTypes.size() + 1);
  argTypes.push_back(bbArgType);
  argTypes.append(extraBindingTypes.begin(), extraBindingTypes.end());
  SmallVector<Location> argLocs(argTypes.size(), state.location);

  OpBuilder::InsertionGuard guard(builder);
  Region &region = *state.regions.back();
  Block *entry =
      builder.createBlock(&region, region.begin(), argTypes, argLocs);
  builder.setInsertionPointToStart(entry);

  BlockArgument rootArg = entry->getArgument(0);
  if constexpr (llvm::function_traits<FnTy>::num_args == 3)
    bodyBuilder(builder, state.location, rootArg);
  else
    bodyBuilder(builder, state.location, rootArg,
                entry->getArguments().drop_front());
}
/// Builds a named sequence `symName` with function type
/// (`rootType`) -> `resultTypes`, extra attributes `attrs`, and optional
/// per-argument attribute dictionaries `argAttrs`. The body is populated by
/// `bodyBuilder`, which receives the single root block argument.
///
/// `argAttrs` used to be accepted and silently discarded. It is now stored
/// under the op's arg-attrs attribute, where callers such as
/// IncludeOp::getEffects look up annotations like
/// `TransformDialect::kArgConsumedAttrName`. The attribute is only added when
/// at least one dictionary is non-empty; null entries become empty
/// dictionaries.
void transform::NamedSequenceOp::build(OpBuilder &builder,
                                       OperationState &state, StringRef symName,
                                       Type rootType, TypeRange resultTypes,
                                       SequenceBodyBuilderFn bodyBuilder,
                                       ArrayRef<NamedAttribute> attrs,
                                       ArrayRef<DictionaryAttr> argAttrs) {
  state.addAttribute(SymbolTable::getSymbolAttrName(),
                     builder.getStringAttr(symName));
  state.addAttribute(getFunctionTypeAttrName(state.name),
                     TypeAttr::get(FunctionType::get(builder.getContext(),
                                                     rootType, resultTypes)));
  state.attributes.append(attrs.begin(), attrs.end());

  // Attach argument attributes if any carry information.
  bool hasArgAttrs = llvm::any_of(argAttrs, [](DictionaryAttr dict) {
    return dict && !dict.empty();
  });
  if (hasArgAttrs) {
    SmallVector<Attribute> argAttrList;
    argAttrList.reserve(argAttrs.size());
    for (DictionaryAttr dict : argAttrs)
      argAttrList.push_back(dict ? dict : builder.getDictionaryAttr({}));
    state.addAttribute(getArgAttrsAttrName(state.name),
                       builder.getArrayAttr(argAttrList));
  }

  state.addRegion();
  buildSequenceBody(builder, state, rootType,
                    /*extraBindingTypes=*/TypeRange(), bodyBuilder);
}
//===----------------------------------------------------------------------===//
// NumAssociationsOp
//===----------------------------------------------------------------------===//
/// Sets the result param to a single i64 attribute holding the number of
/// payload entities associated with `handle`. Depending on the handle's
/// transform type, these are ops, values or params.
DiagnosedSilenceableFailure
transform::NumAssociationsOp::apply(transform::TransformRewriter &rewriter,
                                    transform::TransformResults &results,
                                    transform::TransformState &state) {
  Value handle = getHandle();
  Type handleType = handle.getType();
  size_t numAssociations = 0;
  if (isa<TransformHandleTypeInterface>(handleType))
    numAssociations = llvm::range_size(state.getPayloadOps(handle));
  else if (isa<TransformValueHandleTypeInterface>(handleType))
    numAssociations = llvm::range_size(state.getPayloadValues(handle));
  else if (isa<TransformParamTypeInterface>(handleType))
    numAssociations = llvm::range_size(state.getParams(handle));
  else
    llvm_unreachable("unknown kind of transform dialect type");

  results.setParams(cast<OpResult>(getNum()),
                    rewriter.getI64IntegerAttr(numAssociations));
  return DiagnosedSilenceableFailure::success();
}
/// Verifies that the declared result param type accepts the i64 integer
/// attribute this op produces, using 0 as a representative payload.
LogicalResult transform::NumAssociationsOp::verify() {
  Builder builder(getContext());
  Attribute sample = builder.getI64IntegerAttr(0);
  auto paramType = cast<TransformParamTypeInterface>(getNum().getType());
  DiagnosedSilenceableFailure check =
      paramType.checkPayload(getLoc(), {sample});
  return check.checkAndReport();
}
//===----------------------------------------------------------------------===//
// SelectOp
//===----------------------------------------------------------------------===//
/// Maps the result to the payload ops of `target` whose operation name equals
/// `op_name`, keeping their original order.
DiagnosedSilenceableFailure
transform::SelectOp::apply(transform::TransformRewriter &rewriter,
                           transform::TransformResults &results,
                           transform::TransformState &state) {
  auto hasRequestedName = [&](Operation *candidate) {
    return candidate->getName().getStringRef() == getOpName();
  };
  auto payloadOps = state.getPayloadOps(getTarget());
  SmallVector<Operation *> selected;
  llvm::append_range(selected,
                     llvm::make_filter_range(payloadOps, hasRequestedName));
  results.set(cast<OpResult>(getResult()), selected);
  return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// SplitHandleOp
//===----------------------------------------------------------------------===//
/// Builds a split_handle op on `target` with `numResultHandles` results, each
/// having the same type as `target`.
void transform::SplitHandleOp::build(OpBuilder &builder, OperationState &result,
                                     Value target, int64_t numResultHandles) {
  Type handleType = target.getType();
  SmallVector<Type> resultTypes(numResultHandles, handleType);
  result.addOperands(target);
  result.addTypes(resultTypes);
}
/// Distributes the payload entities (ops, values or params) of `handle` over
/// the results: the i-th entity goes to the i-th result. Entities beyond the
/// number of results go to the `overflow_result` result, if one is specified.
///
/// Emits a silenceable error in two cases:
///   - there are more entities than results and no overflow result is given;
///   - there are fewer entities than results, `fail_on_payload_too_small` is
///     set, and the handle is not empty with `pass_through_empty_handle` set.
/// Results that receive nothing are mapped to empty lists.
DiagnosedSilenceableFailure
transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter,
                                transform::TransformResults &results,
                                transform::TransformState &state) {
  // Count the payload entities, dispatching on the handle's transform type.
  int64_t numPayloads =
      llvm::TypeSwitch<Type, int64_t>(getHandle().getType())
          .Case<TransformHandleTypeInterface>([&](auto x) {
            return llvm::range_size(state.getPayloadOps(getHandle()));
          })
          .Case<TransformValueHandleTypeInterface>([&](auto x) {
            return llvm::range_size(state.getPayloadValues(getHandle()));
          })
          .Case<TransformParamTypeInterface>([&](auto x) {
            return llvm::range_size(state.getParams(getHandle()));
          })
          .Default([](auto x) {
            llvm_unreachable("unknown transform dialect type interface");
            return -1;
          });
  auto numResults = static_cast<int64_t>(getNumResults());

  auto produceNumOpsError = [&]() {
    return emitSilenceableError()
           << getHandle() << " expected to contain " << this->getNumResults()
           << " payloads but it contains " << numPayloads << " payloads";
  };

  // Fail if there are more payload ops than results and no overflow result was
  // specified.
  if (numPayloads > numResults && !getOverflowResult().has_value())
    return produceNumOpsError();

  // Fail if there are more results than payload ops. Unless:
  // - "fail_on_payload_too_small" is set to "false", or
  // - "pass_through_empty_handle" is set to "true" and there are 0 payload ops.
  if (numPayloads < numResults && getFailOnPayloadTooSmall() &&
      (numPayloads != 0 || !getPassThroughEmptyHandle()))
    return produceNumOpsError();

  // Distribute payloads.
  SmallVector<SmallVector<MappedValue, 1>> resultHandles(getNumResults(), {});
  // Only reserve when there actually is overflow. A payload smaller than the
  // number of results can reach this point (see the checks above). The
  // difference would then be negative, wrap around to a huge size_t, and
  // abort inside SmallVector::reserve.
  if (getOverflowResult() && numPayloads > numResults)
    resultHandles[*getOverflowResult()].reserve(numPayloads - numResults);

  // Materialize the payload as a uniform list of MappedValue.
  auto container = [&]() {
    if (isa<TransformHandleTypeInterface>(getHandle().getType())) {
      return llvm::map_to_vector(
          state.getPayloadOps(getHandle()),
          [](Operation *op) -> MappedValue { return op; });
    }
    if (isa<TransformValueHandleTypeInterface>(getHandle().getType())) {
      return llvm::map_to_vector(state.getPayloadValues(getHandle()),
                                 [](Value v) -> MappedValue { return v; });
    }
    assert(isa<TransformParamTypeInterface>(getHandle().getType()) &&
           "unsupported kind of transform dialect type");
    return llvm::map_to_vector(state.getParams(getHandle()),
                               [](Attribute a) -> MappedValue { return a; });
  }();

  for (auto &&en : llvm::enumerate(container)) {
    int64_t resultNum = en.index();
    // Entities past the last result go to the overflow result. The earlier
    // check guarantees one exists whenever this happens.
    if (resultNum >= numResults)
      resultNum = *getOverflowResult();
    resultHandles[resultNum].push_back(en.value());
  }

  // Set transform op results.
  for (auto &&it : llvm::enumerate(resultHandles))
    results.setMappedValues(llvm::cast<OpResult>(getResult(it.index())),
                            it.value());

  return DiagnosedSilenceableFailure::success();
}
/// Declares the effects of split_handle. The operand handle is only read and
/// all results are produced handles. No effect on the payload IR is declared,
/// because the op merely manipulates handles.
void transform::SplitHandleOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  auto producedHandles = getOperation()->getOpResults();
  onlyReadsHandle(getHandleMutable(), effects);
  producesHandle(producedHandles, effects);
}
/// Verifies that the optional overflow result index names an existing result,
/// and that every result type implements the same transform interface as the
/// operand handle type.
LogicalResult transform::SplitHandleOp::verify() {
  if (auto overflow = getOverflowResult();
      overflow && *overflow >= getNumResults())
    return emitOpError("overflow_result is not a valid result index");

  Type handleType = getHandle().getType();
  bool allResultsCompatible =
      llvm::all_of(getResultTypes(), [&](Type resultType) {
        return implementSameTransformInterface(handleType, resultType);
      });
  if (!allResultsCompatible)
    return emitOpError("expects result types to implement the same transform "
                       "interface as the operand type");
  return success();
}
//===----------------------------------------------------------------------===//
// ReplicateOp
//===----------------------------------------------------------------------===//
/// Maps each replicated result to the payload of the corresponding input
/// handle, concatenated once per payload op associated with `pattern`.
/// Operation handles and parameter handles are supported.
DiagnosedSilenceableFailure
transform::ReplicateOp::apply(transform::TransformRewriter &rewriter,
                              transform::TransformResults &results,
                              transform::TransformState &state) {
  unsigned numRepetitions = llvm::range_size(state.getPayloadOps(getPattern()));

  auto handles = getHandles();
  for (size_t idx = 0, e = handles.size(); idx < e; ++idx) {
    Value handle = handles[idx];
    auto replicated = llvm::cast<OpResult>(getReplicated()[idx]);

    if (!isa<TransformHandleTypeInterface>(handle.getType())) {
      // Anything that is not an operation handle must be a parameter here.
      assert(llvm::isa<TransformParamTypeInterface>(handle.getType()) &&
             "expected param type");
      ArrayRef<Attribute> chunk = state.getParams(handle);
      SmallVector<Attribute> params;
      params.reserve(numRepetitions * chunk.size());
      for (unsigned rep = 0; rep != numRepetitions; ++rep)
        llvm::append_range(params, chunk);
      results.setParams(replicated, params);
      continue;
    }

    SmallVector<Operation *> chunk =
        llvm::to_vector(state.getPayloadOps(handle));
    SmallVector<Operation *> payload;
    payload.reserve(numRepetitions * chunk.size());
    for (unsigned rep = 0; rep != numRepetitions; ++rep)
      llvm::append_range(payload, chunk);
    results.set(replicated, payload);
  }
  return DiagnosedSilenceableFailure::success();
}
/// Declares the effects of replicate. The pattern handle and the handles to
/// replicate are only read, and all results are produced handles. No payload
/// effects are declared.
void transform::ReplicateOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  Operation *op = getOperation();
  onlyReadsHandle(getPatternMutable(), effects);
  onlyReadsHandle(getHandlesMutable(), effects);
  producesHandle(op->getOpResults(), effects);
}
//===----------------------------------------------------------------------===//
// SequenceOp
//===----------------------------------------------------------------------===//
/// Applies the body of the sequence. It first maps the body block arguments,
/// then applies the body block with the op's failure propagation mode.
DiagnosedSilenceableFailure
transform::SequenceOp::apply(transform::TransformRewriter &rewriter,
                             transform::TransformResults &results,
                             transform::TransformState &state) {
  Block *body = getBodyBlock();
  // The scope object is held for the duration of the body application.
  auto scope = state.make_region_scope(*body->getParent());
  if (failed(mapBlockArguments(state)))
    return DiagnosedSilenceableFailure::definiteFailure();
  return applySequenceBlock(*body, getFailurePropagationMode(), state,
                            results);
}
/// Parses the optional operand list of a sequence op:
///   (`%root` (`,` `%extra`...)? `:` `(`? root-type (`,` extra-types)? `)`?)?
/// When no operand is present, `root` is reset and parsing succeeds. The
/// parentheses around the type list are optional, but a `(` must be matched by
/// a `)`. The number of extra types must equal the number of extra operands.
static ParseResult parseSequenceOpOperands(
    OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
    Type &rootType,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
    SmallVectorImpl<Type> &extraBindingTypes) {
  OpAsmParser::UnresolvedOperand rootOperand;
  OptionalParseResult hasRoot = parser.parseOptionalOperand(rootOperand);
  if (!hasRoot.has_value()) {
    root = std::nullopt;
    return success();
  }
  if (failed(hasRoot.value()))
    return failure();
  root = rootOperand;

  if (succeeded(parser.parseOptionalComma())) {
    if (failed(parser.parseOperandList(extraBindings)))
      return failure();
  }
  if (failed(parser.parseColon()))
    return failure();

  // The parentheses are optional. Remember whether one was opened so that the
  // closing one is required exactly in that case, and unbalanced input is
  // rejected.
  bool hasParens = succeeded(parser.parseOptionalLParen());
  if (failed(parser.parseType(rootType)))
    return failure();
  if (!extraBindings.empty()) {
    if (parser.parseComma() || parser.parseTypeList(extraBindingTypes))
      return failure();
  }
  if (extraBindingTypes.size() != extraBindings.size()) {
    return parser.emitError(parser.getNameLoc(),
                            "expected types to be provided for all operands");
  }
  if (hasParens && failed(parser.parseRParen()))
    return failure();
  return success();
}
/// Prints the operand list of a sequence op. This is the inverse of
/// parseSequenceOpOperands. Nothing is printed when there is no root. The type
/// list is parenthesized only when extra bindings are present.
static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op,
                                    Value root, Type rootType,
                                    ValueRange extraBindings,
                                    TypeRange extraBindingTypes) {
  if (!root)
    return;

  printer << root;
  if (extraBindings.empty()) {
    printer << " : ";
    printer << rootType;
    return;
  }

  printer << ", ";
  printer.printOperands(extraBindings);
  printer << " : (";
  printer << rootType;
  printer << ", " << llvm::interleaved(extraBindingTypes) << ')';
}
/// Returns `true` if the given op operand may be consuming the handle value in
/// the Transform IR. That is, if it may have a Free effect on it.
static bool isValueUsePotentialConsumer(OpOperand &use) {
// Conservatively assume the effect being present in absence of the interface.
auto iface = dyn_cast<transform::TransformOpInterface>(use.getOwner());
if (!iface)
return true;
return isHandleConsumed(use.get(), iface);
}
/// Fails if `value` has more than one use that may consume it. On failure, the
/// diagnostic created by `reportError` is extended with a message and with
/// notes pointing at the first two potentially-consuming uses.
LogicalResult
checkDoubleConsume(Value value,
                   function_ref<InFlightDiagnostic()> reportError) {
  OpOperand *firstConsumer = nullptr;
  for (OpOperand &use : value.getUses()) {
    if (!isValueUsePotentialConsumer(use))
      continue;
    if (firstConsumer == nullptr) {
      firstConsumer = &use;
      continue;
    }
    // Found a second potential consumer: report both and stop.
    InFlightDiagnostic diag = reportError()
                              << " has more than one potential consumer";
    diag.attachNote(firstConsumer->getOwner()->getLoc())
        << "used here as operand #" << firstConsumer->getOperandNumber();
    diag.attachNote(use.getOwner()->getLoc())
        << "used here as operand #" << use.getOperandNumber();
    return diag;
  }
  return success();
}
/// Verifies the sequence op, checking that:
///  - extra bindings are only used together with a root operand;
///  - no block argument or nested op result has more than one potential
///    consumer;
///  - all nested ops except the last one implement TransformOpInterface;
///  - the body has a terminator whose operand types match the op result types.
LogicalResult transform::SequenceOp::verify() {
  Block *body = getBodyBlock();
  assert(body->getNumArguments() >= 1 &&
         "the number of arguments must have been verified to be at least 1 by "
         "PossibleTopLevelTransformOpTrait");

  if (!getRoot() && !getExtraBindings().empty()) {
    return emitOpError()
           << "does not expect extra operands when used as top-level";
  }

  // Check if a block argument has more than one consuming use.
  for (BlockArgument arg : body->getArguments()) {
    if (failed(checkDoubleConsume(arg, [this, arg]() {
          return (emitOpError() << "block argument #" << arg.getArgNumber());
        }))) {
      return failure();
    }
  }

  // Check properties of the nested operations they cannot check themselves.
  for (Operation &child : *body) {
    // The last op (the terminator) is exempt from the interface requirement.
    if (!isa<TransformOpInterface>(child) && &child != &body->back()) {
      InFlightDiagnostic diag =
          emitOpError()
          << "expected children ops to implement TransformOpInterface";
      diag.attachNote(child.getLoc()) << "op without interface";
      return diag;
    }

    for (OpResult result : child.getResults()) {
      auto report = [&]() {
        return (child.emitError() << "result #" << result.getResultNumber());
      };
      if (failed(checkDoubleConsume(result, report)))
        return failure();
    }
  }

  if (!body->mightHaveTerminator())
    return emitOpError() << "expects to have a terminator in the body";

  Operation *terminator = body->getTerminator();
  if (terminator->getOperandTypes() != getOperation()->getResultTypes()) {
    InFlightDiagnostic diag = emitOpError()
                              << "expects the types of the terminator operands "
                                 "to match the types of the result";
    diag.attachNote(terminator->getLoc()) << "terminator";
    return diag;
  }
  return success();
}
/// Reports the effects of the sequence by delegating entirely to
/// `getPotentialTopLevelEffects`.
void transform::SequenceOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  this->getPotentialTopLevelEffects(effects);
}
/// Returns the operands forwarded to the body region on entry: all of the op
/// operands. When the op has none, returns an empty range positioned at the
/// end of the operand list.
OperandRange
transform::SequenceOp::getEntrySuccessorOperands(RegionBranchPoint point) {
  assert(point == getBody() && "unexpected region index");
  Operation *op = getOperation();
  if (op->getNumOperands() == 0)
    return OperandRange(op->operand_end(), op->operand_end());
  return op->getOperands();
}
/// Describes region control flow. From the parent op, control enters the body
/// region, which receives its block arguments only if the op has operands.
/// From the body, control returns to the parent and yields its results.
void transform::SequenceOp::getSuccessorRegions(
    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
  if (!point.isParent()) {
    assert(point == getBody() && "unexpected region index");
    regions.emplace_back(getOperation()->getResults());
    return;
  }

  Region *bodyRegion = &getBody();
  Block::BlockArgListType entryArgs;
  if (getNumOperands() != 0)
    entryArgs = bodyRegion->getArguments();
  regions.emplace_back(bodyRegion, entryArgs);
}
/// Reports that the body region is invoked exactly once (lower and upper
/// bounds of 1). The constant operand values do not influence this.
void transform::SequenceOp::getRegionInvocationBounds(
    ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
  (void)operands;
  bounds.emplace_back(/*lb=*/1, /*ub=*/1);
}
/// Builds a sequence rooted at `root` with no extra bindings. The body is
/// built with the root's type as its block-argument type and populated by
/// `bodyBuilder`.
void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
                                  TypeRange resultTypes,
                                  FailurePropagationMode failurePropagationMode,
                                  Value root,
                                  SequenceBodyBuilderFn bodyBuilder) {
  build(builder, state, resultTypes, failurePropagationMode, root,
        /*extra_bindings=*/ValueRange());
  buildSequenceBody(builder, state, /*bbArgType=*/root.getType(),
                    /*extraBindingTypes=*/TypeRange(), bodyBuilder);
}
/// Builds a sequence rooted at `root` that also forwards `extraBindings` as
/// operands. The body is built with block-argument types taken from the root
/// and the extra bindings, and populated by `bodyBuilder`.
void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
                                  TypeRange resultTypes,
                                  FailurePropagationMode failurePropagationMode,
                                  Value root, ValueRange extraBindings,
                                  SequenceBodyBuilderArgsFn bodyBuilder) {
  build(builder, state, resultTypes, failurePropagationMode, root,
        extraBindings);
  Type rootType = root.getType();
  buildSequenceBody(builder, state, rootType, extraBindings.getTypes(),
                    bodyBuilder);
}
/// Builds a sequence without a root operand or extra bindings. The body is
/// built with the explicitly provided `bbArgType` and populated by
/// `bodyBuilder`.
void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
                                  TypeRange resultTypes,
                                  FailurePropagationMode failurePropagationMode,
                                  Type bbArgType,
                                  SequenceBodyBuilderFn bodyBuilder) {
  // No operands at all: a null root and an empty extra-binding list.
  build(builder, state, resultTypes, failurePropagationMode,
        /*root=*/Value(), /*extra_bindings=*/ValueRange());
  buildSequenceBody(builder, state, bbArgType,
                    /*extraBindingTypes=*/TypeRange(), bodyBuilder);
}
/// Builds a sequence without any operands whose body is nevertheless built
/// with `bbArgType` plus `extraBindingTypes` as block-argument types, and
/// populated by `bodyBuilder`.
void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
                                  TypeRange resultTypes,
                                  FailurePropagationMode failurePropagationMode,
                                  Type bbArgType, TypeRange extraBindingTypes,
                                  SequenceBodyBuilderArgsFn bodyBuilder) {
  // Only the body gets the extra types; no operands are attached to the op.
  build(builder, state, resultTypes, failurePropagationMode,
        /*root=*/Value(), /*extra_bindings=*/ValueRange());
  buildSequenceBody(builder, state, bbArgType, extraBindingTypes,
                    bodyBuilder);
}
//===----------------------------------------------------------------------===//
// PrintOp
//===----------------------------------------------------------------------===//
/// Builds a print op without a target. The `name` property is set only when
/// `name` is non-empty; an empty string leaves it unset.
void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
                               StringRef name) {
  if (name.empty())
    return;
  result.getOrAddProperties<Properties>().name = builder.getStringAttr(name);
}
/// Builds a print op with `target` as its single operand, then applies the
/// name-only builder to set the optional `name`.
void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
                               Value target, StringRef name) {
  result.addOperands(target);
  build(builder, result, name);
}
/// Prints IR to stdout (`llvm::outs()`) after a `[[[ IR printer: ... ]]]`
/// header. Without a target, the top-level op is printed. Otherwise, each
/// payload op of the target handle is printed on its own. Printing flags come
/// from the op's optional unit attributes.
DiagnosedSilenceableFailure
transform::PrintOp::apply(transform::TransformRewriter &rewriter,
                          transform::TransformResults &results,
                          transform::TransformState &state) {
  llvm::raw_ostream &os = llvm::outs();
  os << "[[[ IR printer: ";
  if (auto name = getName())
    os << *name << " ";

  OpPrintingFlags printFlags;
  if (getAssumeVerified().value_or(false))
    printFlags.assumeVerified();
  if (getUseLocalScope().value_or(false))
    printFlags.useLocalScope();
  if (getSkipRegions().value_or(false))
    printFlags.skipRegions();

  if (getTarget()) {
    os << "]]]\n";
    for (Operation *target : state.getPayloadOps(getTarget())) {
      target->print(os, printFlags);
      os << "\n";
    }
  } else {
    os << "top-level ]]]\n";
    state.getTopLevel()->print(os, printFlags);
    os << "\n";
  }
  os.flush();
  return DiagnosedSilenceableFailure::success();
}
/// Declares the effects of print:
///  - the optional target handle is only read;
///  - the payload IR is only read;
///  - a write to the default resource models the printed output.
void transform::PrintOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // The target operand is optional. It is accessed through `getTargetMutable`
  // rather than `getTarget`, because `getTarget` unconditionally casts to a
  // specific type and effects may be queried before verification has run.
  if (!getTargetMutable().empty())
    onlyReadsHandle(getTargetMutable()[0], effects);
  onlyReadsPayload(effects);

  // The op prints to stdout (`llvm::outs()`). There is no dedicated resource
  // for the stdout file descriptor, so just declare that print writes into
  // the default resource.
  effects.emplace_back(MemoryEffects::Write::get());
}
//===----------------------------------------------------------------------===//
// VerifyOp
//===----------------------------------------------------------------------===//
/// Runs the MLIR verifier on the given payload op. Emits a definite failure,
/// with a note pointing at the payload op, if verification fails.
DiagnosedSilenceableFailure
transform::VerifyOp::applyToOne(transform::TransformRewriter &rewriter,
                                Operation *target,
                                transform::ApplyToEachResultList &results,
                                transform::TransformState &state) {
  if (succeeded(::mlir::verify(target)))
    return DiagnosedSilenceableFailure::success();

  DiagnosedDefiniteFailure diag = emitDefiniteFailure()
                                  << "failed to verify payload op";
  diag.attachNote(target->getLoc()) << "payload op";
  return diag;
}
/// Declares the effects of verify: the target handle is only read. No other
/// effect is declared here.
void transform::VerifyOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  auto &&target = getTargetMutable();
  transform::onlyReadsHandle(target, effects);
}
//===----------------------------------------------------------------------===//
// YieldOp
//===----------------------------------------------------------------------===//
/// Declares the effects of yield: every yielded handle is only read.
void transform::YieldOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  auto yieldedHandles = getOperandsMutable();
  onlyReadsHandle(yieldedHandles, effects);
}