These are identified by misc-include-cleaner. I've filtered out those that break builds. Also, I'm staying away from llvm-config.h, config.h, and Compiler.h, which likely cause platform- or compiler-specific build failures.
115 lines
4.2 KiB
C++
115 lines
4.2 KiB
C++
//===- SwapExtractSliceWithProducerPatterns.cpp ---------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// Swap a `tensor.extract_slice` with the producer of the source if the producer
|
|
// implements the `TilingInterface`. When used in conjunction with tiling this
|
|
// effectively tiles + fuses the producer with its consumer.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
|
|
#include "mlir/Dialect/Utils/StaticValueUtils.h"
|
|
#include "mlir/Interfaces/TilingInterface.h"
|
|
#include "llvm/Support/Debug.h"
|
|
|
|
#define DEBUG_TYPE "tensor-swap-slices"
|
|
|
|
using namespace mlir;
|
|
|
|
FailureOr<TilingResult> tensor::replaceExtractSliceWithTiledProducer(
|
|
OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producer) {
|
|
auto producerOp = dyn_cast<TilingInterface>(producer.getOwner());
|
|
if (!producerOp)
|
|
return failure();
|
|
|
|
// `TilingInterface` currently only supports strides being 1.
|
|
if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger))
|
|
return failure();
|
|
|
|
FailureOr<TilingResult> tiledResult = producerOp.generateResultTileValue(
|
|
builder, producer.getResultNumber(), sliceOp.getMixedOffsets(),
|
|
sliceOp.getMixedSizes());
|
|
if (failed(tiledResult))
|
|
return failure();
|
|
|
|
// For cases where the slice was rank-reducing, create a rank-reducing slice
|
|
// to get the same type back.
|
|
llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();
|
|
if (droppedDims.any()) {
|
|
assert(tiledResult->tiledValues.size() == 1 &&
|
|
"expected only a single tiled result value to replace the extract "
|
|
"slice");
|
|
SmallVector<OpFoldResult> offsets(sliceOp.getSourceType().getRank(),
|
|
builder.getIndexAttr(0));
|
|
SmallVector<OpFoldResult> strides(sliceOp.getSourceType().getRank(),
|
|
builder.getIndexAttr(1));
|
|
auto newSliceOp = tensor::ExtractSliceOp::create(
|
|
builder, sliceOp.getLoc(), sliceOp.getType(),
|
|
tiledResult->tiledValues[0], offsets, sliceOp.getMixedSizes(), strides);
|
|
tiledResult->tiledValues[0] = newSliceOp;
|
|
}
|
|
|
|
return *tiledResult;
|
|
}
|
|
|
|
FailureOr<TilingResult> tensor::replaceInsertSlicesWithTiledConsumer(
|
|
OpBuilder &builder, ArrayRef<tensor::InsertSliceOp> sliceOps,
|
|
ArrayRef<OpOperand *> consumerOperands) {
|
|
if (sliceOps.empty()) {
|
|
LLVM_DEBUG(
|
|
{ llvm::dbgs() << "expected candidate slices list to be non-empty"; });
|
|
return failure();
|
|
}
|
|
if (sliceOps.size() != consumerOperands.size()) {
|
|
LLVM_DEBUG({
|
|
llvm::dbgs()
|
|
<< "expected as many operands as the number of slices passed";
|
|
});
|
|
return failure();
|
|
}
|
|
auto consumerOp =
|
|
dyn_cast<TilingInterface>(consumerOperands.front()->getOwner());
|
|
if (!consumerOp)
|
|
return failure();
|
|
for (auto opOperand : consumerOperands.drop_front()) {
|
|
if (opOperand->getOwner() != consumerOp) {
|
|
LLVM_DEBUG({
|
|
llvm::dbgs()
|
|
<< "expected all consumer operands to be from the same operation";
|
|
});
|
|
return failure();
|
|
}
|
|
}
|
|
|
|
auto consumerOperandNums = llvm::map_to_vector(
|
|
consumerOperands, [](OpOperand *opOperand) -> unsigned {
|
|
return opOperand->getOperandNumber();
|
|
});
|
|
SmallVector<SmallVector<OpFoldResult>> allOffsets;
|
|
SmallVector<SmallVector<OpFoldResult>> allSizes;
|
|
for (auto sliceOp : sliceOps) {
|
|
|
|
// `TilingInterface` currently only supports strides being 1.
|
|
if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger))
|
|
return failure();
|
|
|
|
SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
|
|
SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
|
|
allOffsets.emplace_back(std::move(offsets));
|
|
allSizes.emplace_back(std::move(sizes));
|
|
}
|
|
FailureOr<TilingResult> tiledResult =
|
|
consumerOp.getTiledImplementationFromOperandTiles(
|
|
builder, consumerOperandNums, allOffsets, allSizes);
|
|
if (failed(tiledResult))
|
|
return failure();
|
|
|
|
return *tiledResult;
|
|
}
|