This PR passes the loops generated by `GenerateLoopHeaderFn` to the `GenerateLoopTerminatorFn` callback. This is needed to correctly set the insertion point with `scf.forall` ops.
658 lines
26 KiB
C++
658 lines
26 KiB
C++
//===- TestTilingInterfaceTransformOps.cpp - Test `TilingInterface` ------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file defines transform dialect operations used for testing
|
|
// TilingInterface
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
|
#include "mlir/Dialect/Index/IR/IndexDialect.h"
|
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
|
|
#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
|
|
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
|
|
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
|
|
#include "mlir/Dialect/Utils/StaticValueUtils.h"
|
|
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
|
|
#include "mlir/IR/Dominance.h"
|
|
#include "mlir/IR/OpImplementation.h"
|
|
#include "mlir/Interfaces/TilingInterface.h"
|
|
#include "llvm/Support/Debug.h"
|
|
|
|
#define DEBUG_TYPE "test-tiling-interface"
|
|
|
|
#define GET_OP_CLASSES
|
|
#include "TestTilingInterfaceTransformOps.h.inc"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::transform;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TestFuseAndYieldOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Returns `op` together with every operation reachable from it by walking
/// operand definitions backwards through ops that implement `TilingInterface`.
/// The walk does not continue past values without a defining op or past
/// producers that do not implement `TilingInterface`.
static llvm::SmallDenseSet<Operation *> collectTiledAndFusedOps(Operation *op) {
  llvm::SmallDenseSet<Operation *> visited;
  visited.insert(op);

  SmallVector<Operation *> pending{op};
  while (!pending.empty()) {
    Operation *consumer = pending.pop_back_val();
    for (OpOperand &use : consumer->getOpOperands()) {
      Operation *definingOp = use.get().getDefiningOp();
      if (!definingOp || !isa<TilingInterface>(definingOp))
        continue;
      // `insert` reports whether the producer is new; only new producers need
      // to be expanded.
      if (visited.insert(definingOp).second)
        pending.push_back(definingOp);
    }
  }
  return visited;
}
|
|
|
|
/// Apply a tile and fuse transformation to all payload ops and store both the
|
|
/// tiled operation as well as the created tile loops.
|
|
template <typename Range>
|
|
static LogicalResult
|
|
applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp,
|
|
Range &&payloadOps, unsigned numLoops,
|
|
scf::SCFTilingOptions tilingOptions,
|
|
TransformResults &transformResults) {
|
|
SmallVector<Operation *> tiledOps;
|
|
SmallVector<SmallVector<Operation *>> loopOps(numLoops);
|
|
|
|
for (Operation *target : payloadOps) {
|
|
auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
|
|
if (!tilingInterfaceOp)
|
|
return transformOp->emitError("only TilingInterface ops are supported");
|
|
DominanceInfo dominanceInfo(tilingInterfaceOp);
|
|
|
|
llvm::SmallDenseSet<Operation *> tiledAndFusedOps =
|
|
collectTiledAndFusedOps(tilingInterfaceOp);
|
|
llvm::DenseSet<Operation *> yieldReplacementsFor;
|
|
for (auto op : tiledAndFusedOps) {
|
|
if (llvm::any_of(op->getUsers(), [&](Operation *user) {
|
|
return dominanceInfo.properlyDominates(tilingInterfaceOp, user);
|
|
})) {
|
|
yieldReplacementsFor.insert(op);
|
|
}
|
|
}
|
|
|
|
scf::SCFTileAndFuseOptions tileAndFuseOptions;
|
|
tileAndFuseOptions.setTilingOptions(tilingOptions);
|
|
|
|
scf::SCFTileAndFuseOptions::ControlFnTy controlFn =
|
|
[&](tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
|
|
bool isDestinationOperand)
|
|
-> std::optional<scf::SCFTileAndFuseOptions::ControlFnResult> {
|
|
Operation *owner = originalProducer.getOwner();
|
|
bool yieldProducerReplacement = yieldReplacementsFor.contains(owner);
|
|
return scf::SCFTileAndFuseOptions::ControlFnResult{
|
|
yieldProducerReplacement};
|
|
};
|
|
tileAndFuseOptions.setFusionControlFn(controlFn);
|
|
|
|
rewriter.setInsertionPoint(target);
|
|
FailureOr<scf::SCFTileAndFuseResult> tiledResults =
|
|
scf::tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
|
|
tileAndFuseOptions);
|
|
if (failed(tiledResults))
|
|
return failure();
|
|
|
|
// Perform the replacement of tiled and fused values.
|
|
SmallVector<Operation *> opsToReplace{target};
|
|
llvm::append_range(opsToReplace, tiledResults->fusedProducers);
|
|
for (Operation *toReplace : opsToReplace) {
|
|
for (OpResult res : toReplace->getResults())
|
|
if (auto replacement = tiledResults->replacements.lookup(res)) {
|
|
Operation *replacementOp = replacement.getDefiningOp();
|
|
rewriter.replaceUsesWithIf(res, replacement, [&](OpOperand &use) {
|
|
Operation *user = use.getOwner();
|
|
return dominanceInfo.properlyDominates(replacementOp, user) &&
|
|
user->getParentOp() == replacementOp->getParentOp();
|
|
});
|
|
}
|
|
|
|
if (toReplace->use_empty()) {
|
|
rewriter.eraseOp(toReplace);
|
|
}
|
|
}
|
|
|
|
// Report back the relevant handles to the transform op.
|
|
tiledOps.push_back(tiledResults->tiledAndFusedOps.front());
|
|
assert(tiledResults->loops.size() == numLoops &&
|
|
"Mismatched number of loops, tile and fuse transform should have "
|
|
"failed");
|
|
for (unsigned int i = 0; i < numLoops; ++i)
|
|
loopOps[i].push_back(tiledResults->loops[i]);
|
|
}
|
|
|
|
transformResults.set(transformOp->getOpResult(0), tiledOps);
|
|
for (unsigned int i = 0; i < numLoops; ++i)
|
|
transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
|
|
|
|
return success();
|
|
}
|
|
|
|
/// Tiles and fuses every payload op of the target handle using the tile sizes
/// and interchange given as attributes, using `scf.forall` loops when
/// `use_forall` is set. The number of loops passed on is the number of
/// non-zero tile sizes.
DiagnosedSilenceableFailure
transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
                                     TransformResults &transformResults,
                                     TransformState &state) {
  SmallVector<int64_t> staticSizes =
      extractFromIntegerArrayAttr<int64_t>(getTileSizes());
  SmallVector<int64_t> interchange =
      extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
  SmallVector<OpFoldResult> sizesOfr =
      getAsIndexOpFoldResult(rewriter.getContext(), staticSizes);

  scf::SCFTilingOptions options;
  options.setTileSizes(sizesOfr).setInterchange(interchange);
  if (getUseForall())
    options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);

  // A zero tile size leaves the dimension untiled, so it generates no loop.
  const unsigned numTiledLoops =
      staticSizes.size() - llvm::count(staticSizes, 0);

  if (failed(applyTileAndFuseToAll(rewriter, getOperation(),
                                   state.getPayloadOps(getTarget()),
                                   numTiledLoops, options, transformResults)))
    return DiagnosedSilenceableFailure::definiteFailure();
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TestFuseConsumerOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Apply fusing of consumer transformation to all payload ops and store both
|
|
/// the original consumer operation as well as the fused consumer operation.
|
|
static LogicalResult applyFuseConsumer(
|
|
RewriterBase &rewriter, Operation *transformOp,
|
|
ArrayRef<Operation *> slices, MutableArrayRef<LoopLikeOpInterface> loops,
|
|
uint32_t numConsumerToFuse, TransformResults &transformResults) {
|
|
SmallVector<Operation *> originalConsumerOps;
|
|
SmallVector<Operation *> fusedConsumerOps;
|
|
|
|
rewriter.setInsertionPoint(slices.front());
|
|
|
|
while (numConsumerToFuse--) {
|
|
FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
|
|
scf::tileAndFuseConsumerOfSlices(rewriter, slices, loops);
|
|
|
|
if (failed(fuseConsumerResults))
|
|
return slices.front()->emitOpError("failed to fuse consumer of slice");
|
|
|
|
// Report back the relevant handles to the transform op.
|
|
for (OpOperand *origConsumerOperand :
|
|
fuseConsumerResults->origConsumerOperands) {
|
|
originalConsumerOps.push_back(origConsumerOperand->getOwner());
|
|
}
|
|
for (OpOperand *tiledAndFusedConsumerOperand :
|
|
fuseConsumerResults->tiledAndFusedConsumerOperands) {
|
|
fusedConsumerOps.push_back(tiledAndFusedConsumerOperand->getOwner());
|
|
}
|
|
}
|
|
|
|
transformResults.set(transformOp->getOpResult(0), originalConsumerOps);
|
|
transformResults.set(transformOp->getOpResult(1), fusedConsumerOps);
|
|
return success();
|
|
}
|
|
|
|
/// Fuses consumers of the slices referenced by the target handles into the
/// loops referenced by the loop handles.
///
/// The first payload op of each target handle is used as a slice, and the
/// first payload op of each loop handle (visited in reverse operand order) is
/// used as a loop. A definite failure with a diagnostic is produced when a
/// handle maps to no payload op, when no target is given, or when a loop
/// handle does not map to an op implementing `LoopLikeOpInterface`.
DiagnosedSilenceableFailure
transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter,
                                     TransformResults &transformResults,
                                     TransformState &state) {
  SmallVector<Operation *> slices;
  for (auto targetHandle : getTargets()) {
    auto payloadOps = state.getPayloadOps(targetHandle);
    // Dereferencing `begin()` of an empty range is undefined behavior.
    if (payloadOps.begin() == payloadOps.end()) {
      emitOpError("expected each target handle to map to a payload op");
      return DiagnosedSilenceableFailure::definiteFailure();
    }
    slices.push_back(*payloadOps.begin());
  }
  // `applyFuseConsumer` unconditionally reads the first slice.
  if (slices.empty()) {
    emitOpError("expected at least one target slice");
    return DiagnosedSilenceableFailure::definiteFailure();
  }

  SmallVector<LoopLikeOpInterface> loops;
  for (auto loopHandle : llvm::reverse(getLoops())) {
    auto payloadOps = state.getPayloadOps(loopHandle);
    if (payloadOps.begin() == payloadOps.end()) {
      emitOpError("expected each loop handle to map to a payload op");
      return DiagnosedSilenceableFailure::definiteFailure();
    }
    auto loopLikeOp = dyn_cast<LoopLikeOpInterface>(*payloadOps.begin());
    if (!loopLikeOp) {
      // A definite failure must be accompanied by a diagnostic.
      emitOpError(
          "expected loop handle to map to an op implementing "
          "`LoopLikeOpInterface`");
      return DiagnosedSilenceableFailure::definiteFailure();
    }
    loops.push_back(loopLikeOp);
  }

  LogicalResult result =
      applyFuseConsumer(rewriter, getOperation(), slices, loops,
                        getNumConsumerToFuse(), transformResults);
  return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
                        : DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// Declares the effects of this transform op: both the target handles and the
/// loop handles are consumed, all results are produced handles, and the
/// payload IR is modified.
void transform::TestFuseConsumerOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // Operand handles.
  consumesHandle(getTargetsMutable(), effects);
  consumesHandle(getLoopsMutable(), effects);
  // Result handles.
  producesHandle(getOperation()->getOpResults(), effects);
  // Payload IR.
  modifiesPayload(effects);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TestTileUsingForallOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Apply a tiling transformation to all payload ops and store both the
|
|
/// tiled operation as well as the created tile loops.
|
|
template <typename Range>
|
|
static LogicalResult
|
|
applyTileToAll(RewriterBase &rewriter, Operation *transformOp,
|
|
Range &&payloadOps, ArrayRef<OpFoldResult> tileSizes,
|
|
ArrayRef<int64_t> interchange, std::optional<ArrayAttr> mapping,
|
|
TransformResults &transformResults) {
|
|
SmallVector<Operation *> tiledOps;
|
|
SmallVector<Operation *> loopOps;
|
|
|
|
for (Operation *target : payloadOps) {
|
|
auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
|
|
if (!tilingInterfaceOp)
|
|
return transformOp->emitError("only TilingInterface ops are supported");
|
|
scf::SCFTilingOptions tilingOptions;
|
|
tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
|
|
if (mapping) {
|
|
tilingOptions.setMapping(mapping.value().getValue());
|
|
}
|
|
tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
|
|
|
|
rewriter.setInsertionPoint(target);
|
|
FailureOr<scf::SCFTilingResult> tiledResults =
|
|
scf::tileUsingSCF(rewriter, tilingInterfaceOp, tilingOptions);
|
|
if (failed(tiledResults))
|
|
return failure();
|
|
|
|
// Perform the replacement of tiled and fused values.
|
|
rewriter.replaceOp(tilingInterfaceOp, tiledResults->replacements);
|
|
|
|
// Report back the relevant handles to the transform op.
|
|
tiledOps.push_back(tiledResults->tiledOps.front());
|
|
for (Operation *loop : tiledResults->loops)
|
|
loopOps.push_back(loop);
|
|
}
|
|
|
|
transformResults.set(transformOp->getOpResult(0), tiledOps);
|
|
for (auto [index, loop] : llvm::enumerate(loopOps))
|
|
transformResults.set(transformOp->getOpResult(index + 1), {loop});
|
|
|
|
return success();
|
|
}
|
|
|
|
/// Tiles every payload op of the target handle into `scf.forall` loops using
/// the tile sizes, interchange and optional mapping given as attributes.
DiagnosedSilenceableFailure
transform::TestTileUsingForallOp::apply(TransformRewriter &rewriter,
                                        TransformResults &transformResults,
                                        TransformState &state) {
  SmallVector<int64_t> staticSizes =
      extractFromIntegerArrayAttr<int64_t>(getTileSizes());
  SmallVector<int64_t> loopInterchange =
      extractFromIntegerArrayAttr<int64_t>(getInterchange());
  SmallVector<OpFoldResult> sizesOfr =
      getAsIndexOpFoldResult(rewriter.getContext(), staticSizes);

  if (failed(applyTileToAll(rewriter, getOperation(),
                            state.getPayloadOps(getTarget()), sizesOfr,
                            loopInterchange, getMapping(), transformResults)))
    return DiagnosedSilenceableFailure::definiteFailure();
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// Declares the effects of this transform op: the target handle is consumed,
/// all results are produced handles, and the payload IR is modified.
void transform::TestTileUsingForallOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // Operand handle.
  consumesHandle(getTargetMutable(), effects);
  // Result handles.
  producesHandle(getOperation()->getOpResults(), effects);
  // Payload IR.
  modifiesPayload(effects);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TestFuseUsingForallOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Apply a tiling transformation to all payload ops and store both the
|
|
/// tiled operation as well as the created tile loops.
|
|
template <typename Range>
|
|
static LogicalResult applyTilingToAll(
|
|
RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
|
|
unsigned numLoops, TransformResults &transformResults,
|
|
function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
|
|
applyFn) {
|
|
SmallVector<Operation *> tiledLinalgOps;
|
|
SmallVector<SmallVector<Operation *>> loopOps(1);
|
|
|
|
for (Operation *target : payloadOps) {
|
|
auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
|
|
if (!tilingInterfaceOp)
|
|
return transformOp->emitError("only TilingInterface ops are supported");
|
|
|
|
rewriter.setInsertionPoint(target);
|
|
FailureOr<scf::SCFTileAndFuseResult> tiledResults =
|
|
applyFn(tilingInterfaceOp);
|
|
if (failed(tiledResults))
|
|
return failure();
|
|
|
|
// Perform the replacement of tiled and fused values.
|
|
SmallVector<Operation *> opsToReplace{target};
|
|
llvm::append_range(opsToReplace, tiledResults->fusedProducers);
|
|
for (Operation *toReplace : opsToReplace) {
|
|
for (OpResult res : toReplace->getResults())
|
|
if (auto replacement = tiledResults->replacements.lookup(res))
|
|
rewriter.replaceAllUsesWith(res, replacement);
|
|
if (toReplace->use_empty())
|
|
rewriter.eraseOp(toReplace);
|
|
}
|
|
|
|
// Report back the relevant handles to the transform op.
|
|
tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
|
|
assert(tiledResults->loops.size() == 1 &&
|
|
cast<scf::ForallOp>(tiledResults->loops[0]).getRank() == numLoops &&
|
|
"Mismatched number of loops, tile and fuse transform should have "
|
|
"failed");
|
|
loopOps[0] = {tiledResults->loops[0]};
|
|
}
|
|
|
|
transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
|
|
if (!loopOps.empty())
|
|
transformResults.set(transformOp->getOpResult(1), loopOps[0]);
|
|
|
|
return success();
|
|
}
|
|
|
|
/// Tiles every payload op of the root handle into an `scf.forall` and fuses
/// its producers, using the tile sizes and interchange given as attributes.
/// The expected number of loops is the number of non-zero tile sizes.
DiagnosedSilenceableFailure
transform::TestFuseUsingForallOp::apply(TransformRewriter &rewriter,
                                        TransformResults &transformResults,
                                        TransformState &state) {
  SmallVector<int64_t> staticSizes =
      extractFromIntegerArrayAttr<int64_t>(getTileSizes());
  SmallVector<int64_t> loopInterchange =
      extractFromIntegerArrayAttr<int64_t>(getInterchange());
  SmallVector<OpFoldResult> sizesOfr =
      getAsIndexOpFoldResult(rewriter.getContext(), staticSizes);

  scf::SCFTilingOptions options;
  options.setTileSizes(sizesOfr)
      .setInterchange(loopInterchange)
      .setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);

  scf::SCFTileAndFuseOptions fuseOptions;
  fuseOptions.setTilingOptions(options);

  // Transformation applied to each payload op by `applyTilingToAll`.
  auto tileAndFuse = [&](TilingInterface tileableOp)
      -> FailureOr<scf::SCFTileAndFuseResult> {
    return scf::tileConsumerAndFuseProducersUsingSCF(rewriter, tileableOp,
                                                     fuseOptions);
  };

  // A zero tile size leaves the dimension untiled.
  const unsigned numTiledLoops =
      staticSizes.size() - llvm::count(staticSizes, 0);

  if (failed(applyTilingToAll(rewriter, getOperation(),
                              state.getPayloadOps(getRootOp()), numTiledLoops,
                              transformResults, tileAndFuse)))
    return DiagnosedSilenceableFailure::definiteFailure();
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// Declares the effects of this transform op: the root handle is consumed,
/// all results are produced handles, and the payload IR is modified.
void transform::TestFuseUsingForallOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // Operand handle.
  consumesHandle(getRootOpMutable(), effects);
  // Result handles.
  producesHandle(getOperation()->getOpResults(), effects);
  // Payload IR.
  modifiesPayload(effects);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TestTileAndFuseOuterParallelPartialReduction
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Tiles the reduction dimensions of the payload ops of the root handle into
/// an `scf.forall` using the `PartialReductionOuterParallel` strategy, fusing
/// producers along the way.
///
/// The reduction dimensions are taken from the `reduction_dims` attribute or,
/// when that is empty, are all loops of the first payload op whose iterator
/// type is `reduction`. The i-th entry of the tile sizes attribute is the tile
/// size of the i-th reduction dimension; every other loop gets tile size 0.
///
/// A definite failure with a diagnostic is produced when the root handle is
/// empty, the first payload op does not implement `TilingInterface`, no
/// reduction dimension is available, the number of tile sizes differs from
/// the number of reduction dimensions, or a reduction dimension is not a
/// valid loop index of the op.
DiagnosedSilenceableFailure
transform::TestTileAndFuseOuterParallelPartialReductionOp::apply(
    TransformRewriter &rewriter, TransformResults &transformResults,
    TransformState &state) {
  auto rootPayloadOps = state.getPayloadOps(getRootOp());
  // Dereferencing `begin()` of an empty range is undefined behavior.
  if (rootPayloadOps.begin() == rootPayloadOps.end()) {
    emitOpError("expected root handle to map to at least one payload op");
    return DiagnosedSilenceableFailure::definiteFailure();
  }
  auto target = dyn_cast<TilingInterface>(*rootPayloadOps.begin());
  if (!target) {
    emitOpError("expected root operation to implement `TilingInterface`");
    return DiagnosedSilenceableFailure::definiteFailure();
  }

  SmallVector<utils::IteratorType> iteratorTypes =
      target.getLoopIteratorTypes();

  SmallVector<unsigned> reductionDims =
      extractFromIntegerArrayAttr<unsigned>(getReductionDims());
  if (reductionDims.empty()) {
    for (auto [index, iterator] : llvm::enumerate(iteratorTypes))
      if (iterator == utils::IteratorType::reduction)
        reductionDims.push_back(index);
  }

  if (reductionDims.empty()) {
    emitOpError(
        "no reduction dimension specified or found in the target operation");
    return DiagnosedSilenceableFailure::definiteFailure();
  }

  SmallVector<int64_t> reductionTileSizes =
      extractFromIntegerArrayAttr<int64_t>(getTileSizes());
  if (reductionTileSizes.size() != reductionDims.size()) {
    emitOpError(
        "missing tile sizes for reduction dimensions that are to be tiled");
    return DiagnosedSilenceableFailure::definiteFailure();
  }

  // Expand the per-reduction-dimension tile sizes to one tile size per loop.
  // Sizes are matched to `reductionDims` positionally so that an explicit
  // subset of the reduction loops never indexes past `reductionTileSizes`.
  OpFoldResult zero = rewriter.getIndexAttr(0);
  SmallVector<OpFoldResult> tileSizes(iteratorTypes.size(), zero);
  for (auto [dim, size] : llvm::zip_equal(reductionDims, reductionTileSizes)) {
    if (dim >= iteratorTypes.size()) {
      emitOpError("reduction dimension is out of range for the target "
                  "operation");
      return DiagnosedSilenceableFailure::definiteFailure();
    }
    tileSizes[dim] = rewriter.getIndexAttr(size);
  }

  scf::SCFTilingOptions tilingOptions;
  tilingOptions.setTileSizes(tileSizes)
      .setLoopType(scf::SCFTilingOptions::LoopType::ForallOp)
      .setReductionTilingStrategy(
          ReductionTilingStrategy::PartialReductionOuterParallel)
      .setReductionDims(reductionDims);
  if (auto mapping = getMapping()) {
    tilingOptions.setMapping(mapping.value());
  }

  LogicalResult result = applyTileAndFuseToAll(
      rewriter, getOperation(), state.getPayloadOps(getRootOp()),
      /*numLoops =*/1, tilingOptions, transformResults);

  return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
                        : DiagnosedSilenceableFailure::success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TestTileUsingCustomLoopOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Tiles the first payload op of the root handle using custom loop generation
/// callbacks: the loop nest is a single `scf.for` over the linearized number
/// of tiles, whose induction variable is delinearized to obtain the per-loop
/// tile offsets, and the terminator inserts each tiled result into its
/// destination tensor and yields it.
///
/// Result 0 receives the tiled ops and result 1 the generated loops. A
/// definite failure is produced when the root handle is empty, the payload op
/// does not implement `TilingInterface`, or tiling fails (the callbacks emit
/// their own diagnostics for the cases they do not support).
DiagnosedSilenceableFailure transform::TestTileUsingCustomLoopOp::apply(
    TransformRewriter &transformRewriter, TransformResults &transformResults,
    TransformState &state) {
  auto rootPayloadOps = state.getPayloadOps(getRootOp());
  // Dereferencing `begin()` of an empty range is undefined behavior.
  if (rootPayloadOps.begin() == rootPayloadOps.end()) {
    emitOpError("expected root handle to map to at least one payload op");
    return DiagnosedSilenceableFailure::definiteFailure();
  }
  auto target = dyn_cast<TilingInterface>(*rootPayloadOps.begin());
  if (!target) {
    emitOpError("expected root operation to implement `TilingInterface`");
    return DiagnosedSilenceableFailure::definiteFailure();
  }

  OpFoldResult oneOfr = transformRewriter.getIndexAttr(1);

  // Generates the loop header: one `scf.for` iterating over all tiles, with
  // the destination tensors as iter_args. Leaves the insertion point at the
  // end of the loop body.
  scf::SCFTilingOptions::GenerateLoopHeaderFn loopHeaderFn =
      [&](RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
          ArrayRef<OpFoldResult> givenTileSizes,
          ValueRange outerDestinationTensors)
      -> FailureOr<scf::SCFTilingOptions::CustomLoopHeaderInfo> {
    // Check that the strides are all 1 (to make it easier in the test).
    if (llvm::any_of(loopRanges, [](Range r) {
          return !isConstantIntValue(r.stride, 1);
        })) {
      return emitOpError("unable to handle loop ranges with strides != 1");
    }
    // Check number of tile sizes is equal to loop dimensions.
    if (loopRanges.size() != givenTileSizes.size()) {
      return emitOpError("expected number of tile sizes to be same as the "
                         "number of loops in the operation");
    }
    // For testing disallow any of the tile sizes being 0.
    if (llvm::any_of(givenTileSizes, isZeroInteger)) {
      return emitOpError("unhandled case of zero tile size");
    }
    // For testing, only handle tensor tiling.
    if (outerDestinationTensors.empty()) {
      return emitOpError("expected destination tensors");
    }

    // Compute the number of iterations for each of the loops.
    AffineExpr s0, s1, s2;
    bindSymbols(rewriter.getContext(), s0, s1, s2);
    // ceildiv(ub - lb, tileSize)
    AffineExpr numItersExpr = (s1 - s0).ceilDiv(s2);

    SmallVector<OpFoldResult> allNumIters;
    allNumIters.reserve(loopRanges.size());
    for (auto [loopRange, tileSize] :
         llvm::zip_equal(loopRanges, givenTileSizes)) {
      OpFoldResult numIters = affine::makeComposedFoldedAffineApply(
          rewriter, loc, numItersExpr,
          {loopRange.offset, loopRange.size, tileSize});
      allNumIters.push_back(numIters);
    }
    if (allNumIters.empty()) {
      return emitOpError("invalid empty tile sizes and loop ranges");
    }

    // The upper bound of the linearized loop is the product of the iteration
    // counts of all loops.
    AffineExpr mulExpr = s0 * s1;
    OpFoldResult cumulative = oneOfr;
    for (auto numIters : allNumIters) {
      cumulative = affine::makeComposedFoldedAffineApply(
          rewriter, loc, mulExpr, {cumulative, numIters});
    }

    Value zeroVal = arith::ConstantIndexOp::create(rewriter, loc, 0);
    Value oneVal = arith::ConstantIndexOp::create(rewriter, loc, 1);
    Value ub = getValueOrCreateConstantIndexOp(rewriter, loc, cumulative);

    SmallVector<OpFoldResult> offsets;
    SmallVector<OpFoldResult> sizes;
    SmallVector<Value> innerDestinationTensors;
    offsets.reserve(loopRanges.size());
    sizes.reserve(loopRanges.size());

    AffineExpr d0;
    bindDims(rewriter.getContext(), d0);
    AffineExpr offsetExpr = s0 + d0 * s1; // lb + iv * tileSize
    AffineMap minMap =
        AffineMap::get(1, 2, {s0 - d0, s1},
                       rewriter.getContext()); // min(ub - offset, tileSize)
    auto forOp = scf::ForOp::create(
        rewriter, loc, zeroVal, ub, oneVal, outerDestinationTensors,
        [&](OpBuilder &b, Location bodyLoc, Value linearizedIv,
            ValueRange destinations) {
          // Recover one normalized induction variable per original loop.
          auto delinearizeOp = affine::AffineDelinearizeIndexOp::create(
              b, bodyLoc, linearizedIv, allNumIters);
          for (auto [normalizedIv, range, tileSize] : llvm::zip_equal(
                   delinearizeOp.getResults(), loopRanges, givenTileSizes)) {

            OpFoldResult normalizedIvOfr = getAsOpFoldResult(normalizedIv);
            OpFoldResult offset = affine::makeComposedFoldedAffineApply(
                b, bodyLoc, offsetExpr,
                {normalizedIvOfr, range.offset, tileSize});
            offsets.push_back(offset);

            OpFoldResult size = affine::makeComposedFoldedAffineMin(
                b, bodyLoc, minMap, {offset, range.size, tileSize});
            sizes.push_back(size);
          }
          innerDestinationTensors = llvm::to_vector(destinations);
        });
    rewriter.setInsertionPointToEnd(forOp.getBody());
    return scf::SCFTilingOptions::CustomLoopHeaderInfo{
        {cast<LoopLikeOpInterface>(forOp.getOperation())},
        offsets,
        sizes,
        innerDestinationTensors};
  };

  // Generates the loop terminator: inserts every tiled result into its
  // destination tensor with unit strides and yields the updated tensors. The
  // `loops` argument is not needed for the `scf.for` generated above.
  scf::SCFTilingOptions::GenerateLoopTerminatorFn terminatorFn =
      [&](RewriterBase &rewriter, Location loc,
          ArrayRef<LoopLikeOpInterface> loops, ValueRange tiledResults,
          ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
          ArrayRef<SmallVector<OpFoldResult>> resultSizes,
          ValueRange destinationTensors) -> LogicalResult {
    SmallVector<Value> yieldValues;
    yieldValues.reserve(destinationTensors.size());
    for (auto [tiledResult, offsets, sizes, destination] : llvm::zip_equal(
             tiledResults, resultOffsets, resultSizes, destinationTensors)) {
      SmallVector<OpFoldResult> strides(offsets.size(), oneOfr);
      Value insertedVal = tensor::InsertSliceOp::create(
          rewriter, loc, tiledResult, destination, offsets, sizes, strides);
      yieldValues.push_back(insertedVal);
    }
    scf::YieldOp::create(rewriter, loc, yieldValues);
    return success();
  };

  scf::SCFTilingOptions tilingOptions;
  SmallVector<int64_t> staticTileSizes =
      extractFromIntegerArrayAttr<int64_t>(getTileSizes());
  SmallVector<OpFoldResult> tileSizes =
      getAsIndexOpFoldResult(transformRewriter.getContext(), staticTileSizes);
  tilingOptions.setTileSizes(tileSizes)
      .setLoopType(scf::SCFTilingOptions::LoopType::CustomOp)
      .setCustomLoopGenerationFns(loopHeaderFn, terminatorFn);

  OpBuilder::InsertionGuard g(transformRewriter);
  transformRewriter.setInsertionPoint(target);
  FailureOr<scf::SCFTilingResult> tiledResults =
      scf::tileUsingSCF(transformRewriter, target, tilingOptions);
  if (failed(tiledResults)) {
    return DiagnosedSilenceableFailure::definiteFailure();
  }
  transformRewriter.replaceOp(target, tiledResults->replacements);
  transformResults.set(getOperation()->getResult(0), tiledResults->tiledOps);
  transformResults.set(getOperation()->getResult(1), tiledResults->loops);

  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
#define GET_OP_CLASSES
|
|
#include "TestTilingInterfaceTransformOps.cpp.inc"
|
|
|
|
namespace {
/// Transform dialect extension that registers the transform ops generated
/// into `TestTilingInterfaceTransformOps.cpp.inc` and declares the affine,
/// index, scf and tensor dialects as dependent dialects.
class TestTilingInterfaceDialectExtension
    : public transform::TransformDialectExtension<
          TestTilingInterfaceDialectExtension> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestTilingInterfaceDialectExtension)

  using Base::Base;

  /// Declares the dependent dialects and registers the transform ops.
  void init() {
    declareDependentDialect<affine::AffineDialect>();
    declareDependentDialect<index::IndexDialect>();
    declareDependentDialect<scf::SCFDialect>();
    declareDependentDialect<tensor::TensorDialect>();

    // The generated op list is expanded in place as the template arguments.
    registerTransformOps<
#define GET_OP_LIST
#include "TestTilingInterfaceTransformOps.cpp.inc"
        >();
  }
};
} // namespace
|
|
|
|
namespace test {
|
|
void registerTestTilingInterfaceTransformDialectExtension(
|
|
DialectRegistry ®istry) {
|
|
registry.addExtensions<TestTilingInterfaceDialectExtension>();
|
|
}
|
|
} // namespace test
|