The existing `scf::tileAndFuseConsumerOfSlices` takes a list of slices (and the loops they are part of), tries to find the consumer of these slices (all slices are expected to have the same consumer), and then tiles the consumer into the loop nest using the `TilingInterface`. A more natural way of doing consumer fusion is to start from the consumer and look for operands that are produced by the loop nest passed in as `loops` (presumably these loops are generated by tiling, but that is not a requirement for consumer fusion). Using the consumer, you can find the slices of the operands that are accessed within the loop, which you can then use to tile and fuse the consumer (using `TilingInterface`). This handles more naturally the case where multiple operands of the consumer come from the loop nest. `scf::tileAndFuseConsumerOfSlices` was implemented as a mirror of `scf::tileAndFuseProducerOfSlice`. For the latter, the source of the slice has a single producer, which makes the slice a natural way of specifying producer fusion. But for consumers, a result might have multiple users, resulting in multiple candidates for fusion, and a fusion candidate might use multiple results of the tiled loop nest. This means that using slices (`tensor.insert_slice`/`tensor.parallel_insert_slice`) as a hook for consumer fusion turns out to be quite hard to navigate. Using the consumer directly avoids all those pain points. In time, `scf::tileAndFuseConsumerOfSlices` should be deprecated in favor of `scf::tileAndFuseConsumer`. A lot of tech debt has accumulated in `scf::tileAndFuseConsumerOfSlices` that needs to be cleaned up, so while that gets cleaned up and the required functionality is moved to `scf::tileAndFuseConsumer`, the old path is still maintained. The tests for `scf::tileAndFuseConsumerOfSlices` in `tile-and-fuse-consumer.mlir` are copied to `tile-and-fuse-consumer-using-slices.mlir`. 
All the tests that were there in this file are now using the `tileAndFuseConsumer` method. The test op `test.tile_and_fuse_consumer` is modified to call `scf::tileAndFuseConsumer`, while a new op `test.tile_and_fuse_consumer_of_slice` is used to keep the old path tested while it is deprecated. --------- Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
719 lines
29 KiB
C++
719 lines
29 KiB
C++
//===- TestTilingInterfaceTransformOps.cpp - Test `TilingInterface` ------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file defines transform dialect operations used for testing
|
|
// TilingInterface
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
|
#include "mlir/Dialect/Index/IR/IndexDialect.h"
|
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
|
|
#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
|
|
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
|
|
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
|
|
#include "mlir/Dialect/Utils/StaticValueUtils.h"
|
|
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
|
|
#include "mlir/IR/Dominance.h"
|
|
#include "mlir/IR/OpImplementation.h"
|
|
#include "mlir/Interfaces/TilingInterface.h"
|
|
#include "llvm/Support/Debug.h"
|
|
|
|
#define DEBUG_TYPE "test-tiling-interface"
|
|
|
|
#define GET_OP_CLASSES
|
|
#include "TestTilingInterfaceTransformOps.h.inc"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::transform;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TestFuseAndYieldOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Returns `op` together with every transitive producer of its operands that
/// implements `TilingInterface`. The walk only continues through producers
/// that implement `TilingInterface`. Block arguments and other producers end
/// that branch of the walk.
static llvm::SmallDenseSet<Operation *> collectTiledAndFusedOps(Operation *op) {
  llvm::SmallDenseSet<Operation *> visited;
  visited.insert(op);
  SmallVector<Operation *> stack = {op};
  while (!stack.empty()) {
    Operation *next = stack.pop_back_val();
    for (Value operandValue : next->getOperands()) {
      Operation *def = operandValue.getDefiningOp();
      // `insert(...).second` is true only the first time `def` is seen, so
      // every producer is queued at most once.
      if (def && isa<TilingInterface>(def) && visited.insert(def).second)
        stack.push_back(def);
    }
  }
  return visited;
}
|
|
|
|
/// Apply a tile and fuse transformation to all payload ops and store both the
|
|
/// tiled operation as well as the created tile loops.
|
|
template <typename Range>
|
|
static LogicalResult
|
|
applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp,
|
|
Range &&payloadOps, unsigned numLoops,
|
|
scf::SCFTilingOptions tilingOptions,
|
|
TransformResults &transformResults) {
|
|
SmallVector<Operation *> tiledOps;
|
|
SmallVector<SmallVector<Operation *>> loopOps(numLoops);
|
|
|
|
for (Operation *target : payloadOps) {
|
|
auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
|
|
if (!tilingInterfaceOp)
|
|
return transformOp->emitError("only TilingInterface ops are supported");
|
|
DominanceInfo dominanceInfo(tilingInterfaceOp);
|
|
|
|
llvm::SmallDenseSet<Operation *> tiledAndFusedOps =
|
|
collectTiledAndFusedOps(tilingInterfaceOp);
|
|
llvm::DenseSet<Operation *> yieldReplacementsFor;
|
|
for (auto op : tiledAndFusedOps) {
|
|
if (llvm::any_of(op->getUsers(), [&](Operation *user) {
|
|
return dominanceInfo.properlyDominates(tilingInterfaceOp, user);
|
|
})) {
|
|
yieldReplacementsFor.insert(op);
|
|
}
|
|
}
|
|
|
|
scf::SCFTileAndFuseOptions tileAndFuseOptions;
|
|
tileAndFuseOptions.setTilingOptions(tilingOptions);
|
|
|
|
scf::SCFTileAndFuseOptions::ControlFnTy controlFn =
|
|
[&](tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
|
|
bool isDestinationOperand)
|
|
-> std::optional<scf::SCFTileAndFuseOptions::ControlFnResult> {
|
|
Operation *owner = originalProducer.getOwner();
|
|
bool yieldProducerReplacement = yieldReplacementsFor.contains(owner);
|
|
return scf::SCFTileAndFuseOptions::ControlFnResult{
|
|
yieldProducerReplacement};
|
|
};
|
|
tileAndFuseOptions.setFusionControlFn(controlFn);
|
|
|
|
rewriter.setInsertionPoint(target);
|
|
FailureOr<scf::SCFTileAndFuseResult> tiledResults =
|
|
scf::tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
|
|
tileAndFuseOptions);
|
|
if (failed(tiledResults))
|
|
return failure();
|
|
|
|
// Perform the replacement of tiled and fused values.
|
|
SmallVector<Operation *> opsToReplace{target};
|
|
llvm::append_range(opsToReplace, tiledResults->fusedProducers);
|
|
for (Operation *toReplace : opsToReplace) {
|
|
for (OpResult res : toReplace->getResults())
|
|
if (auto replacement = tiledResults->replacements.lookup(res)) {
|
|
Operation *replacementOp = replacement.getDefiningOp();
|
|
rewriter.replaceUsesWithIf(res, replacement, [&](OpOperand &use) {
|
|
Operation *user = use.getOwner();
|
|
return dominanceInfo.properlyDominates(replacementOp, user) &&
|
|
user->getParentOp() == replacementOp->getParentOp();
|
|
});
|
|
}
|
|
|
|
if (toReplace->use_empty()) {
|
|
rewriter.eraseOp(toReplace);
|
|
}
|
|
}
|
|
|
|
// Report back the relevant handles to the transform op.
|
|
tiledOps.push_back(tiledResults->tiledAndFusedOps.front());
|
|
assert(tiledResults->loops.size() == numLoops &&
|
|
"Mismatched number of loops, tile and fuse transform should have "
|
|
"failed");
|
|
for (unsigned int i = 0; i < numLoops; ++i)
|
|
loopOps[i].push_back(tiledResults->loops[i]);
|
|
}
|
|
|
|
transformResults.set(transformOp->getOpResult(0), tiledOps);
|
|
for (unsigned int i = 0; i < numLoops; ++i)
|
|
transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
|
|
|
|
return success();
|
|
}
|
|
|
|
/// Tiles each target op with the requested tile sizes and interchange, fuses
/// its producers (yielding replacements where needed), and sets the op's
/// results to the tiled op and the generated loops. One loop is generated per
/// non-zero tile size. `getUseForall()` switches the loops to `scf.forall`.
DiagnosedSilenceableFailure
transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
                                     TransformResults &transformResults,
                                     TransformState &state) {
  SmallVector<int64_t> staticSizes =
      extractFromIntegerArrayAttr<int64_t>(getTileSizes());
  SmallVector<int64_t> interchange =
      extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
  SmallVector<OpFoldResult> sizeOfrs =
      getAsIndexOpFoldResult(rewriter.getContext(), staticSizes);

  scf::SCFTilingOptions options;
  options.setTileSizes(sizeOfrs).setInterchange(interchange);
  if (getUseForall())
    options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);

  // A zero tile size leaves that dimension untiled, so it contributes no loop.
  unsigned numTiledLoops = staticSizes.size() - llvm::count(staticSizes, 0);
  if (failed(applyTileAndFuseToAll(rewriter, getOperation(),
                                   state.getPayloadOps(getTarget()),
                                   numTiledLoops, options, transformResults)))
    return DiagnosedSilenceableFailure::definiteFailure();
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TestFuseConsumerOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Fuse the consumer and store both the original consumer operation as well as
|
|
/// the fused consumer operation.
|
|
static LogicalResult
|
|
applyFuseConsumer(RewriterBase &rewriter, Operation *transformOp,
|
|
Operation *consumer,
|
|
MutableArrayRef<LoopLikeOpInterface> loops,
|
|
TransformResults &transformResults) {
|
|
SmallVector<Operation *> fusedConsumerOps;
|
|
rewriter.setInsertionPoint(consumer);
|
|
|
|
FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
|
|
scf::tileAndFuseConsumer(rewriter, consumer, loops);
|
|
if (failed(fuseConsumerResults))
|
|
return consumer->emitOpError("failed to fuse consumer of slice");
|
|
|
|
// Report back the relevant handles to the transform op.
|
|
for (OpOperand *tiledAndFusedConsumerOperand :
|
|
fuseConsumerResults->tiledAndFusedConsumerOperands) {
|
|
fusedConsumerOps.push_back(tiledAndFusedConsumerOperand->getOwner());
|
|
}
|
|
transformResults.set(transformOp->getOpResult(0), fusedConsumerOps);
|
|
for (auto [index, loop] : llvm::enumerate(loops)) {
|
|
transformResults.set(transformOp->getOpResult(index + 1), {loop});
|
|
}
|
|
return success();
|
|
}
|
|
|
|
/// Fuses the payload op mapped to the consumer handle into the loop nest
/// mapped to the loop handles, using `scf::tileAndFuseConsumer`.
///
/// The loop handles are reversed before use (see the comment in the body), and
/// only the first payload op of each handle is used. Emits an error and
/// returns a definite failure when a handle maps to no payload op, when a loop
/// handle does not map to a `LoopLikeOpInterface` op, or when fusion fails.
DiagnosedSilenceableFailure
transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter,
                                     TransformResults &transformResults,
                                     TransformState &state) {
  // Dereferencing `begin()` of an empty payload range is undefined, so check
  // for emptiness before taking the first payload op.
  auto consumerPayload = state.getPayloadOps(getConsumer());
  if (consumerPayload.begin() == consumerPayload.end()) {
    emitOpError("expected the consumer handle to map to a payload op");
    return DiagnosedSilenceableFailure::definiteFailure();
  }
  Operation *consumer = *consumerPayload.begin();

  SmallVector<LoopLikeOpInterface> loops;
  // Since the matcher works inside-out, we need to iterate the loops in
  // reverse.
  for (Value loopHandle : llvm::reverse(getLoops())) {
    auto loopPayload = state.getPayloadOps(loopHandle);
    if (loopPayload.begin() == loopPayload.end()) {
      emitOpError("expected each loop handle to map to a payload op");
      return DiagnosedSilenceableFailure::definiteFailure();
    }
    auto loopLikeOp = dyn_cast<LoopLikeOpInterface>(*loopPayload.begin());
    if (!loopLikeOp) {
      // Report the problem instead of failing without a diagnostic.
      emitOpError("expected loop handles to map to LoopLikeOpInterface ops");
      return DiagnosedSilenceableFailure::definiteFailure();
    }
    loops.push_back(loopLikeOp);
  }

  LogicalResult result = applyFuseConsumer(rewriter, getOperation(), consumer,
                                           loops, transformResults);
  return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
                        : DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// Memory effects: the consumer and loop handles are consumed (invalidated),
/// every result of the op is a freshly produced handle, and the payload IR is
/// modified.
void transform::TestFuseConsumerOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  consumesHandle(getConsumerMutable(), effects);
  consumesHandle(getLoopsMutable(), effects);
  producesHandle(getOperation()->getResults(), effects);
  modifiesPayload(effects);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TestFuseConsumerUsingSliceOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Apply fusing of consumer transformation to all payload ops and store both
|
|
/// the original consumer operation as well as the fused consumer operation.
|
|
static LogicalResult applyFuseConsumerUsingSlices(
|
|
RewriterBase &rewriter, Operation *transformOp,
|
|
ArrayRef<Operation *> slices, MutableArrayRef<LoopLikeOpInterface> loops,
|
|
uint32_t numConsumerToFuse, TransformResults &transformResults) {
|
|
SmallVector<Operation *> originalConsumerOps;
|
|
SmallVector<Operation *> fusedConsumerOps;
|
|
|
|
rewriter.setInsertionPoint(slices.front());
|
|
|
|
while (numConsumerToFuse--) {
|
|
FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
|
|
scf::tileAndFuseConsumerOfSlices(rewriter, slices, loops);
|
|
|
|
if (failed(fuseConsumerResults))
|
|
return slices.front()->emitOpError("failed to fuse consumer of slice");
|
|
|
|
// Report back the relevant handles to the transform op.
|
|
for (OpOperand *origConsumerOperand :
|
|
fuseConsumerResults->origConsumerOperands) {
|
|
originalConsumerOps.push_back(origConsumerOperand->getOwner());
|
|
}
|
|
for (OpOperand *tiledAndFusedConsumerOperand :
|
|
fuseConsumerResults->tiledAndFusedConsumerOperands) {
|
|
fusedConsumerOps.push_back(tiledAndFusedConsumerOperand->getOwner());
|
|
}
|
|
}
|
|
|
|
transformResults.set(transformOp->getOpResult(0), originalConsumerOps);
|
|
transformResults.set(transformOp->getOpResult(1), fusedConsumerOps);
|
|
return success();
|
|
}
|
|
|
|
/// Fuses the consumer(s) of the slices mapped to the target handles into the
/// loop nest mapped to the loop handles, using the slice-based
/// `scf::tileAndFuseConsumerOfSlices` path.
///
/// Only the first payload op of each handle is used, and the loop handles are
/// taken in reverse order. Emits an error and returns a definite failure when
/// a handle maps to no payload op, when there are no target slices, when a
/// loop handle does not map to a `LoopLikeOpInterface` op, or when fusion
/// fails.
DiagnosedSilenceableFailure transform::TestFuseConsumerUsingSliceOp::apply(
    TransformRewriter &rewriter, TransformResults &transformResults,
    TransformState &state) {
  SmallVector<Operation *> slices;
  for (Value sliceHandle : getTargets()) {
    // Dereferencing `begin()` of an empty payload range is undefined.
    auto slicePayload = state.getPayloadOps(sliceHandle);
    if (slicePayload.begin() == slicePayload.end()) {
      emitOpError("expected each target handle to map to a payload op");
      return DiagnosedSilenceableFailure::definiteFailure();
    }
    slices.push_back(*slicePayload.begin());
  }
  // The fusion helper anchors its insertion point on `slices.front()`.
  if (slices.empty()) {
    emitOpError("expected at least one target slice");
    return DiagnosedSilenceableFailure::definiteFailure();
  }

  SmallVector<LoopLikeOpInterface> loops;
  for (Value loopHandle : llvm::reverse(getLoops())) {
    auto loopPayload = state.getPayloadOps(loopHandle);
    if (loopPayload.begin() == loopPayload.end()) {
      emitOpError("expected each loop handle to map to a payload op");
      return DiagnosedSilenceableFailure::definiteFailure();
    }
    auto loopLikeOp = dyn_cast<LoopLikeOpInterface>(*loopPayload.begin());
    if (!loopLikeOp) {
      // Report the problem instead of failing without a diagnostic.
      emitOpError("expected loop handles to map to LoopLikeOpInterface ops");
      return DiagnosedSilenceableFailure::definiteFailure();
    }
    loops.push_back(loopLikeOp);
  }

  LogicalResult result =
      applyFuseConsumerUsingSlices(rewriter, getOperation(), slices, loops,
                                   getNumConsumerToFuse(), transformResults);
  return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
                        : DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// Memory effects: the target (slice) and loop handles are consumed
/// (invalidated), every result of the op is a freshly produced handle, and the
/// payload IR is modified.
void transform::TestFuseConsumerUsingSliceOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  consumesHandle(getTargetsMutable(), effects);
  consumesHandle(getLoopsMutable(), effects);
  producesHandle(getOperation()->getResults(), effects);
  modifiesPayload(effects);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TestTileUsingForallOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Apply a tiling transformation to all payload ops and store both the
|
|
/// tiled operation as well as the created tile loops.
|
|
template <typename Range>
|
|
static LogicalResult
|
|
applyTileToAll(RewriterBase &rewriter, Operation *transformOp,
|
|
Range &&payloadOps, ArrayRef<OpFoldResult> tileSizes,
|
|
ArrayRef<int64_t> interchange, std::optional<ArrayAttr> mapping,
|
|
TransformResults &transformResults) {
|
|
SmallVector<Operation *> tiledOps;
|
|
SmallVector<Operation *> loopOps;
|
|
|
|
for (Operation *target : payloadOps) {
|
|
auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
|
|
if (!tilingInterfaceOp)
|
|
return transformOp->emitError("only TilingInterface ops are supported");
|
|
scf::SCFTilingOptions tilingOptions;
|
|
tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
|
|
if (mapping) {
|
|
tilingOptions.setMapping(mapping.value().getValue());
|
|
}
|
|
tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
|
|
|
|
rewriter.setInsertionPoint(target);
|
|
FailureOr<scf::SCFTilingResult> tiledResults =
|
|
scf::tileUsingSCF(rewriter, tilingInterfaceOp, tilingOptions);
|
|
if (failed(tiledResults))
|
|
return failure();
|
|
|
|
// Perform the replacement of tiled and fused values.
|
|
rewriter.replaceOp(tilingInterfaceOp, tiledResults->replacements);
|
|
|
|
// Report back the relevant handles to the transform op.
|
|
tiledOps.push_back(tiledResults->tiledOps.front());
|
|
for (Operation *loop : tiledResults->loops)
|
|
loopOps.push_back(loop);
|
|
}
|
|
|
|
transformResults.set(transformOp->getOpResult(0), tiledOps);
|
|
for (auto [index, loop] : llvm::enumerate(loopOps))
|
|
transformResults.set(transformOp->getOpResult(index + 1), {loop});
|
|
|
|
return success();
|
|
}
|
|
|
|
/// Tiles each target op into an `scf.forall` with the requested tile sizes,
/// interchange and optional mapping, and sets the op's results to the tiled
/// ops and the generated loops.
DiagnosedSilenceableFailure
transform::TestTileUsingForallOp::apply(TransformRewriter &rewriter,
                                        TransformResults &transformResults,
                                        TransformState &state) {
  SmallVector<int64_t> staticSizes =
      extractFromIntegerArrayAttr<int64_t>(getTileSizes());
  SmallVector<int64_t> interchangeVec =
      extractFromIntegerArrayAttr<int64_t>(getInterchange());
  SmallVector<OpFoldResult> sizeOfrs =
      getAsIndexOpFoldResult(rewriter.getContext(), staticSizes);

  if (failed(applyTileToAll(rewriter, getOperation(),
                            state.getPayloadOps(getTarget()), sizeOfrs,
                            interchangeVec, getMapping(), transformResults)))
    return DiagnosedSilenceableFailure::definiteFailure();
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// Memory effects: the target handle is consumed (invalidated), every result
/// of the op is a freshly produced handle, and the payload IR is modified.
void transform::TestTileUsingForallOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  consumesHandle(getTargetMutable(), effects);
  producesHandle(getOperation()->getResults(), effects);
  modifiesPayload(effects);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TestFuseUsingForallOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Apply a tiling transformation to all payload ops and store both the
|
|
/// tiled operation as well as the created tile loops.
|
|
template <typename Range>
|
|
static LogicalResult applyTilingToAll(
|
|
RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
|
|
unsigned numLoops, TransformResults &transformResults,
|
|
function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
|
|
applyFn) {
|
|
SmallVector<Operation *> tiledLinalgOps;
|
|
SmallVector<SmallVector<Operation *>> loopOps(1);
|
|
|
|
for (Operation *target : payloadOps) {
|
|
auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
|
|
if (!tilingInterfaceOp)
|
|
return transformOp->emitError("only TilingInterface ops are supported");
|
|
|
|
rewriter.setInsertionPoint(target);
|
|
FailureOr<scf::SCFTileAndFuseResult> tiledResults =
|
|
applyFn(tilingInterfaceOp);
|
|
if (failed(tiledResults))
|
|
return failure();
|
|
|
|
// Perform the replacement of tiled and fused values.
|
|
SmallVector<Operation *> opsToReplace{target};
|
|
llvm::append_range(opsToReplace, tiledResults->fusedProducers);
|
|
for (Operation *toReplace : opsToReplace) {
|
|
for (OpResult res : toReplace->getResults())
|
|
if (auto replacement = tiledResults->replacements.lookup(res))
|
|
rewriter.replaceAllUsesWith(res, replacement);
|
|
if (toReplace->use_empty())
|
|
rewriter.eraseOp(toReplace);
|
|
}
|
|
|
|
// Report back the relevant handles to the transform op.
|
|
tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
|
|
assert(tiledResults->loops.size() == 1 &&
|
|
cast<scf::ForallOp>(tiledResults->loops[0]).getRank() == numLoops &&
|
|
"Mismatched number of loops, tile and fuse transform should have "
|
|
"failed");
|
|
loopOps[0] = {tiledResults->loops[0]};
|
|
}
|
|
|
|
transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
|
|
if (!loopOps.empty())
|
|
transformResults.set(transformOp->getOpResult(1), loopOps[0]);
|
|
|
|
return success();
|
|
}
|
|
|
|
/// Tiles each root op into an `scf.forall` with the requested tile sizes and
/// interchange, and greedily fuses its producers into the loop. The op's
/// results are set to the tiled root ops and the generated foralls.
DiagnosedSilenceableFailure
transform::TestFuseUsingForallOp::apply(TransformRewriter &rewriter,
                                        TransformResults &transformResults,
                                        TransformState &state) {
  SmallVector<int64_t> staticSizes =
      extractFromIntegerArrayAttr<int64_t>(getTileSizes());
  SmallVector<int64_t> interchange =
      extractFromIntegerArrayAttr<int64_t>(getInterchange());
  SmallVector<OpFoldResult> sizeOfrs =
      getAsIndexOpFoldResult(rewriter.getContext(), staticSizes);

  scf::SCFTilingOptions tilingOptions;
  tilingOptions.setInterchange(interchange)
      .setTileSizes(sizeOfrs)
      .setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
  scf::SCFTileAndFuseOptions fuseOptions;
  fuseOptions.setTilingOptions(tilingOptions);

  auto tileAndFuse =
      [&](TilingInterface rootOp) -> FailureOr<scf::SCFTileAndFuseResult> {
    return scf::tileConsumerAndFuseProducersUsingSCF(rewriter, rootOp,
                                                     fuseOptions);
  };

  // A zero tile size leaves that dimension untiled.
  unsigned numTiledLoops = staticSizes.size() - llvm::count(staticSizes, 0);
  if (failed(applyTilingToAll(rewriter, getOperation(),
                              state.getPayloadOps(getRootOp()), numTiledLoops,
                              transformResults, tileAndFuse)))
    return DiagnosedSilenceableFailure::definiteFailure();
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// Memory effects: the root-op handle is consumed (invalidated), every result
/// of the op is a freshly produced handle, and the payload IR is modified.
void transform::TestFuseUsingForallOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  consumesHandle(getRootOpMutable(), effects);
  producesHandle(getOperation()->getResults(), effects);
  modifiesPayload(effects);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TestTileAndFuseOuterParallelPartialReduction
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Tiles the selected reduction dimensions of the root op with the
/// `PartialReductionOuterParallel` strategy inside a single `scf.forall`,
/// fusing producers along the way.
///
/// The dimensions to tile come from `getReductionDims()`. When that list is
/// empty, every loop whose iterator type is `reduction` is selected.
/// `getTileSizes()` must supply exactly one tile size per selected dimension,
/// and sizes are paired positionally with the selected dimensions. Every other
/// dimension gets a tile size of 0, i.e. it is not tiled. The optional
/// `getMapping()` is forwarded to the tiling options.
///
/// Emits an error and returns a definite failure on invalid inputs or when
/// tiling fails.
DiagnosedSilenceableFailure
transform::TestTileAndFuseOuterParallelPartialReductionOp::apply(
    TransformRewriter &rewriter, TransformResults &transformResults,
    TransformState &state) {
  // Dereferencing `begin()` of an empty payload range is undefined.
  auto rootPayload = state.getPayloadOps(getRootOp());
  if (rootPayload.begin() == rootPayload.end()) {
    emitOpError("expected the root handle to map to a payload op");
    return DiagnosedSilenceableFailure::definiteFailure();
  }
  auto target = dyn_cast<TilingInterface>(*rootPayload.begin());
  if (!target) {
    emitOpError("expected root operation to implement `TilingInterface`");
    return DiagnosedSilenceableFailure::definiteFailure();
  }

  SmallVector<utils::IteratorType> iteratorTypes =
      target.getLoopIteratorTypes();

  SmallVector<unsigned> reductionDims =
      extractFromIntegerArrayAttr<unsigned>(getReductionDims());
  if (reductionDims.empty()) {
    for (auto [index, iterator] : llvm::enumerate(iteratorTypes))
      if (iterator == utils::IteratorType::reduction)
        reductionDims.push_back(index);
  }

  if (reductionDims.empty()) {
    emitOpError(
        "no reduction dimension specified or found in the target operation");
    return DiagnosedSilenceableFailure::definiteFailure();
  }

  if (llvm::any_of(reductionDims, [&](unsigned dim) {
        return dim >= iteratorTypes.size();
      })) {
    emitOpError("reduction dimension out of range of the target's loops");
    return DiagnosedSilenceableFailure::definiteFailure();
  }

  SmallVector<int64_t> reductionTileSizes =
      extractFromIntegerArrayAttr<int64_t>(getTileSizes());
  if (reductionTileSizes.size() != reductionDims.size()) {
    emitOpError(
        "missing tile sizes for reduction dimensions that are to be tiled");
    return DiagnosedSilenceableFailure::definiteFailure();
  }

  // Only the selected reduction dimensions are tiled, each with the tile size
  // at the same position. Pairing by position with `reductionDims` (rather
  // than consuming sizes for every reduction iterator) keeps the lookup in
  // bounds when only a subset of the reduction dimensions is selected.
  OpFoldResult zero = rewriter.getIndexAttr(0);
  SmallVector<OpFoldResult> tileSizes(iteratorTypes.size(), zero);
  for (auto [dim, size] : llvm::zip_equal(reductionDims, reductionTileSizes))
    tileSizes[dim] = rewriter.getIndexAttr(size);

  scf::SCFTilingOptions tilingOptions;
  tilingOptions.setTileSizes(tileSizes)
      .setLoopType(scf::SCFTilingOptions::LoopType::ForallOp)
      .setReductionTilingStrategy(
          ReductionTilingStrategy::PartialReductionOuterParallel)
      .setReductionDims(reductionDims);
  if (auto mapping = getMapping()) {
    tilingOptions.setMapping(mapping.value());
  }

  LogicalResult result = applyTileAndFuseToAll(
      rewriter, getOperation(), state.getPayloadOps(getRootOp()),
      /*numLoops =*/1, tilingOptions, transformResults);

  return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
                        : DiagnosedSilenceableFailure::success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TestTileUsingCustomLoopOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Tiles the root op using the custom-loop hooks of `scf::tileUsingSCF`. It
/// generates a single `scf.for` that iterates over the linearized tile space:
///   - the loop-header hook computes the number of tiles along every loop
///     dimension and builds an `scf.for` over their product. It then
///     delinearizes the induction variable back into per-dimension tile
///     offsets and sizes;
///   - the terminator hook inserts each tiled result into its destination with
///     `tensor.insert_slice` and yields the updated tensors.
/// For testing, only unit-stride loops, non-zero tile sizes and tensor
/// destinations are supported. Result 0 receives the tiled ops and result 1
/// the generated loops.
DiagnosedSilenceableFailure transform::TestTileUsingCustomLoopOp::apply(
    TransformRewriter &transformRewriter, TransformResults &transformResults,
    TransformState &state) {
  // Dereferencing `begin()` of an empty payload range is undefined.
  auto rootPayload = state.getPayloadOps(getRootOp());
  if (rootPayload.begin() == rootPayload.end()) {
    emitOpError("expected the root handle to map to a payload op");
    return DiagnosedSilenceableFailure::definiteFailure();
  }
  auto target = dyn_cast<TilingInterface>(*rootPayload.begin());
  if (!target) {
    emitOpError("expected root operation to implement `TilingInterface`");
    return DiagnosedSilenceableFailure::definiteFailure();
  }

  OpFoldResult oneOfr = transformRewriter.getIndexAttr(1);

  scf::SCFTilingOptions::GenerateLoopHeaderFn loopHeaderFn =
      [&](RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
          ArrayRef<OpFoldResult> givenTileSizes,
          ValueRange outerDestinationTensors)
      -> FailureOr<scf::SCFTilingOptions::CustomLoopHeaderInfo> {
    // Check that the strides are all 1 (to make it easier in the test).
    if (llvm::any_of(loopRanges, [](Range r) {
          return !isConstantIntValue(r.stride, 1);
        })) {
      return emitOpError("unable to handle loop ranges with strides != 1");
    }
    // Check number of tile sizes is equal to loop dimensions.
    if (loopRanges.size() != givenTileSizes.size()) {
      return emitOpError("expected number of tile sizes to be same as the "
                         "number of loops in the operation");
    }
    // For testing disallow any of the tile sizes being 0.
    if (llvm::any_of(givenTileSizes, isZeroInteger)) {
      return emitOpError("unhandled case of zero tile size");
    }
    // For testing, only handle tensor tiling.
    if (outerDestinationTensors.empty()) {
      return emitOpError("expected destination tensors");
    }

    // Each `Range` is (offset, size, stride). With unit stride the loop covers
    // [offset, offset + size), so the number of tiles along it is
    // ceilDiv(size, tileSize), independent of the lower bound.
    AffineExpr s0, s1, s2;
    bindSymbols(rewriter.getContext(), s0, s1, s2);
    AffineExpr numItersExpr = s0.ceilDiv(s1); // size / tileSize

    SmallVector<OpFoldResult> allNumIters;
    allNumIters.reserve(loopRanges.size());
    for (auto [loopRange, tileSize] :
         llvm::zip_equal(loopRanges, givenTileSizes)) {
      OpFoldResult numIters = affine::makeComposedFoldedAffineApply(
          rewriter, loc, numItersExpr, {loopRange.size, tileSize});
      allNumIters.push_back(numIters);
    }
    if (allNumIters.empty()) {
      return emitOpError("invalid empty tile sizes and loop ranges");
    }

    // Trip count of the linearized loop: product of the per-dimension counts.
    AffineExpr mulExpr = s0 * s1;
    OpFoldResult cumulative = oneOfr;
    for (auto numIters : allNumIters) {
      cumulative = affine::makeComposedFoldedAffineApply(
          rewriter, loc, mulExpr, {cumulative, numIters});
    }

    Value zeroVal = arith::ConstantIndexOp::create(rewriter, loc, 0);
    Value oneVal = arith::ConstantIndexOp::create(rewriter, loc, 1);
    Value ub = getValueOrCreateConstantIndexOp(rewriter, loc, cumulative);

    SmallVector<OpFoldResult> offsets;
    SmallVector<OpFoldResult> sizes;
    SmallVector<Value> innerDestinationTensors;
    offsets.reserve(loopRanges.size());
    sizes.reserve(loopRanges.size());

    AffineExpr d0;
    bindDims(rewriter.getContext(), d0);
    AffineExpr offsetExpr = s0 + d0 * s1; // lb + iv * tileSize
    // min(lb + size - offset, tileSize): the last tile along a dimension may
    // be partial; the upper bound is lb + size, not size.
    AffineMap minMap = AffineMap::get(1, 3, {s0 + s1 - d0, s2},
                                      rewriter.getContext());
    auto forOp = scf::ForOp::create(
        rewriter, loc, zeroVal, ub, oneVal, outerDestinationTensors,
        [&](OpBuilder &b, Location bodyLoc, Value linearizedIv,
            ValueRange destinations) {
          auto delinearizeOp = affine::AffineDelinearizeIndexOp::create(
              b, bodyLoc, linearizedIv, allNumIters);
          for (auto [normalizedIv, range, tileSize] : llvm::zip_equal(
                   delinearizeOp.getResults(), loopRanges, givenTileSizes)) {

            OpFoldResult normalizedIvOfr = getAsOpFoldResult(normalizedIv);
            OpFoldResult offset = affine::makeComposedFoldedAffineApply(
                b, bodyLoc, offsetExpr,
                {normalizedIvOfr, range.offset, tileSize});
            offsets.push_back(offset);

            OpFoldResult size = affine::makeComposedFoldedAffineMin(
                b, bodyLoc, minMap,
                {offset, range.offset, range.size, tileSize});
            sizes.push_back(size);
          }
          innerDestinationTensors = llvm::to_vector(destinations);
        });
    rewriter.setInsertionPointToEnd(forOp.getBody());
    return scf::SCFTilingOptions::CustomLoopHeaderInfo{
        {cast<LoopLikeOpInterface>(forOp.getOperation())},
        offsets,
        sizes,
        innerDestinationTensors};
  };

  scf::SCFTilingOptions::GenerateLoopTerminatorFn terminatorFn =
      [&](RewriterBase &rewriter, Location loc,
          ArrayRef<LoopLikeOpInterface> loops, ValueRange tiledResults,
          ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
          ArrayRef<SmallVector<OpFoldResult>> resultSizes,
          ValueRange destinationTensors) -> LogicalResult {
    SmallVector<Value> yieldValues;
    yieldValues.reserve(destinationTensors.size());
    for (auto [tiledResult, offsets, sizes, destination] : llvm::zip_equal(
             tiledResults, resultOffsets, resultSizes, destinationTensors)) {
      SmallVector<OpFoldResult> strides(offsets.size(), oneOfr);
      Value insertedVal = tensor::InsertSliceOp::create(
          rewriter, loc, tiledResult, destination, offsets, sizes, strides);
      yieldValues.push_back(insertedVal);
    }
    scf::YieldOp::create(rewriter, loc, yieldValues);
    return success();
  };

  scf::SCFTilingOptions tilingOptions;
  SmallVector<int64_t> staticTileSizes =
      extractFromIntegerArrayAttr<int64_t>(getTileSizes());
  SmallVector<OpFoldResult> tileSizes =
      getAsIndexOpFoldResult(transformRewriter.getContext(), staticTileSizes);
  tilingOptions.setTileSizes(tileSizes)
      .setLoopType(scf::SCFTilingOptions::LoopType::CustomOp)
      .setCustomLoopGenerationFns(loopHeaderFn, terminatorFn);

  OpBuilder::InsertionGuard g(transformRewriter);
  transformRewriter.setInsertionPoint(target);
  FailureOr<scf::SCFTilingResult> tiledResults =
      scf::tileUsingSCF(transformRewriter, target, tilingOptions);
  if (failed(tiledResults)) {
    return DiagnosedSilenceableFailure::definiteFailure();
  }
  transformRewriter.replaceOp(target, tiledResults->replacements);
  transformResults.set(getOperation()->getResult(0), tiledResults->tiledOps);
  transformResults.set(getOperation()->getResult(1), tiledResults->loops);

  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
#define GET_OP_CLASSES
|
|
#include "TestTilingInterfaceTransformOps.cpp.inc"
|
|
|
|
namespace {
/// Transform dialect extension that registers the TilingInterface test ops
/// defined in this file. It also declares the dialects they depend on
/// (presumably the dialects whose ops the rewrites may create).
class TestTilingInterfaceDialectExtension
    : public transform::TransformDialectExtension<
          TestTilingInterfaceDialectExtension> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestTilingInterfaceDialectExtension)

  using Base::Base;

  void init() {
    declareDialectDependencies<affine::AffineDialect, index::IndexDialect,
                               scf::SCFDialect, tensor::TensorDialect>();

    registerTransformOps<
#define GET_OP_LIST
#include "TestTilingInterfaceTransformOps.cpp.inc"
        >();
  }

private:
  /// Declares every dialect in `DialectTys` as a dependent dialect, in order.
  template <typename... DialectTys>
  void declareDialectDependencies() {
    (declareDependentDialect<DialectTys>(), ...);
  }
};
} // namespace
|
|
|
|
namespace test {
|
|
void registerTestTilingInterfaceTransformDialectExtension(
|
|
DialectRegistry ®istry) {
|
|
registry.addExtensions<TestTilingInterfaceDialectExtension>();
|
|
}
|
|
} // namespace test
|