llvm-project/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
MaheshRavishankar 6740d701bd
[mlir][Linalg] Deprecate linalg::tileToForallOp and linalg::tileToForallOpUsingTileSizes (#91878)
The implementations of these methods are legacy, and they are removed in
favor of the `scf::tileUsingSCF` methods. To bring the latter on par
with the requirements of the deprecated methods, the tiling options now
allow specifying the maximum number of tiles to use instead of the tile
sizes. When tiling to `scf.forall`, this specification is used to
generate the `num_threads` version of the operation.

A slight deviation from the previous implementation is that the
deprecated methods always generated the `num_threads` variant of the
`scf.forall` operation. Now the variant is driven by the tiling options
specified, which reduces the indexing math generated when tile sizes
are specified.

**Moving from `linalg::tileToForallOp` to `scf::tileUsingSCF`**

```
OpBuilder b;
TilingInterface op;
ArrayRef<OpFoldResult> numThreads;
ArrayAttr mapping;
FailureOr<ForallTilingResult> result = linalg::tileToForallOp(b, op, numThreads, mapping);
```

can be replaced by
```
scf::SCFTilingOptions options;
options.setNumThreads(numThreads);
options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
options.setMapping(mapping.getValue()); /*note the difference that setMapping takes an ArrayRef<Attribute> */
FailureOr<scf::SCFTilingResult> result = scf::tileUsingSCF(b, op, options);
```

This generates the `numThreads` version of the `scf.forall` for the
inter-tile loops, i.e.

```
... = scf.forall (%arg0, %arg1) in (%nt0, %nt1) shared_outs(...)
```

**Moving from `linalg::tileToForallOpUsingTileSizes` to
`scf::tileUsingSCF`**

```
OpBuilder b;
TilingInterface op;
ArrayRef<OpFoldResult> tileSizes;
ArrayAttr mapping;
FailureOr<ForallTilingResult> result = linalg::tileToForallOpUsingTileSizes(b, op, tileSizes, mapping);
```

can be replaced by
```
scf::SCFTilingOptions options;
options.setTileSizes(tileSizes);
options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
options.setMapping(mapping.getValue()); /*note the difference that setMapping takes an ArrayRef<Attribute> */
FailureOr<scf::SCFTilingResult> result = scf::tileUsingSCF(b, op, options);
```

Also note that `linalg::tileToForallOpUsingTileSizes` would effectively
call `linalg::tileToForallOp` by computing the `numThreads` from the
`op` and `tileSizes`, and generate the `numThreads` version of the
`scf.forall`. That is no longer the case. Instead, this directly
generates the `tileSizes` version of the `scf.forall` op:

```
... = scf.forall (%arg0, %arg1) = (%lb0, %lb1) to (%ub0, %ub1) step (%step0, %step1) shared_outs(...)
```

If you actually want the `numThreads` version, it is up to the caller
to compute the `numThreads` and call `options.setNumThreads` instead of
`options.setTileSizes`. Note that there is a slight difference between
the `numThreads` version and the tile-size version. The former requires
an additional `affine.max` on the tile size to ensure non-negative tile
sizes. When tile sizes are specified and the loop is then lowered to the
`numThreads` version, this `affine.max` is not needed, since by
construction the tile sizes are non-negative. In previous
implementations, the `numThreads` version generated by
`linalg::tileToForallOpUsingTileSizes` avoided generating the
`affine.max` operation. To get the same state, downstream users will
have to additionally normalize the `scf.forall` operation.
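
As an illustration of why the clamp matters, here is a minimal
standalone sketch of the per-tile size arithmetic. It is a simplified
model for a single dimension starting at 0, not the exact affine maps
MLIR emits; the helper names (`ceilDiv`, `rawSizeFromNumThreads`,
`clampedSizeFromNumThreads`, `sizeFromTileSize`) are hypothetical and
are not MLIR APIs.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Ceiling division for a non-negative numerator and a positive denominator.
int64_t ceilDiv(int64_t a, int64_t b) {
  assert(a >= 0 && b > 0 && "expected a >= 0 and b > 0");
  return (a + b - 1) / b;
}

// Simplified model of the `numThreads` form: a dimension of extent `ub` is
// split over `numThreads` threads. Without a clamp, trailing threads can see
// a negative size when numThreads * tileSize overshoots `ub`.
int64_t rawSizeFromNumThreads(int64_t ub, int64_t numThreads, int64_t thread) {
  int64_t tileSize = ceilDiv(ub, numThreads);
  int64_t offset = thread * tileSize;
  return std::min(tileSize, ub - offset);
}

// Same computation with the clamp that plays the role of `affine.max`.
int64_t clampedSizeFromNumThreads(int64_t ub, int64_t numThreads,
                                  int64_t thread) {
  return std::max<int64_t>(0, rawSizeFromNumThreads(ub, numThreads, thread));
}

// Simplified model of the tile-size form: the tile count is derived as
// ceilDiv(ub, tileSize), so every tile starts strictly inside [0, ub) and
// its size is at least 1 without any clamp.
int64_t sizeFromTileSize(int64_t ub, int64_t tileSize, int64_t tile) {
  assert(tile >= 0 && tile < ceilDiv(ub, tileSize) && "tile out of range");
  int64_t offset = tile * tileSize;
  return std::min(tileSize, ub - offset);
}
```

With `ub = 5` and `numThreads = 4`, the tile size is 2 and the last
thread starts at offset 6, past the end of the dimension, so the
unclamped size is negative. With `ub = 5` and `tileSize = 2`, only 3
tiles are created and the last one has size 1.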

**Changes to `transform.structured.tile_using_forall`**

The transform dialect op that called into `linalg::tileToForallOp` and
`linalg::tileToForallOpUsingTileSizes` has been modified to call
`scf::tileUsingSCF`. The transform dialect op always generates the
`numThreads` version of the `scf.forall` op. So when `tile_sizes` are
specified for the transform dialect op, the `tile_sizes` version of the
`scf.forall` is first generated by `scf::tileUsingSCF` and then
normalized to get back to the same state. There is therefore no
functional change to `transform.structured.tile_using_forall`: it always
generates the `numThreads` version of the `scf.forall` op (as it did
before this change).
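
The normalization step can be pictured with the following standalone
sketch for a single loop dimension. It assumes that normalizing means
rewriting `lb to ub step s` into a loop that starts at 0 with step 1;
`LoopDim`, `normalizedTripCount`, and `originalIndex` are hypothetical
names used only for illustration, not MLIR APIs.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical description of one dimension of a loop `lb to ub step s`.
struct LoopDim {
  int64_t lb;
  int64_t ub;
  int64_t step;
};

// Number of iterations of the dimension. In the normalized loop this is the
// upper bound, i.e. the per-dimension thread count.
int64_t normalizedTripCount(const LoopDim &d) {
  assert(d.step > 0 && "expected a positive step");
  if (d.ub <= d.lb)
    return 0;
  return (d.ub - d.lb + d.step - 1) / d.step;
}

// Recovers the original induction variable from the normalized index, so the
// loop body can keep using the same offsets as before normalization.
int64_t originalIndex(const LoopDim &d, int64_t normalizedIv) {
  assert(normalizedIv >= 0 && normalizedIv < normalizedTripCount(d) &&
         "normalized index out of range");
  return d.lb + normalizedIv * d.step;
}
```

For example, a dimension `0 to 10 step 4` normalizes to 3 iterations,
and normalized index 2 maps back to the original offset 8.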

---------

Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
2024-07-31 12:32:07 -07:00


//===- TestTilingInterfaceTransformOps.cpp - Test `TilingInterface` ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines transform dialect operations used for testing
// TilingInterface
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Interfaces/TilingInterface.h"
#define GET_OP_CLASSES
#include "TestTilingInterfaceTransformOps.h.inc"
using namespace mlir;
using namespace mlir::transform;
//===----------------------------------------------------------------------===//
// TestFuseAndYieldOp
//===----------------------------------------------------------------------===//
/// Collect `op` and all of its transitive producers that implement
/// `TilingInterface`, using a worklist traversal of the operand chain.
static llvm::SmallDenseSet<Operation *> collectTiledAndFusedOps(Operation *op) {
  SmallVector<Operation *> worklist;
  llvm::SmallDenseSet<Operation *> producers;
  worklist.push_back(op);
  producers.insert(op);
  while (!worklist.empty()) {
    Operation *current = worklist.pop_back_val();
    for (OpOperand &operand : current->getOpOperands()) {
      Operation *producer = operand.get().getDefiningOp();
      if (!producer || !isa<TilingInterface>(producer) ||
          producers.contains(producer))
        continue;
      worklist.push_back(producer);
      producers.insert(producer);
    }
  }
  return producers;
}

/// Apply a tile and fuse transformation to all payload ops and store both the
/// tiled operation as well as the created tile loops.
template <typename Range>
static LogicalResult
applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp,
                      Range &&payloadOps, unsigned numLoops,
                      ArrayRef<OpFoldResult> tileSizes,
                      ArrayRef<int64_t> interchange, bool useForall,
                      TransformResults &transformResults) {
  SmallVector<Operation *> tiledOps;
  SmallVector<SmallVector<Operation *>> loopOps(numLoops);

  for (Operation *target : payloadOps) {
    auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
    if (!tilingInterfaceOp)
      return transformOp->emitError("only TilingInterface ops are supported");
    DominanceInfo dominanceInfo(tilingInterfaceOp);

    llvm::SmallDenseSet<Operation *> tiledAndFusedOps =
        collectTiledAndFusedOps(tilingInterfaceOp);
    llvm::DenseSet<Operation *> yieldReplacementsFor;
    for (auto op : tiledAndFusedOps) {
      if (llvm::any_of(op->getUsers(), [&](Operation *user) {
            return dominanceInfo.properlyDominates(tilingInterfaceOp, user);
          })) {
        yieldReplacementsFor.insert(op);
      }
    }

    scf::SCFTilingOptions tilingOptions;
    tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
    if (useForall) {
      tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
    }

    scf::SCFTileAndFuseOptions tileAndFuseOptions;
    tileAndFuseOptions.setTilingOptions(tilingOptions);

    scf::SCFTileAndFuseOptions::ControlFnTy controlFn =
        [&](tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
            bool isDestinationOperand) {
          Operation *owner = originalProducer.getOwner();
          bool yieldProducerReplacement = yieldReplacementsFor.contains(owner);
          return std::make_tuple(true, yieldProducerReplacement);
        };
    tileAndFuseOptions.setFusionControlFn(controlFn);

    rewriter.setInsertionPoint(target);
    FailureOr<scf::SCFTileAndFuseResult> tiledResults =
        scf::tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
                                                  tileAndFuseOptions);
    if (failed(tiledResults))
      return failure();

    // Perform the replacement of tiled and fused values.
    SmallVector<Operation *> opsToReplace{target};
    llvm::append_range(opsToReplace, tiledResults->fusedProducers);
    for (Operation *toReplace : opsToReplace) {
      for (OpResult res : toReplace->getResults())
        if (auto replacement = tiledResults->replacements.lookup(res)) {
          Operation *replacementOp = replacement.getDefiningOp();
          rewriter.replaceUsesWithIf(res, replacement, [&](OpOperand &use) {
            Operation *user = use.getOwner();
            return dominanceInfo.properlyDominates(replacementOp, user) &&
                   user->getParentOp() == replacementOp->getParentOp();
          });
        }

      if (toReplace->use_empty()) {
        rewriter.eraseOp(toReplace);
      }
    }

    // Report back the relevant handles to the transform op.
    tiledOps.push_back(tiledResults->tiledAndFusedOps.front());
    assert(tiledResults->loops.size() == numLoops &&
           "Mismatched number of loops, tile and fuse transform should have "
           "failed");
    for (unsigned int i = 0; i < numLoops; ++i)
      loopOps[i].push_back(tiledResults->loops[i]);
  }

  transformResults.set(transformOp->getOpResult(0), tiledOps);
  for (unsigned int i = 0; i < numLoops; ++i)
    transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);

  return success();
}

DiagnosedSilenceableFailure
transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
                                     TransformResults &transformResults,
                                     TransformState &state) {
  SmallVector<int64_t> tileSizes =
      extractFromIntegerArrayAttr<int64_t>(getTileSizes());
  SmallVector<int64_t> tileInterchange =
      extractFromIntegerArrayAttr<int64_t>(getTileInterchange());

  SmallVector<OpFoldResult> tileSizesOfr =
      getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);

  LogicalResult result = applyTileAndFuseToAll(
      rewriter, getOperation(), state.getPayloadOps(getTarget()),
      tileSizes.size() - llvm::count(tileSizes, 0), tileSizesOfr,
      tileInterchange, getUseForall(), transformResults);
  return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
                        : DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// TestFuseConsumerOp
//===----------------------------------------------------------------------===//
/// Apply fusing of consumer transformation to all payload ops and store both
/// the original consumer operation as well as the fused consumer operation.
template <typename Range>
static LogicalResult
applyFuseConsumer(RewriterBase &rewriter, Operation *transformOp,
                  Range &&payloadOps, TransformResults &transformResults) {
  SmallVector<Operation *> originalConsumerOps;
  SmallVector<Operation *> fusedConsumerOps;

  for (Operation *target : payloadOps) {
    rewriter.setInsertionPoint(target);

    FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
        scf::tileAndFuseConsumerOfSlice(rewriter, target);

    if (failed(fuseConsumerResults))
      return failure();

    // Report back the relevant handles to the transform op.
    originalConsumerOps.push_back(
        fuseConsumerResults->origConsumerOperand->getOwner());
    fusedConsumerOps.push_back(
        fuseConsumerResults->tiledAndFusedConsumerOperand->getOwner());
  }

  transformResults.set(transformOp->getOpResult(0), originalConsumerOps);
  transformResults.set(transformOp->getOpResult(1), fusedConsumerOps);
  return success();
}

DiagnosedSilenceableFailure
transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter,
                                     TransformResults &transformResults,
                                     TransformState &state) {
  LogicalResult result =
      applyFuseConsumer(rewriter, getOperation(),
                        state.getPayloadOps(getTarget()), transformResults);
  return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
                        : DiagnosedSilenceableFailure::success();
}

void transform::TestFuseConsumerOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  consumesHandle(getTargetMutable(), effects);
  producesHandle(getOperation()->getOpResults(), effects);
  modifiesPayload(effects);
}

//===----------------------------------------------------------------------===//
// TestTileUsingForallOp
//===----------------------------------------------------------------------===//
/// Apply a tiling transformation to all payload ops and store both the
/// tiled operation as well as the created tile loops.
template <typename Range>
static LogicalResult
applyTileToAll(RewriterBase &rewriter, Operation *transformOp,
               Range &&payloadOps, ArrayRef<OpFoldResult> tileSizes,
               ArrayRef<int64_t> interchange, std::optional<ArrayAttr> mapping,
               TransformResults &transformResults) {
  SmallVector<Operation *> tiledOps;
  SmallVector<Operation *> loopOps;

  for (Operation *target : payloadOps) {
    auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
    if (!tilingInterfaceOp)
      return transformOp->emitError("only TilingInterface ops are supported");
    scf::SCFTilingOptions tilingOptions;
    tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
    if (mapping) {
      tilingOptions.setMapping(mapping.value().getValue());
    }
    tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);

    rewriter.setInsertionPoint(target);
    FailureOr<scf::SCFTilingResult> tiledResults =
        scf::tileUsingSCF(rewriter, tilingInterfaceOp, tilingOptions);
    if (failed(tiledResults))
      return failure();

    // Perform the replacement of tiled and fused values.
    rewriter.replaceOp(tilingInterfaceOp, tiledResults->replacements);

    // Report back the relevant handles to the transform op.
    tiledOps.push_back(tiledResults->tiledOps.front());
    for (Operation *loop : tiledResults->loops)
      loopOps.push_back(loop);
  }

  transformResults.set(transformOp->getOpResult(0), tiledOps);
  for (auto [index, loop] : llvm::enumerate(loopOps))
    transformResults.set(transformOp->getOpResult(index + 1), {loop});

  return success();
}

DiagnosedSilenceableFailure
transform::TestTileUsingForallOp::apply(TransformRewriter &rewriter,
                                        TransformResults &transformResults,
                                        TransformState &state) {
  SmallVector<int64_t> tileSizes =
      extractFromIntegerArrayAttr<int64_t>(getTileSizes());
  SmallVector<int64_t> interchange =
      extractFromIntegerArrayAttr<int64_t>(getInterchange());
  SmallVector<OpFoldResult> tileSizesOfr =
      getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);

  LogicalResult result =
      applyTileToAll(rewriter, getOperation(), state.getPayloadOps(getTarget()),
                     tileSizesOfr, interchange, getMapping(), transformResults);
  return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
                        : DiagnosedSilenceableFailure::success();
}

void transform::TestTileUsingForallOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  consumesHandle(getTargetMutable(), effects);
  producesHandle(getOperation()->getOpResults(), effects);
  modifiesPayload(effects);
}

//===----------------------------------------------------------------------===//
// TestFuseUsingForallOp
//===----------------------------------------------------------------------===//
/// Apply a tiling transformation to all payload ops and store both the
/// tiled operation as well as the created tile loops.
template <typename Range>
static LogicalResult applyTilingToAll(
    RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
    unsigned numLoops, TransformResults &transformResults,
    function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
        applyFn) {
  SmallVector<Operation *> tiledLinalgOps;
  SmallVector<SmallVector<Operation *>> loopOps(1);

  for (Operation *target : payloadOps) {
    auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
    if (!tilingInterfaceOp)
      return transformOp->emitError("only TilingInterface ops are supported");

    rewriter.setInsertionPoint(target);
    FailureOr<scf::SCFTileAndFuseResult> tiledResults =
        applyFn(tilingInterfaceOp);
    if (failed(tiledResults))
      return failure();

    // Perform the replacement of tiled and fused values.
    SmallVector<Operation *> opsToReplace{target};
    llvm::append_range(opsToReplace, tiledResults->fusedProducers);
    for (Operation *toReplace : opsToReplace) {
      for (OpResult res : toReplace->getResults())
        if (auto replacement = tiledResults->replacements.lookup(res))
          rewriter.replaceAllUsesWith(res, replacement);
      if (toReplace->use_empty())
        rewriter.eraseOp(toReplace);
    }

    // Report back the relevant handles to the transform op.
    tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
    assert(tiledResults->loops.size() == 1 &&
           cast<scf::ForallOp>(tiledResults->loops[0]).getRank() == numLoops &&
           "Mismatched number of loops, tile and fuse transform should have "
           "failed");
    loopOps[0] = {tiledResults->loops[0]};
  }

  transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
  if (!loopOps.empty())
    transformResults.set(transformOp->getOpResult(1), loopOps[0]);

  return success();
}

DiagnosedSilenceableFailure
transform::TestFuseUsingForallOp::apply(TransformRewriter &rewriter,
                                        TransformResults &transformResults,
                                        TransformState &state) {
  SmallVector<int64_t> tileSizes =
      extractFromIntegerArrayAttr<int64_t>(getTileSizes());
  SmallVector<int64_t> tileInterchange =
      extractFromIntegerArrayAttr<int64_t>(getInterchange());

  scf::SCFTilingOptions tilingOptions;
  tilingOptions.interchangeVector = tileInterchange;
  SmallVector<OpFoldResult> tileSizesOfr =
      getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
  tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
  tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
  scf::SCFTileAndFuseOptions tileAndFuseOptions;
  tileAndFuseOptions.tilingOptions = tilingOptions;
  LogicalResult result = applyTilingToAll(
      rewriter, getOperation(), state.getPayloadOps(getRootOp()),
      tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
      [&](TilingInterface tilingInterfaceOp)
          -> FailureOr<scf::SCFTileAndFuseResult> {
        return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
                                                    tileAndFuseOptions);
      });
  return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
                        : DiagnosedSilenceableFailure::success();
}

void transform::TestFuseUsingForallOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  consumesHandle(getRootOpMutable(), effects);
  producesHandle(getOperation()->getOpResults(), effects);
  modifiesPayload(effects);
}

#define GET_OP_CLASSES
#include "TestTilingInterfaceTransformOps.cpp.inc"
namespace {
class TestTilingInterfaceDialectExtension
    : public transform::TransformDialectExtension<
          TestTilingInterfaceDialectExtension> {
public:
  using Base::Base;

  void init() {
    declareDependentDialect<affine::AffineDialect>();
    declareDependentDialect<index::IndexDialect>();
    declareDependentDialect<scf::SCFDialect>();
    declareDependentDialect<tensor::TensorDialect>();

    registerTransformOps<
#define GET_OP_LIST
#include "TestTilingInterfaceTransformOps.cpp.inc"
        >();
  }
};
} // namespace

namespace test {
void registerTestTilingInterfaceTransformDialectExtension(
    DialectRegistry &registry) {
  registry.addExtensions<TestTilingInterfaceDialectExtension>();
}
} // namespace test