This transformation allows breaking up a reduction dimension into a parallel dimension and a reduction dimension. It is followed by a separate reduction op. This makes it possible to generate a tree reduction, which is beneficial on targets that can take advantage of parallelism. Differential Revision: https://reviews.llvm.org/D122045
235 lines
10 KiB
C++
235 lines
10 KiB
C++
//===-------- SplitReduction.cpp - Split reduction dimension --------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements linalg transformation to break a reduction dimension
|
|
// between a parallel and a reduction dimension.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Analysis/SliceAnalysis.h"
|
|
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
|
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
|
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::linalg;
|
|
|
|
/// Return the identity numeric value associated to the give op.
|
|
static Optional<Attribute> getIdentity(Operation *op) {
|
|
// Builder only used as helper for attribute creation.
|
|
OpBuilder b(op->getContext());
|
|
Type resultType = op->getResult(0).getType();
|
|
if (auto floatType = resultType.dyn_cast<FloatType>()) {
|
|
const llvm::fltSemantics &semantic = floatType.getFloatSemantics();
|
|
if (isa<arith::AddFOp>(op))
|
|
return b.getFloatAttr(resultType, llvm::APFloat::getZero(semantic));
|
|
if (isa<arith::MulFOp>(op))
|
|
return b.getFloatAttr(resultType, llvm::APFloat(semantic, 1));
|
|
if (isa<arith::MaxFOp>(op))
|
|
return b.getFloatAttr(resultType,
|
|
llvm::APFloat::getLargest(semantic, true));
|
|
if (isa<arith::MinFOp>(op))
|
|
return b.getFloatAttr(resultType,
|
|
llvm::APFloat::getLargest(semantic, true));
|
|
return llvm::None;
|
|
}
|
|
if (isa<arith::AddIOp, arith::OrIOp, arith::XOrIOp>(op))
|
|
return b.getIntegerAttr(resultType, 0);
|
|
if (isa<arith::AndIOp>(op))
|
|
return b.getIntegerAttr(resultType, -1);
|
|
if (isa<arith::MaxSIOp>(op))
|
|
return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::min());
|
|
if (isa<arith::MinSIOp>(op))
|
|
return b.getIntegerAttr(resultType, std::numeric_limits<int64_t>::max());
|
|
if (isa<arith::MulIOp>(op))
|
|
return b.getIntegerAttr(resultType, 1);
|
|
return llvm::None;
|
|
}
|
|
|
|
/// Split the single reduction dimension of `op` into a parallel dimension of
/// size `ratio` and a reduction dimension of size `reductionDimSize / ratio`,
/// then combine the partial results with a second linalg.generic that reduces
/// the newly added parallel dimension.
///
/// `controlSplitReductionFn(op)` returns the pair (ratio, insertDimIndex):
/// the split ratio and the position at which the new parallel dimension is
/// inserted, both in the loop nest and in the intermediate result.
///
/// The transformation is only applied when:
///  - `filter` accepts the op and the op has tensor semantics,
///  - the op has exactly one reduction loop and exactly one output,
///  - all indexing maps are projected permutations,
///  - the ratio is greater than 1 and evenly divides the (static) size of the
///    reduction dimension,
///  - the insertion index is a valid loop position and a valid position in
///    the intermediate result,
///  - the reduction is performed by a single combiner op with a known
///    identity value.
/// Otherwise a match failure is reported and no IR is modified.
///
/// On success `op` is replaced by the final reduction and the returned
/// LinalgOp is the first (partial reduction) generic op.
FailureOr<LinalgOp>
mlir::linalg::splitReduction(PatternRewriter &b, LinalgOp op,
                             ControlSplitReductionFn controlSplitReductionFn,
                             LinalgTransformationFilter filter) {
  if (failed(filter.checkAndNotify(b, op)) || !op.hasTensorSemantics() ||
      op.getNumReductionLoops() != 1 || op.getNumOutputs() != 1 ||
      !op.hasOnlyProjectedPermutations())
    return b.notifyMatchFailure(op, "precondition not met");

  std::pair<int64_t, unsigned> control = controlSplitReductionFn(op);
  int64_t ratio = control.first;
  unsigned insertDimIndex = control.second;
  if (ratio <= 1)
    return b.notifyMatchFailure(op, "split ratio needs to be greater than 1");

  SmallVector<unsigned> dims;
  op.getReductionDims(dims);
  assert(dims.size() == 1);
  unsigned reductionDim = dims[0];

  Optional<SmallVector<int64_t, 4>> loopRanges = op.getStaticLoopRanges();
  if (!loopRanges)
    return b.notifyMatchFailure(op, "Cannot analyze loops");
  int64_t reductionDimSize = (*loopRanges)[reductionDim];
  if (reductionDimSize == ShapedType::kDynamicSize ||
      reductionDimSize % ratio != 0)
    return b.notifyMatchFailure(
        op, "Reduction dimension not divisible by split ratio");

  // The insertion index is used both as a loop position and as a position in
  // the intermediate result. Reject it up front when it is out of range for
  // either, before any IR is created: a value past the output rank would make
  // the output shape computation below read past the end of the old shape.
  AffineMap oldOutputMap = op.getTiedIndexingMap(op.getOutputOperand(0));
  if (insertDimIndex >= loopRanges->size() ||
      insertDimIndex > oldOutputMap.getNumResults())
    return b.notifyMatchFailure(op, "Insert dimension index is out of range");

  SmallVector<Operation *, 4> combinerOps;
  if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps) ||
      combinerOps.size() != 1)
    return b.notifyMatchFailure(op, "Cannot match the reduction pattern");
  Operation *reductionOp = combinerOps[0];
  Optional<Attribute> identity = getIdentity(reductionOp);
  if (!identity)
    return b.notifyMatchFailure(op, "Unknown identity value for the redution");

  Location loc = op->getLoc();
  SmallVector<Value> newInputs;
  SmallVector<AffineMap> newMaps;
  // Calculate the new shapes and indexing maps of the input operands.
  for (OpOperand *operand : op.getInputOperands()) {
    AffineMap map = op.getTiedIndexingMap(operand);
    SmallVector<int64_t> newShape;
    SmallVector<AffineExpr> exprs;
    SmallVector<ReassociationIndices> reassociation;
    // Position of the next dimension of the expanded operand.
    unsigned index = 0;
    for (unsigned idx : llvm::seq<unsigned>(0, map.getNumResults())) {
      unsigned dim = map.getDimPosition(idx);
      if (reductionDim == dim) {
        // The reduced operand dimension is expanded into two dimensions:
        // [ratio, size / ratio], indexed by the new parallel loop and by the
        // (shifted) reduction loop respectively.
        newShape.push_back(ratio);
        newShape.push_back(op.getShape(operand)[idx] / ratio);
        reassociation.push_back({index, index + 1});
        index += 2;
        exprs.push_back(b.getAffineDimExpr(insertDimIndex));
        exprs.push_back(
            b.getAffineDimExpr(dim < insertDimIndex ? dim : dim + 1));
        continue;
      }
      newShape.push_back(op.getShape(operand)[idx]);
      // Loops at or after the insertion point are shifted by one.
      exprs.push_back(b.getAffineDimExpr(dim < insertDimIndex ? dim : dim + 1));
      reassociation.push_back({index++});
    }
    newMaps.push_back(
        AffineMap::get(map.getNumDims() + 1, 0, exprs, op.getContext()));
    // If the shape is unchanged the input doesn't change.
    if (newShape == op.getShape(operand)) {
      newInputs.push_back(operand->get());
      continue;
    }
    Type newType = RankedTensorType::get(
        newShape,
        operand->get().getType().cast<RankedTensorType>().getElementType());
    Value newInput = b.create<tensor::ExpandShapeOp>(
        loc, newType, operand->get(), reassociation);
    newInputs.push_back(newInput);
  }

  // Calculate the new output map and shape, we insert the new dimension based
  // on the index returned by `controlSplitReductionFn`.
  SmallVector<int64_t> newOutputShape;
  ArrayRef<int64_t> oldShape = op.getShape(op.getOutputOperand(0));
  SmallVector<AffineExpr> outputExpr;
  for (unsigned idx :
       llvm::seq<unsigned>(0, oldOutputMap.getNumResults() + 1)) {
    if (idx == insertDimIndex) {
      newOutputShape.push_back(ratio);
      outputExpr.push_back(b.getAffineDimExpr(insertDimIndex));
      continue;
    }
    unsigned oldDim = idx < insertDimIndex ? idx : idx - 1;
    newOutputShape.push_back(oldShape[oldDim]);
    unsigned dim = oldOutputMap.getDimPosition(oldDim);
    outputExpr.push_back(
        b.getAffineDimExpr(dim < insertDimIndex ? dim : dim + 1));
  }

  // The intermediate result is initialized with the identity of the combiner
  // so that the partial reductions start from a neutral value.
  Value initTensor = b.create<linalg::InitTensorOp>(
      loc, newOutputShape, op.getRegionOutputArgs()[0].getType());
  Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
  Value identityTensor =
      b.create<linalg::FillOp>(op->getLoc(), constantOp, initTensor)
          .getResult(0);

  newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr,
                                   op.getContext()));
  SmallVector<StringRef> newIteratorTypes;
  for (auto &it : llvm::enumerate(op.iterator_types())) {
    if (insertDimIndex == it.index())
      newIteratorTypes.push_back(getParallelIteratorTypeName());
    newIteratorTypes.push_back(it.value().cast<StringAttr>().getValue());
  }

  // Create the new op matching the original op with an extra parallel
  // dimension.
  GenericOp genericOp = b.create<GenericOp>(
      loc, TypeRange({initTensor.getType()}), newInputs,
      ValueRange({identityTensor}), newMaps, newIteratorTypes);
  // The body of the original op is moved (not copied) into the new op;
  // `reductionOp` keeps pointing at the combiner inside the moved region.
  b.inlineRegionBefore(op->getRegion(0), genericOp.region(),
                       genericOp.region().begin());

  // Then create a new reduction that only reduce the newly added dimension from
  // the previous op.
  unsigned intermRank = newOutputShape.size();
  AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
  SmallVector<Value> outputOperands = op.getOutputOperands();
  SmallVector<StringRef> reductionIteratorTypes;
  SmallVector<AffineExpr> exprs;
  for (unsigned i : llvm::seq<unsigned>(0, intermRank)) {
    if (insertDimIndex == i) {
      reductionIteratorTypes.push_back(getReductionIteratorTypeName());
    } else {
      exprs.push_back(b.getAffineDimExpr(i));
      reductionIteratorTypes.push_back(getParallelIteratorTypeName());
    }
  }
  AffineMap outputMap = AffineMap::get(intermRank, 0, exprs, op.getContext());
  SmallVector<AffineMap> reductionMaps = {inputMap, outputMap};

  auto reduction = b.create<GenericOp>(
      loc, op->getResultTypes(), ValueRange({genericOp.getResult(0)}),
      outputOperands, reductionMaps, reductionIteratorTypes,
      [reductionOp](OpBuilder &nestedBuilder, Location nestedLoc,
                    ValueRange inputs) {
        // Reuse the original combiner, rewired onto the block arguments of
        // the final reduction.
        Operation *clonedReductionOp = nestedBuilder.clone(*reductionOp);
        clonedReductionOp->setOperand(0, inputs[0]);
        clonedReductionOp->setOperand(1, inputs[1]);
        nestedBuilder.create<linalg::YieldOp>(nestedLoc,
                                              clonedReductionOp->getResult(0));
      });
  b.replaceOp(op, reduction.getResults());
  filter.replaceLinalgTransformationFilter(b, genericOp);
  filter.replaceLinalgTransformationFilter(b, reduction);
  return cast<LinalgOp>(genericOp.getOperation());
}
|
|
|
|
namespace {
|
|
|
|
struct LinalgSplitReduction : public OpInterfaceRewritePattern<LinalgOp> {
|
|
/// Construct a generic pattern applied to all LinalgOp that verify `filter`.
|
|
LinalgSplitReduction(MLIRContext *context,
|
|
ControlSplitReductionFn controlSplitReductionFn,
|
|
LinalgTransformationFilter f, PatternBenefit benefit = 1)
|
|
: OpInterfaceRewritePattern<LinalgOp>(context, benefit),
|
|
controlSplitReductionFn(controlSplitReductionFn), filter(std::move(f)) {
|
|
}
|
|
|
|
LogicalResult matchAndRewrite(LinalgOp op,
|
|
PatternRewriter &rewriter) const override {
|
|
return splitReduction(rewriter, op, controlSplitReductionFn, filter);
|
|
}
|
|
|
|
private:
|
|
ControlSplitReductionFn controlSplitReductionFn;
|
|
LinalgTransformationFilter filter;
|
|
};
|
|
|
|
} // namespace
|
|
|
|
/// Add the split-reduction rewrite pattern to `patterns`, configured with the
/// given control function and transformation filter.
void linalg::populateSplitReductionPattern(
    RewritePatternSet &patterns,
    ControlSplitReductionFn controlSplitReductionFn,
    LinalgTransformationFilter f) {
  MLIRContext *context = patterns.getContext();
  patterns.add<LinalgSplitReduction>(context, controlSplitReductionFn, f);
}
|