llvm-project/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Alex Zinenko 3963b4d0dc [mlir] Transform op for multitile size generation
Introduce a structured transform op that emits IR computing the multi-tile
sizes with requested parameters (target size and divisor) for the given
structured op. The sizes may fold to arithmetic constant operations when the
shape is constant. These operations may then be used to call the existing
tiling transformation with a single non-zero dynamic size (i.e. perform
strip-mining) for each of the dimensions separately, thus achieving multi-size
tiling with optional loop interchange. A separate test exercises the entire
script.

Depends On D129217

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D129287
2022-07-12 12:36:28 +00:00

856 lines
34 KiB
C++

//===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::transform;
/// Converts an array attribute into a vector of int64_t by sign-extending the
/// value of every element. Asserts if the attribute contains values other than
/// integers.
static SmallVector<int64_t> extractI64Array(ArrayAttr attr) {
  auto signExtend = [](const APInt &element) -> int64_t {
    return element.getSExtValue();
  };
  return llvm::to_vector(
      llvm::map_range(attr.getAsValueRange<IntegerAttr>(), signExtend));
}
/// Converts an array attribute into a vector of unsigned by zero-extending the
/// value of every element. Asserts if the attribute contains values other than
/// integers. May truncate: the zero-extended value is narrowed to `unsigned`.
static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) {
  auto zeroExtend = [](const APInt &element) -> unsigned {
    // The narrowing is deliberate and spelled out here.
    return static_cast<unsigned>(element.getZExtValue());
  };
  return llvm::to_vector(
      llvm::map_range(attr.getAsValueRange<IntegerAttr>(), zeroExtend));
}
namespace {
/// A simple pattern rewriter that implements no special logic. It only exposes
/// the `PatternRewriter` constructor so that a rewriter can be instantiated
/// directly from a context.
class SimpleRewriter : public PatternRewriter {
public:
  /// Creates a rewriter for the given context. Marked `explicit` so that a
  /// context pointer never converts to a rewriter implicitly.
  explicit SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {}
};
} // namespace
/// Runs the pattern given as template argument on `operation`. The pattern
/// must provide a `returningMatchAndRewrite` function returning the "main"
/// result or failure; the type of its first argument determines which op class
/// the pattern applies to. Any extra arguments are forwarded to the pattern
/// constructor. Returns failure if the operation is not of the expected class
/// or if the pattern did not apply.
template <typename PatternTy, typename... Args>
static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
  // Deduce the op class from the first argument of the rewrite function.
  using RewriteFnTy = decltype(&PatternTy::returningMatchAndRewrite);
  using OpTy = typename llvm::function_traits<RewriteFnTy>::template arg_t<0>;

  // Nothing to do when the operation is not of the class the pattern expects.
  auto typedOp = dyn_cast<OpTy>(operation);
  if (!typedOp)
    return failure();

  // Instantiate the pattern and run it right before the operation.
  MLIRContext *context = operation->getContext();
  PatternTy pattern(context, std::forward<Args>(args)...);
  SimpleRewriter rewriter(context);
  rewriter.setInsertionPoint(operation);
  auto rewritten = pattern.returningMatchAndRewrite(typedOp, rewriter);
  if (succeeded(rewritten))
    return cast<LinalgOp>(rewritten->getOperation());
  return failure();
}
//===----------------------------------------------------------------------===//
// DecomposeOp
//===----------------------------------------------------------------------===//
/// Tries the windowed 2D convolution downscaling pattern on `target` and, if
/// it does not apply, the depthwise convolution one. On success, appends the
/// resulting op to `results`; otherwise sets `results` to a single null entry
/// and returns the default silenceable failure.
DiagnosedSilenceableFailure
transform::DecomposeOp::applyToOne(linalg::LinalgOp target,
                                   SmallVectorImpl<Operation *> &results,
                                   transform::TransformState &state) {
  FailureOr<LinalgOp> decomposed =
      tryApply<DownscaleSizeOneWindowed2DConvolution>(target);
  // Only fall back to the depthwise pattern when the windowed one failed.
  if (failed(decomposed))
    decomposed = tryApply<DownscaleDepthwiseConv2DNhwcHwcOp>(target);

  if (failed(decomposed)) {
    results.assign(1, nullptr);
    return emitDefaultSilenceableFailure(target);
  }
  results.push_back(*decomposed);
  return DiagnosedSilenceableFailure(success());
}
//===----------------------------------------------------------------------===//
// FuseOp
//===----------------------------------------------------------------------===//
/// Applies `applyFn` to every payload op and records both the tiled operations
/// and the tile loops it created. Result #0 of `transformOp` is associated
/// with the tiled ops; result #(i + 1) is associated with the i-th loop of
/// every tiled op. Emits an error if a payload op is not a LinalgOp and fails
/// silently if tiling fails or does not produce exactly `numLoops` loops.
static LogicalResult
applyTilingToAll(Operation *transformOp, ArrayRef<Operation *> payloadOps,
                 unsigned numLoops,
                 transform::TransformResults &transformResults,
                 function_ref<FailureOr<TiledLinalgOp>(LinalgOp)> applyFn) {
  SmallVector<Operation *> tiledOps;
  SmallVector<SmallVector<Operation *>> loopsPerLevel(numLoops);
  for (SmallVector<Operation *> &levelLoops : loopsPerLevel)
    levelLoops.reserve(payloadOps.size());

  for (Operation *payloadOp : payloadOps) {
    auto linalgOp = dyn_cast<linalg::LinalgOp>(payloadOp);
    if (!linalgOp)
      return transformOp->emitError("only LinalgOps are supported");

    FailureOr<TiledLinalgOp> tilingResult = applyFn(linalgOp);
    if (failed(tilingResult))
      return failure();
    tiledOps.push_back(tilingResult->op);

    // Not enough loops were generated. This usually means that the input size
    // was smaller than the tiling size.
    // TODO: LinalgTilingPattern should return failure().
    if (tilingResult->loops.size() != numLoops)
      return failure();

    for (const auto &en : llvm::enumerate(tilingResult->loops))
      loopsPerLevel[en.index()].push_back(en.value());
  }

  transformResults.set(transformOp->getOpResult(0), tiledOps);
  for (unsigned level = 0; level < numLoops; ++level)
    transformResults.set(transformOp->getOpResult(level + 1),
                         loopsPerLevel[level]);
  return success();
}
/// Parse a tiling-like operation that returns the tiled op as well as the
/// created tile loops. The custom syntax is the target operand followed by an
/// optional attribute dictionary, which must contain an array attribute named
/// `sizesAttrName` holding integer tile sizes. The function counts the non-zero
/// tile sizes to compute the number of results: one result for the tiled op
/// plus one per non-zero size, all of PDL operation type.
static ParseResult parseTileLikeOp(OpAsmParser &parser, OperationState &result,
                                   StringRef sizesAttrName) {
  OpAsmParser::UnresolvedOperand targetOperand;
  SMLoc opLoc = parser.getCurrentLocation();
  if (parser.parseOperand(targetOperand) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();

  Attribute sizesAttr = result.attributes.get(sizesAttrName);
  if (!sizesAttr)
    return parser.emitError(opLoc)
           << "expected '" << sizesAttrName << "' attribute";
  auto sizesArrayAttr = sizesAttr.dyn_cast<ArrayAttr>();
  if (!sizesArrayAttr)
    return parser.emitError(opLoc)
           << "'" << sizesAttrName << "' attribute must be an array";

  // `extractI64Array` asserts on non-integer elements. The op verifier has not
  // run yet at parse time, so reject malformed arrays here with a proper
  // diagnostic instead of crashing.
  bool allIntegers =
      llvm::all_of(sizesArrayAttr.getValue(), [](Attribute element) {
        return element.isa<IntegerAttr>();
      });
  if (!allIntegers)
    return parser.emitError(opLoc) << "'" << sizesAttrName
                                   << "' attribute must be an array of integers";

  Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>();
  // Zero tile sizes produce no loop, hence no result.
  size_t numExpectedLoops =
      sizesArrayAttr.size() - llvm::count(extractI64Array(sizesArrayAttr), 0);
  result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOpType));
  if (parser.resolveOperand(targetOperand, pdlOpType, result.operands))
    return failure();
  return success();
}
/// Tiles and fuses every payload op associated with the target handle using
/// the tile sizes and interchange stored on this op, and records the tiled ops
/// and the generated loops in `transformResults`.
DiagnosedSilenceableFailure
transform::FuseOp::apply(mlir::transform::TransformResults &transformResults,
                         mlir::transform::TransformState &state) {
  LinalgTilingAndFusionOptions fusionOptions;
  fusionOptions.tileSizes = extractI64Array(getTileSizes());
  fusionOptions.tileInterchange = extractI64Array(getTileInterchange());

  // One loop is expected for every non-zero tile size.
  unsigned numLoops = fusionOptions.tileSizes.size() -
                      llvm::count(fusionOptions.tileSizes, 0);

  // Runs the tile-and-fuse pattern on one op and repackages its loop nest as
  // a TiledLinalgOp.
  auto tileAndFuse = [&](LinalgOp linalgOp) -> FailureOr<TiledLinalgOp> {
    LinalgTileAndFuseTensorOpsPattern pattern(getContext(), fusionOptions);
    SimpleRewriter rewriter(getContext());
    rewriter.setInsertionPoint(linalgOp);
    FailureOr<TileLoopNest> loopNest =
        pattern.returningMatchAndRewrite(linalgOp, rewriter);
    if (failed(loopNest))
      return failure();

    TiledLinalgOp tiled;
    tiled.op = loopNest->getRootOp();
    tiled.loops = {loopNest->getLoopOps().begin(),
                   loopNest->getLoopOps().end()};
    return tiled;
  };

  LogicalResult status =
      applyTilingToAll(getOperation(), state.getPayloadOps(getTarget()),
                       numLoops, transformResults, tileAndFuse);
  return DiagnosedSilenceableFailure(status);
}
/// Parses the op using the shared tile-like syntax, with the tile sizes
/// attribute of this op determining the number of results.
ParseResult transform::FuseOp::parse(OpAsmParser &parser,
                                     OperationState &result) {
  StringRef sizesAttrName =
      transform::FuseOp::getTileSizesAttrName(result.name).getValue();
  return parseTileLikeOp(parser, result, sizesAttrName);
}
/// Prints the target operand followed by the attribute dictionary.
void transform::FuseOp::print(OpAsmPrinter &p) {
  p << ' ' << getTarget();
  p.printOptionalAttrDict(getOperation()->getAttrs());
}
/// Checks that the tile interchange is a permutation of the sequence
/// 0, 1, ..., N-1, where N is the number of interchange entries.
LogicalResult transform::FuseOp::verify() {
  SmallVector<int64_t> interchange = extractI64Array(getTileInterchange());
  auto identity = llvm::to_vector(llvm::seq<int64_t>(0, interchange.size()));
  bool isPermutation =
      std::is_permutation(identity.begin(), identity.end(),
                          interchange.begin(), interchange.end());
  if (isPermutation)
    return success();
  return emitOpError() << "expects interchange to be a permutation, found "
                       << getTileInterchange();
}
//===----------------------------------------------------------------------===//
// GeneralizeOp
//===----------------------------------------------------------------------===//
/// Applies the generalization pattern to `target`. A target that already is a
/// GenericOp is returned unchanged. On failure, sets `results` to a single
/// null entry and returns the default silenceable failure.
DiagnosedSilenceableFailure
transform::GeneralizeOp::applyToOne(linalg::LinalgOp target,
                                    SmallVectorImpl<Operation *> &results,
                                    transform::TransformState &state) {
  // A generic op needs no transformation.
  if (isa<GenericOp>(target)) {
    results.push_back(target);
    return DiagnosedSilenceableFailure(success());
  }

  FailureOr<LinalgOp> generalized =
      tryApply<LinalgGeneralizationPattern>(target);
  if (failed(generalized)) {
    results.assign(1, nullptr);
    return emitDefaultSilenceableFailure(target);
  }
  results.push_back(generalized->getOperation());
  return DiagnosedSilenceableFailure(success());
}
//===----------------------------------------------------------------------===//
// InterchangeOp
//===----------------------------------------------------------------------===//
/// Interchanges the iterators of `target` according to the iterator
/// interchange attribute. An empty interchange leaves the target unchanged.
/// Returns a definite failure when the interchange cannot be applied.
DiagnosedSilenceableFailure
transform::InterchangeOp::applyToOne(linalg::GenericOp target,
                                     SmallVectorImpl<Operation *> &results,
                                     transform::TransformState &state) {
  SmallVector<unsigned> permutation =
      extractUIntArray(getIteratorInterchange());

  // Nothing to interchange: hand the target back as is.
  if (permutation.empty()) {
    results.push_back(target);
    return DiagnosedSilenceableFailure(success());
  }

  SimpleRewriter rewriter(target->getContext());
  FailureOr<GenericOp> interchanged =
      interchangeGenericOp(rewriter, target, permutation);
  if (succeeded(interchanged)) {
    results.push_back(interchanged->getOperation());
    return DiagnosedSilenceableFailure(success());
  }
  return DiagnosedSilenceableFailure::definiteFailure();
}
/// Checks that the iterator interchange is a permutation of the sequence
/// 0, 1, ..., N-1, where N is the number of interchange entries.
LogicalResult transform::InterchangeOp::verify() {
  SmallVector<unsigned> interchange =
      extractUIntArray(getIteratorInterchange());
  auto identity = llvm::to_vector(llvm::seq<unsigned>(0, interchange.size()));
  bool isPermutation =
      std::is_permutation(identity.begin(), identity.end(),
                          interchange.begin(), interchange.end());
  if (isPermutation)
    return success();
  return emitOpError()
         << "expects iterator_interchange to be a permutation, found "
         << getIteratorInterchange();
}
//===---------------------------------------------------------------------===//
// MultiTileSizesOp
//===---------------------------------------------------------------------===//
/// Emits, right before `target`, the IR computing the multi-tile sizes for the
/// dimension, target size and divisor stored on this op. Appends three ops to
/// `results`, in this order: the producer of the low tile size, the producer
/// of the high tile size, and the producer of the split point (low tile size
/// multiplied by low trip count). Returns a silenceable error if the size
/// computation cannot be generated.
DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne(
    LinalgOp target, SmallVector<Operation *> &results, TransformState &state) {
  OpBuilder builder(target.getContext());
  builder.setInsertionPoint(target);

  OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
  OpFoldResult divisor = builder.getIndexAttr(getDivisor());
  FailureOr<MultiSizeSpecification> spec = computeMultiTileSizes(
      builder, target, getDimension(), targetSize, divisor);
  if (failed(spec))
    return emitSilenceableError() << "could not generate tile size computation";

  // The split point is where the part tiled by the low size ends.
  Value splitPointValue = builder.createOrFold<arith::MulIOp>(
      target.getLoc(), spec->lowTileSize, spec->lowTripCount);
  Operation *splitPoint = splitPointValue.getDefiningOp();
  Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
  Operation *highTileSize = spec->highTileSize.getDefiningOp();
  assert(lowTileSize && highTileSize && splitPoint &&
         "tile sizes are not produced by operations");

  results.append({lowTileSize, highTileSize, splitPoint});
  return DiagnosedSilenceableFailure::success();
}
/// Populates `effects` with: a read of the target handle mapping, an
/// allocation and a write of the mapping for every result, and a read and a
/// write of the payload IR.
void transform::MultiTileSizesOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  auto *mappingResource = transform::TransformMappingResource::get();
  auto *payloadResource = transform::PayloadIRResource::get();

  // The target handle is only read.
  effects.emplace_back(MemoryEffects::Read::get(), getTarget(),
                       mappingResource);

  // Every result handle is allocated and written.
  for (Value resultHandle : getResults()) {
    effects.emplace_back(MemoryEffects::Allocate::get(), resultHandle,
                         mappingResource);
    effects.emplace_back(MemoryEffects::Write::get(), resultHandle,
                         mappingResource);
  }

  // The payload IR is read and written.
  effects.emplace_back(MemoryEffects::Read::get(), payloadResource);
  effects.emplace_back(MemoryEffects::Write::get(), payloadResource);
}
//===---------------------------------------------------------------------===//
// PadOp
//===---------------------------------------------------------------------===//
/// Pads `target` using the padding pattern configured from the attributes of
/// this op. Padding values given as strings are parsed into attributes of the
/// element type of the corresponding operand; other padding values must
/// already have that element type. Emits a definite failure for a padding
/// value of the wrong type, and the default silenceable failure (with a single
/// null entry in `results`) when the pattern does not apply.
DiagnosedSilenceableFailure
transform::PadOp::applyToOne(linalg::LinalgOp target,
                             SmallVectorImpl<Operation *> &results,
                             transform::TransformState &state) {
  // Convert the integer packing flags to booleans.
  SmallVector<bool> packPaddings;
  for (int64_t flag : extractI64Array(getPackPaddings()))
    packPaddings.push_back(flag != 0);

  // Convert the padding values to attributes, pairing each value with the
  // element type of the operand at the same position.
  SmallVector<Attribute> paddingValues;
  for (auto const &valueAndType :
       llvm::zip(getPaddingValues(), target->getOperandTypes())) {
    Attribute paddingAttr = std::get<0>(valueAndType);
    Type elementType = getElementTypeOrSelf(std::get<1>(valueAndType));

    auto stringAttr = paddingAttr.dyn_cast<StringAttr>();
    if (!stringAttr) {
      // Non-string attributes are used directly, provided the type matches.
      if (paddingAttr.getType() != elementType) {
        auto diag = this->emitOpError("expects a padding value of type ")
                    << elementType << ", got " << paddingAttr;
        diag.attachNote(target.getLoc()) << "when applied to this op";
        return DiagnosedSilenceableFailure::definiteFailure();
      }
      paddingValues.push_back(paddingAttr);
      continue;
    }

    // Try to parse string attributes to obtain an attribute of element type.
    Attribute parsedAttr = parseAttribute(stringAttr, elementType);
    if (!parsedAttr) {
      auto diag = this->emitOpError("expects a padding that parses to ")
                  << elementType << ", got " << paddingAttr;
      diag.attachNote(target.getLoc()) << "when applied to this op";
      return DiagnosedSilenceableFailure::definiteFailure();
    }
    paddingValues.push_back(parsedAttr);
  }

  // Extract the transpose vectors.
  SmallVector<SmallVector<int64_t>> transposePaddings;
  for (Attribute transposeVector : getTransposePaddings().cast<ArrayAttr>())
    transposePaddings.push_back(
        extractI64Array(transposeVector.cast<ArrayAttr>()));

  LinalgPaddingOptions paddingOptions;
  paddingOptions.setPaddingValues(paddingValues);
  paddingOptions.setPaddingDimensions(extractI64Array(getPaddingDimensions()));
  paddingOptions.setPackPaddings(packPaddings);
  paddingOptions.setHoistPaddings(extractI64Array(getHoistPaddings()));
  paddingOptions.setTransposePaddings(transposePaddings);

  FailureOr<LinalgOp> padded =
      tryApply<LinalgPaddingPattern>(target, paddingOptions);
  if (failed(padded)) {
    results.assign(1, nullptr);
    return emitDefaultSilenceableFailure(target);
  }
  results.push_back(padded->getOperation());
  return DiagnosedSilenceableFailure(success());
}
/// Verifies the attributes of the op: pack paddings may only contain 0 or 1,
/// padding dimensions and hoist paddings may not contain negative values, and
/// every transpose padding must be a permutation of 0, 1, ..., N-1.
LogicalResult transform::PadOp::verify() {
  auto isNotBoolean = [](int64_t value) { return value != 0 && value != 1; };
  auto isNegative = [](int64_t value) { return value < 0; };

  SmallVector<int64_t> packPaddings = extractI64Array(getPackPaddings());
  if (llvm::any_of(packPaddings, isNotBoolean))
    return emitOpError()
           << "expects pack_paddings to contain booleans (0/1), found "
           << getPackPaddings();

  SmallVector<int64_t> paddingDimensions =
      extractI64Array(getPaddingDimensions());
  if (llvm::any_of(paddingDimensions, isNegative))
    return emitOpError()
           << "expects padding_dimensions to contain positive integers, found "
           << getPaddingDimensions();

  SmallVector<int64_t> hoistPaddings = extractI64Array(getHoistPaddings());
  if (llvm::any_of(hoistPaddings, isNegative))
    return emitOpError()
           << "expects hoist_paddings to contain positive integers, found "
           << getHoistPaddings();

  ArrayAttr transposes = getTransposePaddings();
  for (Attribute attr : transposes) {
    SmallVector<int64_t> transpose = extractFromI64ArrayAttr(attr);
    auto identity = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
    if (std::is_permutation(identity.begin(), identity.end(),
                            transpose.begin(), transpose.end()))
      continue;
    return emitOpError()
           << "expects transpose_paddings to be a permutation, found "
           << attr;
  }
  return success();
}
//===----------------------------------------------------------------------===//
// ScalarizeOp
//===----------------------------------------------------------------------===//
/// Tiles `target` with the "scalarize dynamic dims" tiling option and appends
/// the tiled op to `results`. Reports an unknown transform error if the tiling
/// pattern does not apply.
DiagnosedSilenceableFailure
transform::ScalarizeOp::applyToOne(linalg::LinalgOp target,
                                   SmallVectorImpl<Operation *> &results,
                                   transform::TransformState &state) {
  // Tiling with "scalarize_dyn_dims" actually sets the same lambda as the tile
  // sizes and asserts that it is not already set, so no tile sizes are
  // configured here.
  LinalgTilingOptions tilingOptions;
  tilingOptions.scalarizeDynamicDims();

  LinalgTilingPattern pattern(getContext(), tilingOptions);
  SimpleRewriter rewriter(getContext());
  rewriter.setInsertionPoint(target);
  FailureOr<TiledLinalgOp> result =
      pattern.returningMatchAndRewrite(target, rewriter);
  if (failed(result))
    return DiagnosedSilenceableFailure(reportUnknownTransformError(target));

  results.push_back(result->op);
  return DiagnosedSilenceableFailure(success());
}
//===----------------------------------------------------------------------===//
// SplitOp
//===----------------------------------------------------------------------===//
/// Splits every payload op associated with the target handle along the
/// dimension stored on this op. The split point is either the static split
/// point attribute, used for all targets, or the single index-typed result of
/// the payload ops associated with the dynamic split point handle, one per
/// target. The two parts of every split are associated with the first and the
/// second result of this op, respectively.
///
/// Returns a silenceable error if a dynamic split point is not produced by a
/// single-result index-typed op, if a target is not a structured op, or if
/// the dimension does not exist in a target. Returns a definite failure if the
/// number of dynamic split points differs from the number of targets.
DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results,
                                           TransformState &state) {
  ArrayRef<Operation *> payload = state.getPayloadOps(getTarget());
  SimpleRewriter rewriter(getContext());

  // Collect one split point per target: the dynamic ones if provided, the
  // static one replicated otherwise.
  SmallVector<OpFoldResult> splitPoints;
  splitPoints.reserve(payload.size());
  if (getDynamicSplitPoint()) {
    for (Operation *op : state.getPayloadOps(getDynamicSplitPoint())) {
      // Bail out on the first invalid producer. Result #0 must only be
      // accessed after checking that the op has exactly one result, and an
      // already created diagnostic must not be overwritten by a later one.
      if (op->getNumResults() != 1 ||
          !op->getResult(0).getType().isIndex()) {
        DiagnosedSilenceableFailure diag =
            emitSilenceableError()
            << "expected dynamic split point handle to point to a "
               "single-result index-typed op";
        diag.attachNote(op->getLoc()) << "dynamic split point";
        return diag;
      }
      splitPoints.push_back(OpFoldResult(op->getResult(0)));
    }

    if (splitPoints.size() != payload.size()) {
      emitError() << "expected the dynamic split point handle to point to as "
                     "many operations ("
                  << splitPoints.size() << ") as the target handle ("
                  << payload.size() << ")";
      return DiagnosedSilenceableFailure::definiteFailure();
    }
  } else {
    splitPoints.resize(payload.size(),
                       rewriter.getIndexAttr(getStaticSplitPoint()));
  }

  // Split each target operation.
  SmallVector<Operation *> first, second;
  for (const auto &pair : llvm::zip(payload, splitPoints)) {
    Operation *target = std::get<0>(pair);
    auto linalgOp = dyn_cast<LinalgOp>(target);
    if (!linalgOp) {
      auto diag = emitSilenceableError() << "only applies to structured ops";
      diag.attachNote(target->getLoc()) << "target op";
      return diag;
    }

    if (getDimension() >= linalgOp.getNumLoops()) {
      auto diag = emitSilenceableError() << "dimension " << getDimension()
                                         << " does not exist in target op";
      diag.attachNote(target->getLoc()) << "target op";
      return diag;
    }

    rewriter.setInsertionPoint(linalgOp);
    std::tie(first.emplace_back(), second.emplace_back()) =
        linalg::splitOp(rewriter, linalgOp, getDimension(), std::get<1>(pair));
  }

  results.set(getFirst().cast<OpResult>(), first);
  results.set(getSecond().cast<OpResult>(), second);
  return DiagnosedSilenceableFailure::success();
}
/// Populates `effects` with the effects of this op on the handle mapping and
/// on the payload IR.
void SplitOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  auto *mappingResource = TransformMappingResource::get();
  auto *payloadResource = PayloadIRResource::get();

  // The target handle is consumed.
  Value targetHandle = getTarget();
  effects.emplace_back(MemoryEffects::Read::get(), targetHandle,
                       mappingResource);
  effects.emplace_back(MemoryEffects::Free::get(), targetHandle,
                       mappingResource);

  // The dynamic split point handle is not consumed.
  if (Value dynamicHandle = getDynamicSplitPoint())
    effects.emplace_back(MemoryEffects::Read::get(), dynamicHandle,
                         mappingResource);

  // The resulting handles are produced.
  for (Value resultHandle : getResults()) {
    effects.emplace_back(MemoryEffects::Allocate::get(), resultHandle,
                         mappingResource);
    effects.emplace_back(MemoryEffects::Write::get(), resultHandle,
                         mappingResource);
  }

  // The payload IR is read and written.
  effects.emplace_back(MemoryEffects::Read::get(), payloadResource);
  effects.emplace_back(MemoryEffects::Write::get(), payloadResource);
}
/// Parses the custom syntax: the target operand, the keyword `after`, then
/// either an operand (dynamic split point) or an integer (static split
/// point), and an optional attribute dictionary. With a dynamic split point,
/// the static split point attribute is set to `ShapedType::kDynamicSize`. The
/// op always has two results of PDL operation type.
ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
  Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>();

  OpAsmParser::UnresolvedOperand targetOperand;
  if (parser.parseOperand(targetOperand) ||
      parser.resolveOperand(targetOperand, pdlOpType, result.operands) ||
      parser.parseKeyword("after"))
    return failure();

  IntegerAttr staticSplitPoint;
  OpAsmParser::UnresolvedOperand dynamicOperand;
  OptionalParseResult dynamicParseResult =
      parser.parseOptionalOperand(dynamicOperand);
  if (dynamicParseResult.hasValue()) {
    // An operand was present: it is the dynamic split point.
    if (failed(*dynamicParseResult) ||
        parser.resolveOperand(dynamicOperand, pdlOpType, result.operands))
      return failure();
    staticSplitPoint =
        parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamicSize);
  } else {
    // No operand: an integer static split point is expected instead.
    int64_t staticValue;
    if (failed(parser.parseInteger(staticValue)))
      return failure();
    staticSplitPoint = parser.getBuilder().getI64IntegerAttr(staticValue);
  }

  result.addAttribute(
      SplitOp::getStaticSplitPointAttrName(result.name).getValue(),
      staticSplitPoint);
  if (failed(parser.parseOptionalAttrDict(result.attributes)))
    return failure();

  result.addTypes({pdlOpType, pdlOpType});
  return success();
}
/// Prints the target operand, the keyword `after`, the dynamic split point
/// operand if the static split point is `ShapedType::kDynamicSize` and the
/// static split point otherwise, then the attribute dictionary without the
/// static split point attribute.
void SplitOp::print(OpAsmPrinter &printer) {
  printer << " " << getTarget() << " after ";

  int64_t staticPoint = static_cast<int64_t>(getStaticSplitPoint());
  if (staticPoint == ShapedType::kDynamicSize)
    printer << getDynamicSplitPoint();
  else
    printer << staticPoint;

  printer << " ";
  printer.printOptionalAttrDict(getOperation()->getAttrs(),
                                {getStaticSplitPointAttrName()});
}
/// Checks that exactly one of the static and the dynamic split point is
/// provided. The static split point counts as provided when it differs from
/// `ShapedType::kDynamicSize`.
LogicalResult SplitOp::verify() {
  bool hasStaticPoint =
      static_cast<int64_t>(getStaticSplitPoint()) != ShapedType::kDynamicSize;
  bool hasDynamicPoint = static_cast<bool>(getDynamicSplitPoint());

  // Both provided or both missing is an error.
  if (hasStaticPoint == hasDynamicPoint)
    return emitOpError()
           << "expects either a dynamic or a static split point to be provided";
  return success();
}
//===----------------------------------------------------------------------===//
// SplitReductionOp
//===----------------------------------------------------------------------===//
/// Splits the reduction of `target` using the split factor and the insertion
/// dimension stored on this op, with either the scaling algorithm or the
/// default one depending on the corresponding attribute. On success, appends
/// four ops to `results`: the init-or-alloc op, the fill op, the split op and
/// the result-combining op, in this order.
DiagnosedSilenceableFailure
transform::SplitReductionOp::applyToOne(linalg::LinalgOp target,
                                        SmallVectorImpl<Operation *> &results,
                                        transform::TransformState &state) {
  // The same control values are used regardless of the op being split.
  ControlSplitReductionFn controlFn = [&](LinalgOp) {
    return std::pair<int64_t, unsigned>(getSplitFactor(),
                                        getInsertSplitDimension());
  };

  SimpleRewriter rewriter(getContext());
  rewriter.setInsertionPoint(target);

  FailureOr<SplitReductionResult> split = failure();
  if (getUseScalingAlgorithm())
    split = splitReductionByScaling(rewriter, target, controlFn, getUseAlloc());
  else
    split = splitReduction(rewriter, target, controlFn, getUseAlloc());
  if (failed(split))
    return DiagnosedSilenceableFailure(reportUnknownTransformError(target));

  results.push_back(split->initOrAlloc);
  results.push_back(split->fillOp);
  results.push_back(split->splitLinalgOp);
  results.push_back(split->resultCombiningLinalgOp);
  return DiagnosedSilenceableFailure(success());
}
//===----------------------------------------------------------------------===//
// TileOp
//===----------------------------------------------------------------------===//
/// Tiles every payload op associated with the target handle using the static
/// and dynamic sizes of this op. The tiled ops are associated with the tiled
/// op result; the i-th loop created for every target is associated with the
/// i-th loop result.
///
/// Returns a silenceable error if a dynamic size handle does not map to as
/// many ops as there are targets, if a size producer does not have a single
/// index-typed result, or if a target is not a linalg op. Returns a definite
/// failure if the tiling pattern fails on a target.
DiagnosedSilenceableFailure
transform::TileOp::apply(TransformResults &transformResults,
TransformState &state) {
LinalgTilingOptions tilingOptions;
SmallVector<int64_t> tileSizes = extractI64Array(getStaticSizes());
ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
// For every dynamic size handle, collect the payload ops producing the
// sizes, validating them along the way.
SmallVector<ArrayRef<Operation *>> dynamicSizeProducers;
dynamicSizeProducers.reserve(getDynamicSizes().size());
for (Value dynamicSizeProducerHandle : getDynamicSizes()) {
dynamicSizeProducers.push_back(
state.getPayloadOps(dynamicSizeProducerHandle));
// One size producer is needed per target op.
if (dynamicSizeProducers.back().size() != targets.size()) {
DiagnosedSilenceableFailure diag =
emitSilenceableError()
<< "expected as many dynamic size-producing operations ("
<< dynamicSizeProducers.back().size() << ") as target ops ("
<< targets.size() << ")";
diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle";
return diag;
}
// Every producer must define exactly one value of index type, which is
// used as the tile size below.
for (Operation *op : dynamicSizeProducers.back()) {
if (op->getNumResults() == 1 &&
op->getResult(0).getType().isa<IndexType>())
continue;
DiagnosedSilenceableFailure diag =
emitSilenceableError() << "expected sizes to be produced by ops "
"with a single index-type result";
diag.attachNote(op->getLoc()) << "size producer op";
diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle";
return diag;
}
}
SmallVector<Operation *> tiled;
// One list of loops per loop result of this op.
SmallVector<SmallVector<Operation *, 4>, 4> loops;
loops.resize(getLoops().size());
for (auto &en : llvm::enumerate(targets)) {
auto linalgOp = dyn_cast<LinalgOp>(en.value());
if (!linalgOp) {
DiagnosedSilenceableFailure diag = emitSilenceableError()
<< "only linalg ops are supported";
diag.attachNote(en.value()->getLoc()) << "target op";
return diag;
}
unsigned index = en.index();
if (!tileSizes.empty()) {
// The callback captures the position of the current target by value and
// everything else by reference. Static sizes are materialized as index
// constants; each dynamic size is the result of the producer at the
// position of the current target in the list of the next dynamic handle.
tilingOptions.setTileSizeComputationFunction(
[&, index](OpBuilder &b, Operation *) {
SmallVector<Value, 4> sizes;
sizes.reserve(tileSizes.size());
unsigned dynamicIdx = 0;
for (OpFoldResult ofr : getMixedSizes()) {
if (auto attr = ofr.dyn_cast<Attribute>()) {
sizes.push_back(b.create<arith::ConstantIndexOp>(
getLoc(), attr.cast<IntegerAttr>().getInt()));
} else {
sizes.push_back(
dynamicSizeProducers[dynamicIdx++][index]->getResult(0));
}
}
return sizes;
});
}
tilingOptions.setInterchange(extractUIntArray(getInterchange()));
LinalgTilingPattern pattern(getContext(), tilingOptions);
SimpleRewriter rewriter(linalgOp.getContext());
FailureOr<TiledLinalgOp> tiledOp =
pattern.returningMatchAndRewrite(linalgOp, rewriter);
if (failed(tiledOp))
return DiagnosedSilenceableFailure::definiteFailure();
tiled.push_back(tiledOp->op);
// NOTE(review): indexing `loops` assumes the pattern creates no more loops
// than this op has loop results — confirm.
for (const auto &en2 : llvm::enumerate(tiledOp->loops))
loops[en2.index()].push_back(en2.value());
}
// Associate the tiled ops and the loops with the results of this op.
transformResults.set(getTiledLinalgOp().cast<OpResult>(), tiled);
for (const auto &en : llvm::enumerate(loops))
transformResults.set(getLoops()[en.index()].cast<OpResult>(), en.value());
return DiagnosedSilenceableFailure::success();
}
/// Returns the tile sizes as a list mixing attributes and values: every static
/// size equal to `ShapedType::kDynamicSize` is replaced by the next dynamic
/// size operand, and every other static size becomes an index attribute.
SmallVector<OpFoldResult> transform::TileOp::getMixedSizes() {
  ValueRange dynamicSizes = getDynamicSizes();
  SmallVector<int64_t> staticSizes = extractI64Array(getStaticSizes());
  Builder builder(getContext());

  SmallVector<OpFoldResult> mixedSizes;
  mixedSizes.reserve(staticSizes.size());
  unsigned nextDynamic = 0;
  for (int64_t staticSize : staticSizes) {
    if (staticSize != ShapedType::kDynamicSize) {
      mixedSizes.push_back(builder.getIndexAttr(staticSize));
      continue;
    }
    // Dynamic sizes are consumed in order of appearance.
    mixedSizes.push_back(dynamicSizes[nextDynamic++]);
  }
  return mixedSizes;
}
/// Parses the custom syntax: the target operand, the list of sizes mixing
/// operands and integers, and an optional attribute dictionary. All operands
/// and results have PDL operation type. The op has one result for the tiled
/// op plus one result per non-zero entry of the static sizes.
ParseResult transform::TileOp::parse(OpAsmParser &parser,
                                     OperationState &result) {
  auto pdlOperationType = pdl::OperationType::get(parser.getContext());

  OpAsmParser::UnresolvedOperand targetOperand;
  if (parser.parseOperand(targetOperand) ||
      parser.resolveOperand(targetOperand, pdlOperationType, result.operands))
    return ParseResult::failure();

  SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizeOperands;
  ArrayAttr staticSizesAttr;
  if (parseOperandsOrIntegersSizesList(parser, dynamicSizeOperands,
                                       staticSizesAttr) ||
      parser.resolveOperands(dynamicSizeOperands, pdlOperationType,
                             result.operands))
    return ParseResult::failure();

  if (parser.parseOptionalAttrDict(result.attributes))
    return ParseResult::failure();

  result.addAttribute(getStaticSizesAttrName(result.name), staticSizesAttr);

  // Zero sizes produce no loop, hence no result.
  SmallVector<int64_t> staticSizes = extractI64Array(staticSizesAttr);
  size_t numExpectedLoops = staticSizes.size() - llvm::count(staticSizes, 0);
  result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOperationType));
  return success();
}
/// Prints the target operand, the list of sizes mixing operands and integers,
/// and the attribute dictionary without the static sizes attribute.
void TileOp::print(OpAsmPrinter &p) {
  Operation *op = getOperation();
  p << ' ' << getTarget();
  printOperandsOrIntegersSizesList(p, op, getDynamicSizes(), getStaticSizes());
  p.printOptionalAttrDict(op->getAttrs(), {getStaticSizesAttrName()});
}
/// Populates `effects` by delegating to the transform effect helpers. Going by
/// the helper names, the target handle is consumed, the dynamic size handles
/// are only read, the tiled op and loop handles are produced, and the payload
/// is modified. NOTE(review): the helpers are defined outside of this file —
/// confirm the exact effects they add.
void transform::TileOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
consumesHandle(getTarget(), effects);
onlyReadsHandle(getDynamicSizes(), effects);
producesHandle(getTiledLinalgOp(), effects);
producesHandle(getLoops(), effects);
modifiesPayload(effects);
}
//===----------------------------------------------------------------------===//
// VectorizeOp
//===----------------------------------------------------------------------===//
/// Greedily applies the Linalg vectorization pattern together with a set of
/// vector lowering, forwarding and canonicalization patterns to `target`,
/// which must be isolated from above. Pad op vectorization patterns are added
/// only when requested by the corresponding attribute. On success, `target`
/// itself is appended to `results`.
DiagnosedSilenceableFailure
transform::VectorizeOp::applyToOne(Operation *target,
                                   SmallVectorImpl<Operation *> &results,
                                   transform::TransformState &state) {
  if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
    auto diag = this->emitOpError("requires isolated-from-above targets");
    diag.attachNote(target->getLoc()) << "non-isolated target";
    return DiagnosedSilenceableFailure::definiteFailure();
  }

  MLIRContext *context = getContext();
  RewritePatternSet patterns(context);
  patterns.add<LinalgVectorizationPattern>(context);

  vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
  vector::populateVectorReductionToContractPatterns(patterns);
  patterns.add<linalg::LinalgCopyVTRForwardingPattern,
               linalg::LinalgCopyVTWForwardingPattern>(context,
                                                       /*benefit=*/2);
  vector::TransferReadOp::getCanonicalizationPatterns(patterns, context);
  vector::TransferWriteOp::getCanonicalizationPatterns(patterns, context);

  if (getVectorizePadding())
    linalg::populatePadOpVectorizationPatterns(patterns);

  LogicalResult status =
      applyPatternsAndFoldGreedily(target, std::move(patterns));
  if (succeeded(status)) {
    results.push_back(target);
    return DiagnosedSilenceableFailure(success());
  }
  return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
}
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
namespace {
/// Registers new ops and declares PDL as dependent dialect since the additional
/// ops are using PDL types for operands and results. The affine, arithmetic,
/// SCF and vector dialects are declared as dependent dialects as well.
class LinalgTransformDialectExtension
: public transform::TransformDialectExtension<
LinalgTransformDialectExtension> {
public:
/// Declares the dependent dialects and registers the transform ops listed in
/// the generated op list.
LinalgTransformDialectExtension() {
declareDependentDialect<AffineDialect>();
declareDependentDialect<arith::ArithmeticDialect>();
declareDependentDialect<pdl::PDLDialect>();
declareDependentDialect<scf::SCFDialect>();
declareDependentDialect<vector::VectorDialect>();
registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
>();
}
};
} // namespace
#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
/// Adds the Linalg transform dialect extension defined in this file to the
/// given dialect registry.
void mlir::linalg::registerTransformDialectExtension(
DialectRegistry &registry) {
registry.addExtensions<LinalgTransformDialectExtension>();
}