Alex Zinenko 5f0d4f208e [mlir] Introduce Transform ops for loops
Introduce transform ops for "for" loops, in particular for peeling, software
pipelining and unrolling, along with a couple of "IR navigation" ops. These ops
are intended to be generalized to different kinds of loops when possible and
therefore use the "loop" prefix. They currently live in the SCF dialect because
there is no clear place to put transform ops that may span several dialects;
this decision is postponed until the ops actually need to handle non-SCF loops.

Additionally refactor some common utilities for transform ops into trait or
interface methods, and change the loop pipelining to be a returning pattern.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D127300
2022-06-09 11:41:55 +02:00

233 lines
8.5 KiB
C++

//===- SCFTransformOps.cpp - Implementation of SCF transformation ops -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/Patterns.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
using namespace mlir;
namespace {
/// Concrete PatternRewriter with no custom hooks or behavior of its own; it
/// exists so the transform ops in this file can instantiate a rewriter
/// directly from an MLIRContext.
struct SimpleRewriter : public PatternRewriter {
  SimpleRewriter(MLIRContext *ctx) : PatternRewriter(ctx) {}
};
} // namespace
//===----------------------------------------------------------------------===//
// GetParentForOp
//===----------------------------------------------------------------------===//
/// For every payload op associated with the target handle, walks up
/// `getNumLoops()` levels of enclosing `scf.for` ops and records the loop
/// reached at the last level. The deduplicated set of such loops is
/// associated with the result handle. If a payload op has fewer enclosing
/// `scf.for` ops than requested, emits an error with a note on that payload
/// op and returns failure.
LogicalResult
transform::GetParentForOp::apply(transform::TransformResults &results,
                                 transform::TransformState &state) {
  SetVector<Operation *> parentLoops;
  for (Operation *payload : state.getPayloadOps(getTarget())) {
    Operation *cursor = payload;
    scf::ForOp enclosing;
    for (unsigned depth = getNumLoops(); depth > 0; --depth) {
      enclosing = cursor->getParentOfType<scf::ForOp>();
      if (!enclosing) {
        InFlightDiagnostic diag = emitError()
                                  << "could not find an '"
                                  << scf::ForOp::getOperationName()
                                  << "' parent";
        diag.attachNote(payload->getLoc()) << "target op";
        return diag;
      }
      // Continue the search from the loop just found.
      cursor = enclosing.getOperation();
    }
    parentLoops.insert(enclosing);
  }
  results.set(getResult().cast<OpResult>(), parentLoops.getArrayRef());
  return success();
}
//===----------------------------------------------------------------------===//
// LoopOutlineOp
//===----------------------------------------------------------------------===//
/// Wraps `op` into a new `scf.execute_region` operation and returns it.
///
/// Returns null without touching the IR when `op` does not have exactly one
/// region. Otherwise:
/// - creates an `scf.execute_region` right before `op` with `op`'s result
///   types;
/// - clones `op` without its region into the wrapper's single block and
///   moves `op`'s region into that clone;
/// - yields the clone's results;
/// - replaces `op` with the wrapper's results.
/// All IR changes go through `b`, rather than splicing ops in place, so the
/// rewriting infrastructure observes them.
static scf::ExecuteRegionOp wrapInExecuteRegion(RewriterBase &b,
                                                Operation *op) {
  if (op->getNumRegions() != 1)
    return nullptr;

  OpBuilder::InsertionGuard outerGuard(b);
  b.setInsertionPoint(op);
  Location loc = op->getLoc();
  auto wrapper = b.create<scf::ExecuteRegionOp>(loc, op->getResultTypes());
  {
    OpBuilder::InsertionGuard bodyGuard(b);
    Block &body = wrapper.getRegion().emplaceBlock();
    b.setInsertionPointToStart(&body);
    Operation *copy = b.cloneWithoutRegions(*op);
    Region &dest = copy->getRegion(0);
    assert(dest.empty() && "expected empty region");
    b.inlineRegionBefore(op->getRegion(0), dest, dest.end());
    b.create<scf::YieldOp>(loc, copy->getResults());
  }
  b.replaceOp(op, wrapper.getResults());
  return wrapper;
}
/// Outlines every payload op associated with the target handle into a new
/// function named `getFuncName()` and associates the outlined functions with
/// the `transformed` handle.
///
/// Each target is first wrapped into an `scf.execute_region`. The wrapper's
/// single-block region is then outlined with `outlineSingleBlockRegion`. When
/// the target has a nearest symbol table, the new function is inserted there
/// and the call is updated to reference the function's final symbol name.
LogicalResult
transform::LoopOutlineOp::apply(transform::TransformResults &results,
                                transform::TransformState &state) {
  SmallVector<Operation *> transformed;
  // One SymbolTable per symbol-table op, created on first use and reused for
  // all targets nested under the same op.
  DenseMap<Operation *, SymbolTable> symbolTables;
  // Every helper below sets its own insertion point, so one rewriter can be
  // shared across targets instead of being rebuilt per iteration.
  SimpleRewriter rewriter(getContext());
  for (Operation *target : state.getPayloadOps(getTarget())) {
    // Query `target` before wrapping it. wrapInExecuteRegion replaces it
    // through the rewriter, which erases it, so `target` must not be
    // dereferenced after a successful wrap.
    Location location = target->getLoc();
    Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(target);
    scf::ExecuteRegionOp exec = wrapInExecuteRegion(rewriter, target);
    if (!exec) {
      // Wrapping bails out before modifying the IR, so `target` is alive.
      InFlightDiagnostic diag = emitError() << "failed to outline";
      diag.attachNote(target->getLoc()) << "target op";
      return diag;
    }
    func::CallOp call;
    FailureOr<func::FuncOp> outlined = outlineSingleBlockRegion(
        rewriter, location, exec.getRegion(), getFuncName(), &call);
    // `target` has been erased at this point. Report against the
    // execute_region that replaced it, which was created with the same
    // location, instead of touching the dangling pointer.
    if (failed(outlined))
      return reportUnknownTransformError(exec);
    if (symbolTableOp) {
      SymbolTable &symbolTable =
          symbolTables.try_emplace(symbolTableOp, symbolTableOp)
              .first->getSecond();
      symbolTable.insert(*outlined);
      // Insertion may give the function a different name (presumably to
      // resolve a clash), so point the call at its final symbol.
      call.setCalleeAttr(FlatSymbolRefAttr::get(*outlined));
    }
    transformed.push_back(*outlined);
  }
  results.set(getTransformed().cast<OpResult>(), transformed);
  return success();
}
//===----------------------------------------------------------------------===//
// LoopPeelOp
//===----------------------------------------------------------------------===//
/// Peels `loop` with `scf::peelAndCanonicalizeForLoop`. On success, returns
/// the loop the helper hands back through its out-parameter.
///
/// If peeling does not apply, the result depends on
/// `getFailIfAlreadyDivisible()`: when it is set, a generic transform error
/// is reported against `loop`; otherwise `loop` is returned unchanged.
FailureOr<scf::ForOp> transform::LoopPeelOp::applyToOne(scf::ForOp loop) {
  IRRewriter rewriter(loop->getContext());
  scf::ForOp peeledLoop;
  if (succeeded(scf::peelAndCanonicalizeForLoop(rewriter, loop, peeledLoop)))
    return peeledLoop;
  // Peeling did not apply: pass the loop through unless asked to fail.
  if (!getFailIfAlreadyDivisible())
    return loop;
  return reportUnknownTransformError(loop);
}
//===----------------------------------------------------------------------===//
// LoopPipelineOp
//===----------------------------------------------------------------------===//
/// Callback for PipeliningOption. Populates `schedule` with the mapping from an
/// operation to its logical time position given the iteration interval and the
/// read latency. The latter is only relevant for vector transfers.
///
/// Each op in the loop body except `scf.yield` gets an "early cycle". This is
/// the maximum, over its operands that have a defining op, of the producer's
/// cycle plus the producer's latency. Latency is `readLatency` for
/// `vector.transfer_read` and 1 otherwise. Producers defined outside the body
/// count as cycle 0.
///
/// Ops are bucketed by `cycle % iterationInterval`. They are appended bucket
/// by bucket, in increasing bucket order, each paired with
/// `cycle / iterationInterval`.
///
/// An `iterationInterval` of 0 would divide by zero, so in that case nothing
/// is appended. The pipelining pattern presumably rejects an empty schedule;
/// TODO confirm.
static void
loopScheduling(scf::ForOp forOp,
               std::vector<std::pair<Operation *, unsigned>> &schedule,
               unsigned iterationInterval, unsigned readLatency) {
  if (iterationInterval == 0)
    return;
  auto getLatency = [&](Operation *op) -> unsigned {
    if (isa<vector::TransferReadOp>(op))
      return readLatency;
    return 1;
  };
  DenseMap<Operation *, unsigned> opCycles;
  std::map<unsigned, std::vector<Operation *>> wrappedSchedule;
  for (Operation &op : forOp.getBody()->getOperations()) {
    if (isa<scf::YieldOp>(op))
      continue;
    unsigned earlyCycle = 0;
    for (Value operand : op.getOperands()) {
      Operation *def = operand.getDefiningOp();
      if (!def)
        continue;
      // Producers without an entry (defined outside the loop body) count as
      // cycle 0. `lookup` returns that default without inserting them.
      earlyCycle =
          std::max(earlyCycle, opCycles.lookup(def) + getLatency(def));
    }
    opCycles[&op] = earlyCycle;
    wrappedSchedule[earlyCycle % iterationInterval].push_back(&op);
  }
  // Iterate by reference: a by-value loop variable would copy each bucket's
  // op vector.
  for (const auto &bucket : wrappedSchedule) {
    for (Operation *op : bucket.second) {
      unsigned cycle = opCycles.lookup(op);
      schedule.emplace_back(op, cycle / iterationInterval);
    }
  }
}
/// Software-pipelines `loop` with `scf::ForLoopPipeliningPattern`. The
/// schedule comes from `loopScheduling`, using this op's iteration interval
/// and read latency. Returns the loop produced by the pattern, or reports a
/// generic transform error against `loop` if the pattern fails.
FailureOr<scf::ForOp> transform::LoopPipelineOp::applyToOne(scf::ForOp loop) {
  scf::PipeliningOption pipelineOptions;
  // The callback only runs while `pipeliner` below is applied, i.e. within
  // this call, so capturing `this` is safe.
  pipelineOptions.getScheduleFn =
      [this](scf::ForOp forOp,
             std::vector<std::pair<Operation *, unsigned>> &schedule) {
        loopScheduling(forOp, schedule, getIterationInterval(),
                       getReadLatency());
      };
  SimpleRewriter rewriter(getContext());
  rewriter.setInsertionPoint(loop);
  scf::ForLoopPipeliningPattern pipeliner(pipelineOptions, loop->getContext());
  FailureOr<scf::ForOp> pipelined =
      pipeliner.returningMatchAndRewrite(loop, rewriter);
  if (succeeded(pipelined))
    return pipelined;
  return reportUnknownTransformError(loop);
}
//===----------------------------------------------------------------------===//
// LoopUnrollOp
//===----------------------------------------------------------------------===//
/// Unrolls `loop` by `getFactor()` via `loopUnrollByFactor`. Reports a
/// generic transform error against `loop` if unrolling fails.
LogicalResult transform::LoopUnrollOp::applyToOne(scf::ForOp loop) {
  if (succeeded(loopUnrollByFactor(loop, getFactor())))
    return success();
  return reportUnknownTransformError(loop);
}
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
namespace {
/// Transform dialect extension that registers the SCF transform ops listed in
/// the generated SCFTransformOps.cpp.inc op list.
class SCFTransformDialectExtension
: public transform::TransformDialectExtension<
SCFTransformDialectExtension> {
public:
/// Declares the dialects these transforms depend on and registers every op
/// from the generated GET_OP_LIST.
SCFTransformDialectExtension() {
// NOTE(review): presumably needed because loop peeling/unrolling may
// materialize affine ops (e.g. for bounds) -- TODO confirm.
declareDependentDialect<AffineDialect>();
// Loop outlining creates func.func and func.call ops.
declareDependentDialect<func::FuncDialect>();
// The generated list expands to the comma-separated op class names.
registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc"
>();
}
};
} // namespace
#define GET_OP_CLASSES
#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc"
/// Adds the SCF transform dialect extension (and thus the SCF transform ops
/// defined in this file) to `registry`.
void mlir::scf::registerTransformDialectExtension(DialectRegistry &registry) {
registry.addExtensions<SCFTransformDialectExtension>();
}