ExpandShapeOp builder cannot infer the result type since it doesn't know how the dimension needs to be split. Remove this builder so that it doesn't get used accidentally. Also remove one potential path using it in generic fusion. Differential Revision: https://reviews.llvm.org/D122019
2347 lines
99 KiB
C++
2347 lines
99 KiB
C++
//===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===///
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements the linalg dialect Fusion on tensors operations pass.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
#include <utility>
|
|
|
|
#include "PassDetail.h"
|
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
|
#include "mlir/Dialect/Linalg/Passes.h"
|
|
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
|
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
|
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
|
|
#include "mlir/IR/AffineExpr.h"
|
|
#include "mlir/IR/AffineMap.h"
|
|
#include "mlir/IR/Matchers.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Support/LLVM.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::linalg;
|
|
|
|
//===---------------------------------------------------------------------===//
|
|
// Methods and patterns that fuse elementwise `linalg.generic` operations.
|
|
//===---------------------------------------------------------------------===//
|
|
|
|
/// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of
|
|
/// the `producer` to use in the fused operation given the indexing map of the
|
|
/// result of the producer in the consumer.
|
|
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
|
|
OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
|
|
AffineMap fusedConsumerArgIndexMap) {
|
|
// The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
|
|
// from consumer loop -> consumer arg tensor index/producer result tensor
|
|
// index. The fused loop is same as the consumer loop. For each producer arg
|
|
// the indexing map to be computed is a map from consumer loop -> producer
|
|
// arg tensor index.
|
|
// producerResultIndexMap is a map from producer loop -> tensor index.
|
|
// Compute the inverse to get map from tensor index -> producer loop.
|
|
// The inverse is a map from producer result tensor index -> producer loop.
|
|
AffineMap invProducerResultIndexMap =
|
|
inversePermutation(producerResultIndexMap);
|
|
assert(invProducerResultIndexMap &&
|
|
"expected producer result indexing map to be invertible");
|
|
|
|
LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
|
|
// argMap is a map from producer loop -> producer arg tensor index.
|
|
AffineMap argMap = producer.getTiedIndexingMap(producerOpOperand);
|
|
|
|
// Compose argMap with invProducerResultIndexMap to get a map from
|
|
// producer result tensor index -> producer arg tensor index.
|
|
AffineMap t1 = argMap.compose(invProducerResultIndexMap);
|
|
|
|
// Compose t1 with fusedConsumerArgIndexMap gives an indexing map from
|
|
// consumer loop/ fused loop -> producer arg tensor index.
|
|
return t1.compose(fusedConsumerArgIndexMap);
|
|
}
|
|
|
|
/// Conditions for elementwise fusion of generic operations.
|
|
static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
|
|
OpOperand *consumerOpOperand) {
|
|
// Producer and consumer must have tensor semantics.
|
|
if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
|
|
return false;
|
|
|
|
// Verify that
|
|
// - the producer has all "parallel" iterator type.
|
|
if (producer.getNumParallelLoops() != producer.getNumLoops())
|
|
return false;
|
|
|
|
// Only allow fusing the producer of an input operand for now.
|
|
// TODO: allow fusing the producer of an output operand.
|
|
if (!consumer.isInputTensor(consumerOpOperand))
|
|
return false;
|
|
|
|
// Get the consumer index map. The number of results of the consumer index
|
|
// map must match the number of loops of the producer.
|
|
AffineMap consumerIndexMap = consumer.getTiedIndexingMap(consumerOpOperand);
|
|
if (consumerIndexMap.getNumResults() != producer.getNumLoops())
|
|
return false;
|
|
|
|
// Currently support only operations with single result.
|
|
if (producer.getNumOutputs() != 1)
|
|
return false;
|
|
|
|
// Finally the index_map for the result must be invertible. For now just
|
|
// verify it is a permutation.
|
|
AffineMap producerResultIndexMap =
|
|
producer.getTiedIndexingMap(producer.getOutputOperand(0));
|
|
if (!producerResultIndexMap.isPermutation())
|
|
return false;
|
|
|
|
// Ensure that the fusion does not remove size information required to
|
|
// get the loop bounds. For non-reduction generics, this is trivially the
|
|
// case due to the output operand. For reductions, we need to check that after
|
|
// the fusion, each loop dimension has at least one input that defines it.
|
|
if ((consumer.getNumReductionLoops())) {
|
|
BitVector coveredDims(consumer.getNumLoops(), false);
|
|
|
|
auto addToCoveredDims = [&](AffineMap map) {
|
|
for (auto result : map.getResults())
|
|
if (auto dimExpr = result.dyn_cast<AffineDimExpr>())
|
|
coveredDims[dimExpr.getPosition()] = true;
|
|
};
|
|
|
|
for (auto pair :
|
|
llvm::zip(consumer->getOperands(), consumer.getIndexingMaps())) {
|
|
Value operand = std::get<0>(pair);
|
|
if (operand == consumerOpOperand->get())
|
|
continue;
|
|
AffineMap operandMap = std::get<1>(pair);
|
|
addToCoveredDims(operandMap);
|
|
}
|
|
|
|
for (OpOperand *operand : producer.getInputOperands()) {
|
|
AffineMap newIndexingMap =
|
|
getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
|
|
operand, producerResultIndexMap, consumerIndexMap);
|
|
addToCoveredDims(newIndexingMap);
|
|
}
|
|
if (!coveredDims.all())
|
|
return false;
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
/// Generate the region of the fused tensor operation. The region of the fused
|
|
/// op must be empty.
|
|
static void
|
|
generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
|
|
AffineMap consumerToProducerLoopsMap,
|
|
OpOperand *consumerOpOperand,
|
|
unsigned nloops) {
|
|
auto producer = cast<GenericOp>(consumerOpOperand->get().getDefiningOp());
|
|
auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
|
|
// Build the region of the fused op.
|
|
Block &producerBlock = producer->getRegion(0).front();
|
|
Block &consumerBlock = consumer->getRegion(0).front();
|
|
Block *fusedBlock = new Block();
|
|
fusedOp.region().push_back(fusedBlock);
|
|
BlockAndValueMapping mapper;
|
|
OpBuilder::InsertionGuard guard(rewriter);
|
|
rewriter.setInsertionPointToStart(fusedBlock);
|
|
|
|
// 2. Add an index operation for every fused loop dimension and use the
|
|
// `consumerToProducerLoopsMap` to map the producer indices.
|
|
if (producer.hasIndexSemantics()) {
|
|
// Add an index operation for every fused loop dimension.
|
|
unsigned numFusedOpLoops =
|
|
std::max(producer.getNumLoops(), consumer.getNumLoops());
|
|
SmallVector<Value> fusedIndices;
|
|
fusedIndices.reserve(numFusedOpLoops);
|
|
llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
|
|
std::back_inserter(fusedIndices), [&](uint64_t dim) {
|
|
return rewriter.create<IndexOp>(producer.getLoc(), dim);
|
|
});
|
|
for (IndexOp indexOp :
|
|
llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) {
|
|
Value newIndex = rewriter.create<mlir::AffineApplyOp>(
|
|
producer.getLoc(),
|
|
consumerToProducerLoopsMap.getSubMap(indexOp.dim()), fusedIndices);
|
|
mapper.map(indexOp.getResult(), newIndex);
|
|
}
|
|
}
|
|
// TODO: allow fusing the producer of an output operand.
|
|
assert(consumer.isInputTensor(consumerOpOperand) &&
|
|
"expected producer of input operand");
|
|
// 3. Consumer input operands up to consumerIdx (exclusive).
|
|
for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
|
|
consumerOpOperand->getOperandNumber())) // input assumption.
|
|
mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
|
|
|
|
// Replacing consumerIdx requires getting the cloned, yielded, value from
|
|
// the (cloned) producer block. This happens in step 9.
|
|
|
|
// 4. Splice in producer's input operands.
|
|
for (BlockArgument bbArg :
|
|
producerBlock.getArguments().take_front(producer.getNumInputs()))
|
|
mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
|
|
|
|
// 4.b. Producer output operand/map that is fused needs to be mapped to the
|
|
// producer bbArg if it is an "initTensor" (i.e. its value is actually read).
|
|
assert(producer->getNumResults() == 1 && "expected single result producer");
|
|
if (producer.isInitTensor(producer.getOutputOperand(0))) {
|
|
BlockArgument bbArg = producerBlock.getArguments()
|
|
.drop_front(producer.getNumInputs())
|
|
// TODO: bbArg index of
|
|
.front();
|
|
mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
|
|
}
|
|
// 5. Remaining consumer's input operands (drop past index `consumerIdx`).
|
|
for (BlockArgument bbArg :
|
|
consumerBlock.getArguments()
|
|
.take_front(consumer.getNumInputs())
|
|
.drop_front(consumerOpOperand->getOperandNumber() + 1))
|
|
mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
|
|
// 6. All of consumer's output operands.
|
|
for (BlockArgument bbArg :
|
|
consumerBlock.getArguments().take_back(consumer.getNumOutputs()))
|
|
mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
|
|
// 7. All of producer's output operands except the one fused.
|
|
// TODO: allow fusion of multi-result producers.
|
|
assert(producer->getNumResults() == 1 && "expected single result producer");
|
|
|
|
// 8. Clone all producer operations except for the yield and index operations
|
|
// to the fused operation.
|
|
for (auto &op : producerBlock.without_terminator()) {
|
|
if (!isa<IndexOp>(op))
|
|
rewriter.clone(op, mapper);
|
|
}
|
|
// 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just
|
|
// forward the yield operand.
|
|
auto yieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
|
|
// TODO: allow fusion of multi-result producers.
|
|
assert(producer->getNumResults() == 1 && "expected single result producer");
|
|
unsigned producerResultNumber = 0;
|
|
Value replacement =
|
|
mapper.lookupOrDefault(yieldOp.getOperand(producerResultNumber));
|
|
// Sanity checks, if replacement is not already in the mapper then it must be
|
|
// produced outside.
|
|
if (replacement == yieldOp.getOperand(producerResultNumber)) {
|
|
if (auto bb = replacement.dyn_cast<BlockArgument>())
|
|
assert(bb.getOwner() != &producerBlock &&
|
|
"yielded block argument must have been mapped");
|
|
else
|
|
assert(!producer->isAncestor(replacement.getDefiningOp()) &&
|
|
"yielded value must have been mapped");
|
|
}
|
|
mapper.map(consumerBlock.getArgument(consumerOpOperand->getOperandNumber()),
|
|
replacement);
|
|
// 10. Clone operations from the consumer to the fused op.
|
|
for (auto &op : consumerBlock.getOperations())
|
|
rewriter.clone(op, mapper);
|
|
|
|
// Sanity checks.
|
|
assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
|
|
"Ill-formed GenericOp region");
|
|
}
|
|
|
|
/// Fuses `producer` into the generic op owning `consumerOpOperand`. Returns
/// the results of the newly created fused op, or `llvm::None` when the ops are
/// not fusable, when `controlFn` rejects the fusion, or when the fused op
/// would have invalid indexing maps.
static Optional<SmallVector<Value>>
fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand,
                       const ControlElementwiseOpsFusionFn &controlFn,
                       PatternRewriter &rewriter) {
  auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
  if (!areElementwiseOpsFusable(producer, consumer, consumerOpOperand))
    return llvm::None;
  if (!controlFn(producer->getResult(0), *consumerOpOperand))
    return llvm::None;

  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isInputTensor(consumerOpOperand) &&
         "expected producer of input operand");

  // consumer loop -> index of the tensor that is fused away.
  AffineMap consumerArgIndexMap =
      consumer.getTiedIndexingMap(consumerOpOperand);

  // Operands and indexing maps of the fused op.
  const unsigned maxNumFusedOperands =
      producer->getNumOperands() + consumer->getNumOperands();
  SmallVector<Value> fusedOperands;
  SmallVector<AffineMap> fusedIndexMaps;
  fusedOperands.reserve(maxNumFusedOperands);
  fusedIndexMaps.reserve(maxNumFusedOperands);

  // A consumer input keeps its indexing map unchanged.
  auto appendConsumerInput = [&](OpOperand *opOperand) {
    fusedOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
  };

  // In the following, numbering matches that of
  // `generateFusedElementwiseOpRegion`.
  // 3. Consumer input operands/maps up to the fused operand (exclusive).
  SmallVector<OpOperand *> consumerInputs = consumer.getInputOperands();
  SmallVector<OpOperand *>::iterator fusedInputIt =
      llvm::find(consumerInputs, consumerOpOperand);
  assert(fusedInputIt != consumerInputs.end() &&
         "expected to find the consumer operand");
  for (OpOperand *opOperand :
       llvm::make_range(consumerInputs.begin(), fusedInputIt))
    appendConsumerInput(opOperand);

  // 4. Splice in producer's input operands/maps.
  assert(producer->getNumResults() == 1 && "expected single result producer");
  AffineMap producerResultIndexMap =
      producer.getTiedIndexingMap(producer.getOutputOperand(0));
  // A producer operand gets its indexing map rewritten into the loops of the
  // fused op.
  auto appendProducerOperand = [&](OpOperand *opOperand) {
    fusedOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(
        getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
            opOperand, producerResultIndexMap, consumerArgIndexMap));
  };
  for (OpOperand *opOperand : producer.getInputOperands())
    appendProducerOperand(opOperand);

  // 4.b. Producer output operand/map that is fused needs to be passed if it is
  // an "initTensor" (i.e. its value is actually read).
  assert(producer->getNumResults() == 1 && "expected single result producer");
  if (producer.isInitTensor(producer.getOutputOperand(0)))
    appendProducerOperand(producer.getOutputOperand(0));

  // 5. Remaining consumer's input operands/maps (everything past the fused
  // operand).
  for (OpOperand *opOperand :
       llvm::make_range(std::next(fusedInputIt), consumerInputs.end()))
    appendConsumerInput(opOperand);

  // 6. All of consumer's output operands. Only the maps are recorded here; the
  // operands themselves are handed to the builder as outputs below.
  for (OpOperand *opOperand : consumer.getOutputOperands())
    fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));

  // 7. All of producer's output operands/maps except the one fused.
  // TODO: allow fusion of multi-result producers.
  assert(producer->getNumResults() == 1 && "expected single result producer");

  // Generate the fused op.
  SmallVector<Value> consumerOutputs = consumer.getOutputOperands();
  auto fusedOp = rewriter.create<GenericOp>(
      consumer.getLoc(), consumer->getResultTypes(),
      /*inputs=*/fusedOperands,
      // TODO: handle outputs.
      consumerOutputs, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
      consumer.iterator_types(),
      /*doc=*/nullptr,
      /*library_call=*/nullptr);
  if (!fusedOp.getShapesToLoopsMap()) {
    // Fused op has invalid indexing maps. Typically this means something is off
    // in the input, but going ahead here would result in verification errors.
    // So cleanup and abort.
    rewriter.eraseOp(fusedOp);
    return llvm::None;
  }

  // Build the map from consumer loops to producer loops by going through the
  // index space of the fused tensor:
  //   consumer loop -> tensor index -> producer loop.
  AffineMap tensorIndexToProducerLoop =
      inversePermutation(producerResultIndexMap);
  assert(tensorIndexToProducerLoop &&
         "expected producer result indexig map to be invertible");
  AffineMap consumerToProducerLoopsMap =
      tensorIndexToProducerLoop.compose(consumerArgIndexMap);

  generateFusedElementwiseOpRegion(rewriter, fusedOp,
                                   consumerToProducerLoopsMap,
                                   consumerOpOperand, consumer.getNumLoops());
  return SmallVector<Value>(fusedOp->getResults());
}
|
|
|
|
/// Entry point for elementwise fusion of `producer` into the owner of
/// `consumerOpOperand`. Producers that do not have exactly one result are
/// rejected up front; everything else is delegated to
/// `fuseElementwiseOpsImpl`.
static Optional<SmallVector<Value>>
fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand,
                   GenericOp producer,
                   const ControlElementwiseOpsFusionFn &controlFn) {
  const bool hasSingleResult = producer->getNumResults() == 1;
  if (!hasSingleResult)
    return llvm::None;
  return fuseElementwiseOpsImpl(producer, consumerOpOperand, controlFn,
                                rewriter);
}
|
|
|
|
namespace {
|
|
/// Patterns to fuse a generic op, with the producer of its operands.
|
|
class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
|
|
public:
|
|
FuseElementwiseOps(MLIRContext *context, ControlElementwiseOpsFusionFn &fun,
|
|
PatternBenefit benefit = 1)
|
|
: OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}
|
|
|
|
LogicalResult matchAndRewrite(GenericOp genericOp,
|
|
PatternRewriter &rewriter) const override {
|
|
// Find the first operand that is defined by another generic op on tensors.
|
|
for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
|
|
auto producer =
|
|
dyn_cast_or_null<GenericOp>(opOperand->get().getDefiningOp());
|
|
if (!producer || !producer.hasTensorSemantics())
|
|
continue;
|
|
Optional<SmallVector<Value>> fusedOpResults =
|
|
fuseElementwiseOps(rewriter, opOperand, producer, controlFn);
|
|
if (fusedOpResults) {
|
|
rewriter.replaceOp(genericOp, *fusedOpResults);
|
|
return success();
|
|
}
|
|
}
|
|
return failure();
|
|
}
|
|
|
|
private:
|
|
ControlElementwiseOpsFusionFn controlFn;
|
|
};
|
|
} // namespace
|
|
|
|
//===---------------------------------------------------------------------===//
|
|
// Methods and patterns that fuse reshape ops with elementwise operations by
|
|
// linearization of indexing maps.
|
|
//===---------------------------------------------------------------------===//
|
|
|
|
// TODO(ravishankarm): The indexing maps
|
|
// these produce in the general case are detrimental to transformations.
|
|
// These patterns are on deprecation path in favor of using fusion by
|
|
// collapsing, which covers the only legitimate use case of this pattern of
|
|
// folding unit-extent dims.
|
|
|
|
/// Linearize the expressions in `sourceMap` based on the `reassociationMaps`
|
|
/// provided, given the shape of the source tensor that corresponds to the
|
|
/// `sourceMap`. Note that this implicitly assumes that the tensors dimensions
|
|
/// are "row-major" ordered logically.
|
|
///
|
|
/// For example:
|
|
///
|
|
/// %0 = op ... : tensor<?x?x4x5xf32>
|
|
/// with output index_map `affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>`
|
|
///
|
|
/// and reshape:
|
|
/// %1 = tensor.collapse_shape %0 [[0], [1, 2, 3]] :
|
|
/// tensor<?x?x4x5xf32> into tensor<?x?xf32>
|
|
///
|
|
/// would be rewritten into:
|
|
/// %0 = op ... : tensor<?x?x4x5xf32>
|
|
/// with output index_map
|
|
/// `affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>`
|
|
template <typename TensorReshapeOp>
|
|
static AffineMap linearizeCollapsedDims(AffineMap sourceMap,
|
|
TensorReshapeOp reshapeOp) {
|
|
constexpr bool isExpanding =
|
|
std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value;
|
|
ArrayRef<int64_t> sourceShape =
|
|
(isExpanding ? reshapeOp.getResultType().getShape()
|
|
: reshapeOp.getSrcType().getShape());
|
|
SmallVector<AffineExpr> resultExprs;
|
|
ArrayRef<AffineExpr> sourceExprs = sourceMap.getResults();
|
|
MLIRContext *context = sourceMap.getContext();
|
|
|
|
// Compute the result exprs based on the reassociation maps.
|
|
for (auto &indices : reshapeOp.getReassociationIndices()) {
|
|
// Assume that they are in-order and contiguous (already checked in
|
|
// verifier).
|
|
assert(!indices.empty());
|
|
SmallVector<int64_t> sizes;
|
|
SmallVector<AffineExpr> dimExprs;
|
|
for (auto en : llvm::zip(sourceShape.slice(indices[0], indices.size()),
|
|
sourceExprs.slice(indices[0], indices.size()))) {
|
|
if (std::get<0>(en) == 1)
|
|
continue;
|
|
sizes.push_back(std::get<0>(en));
|
|
dimExprs.push_back(std::get<1>(en));
|
|
}
|
|
AffineExpr linearizedExpr =
|
|
makeCanonicalStridedLayoutExpr(sizes, dimExprs, context);
|
|
resultExprs.push_back(linearizedExpr);
|
|
}
|
|
// The new affine map cannot drop unused dimension but some new symbols may
|
|
// have been added. Create a map with at least as many dimensions/symbols as
|
|
// the original affine map.
|
|
int64_t maxDim = -1;
|
|
int64_t maxSym = -1;
|
|
getMaxDimAndSymbol<SmallVector<AffineExpr>>({resultExprs}, maxDim, maxSym);
|
|
unsigned numDims = std::max(unsigned(maxDim + 1), sourceMap.getNumDims());
|
|
unsigned numSyms = std::max(unsigned(maxSym + 1), sourceMap.getNumSymbols());
|
|
return AffineMap::get(numDims, numSyms, resultExprs, context);
|
|
}
|
|
|
|
// tensor::ExpandShapeOp is fusable with its consumer (i.e. reshape as a
|
|
// producer). Fusing when operand has higher rank will require use of mods and
|
|
// divs in the indexing maps of the fused op which would make it non-invertible.
|
|
static bool isTensorReshapeOpFoldableByLinearization(
|
|
tensor::ExpandShapeOp expandOp, AffineMap useIndexMap, bool asProducer) {
|
|
if (!asProducer)
|
|
return false;
|
|
return useIndexMap.isPermutation();
|
|
}
|
|
|
|
// tensor::CollapseShapeOp is fusable with its producer (i.e. reshape as a
|
|
// consumer).
|
|
static bool
|
|
isTensorReshapeOpFoldableByLinearization(tensor::CollapseShapeOp collapseOp,
|
|
AffineMap useIndexMap,
|
|
bool asProducer) {
|
|
if (asProducer)
|
|
return false;
|
|
return useIndexMap.isPermutation();
|
|
}
|
|
|
|
/// Check if the reshape operation is only expansion into/collapsing of
|
|
/// unit-dimension.
|
|
template <typename TensorReshapeOp>
|
|
static bool isUnitDimExpansionOnly(TensorReshapeOp reshapeOp) {
|
|
constexpr bool isExpanding =
|
|
std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value;
|
|
ArrayRef<int64_t> expandedShape =
|
|
(isExpanding ? reshapeOp.getResultType().getShape()
|
|
: reshapeOp.getSrcType().getShape());
|
|
for (auto &indices : reshapeOp.getReassociationIndices()) {
|
|
unsigned numUnitDims = 0;
|
|
for (int64_t position : indices)
|
|
if (expandedShape[position] == 1)
|
|
numUnitDims++;
|
|
if (numUnitDims != indices.size() - 1)
|
|
return false;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
namespace {
|
|
/// Pattern to fold tensor_expand_shape op with its consumer by using the source
|
|
/// of the reshape op as the operand in the consumer (instead of the result of
|
|
/// the tensor_expand_shape). The corresponding index map in the consumer
|
|
/// needs to be modified to linearize the folded dimension.
|
|
///
|
|
/// For example,
|
|
///
|
|
/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
|
|
/// %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]]
|
|
/// tensor<?x?x?xf32> into tensor<?x?x4x?xf32>
|
|
/// %1 = linalg.generic { indexing_maps = [#map0, #map0, #map0], ... }
|
|
/// ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>) ...
|
|
/// -> tensor<?x?x4x?xf32>
|
|
///
|
|
/// can be folded into
|
|
///
|
|
/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
|
|
/// #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
|
|
/// %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... }
|
|
/// ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>) ...
|
|
/// -> tensor<?x?x4x?xf32>
|
|
template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
|
|
struct FoldProducerReshapeOpByLinearization
|
|
: public OpRewritePattern<GenericOp> {
|
|
using OpRewritePattern<GenericOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(GenericOp genericOp,
|
|
PatternRewriter &rewriter) const override {
|
|
if (!genericOp.hasTensorSemantics())
|
|
return failure();
|
|
SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
|
|
for (const auto &en : llvm::enumerate(inputOperands)) {
|
|
auto reshapeOp = en.value()->get().getDefiningOp<TensorReshapeOp>();
|
|
if (!reshapeOp)
|
|
continue;
|
|
|
|
if (!isTensorReshapeOpFoldableByLinearization(
|
|
reshapeOp, genericOp.getTiedIndexingMap(en.value()),
|
|
/*asProducer =*/true) ||
|
|
(foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
|
|
continue;
|
|
|
|
// Compute the fused operands list,
|
|
SmallVector<Value> fusedOperands = genericOp.getInputOperands();
|
|
fusedOperands[en.index()] = reshapeOp.src();
|
|
SmallVector<Value> outputOperands = genericOp.getOutputOperands();
|
|
llvm::append_range(fusedOperands, outputOperands);
|
|
|
|
// Compute indexing_maps for the fused operation. The indexing_maps for
|
|
// the operands of the consumers that arent fused are the same.
|
|
SmallVector<AffineMap> fusedIndexMaps = genericOp.getIndexingMaps();
|
|
|
|
// Compute the indexing map to use for the result of the producer.
|
|
AffineMap modifiedMap =
|
|
linearizeCollapsedDims(fusedIndexMaps[en.index()], reshapeOp);
|
|
// The modified map cannot have symbols.
|
|
if (modifiedMap.getNumSymbols())
|
|
return failure();
|
|
for (AffineExpr expr : modifiedMap.getResults()) {
|
|
if (!expr.isPureAffine())
|
|
return failure();
|
|
}
|
|
fusedIndexMaps[en.index()] = modifiedMap;
|
|
|
|
// Further check that the resulting index maps can be fused and
|
|
// inverted. Without this the resultant op is not legal.
|
|
if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
|
|
return rewriter.notifyMatchFailure(
|
|
genericOp, "fused op loop bound computation failed");
|
|
}
|
|
|
|
rewriter.startRootUpdate(genericOp);
|
|
genericOp->setOperands(fusedOperands);
|
|
genericOp.indexing_mapsAttr(
|
|
rewriter.getAffineMapArrayAttr(fusedIndexMaps));
|
|
rewriter.finalizeRootUpdate(genericOp);
|
|
return success();
|
|
}
|
|
return failure();
|
|
}
|
|
};
|
|
|
|
/// Pattern to fold tensor_collapse_shape or tensor_expand_shape op with its
|
|
/// producer. The corresponding index map in the consumer needs to be modified
|
|
/// to linearize the folded dimension.
|
|
template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
|
|
struct FoldConsumerReshapeOpByLinearization
|
|
: public OpRewritePattern<TensorReshapeOp> {
|
|
using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
|
|
PatternRewriter &rewriter) const override {
|
|
GenericOp producer = reshapeOp.src().template getDefiningOp<GenericOp>();
|
|
if (!producer || !producer.hasTensorSemantics() ||
|
|
producer.getNumOutputs() != 1 ||
|
|
!isTensorReshapeOpFoldableByLinearization(
|
|
reshapeOp,
|
|
producer.getTiedIndexingMap(producer.getOutputOperand(0)),
|
|
/*asProducer =*/false) ||
|
|
(foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
|
|
return failure();
|
|
// The indexing_maps for the operands of the fused operation are same as
|
|
// those for the operands of the producer.
|
|
SmallVector<AffineMap> fusedIndexMaps = producer.getIndexingMaps();
|
|
|
|
// Compute the indexing map to use for the operand of the producer.
|
|
AffineMap modifiedMap = linearizeCollapsedDims(
|
|
producer.getTiedIndexingMap(producer.getOutputOperand(0)), reshapeOp);
|
|
for (AffineExpr expr : modifiedMap.getResults()) {
|
|
if (!expr.isPureAffine()) {
|
|
return rewriter.notifyMatchFailure(
|
|
producer, "fused op indexing map is not affine");
|
|
}
|
|
}
|
|
fusedIndexMaps.back() = modifiedMap;
|
|
|
|
// Further check that the resulting index maps can be fused and
|
|
// inverted. Without this the resultant op is not legal.
|
|
if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
|
|
return rewriter.notifyMatchFailure(
|
|
producer, "fused op loop bound computation failed");
|
|
}
|
|
|
|
Location loc = producer.getLoc();
|
|
SmallVector<Value> inputOperands = producer.getInputOperands();
|
|
Value output = rewriter.create<TensorReshapeOp>(
|
|
loc, producer.getOutputOperand(0)->get(),
|
|
reshapeOp.getReassociationExprs());
|
|
auto fusedOp = rewriter.create<GenericOp>(
|
|
loc, reshapeOp.getResultType(),
|
|
/*inputs=*/inputOperands,
|
|
// TODO: handle outputs.
|
|
/*outputs=*/output, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
|
|
producer.iterator_types(),
|
|
/*doc=*/nullptr,
|
|
/*library_call=*/nullptr);
|
|
auto &fusedRegion = fusedOp->getRegion(0);
|
|
rewriter.cloneRegionBefore(producer->getRegion(0), fusedRegion,
|
|
fusedRegion.begin());
|
|
rewriter.replaceOp(reshapeOp, fusedOp->getResults());
|
|
return success();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
//===---------------------------------------------------------------------===//
|
|
// Methods and patterns that fuse reshape ops with elementwise operations by
|
|
// expanding the dimensionality of the elementwise operations.
|
|
//===---------------------------------------------------------------------===//
|
|
|
|
/// Conditions for folding a generic operation with a reshape op by expanding
|
|
/// the iteration space dimensionality for tensor operations. These are
|
|
/// preconditions assumed by `foldReshapeByDimExpansion` which implements the
|
|
/// following fusion pattern.
|
|
///
|
|
/// Consider
|
|
///
|
|
/// %c = linalg.generic ins(%a, %b : tensor<?x?x?xf32>, tensor<?x?xf32>)
|
|
/// indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
|
|
/// affine_map<(d0, d1, d2) -> (d1, d2)>,
|
|
/// affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
|
|
/// %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]]
|
|
/// : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
|
|
///
|
|
/// The reshape can be folded into the `genericOp` if its loop dimensionality
|
|
/// is increased to match the result (operand) of the tensor_expand_shape.
|
|
/// The indexing_map of the fused tensor in the `genericOp` and the
|
|
/// reassociation map helps compute the indexing maps of the modified op.
|
|
/// For the above example, based on the reassociation map it
|
|
/// can be concluded that
|
|
///
|
|
/// - The loop used to access the first dimension of the fused tensor is split
|
|
/// into two.
|
|
/// - The loop used to access the second dimension of the fused tensor is kept
|
|
/// as is.
|
|
/// - The loop used to access the third dimension of the fused tensor is split
|
|
/// into three.
|
|
///
|
|
/// i.e. (e0, e1, e2, e3, e4, e5) is the domain of the indexing map of the modified
|
|
/// op, then
|
|
///
|
|
/// d0 -> e0, e1
|
|
/// d1 -> e2, e3, e4
|
|
/// d2 -> e5
|
|
///
|
|
/// substituting this, the generic op can be rewritten as
|
|
///
|
|
/// %d = linalg.generic ins(%0, %1 : )
|
|
/// indexing_maps =
|
|
/// [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
|
|
/// affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
|
|
/// affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
|
|
///
|
|
/// Since the generic op's loop nest is now 6D, reshapes can be introduced
|
|
/// to make it consistent
|
|
///
|
|
/// %0 = tensor.expand_shape %a [[0, 1, 2], [3, 4], [5]]
|
|
/// : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
|
|
/// %1 = tensor.expand_shape %b [[0, 1, 2], [3]]
|
|
/// : tensor<?x?xf32> into tensor<?x?x?x?xf32>
|
|
///
|
|
/// The added reshapes are again expanding patterns, so they will get fused
|
|
/// with its producers if possible.
|
|
static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
                                               OpOperand *fusableOpOperand) {
  // The op must operate on tensors.
  if (!genericOp.hasTensorSemantics())
    return false;

  // All the indexing maps for operands and results must be projected
  // permutations.
  auto isProjectedPermutation = [](Attribute attr) {
    return attr.cast<AffineMapAttr>().getValue().isProjectedPermutation();
  };
  if (!llvm::all_of(genericOp.indexing_maps().getValue(),
                    isProjectedPermutation))
    return false;

  // The fused tensor must not be a scalar.
  if (genericOp.getTiedIndexingMap(fusableOpOperand).getNumResults() == 0)
    return false;

  // All the loops must be parallel loops.
  auto isParallelIterator = [](Attribute attr) {
    return attr.cast<StringAttr>().getValue() ==
           getParallelIteratorTypeName();
  };
  return llvm::all_of(genericOp.iterator_types(), isParallelIterator);
}
|
|
|
|
namespace {
|
|
/// Information needed to expand a generic operation to fold the reshape with
|
|
/// it.
|
|
class ExpansionInfo {
|
|
public:
|
|
// Computes the mapping from original dimensions of the op to the dimensions
|
|
// of the expanded op given the `indexingMap` of the fused operand/result of
|
|
// the generic op, the `reassocationMaps` of the reshape op and the shape of
|
|
// the expanded op.
|
|
LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
|
|
ArrayRef<AffineMap> reassociationMaps,
|
|
ArrayRef<int64_t> expandedShape,
|
|
ArrayRef<int64_t> collapsedShape,
|
|
PatternRewriter &rewriter);
|
|
unsigned getOrigOpNumDims() const { return reassociation.size(); }
|
|
unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
|
|
ReassociationIndicesRef getExpandedDims(unsigned i) const {
|
|
return reassociation[i];
|
|
}
|
|
ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
|
|
return expandedShapeMap[i];
|
|
}
|
|
ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }
|
|
|
|
private:
|
|
/// Reassociation from the dimensions in the original operation to the
|
|
/// dimension of the expanded operation.
|
|
SmallVector<ReassociationIndices> reassociation;
|
|
/// Mapping from extent of loops in the original operation, to the extent of
|
|
/// loops in the expanded operation.
|
|
SmallVector<SmallVector<int64_t>> expandedShapeMap;
|
|
/// Extent of the loop in the original operation.
|
|
SmallVector<int64_t> originalLoopExtent;
|
|
unsigned expandedOpNumDims;
|
|
};
|
|
} // namespace
|
|
|
|
LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                     OpOperand *fusableOpOperand,
                                     ArrayRef<AffineMap> reassociationMaps,
                                     ArrayRef<int64_t> expandedShape,
                                     ArrayRef<int64_t> collapsedShape,
                                     PatternRewriter &rewriter) {
  if (reassociationMaps.empty())
    return failure();
  AffineMap fusedIndexMap = linalgOp.getTiedIndexingMap(fusableOpOperand);
  const unsigned numOrigDims = fusedIndexMap.getNumDims();

  // Record the loop extents of the original operation.
  Optional<SmallVector<int64_t, 4>> staticLoopRanges =
      linalgOp.getStaticLoopRanges();
  if (!staticLoopRanges)
    return rewriter.notifyMatchFailure(linalgOp, "unable to find loop range");
  originalLoopExtent.assign(staticLoopRanges->begin(),
                            staticLoopRanges->end());

  reassociation.clear();
  expandedShapeMap.clear();
  expandedShapeMap.resize(numOrigDims);

  // For every loop accessed by the fused operand, record how many loops it is
  // split into and the extents of those loops. Loops that are not accessed by
  // the fused operand default to a single expanded loop.
  SmallVector<unsigned> numExpandedDims(numOrigDims, 1);
  unsigned resultIdx = 0;
  for (AffineExpr resultExpr : fusedIndexMap.getResults()) {
    unsigned pos = resultExpr.cast<AffineDimExpr>().getPosition();
    AffineMap foldedDims = reassociationMaps[resultIdx++];
    unsigned numFolded = foldedDims.getNumResults();
    numExpandedDims[pos] = numFolded;
    ArrayRef<int64_t> foldedShape =
        expandedShape.slice(foldedDims.getDimPosition(0), numFolded);
    expandedShapeMap[pos].assign(foldedShape.begin(), foldedShape.end());
  }

  // Loops untouched by the reshape keep their original extent.
  for (unsigned dim = 0; dim < numOrigDims; ++dim) {
    if (expandedShapeMap[dim].empty())
      expandedShapeMap[dim] = {originalLoopExtent[dim]};
  }

  // Number the expanded loops consecutively, grouping them by the original
  // loop they were derived from.
  unsigned nextExpandedDim = 0;
  reassociation.reserve(numOrigDims);
  for (unsigned numFolded : numExpandedDims) {
    auto expandedDims =
        llvm::seq<int64_t>(nextExpandedDim, nextExpandedDim + numFolded);
    reassociation.emplace_back(expandedDims.begin(), expandedDims.end());
    nextExpandedDim += numFolded;
  }
  expandedOpNumDims = nextExpandedDim;
  return success();
}
|
|
|
|
/// Epanding the body of a linalg operation requires adaptations of the accessed
|
|
/// loop indices. Specifically, access of indices in the original operation need
|
|
/// to be replaced with linearizations of indices in the expanded op. That
|
|
/// requires the shape of the expanded dimensions to be static (at least all but
|
|
/// the most significant). For now check that these are all statically sized.
|
|
/// Note that this could be extended to handle dynamic case, but the
|
|
/// implementation below uses `affine.apply` which seems to have issues when the
|
|
/// shapes are not static.
|
|
static LogicalResult isGenericOpExpandable(GenericOp genericOp,
|
|
const ExpansionInfo &expansionInfo,
|
|
PatternRewriter &rewriter) {
|
|
if (!genericOp.hasIndexSemantics())
|
|
return success();
|
|
for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
|
|
ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
|
|
if (expandedShape.size() == 1)
|
|
continue;
|
|
for (int64_t shape : expandedShape.drop_front()) {
|
|
if (ShapedType::isDynamic(shape)) {
|
|
return rewriter.notifyMatchFailure(
|
|
genericOp, "cannot expand due to index semantics and dynamic dims");
|
|
}
|
|
}
|
|
}
|
|
return success();
|
|
}
|
|
|
|
/// Return the indexing map to use in the expanded op for a given the
|
|
/// `indexingMap` of the original operation.
|
|
static AffineMap
|
|
getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
|
|
const ExpansionInfo &expansionInfo) {
|
|
SmallVector<AffineExpr> newExprs;
|
|
for (AffineExpr expr : indexingMap.getResults()) {
|
|
unsigned pos = expr.cast<AffineDimExpr>().getPosition();
|
|
SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
|
|
llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
|
|
return builder.getAffineDimExpr(static_cast<unsigned>(v));
|
|
}));
|
|
newExprs.append(expandedExprs.begin(), expandedExprs.end());
|
|
}
|
|
return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
|
|
indexingMap.getNumSymbols(), newExprs,
|
|
builder.getContext());
|
|
}
|
|
|
|
/// Return the type of the operand/result to use in the expanded op given the
|
|
/// type in the original op.
|
|
static RankedTensorType getExpandedType(RankedTensorType originalType,
|
|
AffineMap indexingMap,
|
|
const ExpansionInfo &expansionInfo) {
|
|
SmallVector<int64_t> expandedShape;
|
|
for (AffineExpr expr : indexingMap.getResults()) {
|
|
unsigned dim = expr.cast<AffineDimExpr>().getPosition();
|
|
auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
|
|
expandedShape.append(dimExpansion.begin(), dimExpansion.end());
|
|
}
|
|
return RankedTensorType::get(expandedShape, originalType.getElementType());
|
|
}
|
|
|
|
/// Returns the reassociation maps to use in the `tensor.expand_shape`
|
|
/// operation to convert the operands of the original operation to operands of
|
|
/// the expanded operation. The same method is used to compute the
|
|
/// `tensor.collapse_shape` used to collapse the result of the expanded
|
|
/// op to get the value that can replace all uses of the results of the original
|
|
/// op.
|
|
static SmallVector<ReassociationIndices>
|
|
getReassociationForExpansion(AffineMap indexingMap,
|
|
const ExpansionInfo &expansionInfo) {
|
|
SmallVector<ReassociationIndices> reassociation;
|
|
unsigned numReshapeDims = 0;
|
|
for (AffineExpr expr : indexingMap.getResults()) {
|
|
unsigned dim = expr.cast<AffineDimExpr>().getPosition();
|
|
auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
|
|
SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
|
|
llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
|
|
reassociation.emplace_back(std::move(indices));
|
|
numReshapeDims += numExpandedDims;
|
|
}
|
|
return reassociation;
|
|
}
|
|
|
|
/// Update the body of an expanded linalg operation having index semantics. The
|
|
/// indices of the original operation need to be recovered by linearizing the
|
|
/// indices of the correspoding dimensions of the expanded operation. For now it
|
|
/// is assumed that the shapes of the expanded operation needed for
|
|
/// linearization are static.
|
|
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
|
|
Location loc, Region &fusedRegion,
|
|
const ExpansionInfo &expansionInfo) {
|
|
// Replace the original indices by the linearization of the expanded indices.
|
|
for (IndexOp indexOp :
|
|
llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
|
|
ArrayRef<int64_t> expandedDims =
|
|
expansionInfo.getExpandedDims(indexOp.dim());
|
|
assert(!expandedDims.empty() && "expected valid expansion info");
|
|
|
|
// Skip index operations that are not affected by the expansion.
|
|
if (expandedDims.size() == 1 &&
|
|
expandedDims.front() == (int64_t)indexOp.dim())
|
|
continue;
|
|
|
|
// Linearize the expanded indices of the original index dimension.
|
|
OpBuilder::InsertionGuard guard(rewriter);
|
|
rewriter.setInsertionPointAfter(indexOp);
|
|
ArrayRef<int64_t> expandedDimsShape =
|
|
expansionInfo.getExpandedShapeOfDim(indexOp.dim()).drop_front();
|
|
SmallVector<Value> expandedIndices;
|
|
expandedIndices.reserve(expandedDims.size() - 1);
|
|
llvm::transform(
|
|
expandedDims.drop_front(), std::back_inserter(expandedIndices),
|
|
[&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
|
|
Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front());
|
|
for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
|
|
assert(!ShapedType::isDynamic(std::get<0>(it)));
|
|
AffineExpr idx, acc;
|
|
bindDims(rewriter.getContext(), idx, acc);
|
|
newIndex = rewriter.create<AffineApplyOp>(
|
|
indexOp.getLoc(), idx + acc * std::get<0>(it),
|
|
ValueRange{std::get<1>(it), newIndex});
|
|
}
|
|
rewriter.replaceOp(indexOp, newIndex);
|
|
}
|
|
}
|
|
|
|
/// Implements the fusion of a tensor_collapse_shape or a tensor_expand_shape op
|
|
/// and a generic op as explained in `isFusableWithReshapeByExpansion`. Assumes
|
|
/// that those conditions have been satisfied.
|
|
static Optional<SmallVector<Value>>
|
|
fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
|
|
OpOperand *fusableOpOperand,
|
|
PatternRewriter &rewriter) {
|
|
assert(isFusableWithReshapeByDimExpansion(genericOp, fusableOpOperand) &&
|
|
"preconditions for fuse operation failed");
|
|
// Check if reshape is expanding or collapsing.
|
|
auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
|
|
auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
|
|
bool isExpanding = (expandingReshapeOp != nullptr);
|
|
RankedTensorType expandedType = isExpanding
|
|
? expandingReshapeOp.getResultType()
|
|
: collapsingReshapeOp.getSrcType();
|
|
RankedTensorType collapsedType = isExpanding
|
|
? expandingReshapeOp.getSrcType()
|
|
: collapsingReshapeOp.getResultType();
|
|
|
|
ExpansionInfo expansionInfo;
|
|
if (failed(expansionInfo.compute(
|
|
genericOp, fusableOpOperand,
|
|
isExpanding ? expandingReshapeOp.getReassociationMaps()
|
|
: collapsingReshapeOp.getReassociationMaps(),
|
|
expandedType.getShape(), collapsedType.getShape(), rewriter)))
|
|
return llvm::None;
|
|
|
|
if (failed(isGenericOpExpandable(genericOp, expansionInfo, rewriter)))
|
|
return llvm::None;
|
|
|
|
SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
|
|
llvm::map_range(genericOp.getIndexingMaps(), [&](AffineMap m) {
|
|
return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
|
|
}));
|
|
|
|
SmallVector<Value> expandedOpOperands;
|
|
expandedOpOperands.reserve(genericOp.getNumInputs());
|
|
for (OpOperand *opOperand : genericOp.getInputOperands()) {
|
|
if (opOperand == fusableOpOperand) {
|
|
expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.src()
|
|
: collapsingReshapeOp.src());
|
|
continue;
|
|
}
|
|
if (genericOp.isInputTensor(opOperand)) {
|
|
AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
|
|
auto opOperandType = opOperand->get().getType().cast<RankedTensorType>();
|
|
RankedTensorType expandedOperandType =
|
|
getExpandedType(opOperandType, indexingMap, expansionInfo);
|
|
if (expandedOperandType != opOperand->get().getType()) {
|
|
// Reshape the operand to get the right type.
|
|
SmallVector<ReassociationIndices> reassociation =
|
|
getReassociationForExpansion(indexingMap, expansionInfo);
|
|
if (failed(reshapeLikeShapesAreCompatible(
|
|
[&](const Twine &msg) {
|
|
return rewriter.notifyMatchFailure(genericOp, msg);
|
|
},
|
|
opOperandType.getShape(), expandedOperandType.getShape(),
|
|
reassociation,
|
|
/*isExpandingReshape=*/true)))
|
|
return llvm::None;
|
|
expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
|
|
genericOp.getLoc(), expandedOperandType, opOperand->get(),
|
|
reassociation));
|
|
continue;
|
|
}
|
|
}
|
|
expandedOpOperands.push_back(opOperand->get());
|
|
}
|
|
|
|
Location loc = genericOp.getLoc();
|
|
SmallVector<Value> outputs;
|
|
for (OpOperand *opOperand : genericOp.getOutputOperands()) {
|
|
AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
|
|
auto opOperandType = opOperand->get().getType().cast<RankedTensorType>();
|
|
RankedTensorType expandedOutputType =
|
|
getExpandedType(opOperandType, indexingMap, expansionInfo);
|
|
if (expandedOutputType != opOperand->get().getType()) {
|
|
SmallVector<ReassociationIndices> reassociation =
|
|
getReassociationForExpansion(indexingMap, expansionInfo);
|
|
if (failed(reshapeLikeShapesAreCompatible(
|
|
[&](const Twine &msg) {
|
|
return rewriter.notifyMatchFailure(genericOp, msg);
|
|
},
|
|
opOperandType.getShape(), expandedOutputType.getShape(),
|
|
reassociation,
|
|
/*isExpandingReshape=*/true)))
|
|
return llvm::None;
|
|
outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
|
|
genericOp.getLoc(), expandedOutputType, opOperand->get(),
|
|
reassociation));
|
|
}
|
|
}
|
|
|
|
// The iterator types of the expanded op are all parallel.
|
|
SmallVector<StringRef> iteratorTypes(expansionInfo.getExpandedOpNumDims(),
|
|
getParallelIteratorTypeName());
|
|
|
|
TypeRange resultTypes = ValueRange(outputs).getTypes();
|
|
auto fusedOp =
|
|
rewriter.create<GenericOp>(genericOp.getLoc(), resultTypes,
|
|
/*inputs=*/expandedOpOperands, outputs,
|
|
expandedOpIndexingMaps, iteratorTypes);
|
|
Region &fusedRegion = fusedOp->getRegion(0);
|
|
Region &originalRegion = genericOp->getRegion(0);
|
|
rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());
|
|
|
|
// Update the index accesses after the expansion.
|
|
updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo);
|
|
|
|
// Reshape the result values to their original shape if this is a collapsing
|
|
// reshape folded into its consumer.
|
|
SmallVector<Value> resultVals;
|
|
for (OpResult opResult : genericOp->getOpResults()) {
|
|
int64_t resultNumber = opResult.getResultNumber();
|
|
if (!isExpanding && resultTypes[resultNumber] != opResult.getType()) {
|
|
SmallVector<ReassociationIndices> reassociation =
|
|
getReassociationForExpansion(
|
|
genericOp.getTiedIndexingMap(
|
|
genericOp.getOutputOperand(resultNumber)),
|
|
expansionInfo);
|
|
resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>(
|
|
genericOp.getLoc(), opResult.getType(),
|
|
fusedOp->getResult(resultNumber), reassociation));
|
|
} else {
|
|
resultVals.push_back(fusedOp->getResult(resultNumber));
|
|
}
|
|
}
|
|
// Assuming a single result.
|
|
return resultVals;
|
|
}
|
|
|
|
namespace {
|
|
|
|
/// Pattern to fuse a tensor_collapse_shape op with its consumer generic op,
|
|
/// when the reshape op is collapsing dimensions. The dimensionality of the loop
|
|
/// in the consumer is expanded.
|
|
class FoldWithProducerReshapeOpByExpansion
|
|
: public OpRewritePattern<GenericOp> {
|
|
public:
|
|
FoldWithProducerReshapeOpByExpansion(
|
|
MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
|
|
PatternBenefit benefit = 1)
|
|
: OpRewritePattern<GenericOp>(context, benefit),
|
|
controlFoldingReshapes(std::move(foldReshapes)) {}
|
|
|
|
LogicalResult matchAndRewrite(GenericOp genericOp,
|
|
PatternRewriter &rewriter) const override {
|
|
for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
|
|
tensor::CollapseShapeOp reshapeOp =
|
|
opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
|
|
if (!reshapeOp)
|
|
continue;
|
|
// Fold only if
|
|
// - The tensor reshape op is folding.
|
|
// - All constraints of fusing with reshape by expansion are met.
|
|
if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
|
|
(!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)))
|
|
continue;
|
|
|
|
Optional<SmallVector<Value>> replacementValues =
|
|
fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter);
|
|
if (!replacementValues)
|
|
return failure();
|
|
rewriter.replaceOp(genericOp, replacementValues.getValue());
|
|
return success();
|
|
}
|
|
return failure();
|
|
}
|
|
|
|
private:
|
|
ControlElementwiseOpsFusionFn controlFoldingReshapes;
|
|
};
|
|
|
|
/// Pattern to fold a tensor_expand_shape op with its producer generic op
|
|
/// by expanding the dimensionality of the loop in the producer op.
|
|
struct FoldReshapeWithGenericOpByExpansion
|
|
: public OpRewritePattern<tensor::ExpandShapeOp> {
|
|
|
|
FoldReshapeWithGenericOpByExpansion(
|
|
MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
|
|
PatternBenefit benefit = 1)
|
|
: OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
|
|
controlFoldingReshapes(std::move(foldReshapes)) {}
|
|
|
|
LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
|
|
PatternRewriter &rewriter) const override {
|
|
// Fold only if all constraints of fusing with reshape by expansion are met.
|
|
GenericOp producer = reshapeOp.src().getDefiningOp<GenericOp>();
|
|
if (!producer || producer.getNumOutputs() != 1 ||
|
|
!isFusableWithReshapeByDimExpansion(producer,
|
|
producer.getOutputOperand(0)) ||
|
|
!controlFoldingReshapes(producer->getResult(0),
|
|
reshapeOp->getOpOperand(0)))
|
|
return failure();
|
|
Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion(
|
|
producer, reshapeOp, producer.getOutputOperand(0), rewriter);
|
|
if (!replacementValues)
|
|
return failure();
|
|
rewriter.replaceOp(reshapeOp, replacementValues.getValue());
|
|
return success();
|
|
}
|
|
|
|
private:
|
|
ControlElementwiseOpsFusionFn controlFoldingReshapes;
|
|
};
|
|
} // namespace
|
|
|
|
//===---------------------------------------------------------------------===//
|
|
// Methods and patterns to fuse reshape with linalg.generic operations by
|
|
// contraction of dimensions.
|
|
//===---------------------------------------------------------------------===//
|
|
|
|
/// For an `indexingMap` that is a projected permutation, if the range is to be
|
|
/// collapsed using the given `reassociation`, get the reassociation in the
|
|
/// domain that would keep the map a projected permutation.
|
|
static SmallVector<ReassociationIndices>
|
|
getDomainReassociation(AffineMap indexingMap,
|
|
ArrayRef<ReassociationIndices> rangeReassociation) {
|
|
assert(indexingMap.isProjectedPermutation() &&
|
|
"expected projected permutation map");
|
|
unsigned counter = 0;
|
|
SmallVector<ReassociationIndices> domainReassociation;
|
|
llvm::SmallDenseSet<unsigned, 4> processedDomainDims;
|
|
// Iterate over the reassociation indices.
|
|
for (ReassociationIndicesRef foldedRangeDims : rangeReassociation) {
|
|
ReassociationIndices foldedDomainDims;
|
|
for (auto rangeDim : foldedRangeDims) {
|
|
(void)rangeDim;
|
|
AffineDimExpr dimExpr =
|
|
indexingMap.getResult(counter++).cast<AffineDimExpr>();
|
|
foldedDomainDims.push_back(dimExpr.getPosition());
|
|
processedDomainDims.insert(dimExpr.getPosition());
|
|
}
|
|
domainReassociation.emplace_back(std::move(foldedDomainDims));
|
|
}
|
|
// Fill in the missing domain dims.
|
|
for (auto dim : llvm::seq<unsigned>(0, indexingMap.getNumDims())) {
|
|
if (processedDomainDims.count(dim))
|
|
continue;
|
|
ReassociationIndices vec = {dim};
|
|
domainReassociation.emplace_back(std::move(vec));
|
|
}
|
|
|
|
// Sort the reassociation using the first dimension of the folded range to
|
|
// not create unnecessary transposes.
|
|
llvm::sort(domainReassociation,
|
|
[](ReassociationIndicesRef lhs, ReassociationIndicesRef rhs) {
|
|
return lhs[0] < rhs[0];
|
|
});
|
|
return domainReassociation;
|
|
}
|
|
|
|
/// For a given `dimSequence`, check if the sequence is conserved in the
|
|
/// `indexingMap`. `indexingMap` is expected to be a projected permutation.
|
|
/// Non-existence of the sequence returns true as well.
|
|
static bool isDimSequencePreserved(AffineMap indexingMap,
|
|
ReassociationIndicesRef dimSequence) {
|
|
assert(!dimSequence.empty() &&
|
|
"expected non-empty list for dimension sequence");
|
|
assert(indexingMap.isProjectedPermutation() &&
|
|
"expected indexing map to be projected permutation");
|
|
|
|
llvm::SmallDenseSet<unsigned, 4> sequenceElements;
|
|
sequenceElements.insert(dimSequence.begin(), dimSequence.end());
|
|
|
|
unsigned dimSequenceStart = dimSequence[0];
|
|
for (const auto &expr : enumerate(indexingMap.getResults())) {
|
|
unsigned dimInMapStart = expr.value().cast<AffineDimExpr>().getPosition();
|
|
// 1. Check if this start of the sequence.
|
|
if (dimInMapStart == dimSequenceStart) {
|
|
if (expr.index() + dimSequence.size() > indexingMap.getNumResults())
|
|
return false;
|
|
// 1a. Check if sequence is preserved.
|
|
for (const auto &dimInSequence : enumerate(dimSequence)) {
|
|
unsigned dimInMap =
|
|
indexingMap.getResult(expr.index() + dimInSequence.index())
|
|
.cast<AffineDimExpr>()
|
|
.getPosition();
|
|
if (dimInMap != dimInSequence.value())
|
|
return false;
|
|
}
|
|
// Found the sequence. Projected permutation
|
|
// enforces that all AffineDimExprs in the result are unique, so no
|
|
// further checks are needed.
|
|
return true;
|
|
}
|
|
// 2. If position in the expr (which is of type AffineDimExpr) is part
|
|
// of sequence, return false here. This implies the entire sequence does not
|
|
// exist in the indexing map.
|
|
if (sequenceElements.count(dimInMapStart))
|
|
return false;
|
|
}
|
|
// 3. No element of sequence found. Return true.
|
|
return true;
|
|
}
|
|
|
|
// Check if a generic op can be fused along an operand by collapsing dimensions.
|
|
static bool isFusableWithReshapeByDimCollapse(
|
|
GenericOp genericOp, OpOperand *fusableOperand,
|
|
ArrayRef<ReassociationIndices> reassociation) {
|
|
// Some basic checks for this fusion to be valid.
|
|
if (!genericOp.hasTensorSemantics() || genericOp.getNumOutputs() != 1)
|
|
return false;
|
|
|
|
if (!llvm::all_of(genericOp.getIndexingMaps(), [](AffineMap map) {
|
|
return map.isProjectedPermutation();
|
|
})) {
|
|
return false;
|
|
}
|
|
|
|
// Get the reassociation for the iteration space.
|
|
SmallVector<ReassociationIndices> iterationReassociation =
|
|
getDomainReassociation(genericOp.getTiedIndexingMap(fusableOperand),
|
|
reassociation);
|
|
if (iterationReassociation.empty()) {
|
|
// If the domain reassociation indices is empty, then this is a scalar op.
|
|
// Nothing to do.
|
|
return false;
|
|
}
|
|
|
|
auto iteratorTypes = genericOp.iterator_types().getValue();
|
|
ArrayRef<Attribute> iteratorTypesRef(iteratorTypes);
|
|
for (ReassociationIndicesRef foldedIterDims : iterationReassociation) {
|
|
// Check that for all indexing maps, the folded dimensions sequence is
|
|
// preserved.
|
|
if (!llvm::all_of(genericOp.getIndexingMaps(), [&](AffineMap indexingMap) {
|
|
return isDimSequencePreserved(indexingMap, foldedIterDims);
|
|
}))
|
|
return false;
|
|
unsigned startDim = foldedIterDims[0];
|
|
ArrayRef<Attribute> foldedIteratorTypes =
|
|
iteratorTypesRef.drop_front(startDim).take_front(foldedIterDims.size());
|
|
// Check that all folded iterator types are either all parallel, or all
|
|
// reduction.
|
|
if (!llvm::all_of(
|
|
foldedIteratorTypes,
|
|
[](Attribute attr) { return isParallelIterator(attr); }) &&
|
|
!llvm::all_of(foldedIteratorTypes,
|
|
[](Attribute attr) { return isReductionIterator(attr); }))
|
|
return false;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
/// Helper class to carry state while collapsing the `linalg.generic` op.
|
|
namespace {
|
|
class CollapsingInfo {
|
|
public:
|
|
CollapsingInfo(SmallVector<ReassociationIndices> &&reassociation) {
|
|
iterationReassociation = std::move(reassociation);
|
|
for (const auto &foldedIterDims : enumerate(iterationReassociation)) {
|
|
foldedDimStartToSequenceMap[foldedIterDims.value()[0]] =
|
|
foldedIterDims.index();
|
|
}
|
|
}
|
|
|
|
// Returns the iteration space reassociation.
|
|
ArrayRef<ReassociationIndices> getReassociationIndices() {
|
|
return iterationReassociation;
|
|
}
|
|
|
|
// Returns true if the given dimension is the start of a sequence of folded
|
|
// dimensions.
|
|
bool isDimStartOfFoldedDims(unsigned dim) {
|
|
return foldedDimStartToSequenceMap.count(dim);
|
|
}
|
|
|
|
// Return the folded dimensions starting at `dim`.
|
|
ReassociationIndicesRef getFoldedDimsStartingAt(unsigned dim) {
|
|
assert(foldedDimStartToSequenceMap.count(dim) &&
|
|
"invalid start dim of folded dim "
|
|
"sequence");
|
|
return iterationReassociation[foldedDimStartToSequenceMap[dim]];
|
|
}
|
|
|
|
// For a dim in the original op, return the dim in the collapsed op, that it
|
|
// is mapped to. Expectes `dim` to be start of a folded dimension sequence.
|
|
unsigned getDimInCollapsedOpForStartOfFoldedDims(unsigned dim) {
|
|
assert(foldedDimStartToSequenceMap.count(dim) &&
|
|
"invalid start dim of folded dim sequence");
|
|
return foldedDimStartToSequenceMap[dim];
|
|
}
|
|
|
|
private:
|
|
/// Reassociation describing the folded iteration space dimensions.
|
|
SmallVector<ReassociationIndices> iterationReassociation;
|
|
|
|
/// Map from the starting dimensions of the folded dimension sequences to
|
|
/// their index in `iterationReassociation`.
|
|
llvm::DenseMap<unsigned, unsigned> foldedDimStartToSequenceMap;
|
|
};
|
|
} // namespace
|
|
|
|
/// Get the iterator types for the collapsed operation given the original
|
|
/// iterator types and collapsed dimensions.
|
|
static SmallVector<StringRef>
|
|
getCollapsedOpIteratorTypes(ArrayRef<Attribute> iteratorTypes,
|
|
CollapsingInfo &collapsingInfo) {
|
|
SmallVector<StringRef> collapsedIteratorTypes;
|
|
for (ReassociationIndicesRef foldedIterDims :
|
|
collapsingInfo.getReassociationIndices()) {
|
|
assert(!foldedIterDims.empty() &&
|
|
"reassociation indices expected to have non-empty sets");
|
|
// Just pick the iterator type of the first folded dim. Pre-condition checks
|
|
// expected to have checked that iterator types of all folded dimensions are
|
|
// the same.
|
|
collapsedIteratorTypes.push_back(
|
|
iteratorTypes[foldedIterDims[0]].cast<StringAttr>().getValue());
|
|
}
|
|
return collapsedIteratorTypes;
|
|
}
|
|
|
|
/// Compute the indexing map in the collapsed op that corresponds to the given
|
|
/// `indexingMap` of the original operation.
|
|
static AffineMap getCollapsedOpIndexingMap(AffineMap indexingMap,
|
|
CollapsingInfo &collapsingInfo) {
|
|
MLIRContext *context = indexingMap.getContext();
|
|
assert(indexingMap.isProjectedPermutation() &&
|
|
"expected indexing map to be projected permutation");
|
|
SmallVector<AffineExpr> resultExprs;
|
|
for (auto expr : indexingMap.getResults()) {
|
|
unsigned dim = expr.cast<AffineDimExpr>().getPosition();
|
|
if (collapsingInfo.isDimStartOfFoldedDims(dim)) {
|
|
resultExprs.push_back(getAffineDimExpr(
|
|
collapsingInfo.getDimInCollapsedOpForStartOfFoldedDims(dim),
|
|
context));
|
|
}
|
|
}
|
|
return AffineMap::get(collapsingInfo.getReassociationIndices().size(), 0,
|
|
resultExprs, context);
|
|
}
|
|
|
|
/// Return the `reassociation` indices to use to collapse the operand when the
|
|
/// iteration space of a generic op is collapsed.
|
|
static SmallVector<ReassociationIndices>
|
|
getOperandReassociation(AffineMap indexingMap, CollapsingInfo &collapsingInfo) {
|
|
unsigned counter = 0;
|
|
SmallVector<ReassociationIndices> operandReassociation;
|
|
for (auto expr : indexingMap.getResults()) {
|
|
unsigned dim = expr.cast<AffineDimExpr>().getPosition();
|
|
if (collapsingInfo.isDimStartOfFoldedDims(dim)) {
|
|
unsigned numFoldedDims =
|
|
collapsingInfo.getFoldedDimsStartingAt(dim).size();
|
|
auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
|
|
operandReassociation.emplace_back(range.begin(), range.end());
|
|
counter += numFoldedDims;
|
|
}
|
|
}
|
|
return operandReassociation;
|
|
}
|
|
|
|
/// Get the new value to use for a given `OpOperand` in the collapsed operation.
|
|
static Value getCollapsedOpOperand(Location loc, GenericOp genericOp,
|
|
OpOperand *opOperand,
|
|
CollapsingInfo &collapsingInfo,
|
|
OpBuilder &builder) {
|
|
AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
|
|
SmallVector<ReassociationIndices> operandReassociation =
|
|
getOperandReassociation(indexingMap, collapsingInfo);
|
|
|
|
// If the number of entries in the reassocation for the operand is same as the
|
|
// number of results of the indexing map, then nothing to do for this operand.
|
|
Value operand = opOperand->get();
|
|
if (operandReassociation.size() == indexingMap.getNumResults())
|
|
return operand;
|
|
|
|
// Insert a reshape to collapse the dimensions.
|
|
auto reshapeOp = builder.create<tensor::CollapseShapeOp>(
|
|
loc, operand, operandReassociation);
|
|
return reshapeOp.getResult();
|
|
}
|
|
|
|
/// Modify the `linalg.index` operations in the original generic op, to its
|
|
/// value in the collapsed operation.
|
|
void generateCollapsedIndexingRegion(Location loc, Block *block,
|
|
CollapsingInfo &collapsingInfo,
|
|
ValueRange loopRange,
|
|
PatternRewriter &rewriter) {
|
|
OpBuilder::InsertionGuard g(rewriter);
|
|
rewriter.setInsertionPointToStart(block);
|
|
|
|
// Collect all the original index ops.
|
|
auto indexOps = llvm::to_vector(block->getOps<linalg::IndexOp>());
|
|
|
|
// For each folded dimension list resolve the original induction variable
|
|
// values in terms of the folded dimension induction variable.
|
|
// i_{folded} = (i_0 * d1 + i1) * d2 + i2.
|
|
// can be inverted to
|
|
// i2 = i_{folded} % d2
|
|
// i1 = (i_{folded} / d2) % d1
|
|
// i0 = i_{folded} / (d1 * d2)
|
|
llvm::DenseMap<unsigned, Value> indexReplacementVals;
|
|
for (auto &foldedDims : enumerate(collapsingInfo.getReassociationIndices())) {
|
|
ReassociationIndicesRef foldedDimsRef(foldedDims.value());
|
|
Value newIndexVal =
|
|
rewriter.create<linalg::IndexOp>(loc, foldedDims.index());
|
|
for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
|
|
indexReplacementVals[dim] =
|
|
rewriter.create<arith::RemUIOp>(loc, newIndexVal, loopRange[dim]);
|
|
newIndexVal =
|
|
rewriter.create<arith::DivUIOp>(loc, newIndexVal, loopRange[dim]);
|
|
}
|
|
indexReplacementVals[foldedDims.value().front()] = newIndexVal;
|
|
}
|
|
|
|
for (auto indexOp : indexOps) {
|
|
auto dim = indexOp.dim();
|
|
rewriter.replaceOp(indexOp, indexReplacementVals[dim]);
|
|
}
|
|
}
|
|
|
|
/// Implementation of fusion with reshape operation by collapsing dimensions.
|
|
static Optional<SmallVector<Value>>
|
|
fuseWithReshapeByCollapsing(GenericOp genericOp, Operation *reshapeOp,
|
|
OpOperand *fusableOpOperand,
|
|
PatternRewriter &rewriter) {
|
|
SmallVector<ReassociationIndices> reassociation =
|
|
isa<tensor::CollapseShapeOp>(reshapeOp)
|
|
? cast<tensor::CollapseShapeOp>(reshapeOp).getReassociationIndices()
|
|
: cast<tensor::ExpandShapeOp>(reshapeOp).getReassociationIndices();
|
|
assert(isFusableWithReshapeByDimCollapse(genericOp, fusableOpOperand,
|
|
reassociation) &&
|
|
"preconditions for fusing with reshape by collapse failed");
|
|
|
|
CollapsingInfo collapsingInfo(getDomainReassociation(
|
|
genericOp.getTiedIndexingMap(fusableOpOperand), reassociation));
|
|
// Check for trivial no transformation cases. In that case return nothing.
|
|
if (collapsingInfo.getReassociationIndices().size() ==
|
|
genericOp.getNumLoops())
|
|
return llvm::None;
|
|
|
|
// Get the iterator types for the operand.
|
|
SmallVector<StringRef> iteratorTypes = getCollapsedOpIteratorTypes(
|
|
genericOp.iterator_types().getValue(), collapsingInfo);
|
|
|
|
// Get the indexing maps.
|
|
auto indexingMaps = llvm::to_vector(
|
|
llvm::map_range(genericOp.getIndexingMaps(), [&](AffineMap map) {
|
|
return getCollapsedOpIndexingMap(map, collapsingInfo);
|
|
}));
|
|
|
|
Location loc =
|
|
rewriter.getFusedLoc({genericOp->getLoc(), reshapeOp->getLoc()});
|
|
|
|
// Get the input operands.
|
|
auto inputOperands = llvm::to_vector(
|
|
llvm::map_range(genericOp.getInputOperands(), [&](OpOperand *opOperand) {
|
|
return getCollapsedOpOperand(loc, genericOp, opOperand, collapsingInfo,
|
|
rewriter);
|
|
}));
|
|
|
|
// Get the output operands and result types.
|
|
SmallVector<Type> resultTypes;
|
|
SmallVector<Value> outputOperands;
|
|
resultTypes.reserve(genericOp.getNumOutputs());
|
|
outputOperands.reserve(genericOp.getNumOutputs());
|
|
for (OpOperand *output : genericOp.getOutputOperands()) {
|
|
Value newOutput =
|
|
getCollapsedOpOperand(loc, genericOp, output, collapsingInfo, rewriter);
|
|
outputOperands.push_back(newOutput);
|
|
resultTypes.push_back(newOutput.getType());
|
|
}
|
|
|
|
// Create the generic op.
|
|
auto collapsedGenericOp = rewriter.create<linalg::GenericOp>(
|
|
loc, resultTypes, inputOperands, outputOperands, indexingMaps,
|
|
iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
|
|
Block *origOpBlock = &genericOp->getRegion(0).front();
|
|
Block *collapsedOpBlock = &collapsedGenericOp->getRegion(0).front();
|
|
rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
|
|
collapsedOpBlock->getArguments());
|
|
|
|
if (collapsedGenericOp.hasIndexSemantics()) {
|
|
// Collect the loop range of the generic op.
|
|
OpBuilder::InsertionGuard g(rewriter);
|
|
rewriter.setInsertionPoint(collapsedGenericOp);
|
|
SmallVector<Range> loopRanges =
|
|
cast<LinalgOp>(genericOp.getOperation())
|
|
.createLoopRanges(rewriter, genericOp.getLoc());
|
|
assert(llvm::all_of(loopRanges,
|
|
[](Range range) {
|
|
return matchPattern(range.offset, m_Zero()) &&
|
|
matchPattern(range.stride, m_One());
|
|
}) &&
|
|
"expected all loop ranges to have zero start and unit stride");
|
|
SmallVector<Value> loopBound = llvm::to_vector(
|
|
llvm::map_range(loopRanges, [](Range range) { return range.size; }));
|
|
generateCollapsedIndexingRegion(loc,
|
|
&collapsedGenericOp->getRegion(0).front(),
|
|
collapsingInfo, loopBound, rewriter);
|
|
}
|
|
|
|
// Insert expanding reshape for the result to get back the original result
|
|
// type.
|
|
SmallVector<Value> results;
|
|
for (const auto &originalResult : llvm::enumerate(genericOp->getResults())) {
|
|
Value collapsedOpResult =
|
|
collapsedGenericOp->getResult(originalResult.index());
|
|
auto originalResultType =
|
|
originalResult.value().getType().cast<ShapedType>();
|
|
auto collapsedOpResultType = collapsedOpResult.getType().cast<ShapedType>();
|
|
if (collapsedOpResultType.getRank() != originalResultType.getRank()) {
|
|
AffineMap indexingMap =
|
|
genericOp.getTiedIndexingMapForResult(originalResult.value());
|
|
SmallVector<ReassociationIndices> reassociation =
|
|
getOperandReassociation(indexingMap, collapsingInfo);
|
|
Value result = rewriter.create<tensor::ExpandShapeOp>(
|
|
loc, originalResultType, collapsedOpResult, reassociation);
|
|
results.push_back(result);
|
|
} else {
|
|
results.push_back(collapsedOpResult);
|
|
}
|
|
}
|
|
return results;
|
|
}
|
|
|
|
namespace {
|
|
|
|
/// Pattern to fuse a tensor.expand_shape op with its consumer generic op by
|
|
/// contracting dimensions of the loop.
|
|
class FoldWithProducerReshapeOpByCollapsing
|
|
: public OpRewritePattern<GenericOp> {
|
|
public:
|
|
FoldWithProducerReshapeOpByCollapsing(
|
|
MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
|
|
PatternBenefit benefit = 1)
|
|
: OpRewritePattern<GenericOp>(context, benefit),
|
|
controlFoldingReshapes(std::move(foldReshapes)) {}
|
|
|
|
LogicalResult matchAndRewrite(GenericOp genericOp,
|
|
PatternRewriter &rewriter) const override {
|
|
for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
|
|
tensor::ExpandShapeOp reshapeOp =
|
|
opOperand->get().getDefiningOp<tensor::ExpandShapeOp>();
|
|
if (!reshapeOp)
|
|
continue;
|
|
|
|
if (!isFusableWithReshapeByDimCollapse(
|
|
genericOp, opOperand, reshapeOp.getReassociationIndices()) ||
|
|
!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)) {
|
|
continue;
|
|
}
|
|
|
|
Optional<SmallVector<Value>> replacements = fuseWithReshapeByCollapsing(
|
|
genericOp, reshapeOp, opOperand, rewriter);
|
|
if (!replacements) {
|
|
return rewriter.notifyMatchFailure(
|
|
genericOp, "failed to do the fusion by collapsing transformation");
|
|
}
|
|
|
|
rewriter.replaceOp(genericOp, replacements.getValue());
|
|
return success();
|
|
}
|
|
return failure();
|
|
}
|
|
|
|
private:
|
|
ControlElementwiseOpsFusionFn controlFoldingReshapes;
|
|
};
|
|
} // namespace
|
|
|
|
//===---------------------------------------------------------------------===//
|
|
// Methods and patterns to convert tensor.expand_shape -> linalg.generic
|
|
// into linalg.generic -> tensor.expand_shape, i.e. push the reshape down.
|
|
//===---------------------------------------------------------------------===//
|
|
|
|
// TODO(ravishankarm): This pattern is to be deprecated in favor of fusion by
|
|
// collapsing that provides a more general functionality. This pattern is very
|
|
// specific to a particular use case. The fusion by collapsing can provide the
|
|
// same control to clients using the control function there.
|
|
|
|
/// Converts reassociation maps into reassociation indices: each map yields one
/// group holding the dimension positions of its results, in result order.
/// Every result of every map must be an `AffineDimExpr` (enforced by `cast`).
static SmallVector<ReassociationIndices>
getReassociationIndices(ArrayRef<AffineMap> maps) {
  SmallVector<ReassociationIndices> reassociation;
  reassociation.reserve(maps.size());
  for (AffineMap map : maps) {
    ReassociationIndices group;
    for (AffineExpr result : map.getResults())
      group.push_back(result.cast<AffineDimExpr>().getPosition());
    reassociation.push_back(std::move(group));
  }
  return reassociation;
}
|
|
|
|
namespace {
|
|
/// Pattern to move rank reducing reshape after an elementwise linalg generic
|
|
/// op. This is useful to expose more fusion opportunities between named ops and
|
|
/// generic ops. This can only be done if there is no broadcast or permuation
|
|
/// within the dimensions we need to merge.
|
|
///
|
|
/// For example,
|
|
///
|
|
/// %0 = tensor.expand_shape %A [[0, 1], [2]]
|
|
/// : tensor<12544x16xf32> into tensor<112x112x16xf32>
|
|
/// %2 = linalg.generic {indexing_maps = [
|
|
/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
|
|
/// affine_map<(d0, d1, d2) -> (d2)>,
|
|
/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types =
|
|
/// ["parallel", "parallel", "parallel"]} {
|
|
/// } -> tensor<112x112x16xf32>
|
|
///
|
|
/// into
|
|
///
|
|
/// %2 = linalg.generic {indexing_maps = [
|
|
/// affine_map<(d0, d1) -> (d0, d1)>,
|
|
/// affine_map<(d0, d1) -> (d1)>,
|
|
/// affine_map<(d0, d1) -> (d0, d1)>],
|
|
/// iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1
|
|
/// : tensor<12544x16xf32>, tensor<16xf32>) outs(%1 : tensor<12544x16xf32>) {
|
|
/// } -> tensor<12544x16xf32>
|
|
/// %3 = tensor.expand_shape %2 [[0, 1], [2]]
|
|
/// : tensor<12544x16xf32> into tensor<112x112x16xf32>
|
|
struct PushExpandingReshape : public OpRewritePattern<GenericOp> {
|
|
using OpRewritePattern<GenericOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(GenericOp genericOp,
|
|
PatternRewriter &rewriter) const override {
|
|
// Only apply to elementwise linalg on tensor.
|
|
if (!genericOp.hasTensorSemantics() || genericOp.hasIndexSemantics() ||
|
|
genericOp.getNumParallelLoops() != genericOp.getNumLoops())
|
|
return failure();
|
|
// Only support identity output maps. It could be extended to permuations if
|
|
// needed.
|
|
if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *opOperand) {
|
|
return !genericOp.getTiedIndexingMap(opOperand).isIdentity();
|
|
}))
|
|
return failure();
|
|
int64_t destRank = genericOp.getNumParallelLoops();
|
|
SmallVector<Value> newOperands = genericOp.getInputOperands();
|
|
tensor::ExpandShapeOp reshapeFound;
|
|
// 1. Look for tensor_expand_shape operands and figure out save the
|
|
// dimensions merged.
|
|
SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
|
|
for (const auto &en : llvm::enumerate(inputOperands)) {
|
|
auto reshapeOp =
|
|
en.value()->get().template getDefiningOp<tensor::ExpandShapeOp>();
|
|
if (!reshapeOp)
|
|
continue;
|
|
// TODO: We could support non-identity map as long as the merged
|
|
// dimensions are still contiguous.
|
|
if (!genericOp.getTiedIndexingMap(en.value()).isIdentity())
|
|
continue;
|
|
if (reshapeFound) {
|
|
// Only support a second reshape op if it has the same reassociate maps.
|
|
if (reshapeFound.getReassociationMaps() ==
|
|
reshapeOp.getReassociationMaps())
|
|
newOperands[en.index()] = reshapeOp.src();
|
|
continue;
|
|
}
|
|
reshapeFound = reshapeOp;
|
|
newOperands[en.index()] = reshapeOp.src();
|
|
}
|
|
if (!reshapeFound)
|
|
return failure();
|
|
|
|
// Calculate the reassociation indices and rassociated reverse map.
|
|
SmallVector<ReassociationIndices> reassociation =
|
|
getReassociationIndices(reshapeFound.getReassociationMaps());
|
|
SmallVector<unsigned> remap(destRank);
|
|
for (auto &indices : llvm::enumerate(reassociation)) {
|
|
for (int64_t index : indices.value()) {
|
|
remap[index] = indices.index();
|
|
}
|
|
}
|
|
// 2. Verify that we can merge the dimensions in the linalg and that we
|
|
// don't need to create new reshapes operands. Inserting new reshape
|
|
// operands would defeat the purpose of the transformation.
|
|
for (const auto &en : llvm::enumerate(inputOperands)) {
|
|
if (en.value()->get() == newOperands[en.index()]) {
|
|
AffineMap map = genericOp.getTiedIndexingMap(en.value());
|
|
for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) {
|
|
if (reassociation[remap[map.getDimPosition(i)]].size() > 1)
|
|
return failure();
|
|
}
|
|
}
|
|
}
|
|
|
|
// 3. Calculate the affine map remapping and the reassociation to apply to
|
|
// output tensors.
|
|
SmallVector<AffineMap> newMaps;
|
|
unsigned newRank = reassociation.size();
|
|
for (auto map : genericOp.getIndexingMaps()) {
|
|
SmallVector<AffineExpr> newExprs;
|
|
for (auto expr : map.getResults()) {
|
|
unsigned position = expr.template cast<AffineDimExpr>().getPosition();
|
|
// Skip dimension merged except for the last of the group.
|
|
if (reassociation[remap[position]].back() == position) {
|
|
newExprs.push_back(
|
|
getAffineDimExpr(remap[position], genericOp.getContext()));
|
|
}
|
|
}
|
|
newMaps.push_back(
|
|
AffineMap::get(newRank, 0, newExprs, genericOp.getContext()));
|
|
}
|
|
|
|
// 4. Reshape the output tensors.
|
|
SmallVector<Value> newOutputs;
|
|
SmallVector<Type> newOutputTypes;
|
|
for (auto output : genericOp.outputs()) {
|
|
auto newOutputType = RankedTensorType::get(
|
|
reshapeFound.getSrcType().getShape(),
|
|
output.getType().template cast<RankedTensorType>().getElementType());
|
|
Value newOutput = rewriter.create<tensor::CollapseShapeOp>(
|
|
genericOp->getLoc(), newOutputType, output, reassociation);
|
|
newOutputTypes.push_back(newOutputType);
|
|
newOutputs.push_back(newOutput);
|
|
}
|
|
// 5. Create a new generic op with lowerer rank.
|
|
SmallVector<StringRef> iteratorTypes(newRank,
|
|
getParallelIteratorTypeName());
|
|
auto newOp = rewriter.create<GenericOp>(genericOp->getLoc(), newOutputTypes,
|
|
newOperands, newOutputs, newMaps,
|
|
iteratorTypes);
|
|
rewriter.inlineRegionBefore(genericOp.region(), newOp.region(),
|
|
newOp.region().begin());
|
|
// 6. Reshape the so that the type matches the uses.
|
|
SmallVector<Value> newResults;
|
|
for (const auto &result : llvm::enumerate(newOp->getResults())) {
|
|
newResults.push_back(rewriter.create<tensor::ExpandShapeOp>(
|
|
genericOp->getLoc(), genericOp.getOutputTensorTypes()[result.index()],
|
|
result.value(), reassociation));
|
|
}
|
|
rewriter.replaceOp(genericOp, newResults);
|
|
return success();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
//===---------------------------------------------------------------------===//
|
|
// Methods and patterns that fuse constants with linalg.generic operations.
|
|
//===---------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
/// Pattern to fold a generic op with a splat constant/scalar constant. Does not
|
|
/// handle cases where the constant is not single-valued.
|
|
class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
|
|
public:
|
|
FoldScalarOrSplatConstant(MLIRContext *context,
|
|
ControlElementwiseOpsFusionFn &fun,
|
|
PatternBenefit benefit = 1)
|
|
: OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}
|
|
|
|
LogicalResult matchAndRewrite(GenericOp genericOp,
|
|
PatternRewriter &rewriter) const override {
|
|
if (!genericOp.hasTensorSemantics())
|
|
return failure();
|
|
for (OpOperand *opOperand : genericOp.getInputOperands()) {
|
|
Operation *def = opOperand->get().getDefiningOp();
|
|
Attribute constantAttr;
|
|
auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
|
|
{
|
|
DenseElementsAttr splatAttr;
|
|
if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
|
|
splatAttr.isSplat() &&
|
|
splatAttr.getType().getElementType().isIntOrFloat()) {
|
|
constantAttr = splatAttr.getSplatValue<Attribute>();
|
|
return true;
|
|
}
|
|
}
|
|
{
|
|
IntegerAttr intAttr;
|
|
if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
|
|
constantAttr = intAttr;
|
|
return true;
|
|
}
|
|
}
|
|
{
|
|
FloatAttr floatAttr;
|
|
if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
|
|
constantAttr = floatAttr;
|
|
return true;
|
|
}
|
|
}
|
|
return false;
|
|
};
|
|
|
|
auto resultValue = opOperand->get().dyn_cast<OpResult>();
|
|
if (!def || !resultValue || !isScalarOrSplatConstantOp(def) ||
|
|
!controlFn(resultValue, *opOperand))
|
|
continue;
|
|
|
|
// The operands and the indexing_maps of the fused operation the same as
|
|
// the operands and indexing_maps of the generic operations with the
|
|
// values at the constant index dropped.
|
|
SmallVector<AffineMap> fusedIndexMaps;
|
|
SmallVector<Value> fusedOperands;
|
|
SmallVector<Location> fusedLocs{genericOp.getLoc()};
|
|
fusedIndexMaps.reserve(genericOp.getNumInputsAndOutputs());
|
|
fusedOperands.reserve(genericOp.getNumInputs());
|
|
fusedLocs.reserve(fusedLocs.size() + genericOp.getNumInputs());
|
|
for (OpOperand *inputOperand : genericOp.getInputOperands()) {
|
|
if (inputOperand == opOperand)
|
|
continue;
|
|
Value inputValue = inputOperand->get();
|
|
fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(inputOperand));
|
|
fusedOperands.push_back(inputValue);
|
|
fusedLocs.push_back(inputValue.getLoc());
|
|
}
|
|
for (OpOperand *outputOperand : genericOp.getOutputOperands())
|
|
fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(outputOperand));
|
|
|
|
// Check if the operation shapes to loops map is computable.
|
|
if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
|
|
return rewriter.notifyMatchFailure(
|
|
genericOp, "fused op loop bound computation failed");
|
|
}
|
|
|
|
// Create a constant scalar value from the splat constant.
|
|
Value scalarConstant = rewriter.create<arith::ConstantOp>(
|
|
def->getLoc(), constantAttr, constantAttr.getType());
|
|
|
|
SmallVector<Value> outputOperands = genericOp.getOutputOperands();
|
|
auto fusedOp = rewriter.create<GenericOp>(
|
|
rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(),
|
|
/*inputs=*/fusedOperands,
|
|
/*outputs=*/outputOperands,
|
|
rewriter.getAffineMapArrayAttr(fusedIndexMaps),
|
|
genericOp.iterator_types(),
|
|
/*doc=*/nullptr,
|
|
/*library_call=*/nullptr);
|
|
|
|
// Map the block argument corresponding to the replaced argument with the
|
|
// scalar constant.
|
|
Region ®ion = genericOp->getRegion(0);
|
|
Block &entryBlock = *region.begin();
|
|
BlockAndValueMapping mapping;
|
|
mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
|
|
scalarConstant);
|
|
Region &fusedRegion = fusedOp->getRegion(0);
|
|
rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
|
|
mapping);
|
|
rewriter.replaceOp(genericOp, fusedOp->getResults());
|
|
return success();
|
|
}
|
|
return failure();
|
|
}
|
|
|
|
private:
|
|
ControlElementwiseOpsFusionFn controlFn;
|
|
};
|
|
|
|
/// Base class for constant folding linalg.generic ops with N inputs, 1 output,
|
|
/// and permutation indexing maps.
|
|
///
|
|
/// `ConcreteType` should provide methods with signatures
|
|
///
|
|
/// ```c++
|
|
/// bool matchIndexingMaps(GenericOp genericOp) const;
|
|
/// RegionComputationFn getRegionComputeFn(GenericOp) const;
|
|
/// ```
|
|
///
|
|
/// The latter inspects the region and returns the computation inside as a
|
|
/// functor. The functor will be invoked with constant elements for all inputs
|
|
/// and should return the corresponding computea constant element for output.
|
|
template <typename ConcreteType>
|
|
class FoldConstantBase : public OpRewritePattern<GenericOp> {
|
|
public:
|
|
struct APIntOrFloat {
|
|
Optional<APInt> apInt;
|
|
Optional<APFloat> apFloat;
|
|
};
|
|
struct APIntOrFloatArray {
|
|
SmallVector<APInt> apInts;
|
|
SmallVector<APFloat> apFloats;
|
|
};
|
|
using RegionComputationFn =
|
|
std::function<APIntOrFloat(const APIntOrFloatArray &)>;
|
|
|
|
FoldConstantBase(MLIRContext *context,
|
|
const ControlElementwiseOpsFusionFn &controlFn,
|
|
PatternBenefit benefit = 1)
|
|
: OpRewritePattern<GenericOp>(context, benefit), controlFn(controlFn) {}
|
|
|
|
LogicalResult matchAndRewrite(GenericOp genericOp,
|
|
PatternRewriter &rewriter) const override {
|
|
if (genericOp.hasBufferSemantics())
|
|
return failure();
|
|
|
|
// Only support ops generating one output for now.
|
|
if (genericOp.getNumOutputs() != 1)
|
|
return failure();
|
|
|
|
auto outputType = genericOp.getResultTypes().front().dyn_cast<ShapedType>();
|
|
// Require the output types to be static give we are generating constants.
|
|
if (!outputType || !outputType.hasStaticShape())
|
|
return failure();
|
|
|
|
if (!llvm::all_of(genericOp.getInputOperands(), [](OpOperand *operand) {
|
|
return operand->get().getType().isa<ShapedType>();
|
|
}))
|
|
return failure();
|
|
|
|
// Make sure all element types are the same.
|
|
auto getOperandElementType = [](OpOperand *operand) {
|
|
return operand->get().getType().cast<ShapedType>().getElementType();
|
|
};
|
|
if (!llvm::is_splat(llvm::map_range(genericOp.getInputAndOutputOperands(),
|
|
getOperandElementType)))
|
|
return failure();
|
|
|
|
// We can only handle the case where we have int/float elements.
|
|
auto elementType = outputType.getElementType();
|
|
if (!elementType.isIntOrFloat())
|
|
return failure();
|
|
|
|
// Require all indexing maps to be permutations for now. This is common and
|
|
// it simplifies input/output access greatly: we can do the data shuffling
|
|
// entirely in the compiler, without needing to turn all indices into
|
|
// Values, and then do affine apply on them, and then match back the
|
|
// constant again.
|
|
if (!llvm::all_of(genericOp.getIndexingMaps(),
|
|
[](AffineMap map) { return map.isPermutation(); }))
|
|
return failure();
|
|
|
|
for (OpOperand *operand : genericOp.getOutputOperands()) {
|
|
if (genericOp.payloadUsesValueFromOperand(operand))
|
|
return failure();
|
|
}
|
|
|
|
// Further check the indexing maps are okay for the ConcreteType.
|
|
if (!static_cast<const ConcreteType *>(this)->matchIndexingMaps(genericOp))
|
|
return failure();
|
|
|
|
// Defer to the concrete type to check the region and discover the
|
|
// computation inside.
|
|
RegionComputationFn computeFn =
|
|
static_cast<const ConcreteType *>(this)->getRegionComputeFn(genericOp);
|
|
if (!computeFn)
|
|
return failure();
|
|
|
|
// All inputs should be constants.
|
|
int numInputs = genericOp.getNumInputs();
|
|
SmallVector<DenseIntOrFPElementsAttr> inputValues(numInputs);
|
|
for (const auto &operand : llvm::enumerate(genericOp.getInputOperands())) {
|
|
if (!matchPattern(operand.value()->get(),
|
|
m_Constant(&inputValues[operand.index()])))
|
|
return failure();
|
|
}
|
|
|
|
// Identified this as a potential candidate for folding. Now check the
|
|
// policy to see whether we are allowed to proceed.
|
|
for (int i = 0; i < numInputs; ++i) {
|
|
OpOperand *consumer = genericOp.getInputOperand(i);
|
|
OpResult producer = consumer->get().cast<OpResult>();
|
|
if (!controlFn(producer, *consumer))
|
|
return failure();
|
|
}
|
|
|
|
auto linalgOp = cast<LinalgOp>(genericOp.getOperation());
|
|
SmallVector<int64_t, 4> loopBounds = linalgOp.computeStaticLoopSizes();
|
|
int64_t numElements = outputType.getNumElements();
|
|
|
|
// Use APInt/APFloat instead of Attribute here for constructing the output.
|
|
// This helps to avoid blowing up compiler memory usage: Attributes would
|
|
// unify the following cases but they have lifetime as the MLIRContext.
|
|
SmallVector<APInt> intOutputValues;
|
|
SmallVector<APFloat> fpOutputValues;
|
|
if (elementType.template isa<FloatType>())
|
|
fpOutputValues.resize(numElements, APFloat(0.f));
|
|
else
|
|
intOutputValues.resize(numElements);
|
|
|
|
// Return the constant dim positions from the given permutation map.
|
|
auto getDimPositions = [](AffineMap map) {
|
|
SmallVector<unsigned> dims;
|
|
dims.reserve(map.getNumResults());
|
|
for (AffineExpr result : map.getResults()) {
|
|
dims.push_back(result.cast<AffineDimExpr>().getPosition());
|
|
}
|
|
return dims;
|
|
};
|
|
|
|
SmallVector<SmallVector<unsigned>> inputDims;
|
|
for (int i = 0; i < numInputs; ++i)
|
|
inputDims.push_back(getDimPositions(genericOp.getIndexingMaps()[i]));
|
|
auto outputDims = getDimPositions(genericOp.getIndexingMaps().back());
|
|
auto outputShape = outputType.getShape();
|
|
|
|
// Allocate small vectors for index delinearization. Initial values do not
|
|
// matter here as they will be overwritten later.
|
|
SmallVector<uint64_t> indices(loopBounds.size(), 0);
|
|
SmallVector<uint64_t> dstIndices(loopBounds.size(), 0);
|
|
SmallVector<SmallVector<uint64_t>> srcIndices(
|
|
numInputs, SmallVector<uint64_t>(loopBounds.size(), 0));
|
|
SmallVector<uint64_t> srcLinearIndices(numInputs, 0);
|
|
uint64_t dstLinearIndex = 0;
|
|
|
|
// Allocate spaces for compute function inputs. Initial values do not matter
|
|
// here as they will be overwritten later.
|
|
APIntOrFloatArray computeFnInputs;
|
|
|
|
auto inputShapes = llvm::to_vector<4>(
|
|
llvm::map_range(genericOp.getInputOperands(), [](OpOperand *operand) {
|
|
return operand->get().getType().cast<ShapedType>().getShape();
|
|
}));
|
|
|
|
// Given a `linearIndex`, remap it to a linear index to access linalg op
|
|
// inputs/ouputs. This mutates `indices`, `srcIndices`, `dstIndices`,
|
|
// `srcLinearIndices`, `dstLinearIndex` in place.
|
|
auto computeRemappedLinearIndex = [&](int linearIndex) {
|
|
int totalCount = linearIndex;
|
|
for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
|
|
indices[dim] = totalCount % loopBounds[dim];
|
|
totalCount /= loopBounds[dim];
|
|
}
|
|
|
|
for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
|
|
for (int i = 0; i < numInputs; ++i)
|
|
srcIndices[i][dim] = indices[inputDims[i][dim]];
|
|
dstIndices[dim] = indices[outputDims[dim]];
|
|
}
|
|
|
|
dstLinearIndex = dstIndices.front();
|
|
for (int i = 0; i < numInputs; ++i)
|
|
srcLinearIndices[i] = srcIndices[i].front();
|
|
|
|
for (int dim = 1; dim < outputType.getRank(); ++dim) {
|
|
dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
|
|
for (int i = 0; i < numInputs; ++i)
|
|
srcLinearIndices[i] =
|
|
srcLinearIndices[i] * inputShapes[i][dim] + srcIndices[i][dim];
|
|
}
|
|
};
|
|
|
|
bool isFloat = elementType.isa<FloatType>();
|
|
if (isFloat) {
|
|
SmallVector<DenseElementsAttr::iterator_range<APFloat>> inFpRanges;
|
|
for (int i = 0; i < numInputs; ++i)
|
|
inFpRanges.push_back(inputValues[i].getValues<APFloat>());
|
|
|
|
computeFnInputs.apFloats.resize(numInputs, APFloat(0.f));
|
|
|
|
// Transpose the input constant. Because we don't know its rank in
|
|
// advance, we need to loop over the range [0, element count) and
|
|
// delinearize the index.
|
|
for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
|
|
computeRemappedLinearIndex(linearIndex);
|
|
|
|
// Collect constant elements for all inputs at this loop iteration.
|
|
for (int i = 0; i < numInputs; ++i)
|
|
computeFnInputs.apFloats[i] = inFpRanges[i][srcLinearIndices[i]];
|
|
|
|
// Invoke the computation to get the corresponding constant output
|
|
// element.
|
|
fpOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apFloat;
|
|
}
|
|
} else {
|
|
SmallVector<DenseElementsAttr::iterator_range<APInt>> inIntRanges;
|
|
for (int i = 0; i < numInputs; ++i)
|
|
inIntRanges.push_back(inputValues[i].getValues<APInt>());
|
|
|
|
computeFnInputs.apInts.resize(numInputs);
|
|
|
|
// Transpose the input constant. Because we don't know its rank in
|
|
// advance, we need to loop over the range [0, element count) and
|
|
// delinearize the index.
|
|
for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
|
|
computeRemappedLinearIndex(linearIndex);
|
|
|
|
// Collect constant elements for all inputs at this loop iteration.
|
|
for (int i = 0; i < numInputs; ++i)
|
|
computeFnInputs.apInts[i] = inIntRanges[i][srcLinearIndices[i]];
|
|
|
|
// Invoke the computation to get the corresponding constant output
|
|
// element.
|
|
intOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apInt;
|
|
}
|
|
}
|
|
|
|
DenseElementsAttr outputAttr =
|
|
isFloat ? DenseElementsAttr::get(outputType, fpOutputValues)
|
|
: DenseElementsAttr::get(outputType, intOutputValues);
|
|
|
|
rewriter.replaceOpWithNewOp<arith::ConstantOp>(genericOp, outputAttr);
|
|
return success();
|
|
}
|
|
|
|
private:
|
|
ControlElementwiseOpsFusionFn controlFn;
|
|
};
|
|
|
|
// Folds linalg.generic ops that are actually transposes on constant values.
|
|
struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> {
|
|
using FoldConstantBase::FoldConstantBase;
|
|
|
|
bool matchIndexingMaps(GenericOp genericOp) const {
|
|
// We should have one input and one output.
|
|
return genericOp.getIndexingMaps().size() == 2;
|
|
}
|
|
|
|
RegionComputationFn getRegionComputeFn(GenericOp genericOp) const {
|
|
// Make sure the region only contains a yield op.
|
|
Block &body = genericOp.region().front();
|
|
if (!llvm::hasSingleElement(body))
|
|
return nullptr;
|
|
auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
|
|
if (!yieldOp)
|
|
return nullptr;
|
|
|
|
// The yield op should return the block argument corresponds to the input.
|
|
for (Value yieldVal : yieldOp.values()) {
|
|
auto yieldArg = yieldVal.dyn_cast<BlockArgument>();
|
|
if (!yieldArg || yieldArg.getOwner() != &body)
|
|
return nullptr;
|
|
if (yieldArg.getArgNumber() != 0)
|
|
return nullptr;
|
|
}
|
|
|
|
// No computation; just return the orginal value.
|
|
return [](const APIntOrFloatArray &inputs) {
|
|
if (inputs.apFloats.empty())
|
|
return APIntOrFloat{inputs.apInts.front(), llvm::None};
|
|
return APIntOrFloat{llvm::None, inputs.apFloats.front()};
|
|
};
|
|
}
|
|
|
|
ControlElementwiseOpsFusionFn controlFn;
|
|
};
|
|
|
|
} // namespace
|
|
|
|
//===---------------------------------------------------------------------===//
|
|
// Miscellaneous patterns that help fusion.
|
|
//===---------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
/// Forces `outs` operands of linalg operations to use `linalg.init_tensor` if
|
|
/// the value of the `outs` operand is not used within the op. This is only
|
|
/// implemented for `linalg.generic` operations for now, but should hold for all
|
|
/// linalg structured ops.
|
|
struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
|
|
using OpRewritePattern<GenericOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(GenericOp op,
|
|
PatternRewriter &rewriter) const override {
|
|
rewriter.startRootUpdate(op);
|
|
bool modifiedOutput = false;
|
|
Location loc = op.getLoc();
|
|
for (OpOperand *opOperand : op.getOutputOperands()) {
|
|
if (!op.payloadUsesValueFromOperand(opOperand)) {
|
|
Value operandVal = opOperand->get();
|
|
auto operandType = operandVal.getType().dyn_cast<RankedTensorType>();
|
|
if (!operandType)
|
|
continue;
|
|
|
|
// If outs is sparse, leave it to the sparse compiler.
|
|
if (sparse_tensor::getSparseTensorEncoding(operandVal.getType()))
|
|
continue;
|
|
|
|
// If outs is already an `init_tensor` operation, nothing to do.
|
|
auto definingOp = operandVal.getDefiningOp<InitTensorOp>();
|
|
if (definingOp)
|
|
continue;
|
|
modifiedOutput = true;
|
|
SmallVector<Value> dynamicDims;
|
|
for (const auto &dim : llvm::enumerate(operandType.getShape())) {
|
|
if (dim.value() != ShapedType::kDynamicSize)
|
|
continue;
|
|
dynamicDims.push_back(rewriter.createOrFold<tensor::DimOp>(
|
|
loc, operandVal, dim.index()));
|
|
}
|
|
Value initTensor = rewriter.create<InitTensorOp>(
|
|
loc, dynamicDims, operandType.getShape(),
|
|
operandType.getElementType());
|
|
op->setOperand(opOperand->getOperandNumber(), initTensor);
|
|
}
|
|
}
|
|
if (!modifiedOutput) {
|
|
rewriter.cancelRootUpdate(op);
|
|
return failure();
|
|
}
|
|
rewriter.finalizeRootUpdate(op);
|
|
return success();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
//===---------------------------------------------------------------------===//
|
|
// Methods that add patterns described in this file to a pattern list.
|
|
//===---------------------------------------------------------------------===//
|
|
|
|
/// Adds the fold-by-linearization reshape patterns to `patterns`, instantiated
/// with `false` as their first template argument: producer collapse_shape,
/// producer expand_shape, and consumer collapse_shape, in that order.
void mlir::linalg::populateFoldReshapeOpsByLinearizationPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns
      .add<FoldProducerReshapeOpByLinearization<false, tensor::CollapseShapeOp>>(
          context);
  patterns
      .add<FoldProducerReshapeOpByLinearization<false, tensor::ExpandShapeOp>>(
          context);
  patterns
      .add<FoldConsumerReshapeOpByLinearization<false, tensor::CollapseShapeOp>>(
          context);
}
|
|
|
|
/// Adds the fold-by-linearization reshape patterns to `patterns`, instantiated
/// with `true` as their first template argument: producer collapse_shape,
/// producer expand_shape, and consumer collapse_shape, in that order.
void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns
      .add<FoldProducerReshapeOpByLinearization<true, tensor::CollapseShapeOp>>(
          context);
  patterns
      .add<FoldProducerReshapeOpByLinearization<true, tensor::ExpandShapeOp>>(
          context);
  patterns
      .add<FoldConsumerReshapeOpByLinearization<true, tensor::CollapseShapeOp>>(
          context);
}
|
|
|
|
/// Adds the patterns that fold reshapes with generic ops by expansion. Both
/// patterns receive `controlFoldingReshapes` as their control function.
void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
    RewritePatternSet &patterns,
    const ControlElementwiseOpsFusionFn &controlFoldingReshapes) {
  patterns.add<FoldReshapeWithGenericOpByExpansion,
               FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                     controlFoldingReshapes);
}
|
|
|
|
/// Adds the pattern that fuses a producer expand_shape into a generic op by
/// collapsing loops, gated by `controlFoldingReshapes`.
void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
    RewritePatternSet &patterns,
    const ControlElementwiseOpsFusionFn &controlFoldingReshapes) {
  MLIRContext *context = patterns.getContext();
  patterns.add<FoldWithProducerReshapeOpByCollapsing>(context,
                                                      controlFoldingReshapes);
}
|
|
|
|
/// Adds the elementwise fusion patterns to `patterns`: the fusion and
/// constant-folding patterns (controlled by
/// `options.controlElementwiseOpsFusionFn`), `RemoveOutsDependency`, the sparse
/// tensor rewriting patterns, the reshape-by-expansion patterns (controlled by
/// `options.controlFoldingReshapesFn`), and the canonicalization patterns of
/// the ops and of the Linalg dialect listed below.
void mlir::linalg::populateElementwiseOpsFusionPatterns(
    RewritePatternSet &patterns, LinalgElementwiseFusionOptions options) {
  MLIRContext *context = patterns.getContext();

  // Fusion and constant folding share the elementwise control function.
  patterns.add<FuseElementwiseOps>(context,
                                   options.controlElementwiseOpsFusionFn);
  patterns.add<FoldScalarOrSplatConstant>(
      context, options.controlElementwiseOpsFusionFn);
  patterns.add<FoldConstantTranspose>(context,
                                      options.controlElementwiseOpsFusionFn);
  patterns.add<RemoveOutsDependency>(context);
  populateSparseTensorRewriting(patterns);
  populateFoldReshapeOpsByExpansionPatterns(patterns,
                                            options.controlFoldingReshapesFn);

  // Canonicalization patterns.
  AffineApplyOp::getCanonicalizationPatterns(patterns, context);
  GenericOp::getCanonicalizationPatterns(patterns, context);
  tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
  tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
  auto *linalgDialect = context->getLoadedDialect<LinalgDialect>();
  linalgDialect->getCanonicalizationPatterns(patterns);
}
|
|
|
|
/// Adds the `PushExpandingReshape` pattern to `patterns`, constructed with the
/// pattern set's own context.
void mlir::linalg::populatePushReshapeOpsPatterns(RewritePatternSet &patterns) {
  patterns.add<PushExpandingReshape>(patterns.getContext());
}
|
|
|
|
//===---------------------------------------------------------------------===//
// Passes
//===---------------------------------------------------------------------===//
|
|
|
|
bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
|
|
OpOperand &consumer) {
|
|
if (auto producerCollapseOp =
|
|
dyn_cast<tensor::CollapseShapeOp>(producer.getOwner())) {
|
|
return !isUnitDimExpansionOnly(producerCollapseOp);
|
|
}
|
|
if (auto consumerExpandOp =
|
|
dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) {
|
|
return !isUnitDimExpansionOnly(consumerExpandOp);
|
|
}
|
|
return true;
|
|
}
|
|
|
|
namespace {
|
|
|
|
/// Pass that fuses generic ops on tensors. Used only for testing.
|
|
struct LinalgElementwiseOpFusionPass
|
|
: public LinalgElementwiseOpFusionBase<LinalgElementwiseOpFusionPass> {
|
|
void runOnOperation() override {
|
|
Operation *op = getOperation();
|
|
RewritePatternSet patterns(op->getContext());
|
|
ControlElementwiseOpsFusionFn allowFoldingFn =
|
|
[](const OpResult &producer, const OpOperand &consumer) {
|
|
return true;
|
|
};
|
|
populateElementwiseOpsFusionPatterns(
|
|
patterns,
|
|
LinalgElementwiseFusionOptions().setControlFoldingReshapes(
|
|
allowFoldingUnitDimReshapes ? allowFoldingFn : skipUnitDimReshape));
|
|
|
|
// Use TopDownTraversal for compile time reasons
|
|
GreedyRewriteConfig grc;
|
|
grc.useTopDownTraversal = true;
|
|
(void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns),
|
|
grc);
|
|
}
|
|
};
|
|
|
|
/// Pass to test folding of reshape ops with generic ops by linearization.
|
|
struct FoldReshapeOpsByLinearizationPass
|
|
: public LinalgFoldReshapeOpsByLinearizationBase<
|
|
FoldReshapeOpsByLinearizationPass> {
|
|
void runOnOperation() override {
|
|
Operation *op = getOperation();
|
|
RewritePatternSet patterns(op->getContext());
|
|
populateFoldReshapeOpsByLinearizationPatterns(patterns);
|
|
if (allowFoldingUnitDimReshapes) {
|
|
populateFoldUnitDimsReshapeOpsByLinearizationPatterns(patterns);
|
|
}
|
|
(void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
std::unique_ptr<Pass> mlir::createLinalgElementwiseOpFusionPass() {
|
|
return std::make_unique<LinalgElementwiseOpFusionPass>();
|
|
}
|
|
|
|
std::unique_ptr<Pass> mlir::createFoldReshapeOpsByLinearizationPass() {
|
|
return std::make_unique<FoldReshapeOpsByLinearizationPass>();
|
|
}
|