Thomas Raoux 33d2a780a1 [mlir][linalg] Add pattern to split reduction dimension in a linalg op
This transformation allows breaking up a reduction dimension into a
parallel dimension and a reduction dimension, followed by a separate
reduction op. This makes it possible to generate a tree reduction, which is
beneficial on targets that can take advantage of the extra parallelism.

Differential Revision: https://reviews.llvm.org/D122045
2022-03-24 23:22:53 +00:00

235 lines
10 KiB
C++

//===------- SplitReduction.cpp - Split reduction dimension ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a linalg transformation that splits a reduction
// dimension into a parallel dimension and a reduction dimension.
//
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
using namespace mlir;
using namespace mlir::linalg;
/// Return the identity value of the binary combiner `op`: the value `e`, of
/// `op`'s result type, such that `op(x, e) == x`. It is used to initialize the
/// partial-reduction accumulator. Returns `None` if `op` is not a supported
/// combiner or its result type is not a float, integer or index type.
static Optional<Attribute> getIdentity(Operation *op) {
  // Builder only used as helper for attribute creation.
  OpBuilder b(op->getContext());
  Type resultType = op->getResult(0).getType();
  if (auto floatType = resultType.dyn_cast<FloatType>()) {
    const llvm::fltSemantics &semantic = floatType.getFloatSemantics();
    if (isa<arith::AddFOp>(op))
      return b.getFloatAttr(resultType, llvm::APFloat::getZero(semantic));
    if (isa<arith::MulFOp>(op))
      return b.getFloatAttr(resultType, llvm::APFloat(semantic, 1));
    // max(x, -largest) == x for every finite x.
    if (isa<arith::MaxFOp>(op))
      return b.getFloatAttr(
          resultType, llvm::APFloat::getLargest(semantic, /*Negative=*/true));
    // min(x, +largest) == x for every finite x. The identity of min is the
    // largest *positive* value, not the most negative one.
    if (isa<arith::MinFOp>(op))
      return b.getFloatAttr(
          resultType, llvm::APFloat::getLargest(semantic, /*Negative=*/false));
    return llvm::None;
  }
  // Only integer and index results can carry an integer identity; building an
  // IntegerAttr for any other type would assert.
  if (!resultType.isIntOrIndex())
    return llvm::None;
  // Build signed bounds at the exact bit width of the result type. Using the
  // int64_t limits would be truncated to their low bits for narrower types
  // (e.g. INT64_MIN becomes 0 for i32).
  unsigned bitWidth = resultType.isIndex()
                          ? IndexType::kInternalStorageBitWidth
                          : resultType.getIntOrFloatBitWidth();
  if (isa<arith::AddIOp, arith::OrIOp, arith::XOrIOp, arith::MaxUIOp>(op))
    return b.getIntegerAttr(resultType, 0);
  // All-ones: identity of bitwise and, and of unsigned min.
  if (isa<arith::AndIOp, arith::MinUIOp>(op))
    return b.getIntegerAttr(resultType, -1);
  if (isa<arith::MaxSIOp>(op))
    return b.getIntegerAttr(resultType,
                            llvm::APInt::getSignedMinValue(bitWidth));
  if (isa<arith::MinSIOp>(op))
    return b.getIntegerAttr(resultType,
                            llvm::APInt::getSignedMaxValue(bitWidth));
  if (isa<arith::MulIOp>(op))
    return b.getIntegerAttr(resultType, 1);
  return llvm::None;
}
/// Split the single reduction dimension of `op` in two: an outer parallel
/// dimension of extent `ratio` and an inner reduction dimension of extent
/// `reductionDimSize / ratio`. `controlSplitReductionFn(op)` returns the pair
/// `{ratio, insertDimIndex}`. The new parallel loop is inserted at position
/// `insertDimIndex` in the loop nest, and the matching extra dimension is
/// inserted at the same position in the intermediate result tensor.
///
/// On success the rewrite creates:
///   1. a `tensor.expand_shape` for every input whose shape changes,
///   2. an intermediate `linalg.generic` that takes over the original region
///      and accumulates into an identity-filled tensor that has the extra
///      dimension,
///   3. a final `linalg.generic` that reduces the extra dimension into the
///      original output operand and replaces `op`.
/// All preconditions are checked before any IR is created, so a match failure
/// leaves the IR untouched. On success, returns the intermediate op from
/// step 2.
FailureOr<LinalgOp>
mlir::linalg::splitReduction(PatternRewriter &b, LinalgOp op,
                             ControlSplitReductionFn controlSplitReductionFn,
                             LinalgTransformationFilter filter) {
  if (failed(filter.checkAndNotify(b, op)) || !op.hasTensorSemantics() ||
      op.getNumReductionLoops() != 1 || op.getNumOutputs() != 1 ||
      !op.hasOnlyProjectedPermutations())
    return b.notifyMatchFailure(op, "precondition not met");

  std::pair<int64_t, unsigned> control = controlSplitReductionFn(op);
  int64_t ratio = control.first;
  unsigned insertDimIndex = control.second;
  if (ratio <= 1)
    return b.notifyMatchFailure(op, "split ratio needs to be greater than 1");

  SmallVector<unsigned> dims;
  op.getReductionDims(dims);
  assert(dims.size() == 1);
  unsigned reductionDim = dims[0];

  Optional<SmallVector<int64_t, 4>> loopRanges = op.getStaticLoopRanges();
  if (!loopRanges)
    return b.notifyMatchFailure(op, "Cannot analyze loops");
  int64_t reductionDimSize = (*loopRanges)[reductionDim];
  if (reductionDimSize == ShapedType::kDynamicSize ||
      reductionDimSize % ratio != 0)
    return b.notifyMatchFailure(
        op, "Reduction dimension not divisible by split ratio");
  // The new parallel iterator is inserted in front of the existing iterator
  // `insertDimIndex`, so the index must name an existing loop.
  if (insertDimIndex >= loopRanges->size())
    return b.notifyMatchFailure(op, "Insert dimension index out of range");

  OpOperand *outputOperand = op.getOutputOperand(0);
  AffineMap oldOutputMap = op.getTiedIndexingMap(outputOperand);
  ArrayRef<int64_t> oldShape = op.getShape(outputOperand);
  // The extra dimension is also placed at position `insertDimIndex` of the
  // intermediate tensor, which has exactly one more dimension than the output.
  if (insertDimIndex > oldOutputMap.getNumResults())
    return b.notifyMatchFailure(
        op, "Insert dimension index exceeds the output rank");
  // The intermediate tensor is created by linalg.init_tensor from static sizes
  // only, so every output dimension must be static.
  if (llvm::any_of(oldShape, [](int64_t size) {
        return size == ShapedType::kDynamicSize;
      }))
    return b.notifyMatchFailure(op, "Expected a static output shape");
  // Inputs indexed by the reduction loop are expanded to [ratio, size / ratio];
  // a dynamic extent there cannot be split statically.
  for (OpOperand *operand : op.getInputOperands()) {
    AffineMap map = op.getTiedIndexingMap(operand);
    ArrayRef<int64_t> shape = op.getShape(operand);
    for (unsigned idx : llvm::seq<unsigned>(0, map.getNumResults())) {
      if (map.getDimPosition(idx) == reductionDim &&
          shape[idx] == ShapedType::kDynamicSize)
        return b.notifyMatchFailure(
            op, "Expected a static reduction dimension on every input");
    }
  }

  SmallVector<Operation *, 4> combinerOps;
  if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps) ||
      combinerOps.size() != 1)
    return b.notifyMatchFailure(op, "Cannot match the reduction pattern");
  Operation *reductionOp = combinerOps[0];
  Optional<Attribute> identity = getIdentity(reductionOp);
  if (!identity)
    return b.notifyMatchFailure(op, "Unknown identity value for the reduction");

  // ---- From here on, IR is created; no more match failures. ----
  Location loc = op->getLoc();
  // Expression for original loop `dim` in the new loop nest: loops at or after
  // `insertDimIndex` shift by one to make room for the new parallel loop.
  auto getShiftedDimExpr = [&](unsigned dim) {
    return b.getAffineDimExpr(dim < insertDimIndex ? dim : dim + 1);
  };

  SmallVector<Value> newInputs;
  SmallVector<AffineMap> newMaps;
  // Calculate the new shapes and indexing maps of the input operands.
  for (OpOperand *operand : op.getInputOperands()) {
    AffineMap map = op.getTiedIndexingMap(operand);
    ArrayRef<int64_t> oldInputShape = op.getShape(operand);
    SmallVector<int64_t> newShape;
    SmallVector<AffineExpr> exprs;
    SmallVector<ReassociationIndices> reassociation;
    int64_t index = 0;
    for (unsigned idx : llvm::seq<unsigned>(0, map.getNumResults())) {
      unsigned dim = map.getDimPosition(idx);
      if (reductionDim == dim) {
        // Split the reduction extent into [ratio, extent / ratio], indexed by
        // the new parallel loop and by the (shifted) reduction loop.
        newShape.push_back(ratio);
        newShape.push_back(oldInputShape[idx] / ratio);
        reassociation.push_back({index, index + 1});
        index += 2;
        exprs.push_back(b.getAffineDimExpr(insertDimIndex));
        exprs.push_back(getShiftedDimExpr(dim));
        continue;
      }
      newShape.push_back(oldInputShape[idx]);
      exprs.push_back(getShiftedDimExpr(dim));
      reassociation.push_back({index++});
    }
    newMaps.push_back(
        AffineMap::get(map.getNumDims() + 1, 0, exprs, op.getContext()));
    // If the shape is unchanged the input doesn't change.
    if (newShape == oldInputShape) {
      newInputs.push_back(operand->get());
      continue;
    }
    Type newType = RankedTensorType::get(
        newShape,
        operand->get().getType().cast<RankedTensorType>().getElementType());
    Value newInput = b.create<tensor::ExpandShapeOp>(
        loc, newType, operand->get(), reassociation);
    newInputs.push_back(newInput);
  }

  // Calculate the new output map and shape, we insert the new dimension based
  // on the index returned by `controlSplitReductionFn`.
  SmallVector<int64_t> newOutputShape;
  SmallVector<AffineExpr> outputExpr;
  for (unsigned idx :
       llvm::seq<unsigned>(0, oldOutputMap.getNumResults() + 1)) {
    if (idx == insertDimIndex) {
      newOutputShape.push_back(ratio);
      outputExpr.push_back(b.getAffineDimExpr(insertDimIndex));
      continue;
    }
    unsigned oldDim = idx < insertDimIndex ? idx : idx - 1;
    newOutputShape.push_back(oldShape[oldDim]);
    outputExpr.push_back(
        getShiftedDimExpr(oldOutputMap.getDimPosition(oldDim)));
  }
  // The partial accumulator starts out filled with the combiner's identity.
  Value initTensor = b.create<linalg::InitTensorOp>(
      loc, newOutputShape, op.getRegionOutputArgs()[0].getType());
  Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
  Value identityTensor =
      b.create<linalg::FillOp>(loc, constantOp, initTensor).getResult(0);

  newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr,
                                   op.getContext()));
  SmallVector<StringRef> newIteratorTypes;
  for (auto &it : llvm::enumerate(op.iterator_types())) {
    if (insertDimIndex == it.index())
      newIteratorTypes.push_back(getParallelIteratorTypeName());
    newIteratorTypes.push_back(it.value().cast<StringAttr>().getValue());
  }
  // Create the new op matching the original op with an extra parallel
  // dimension. It takes over the original region.
  GenericOp genericOp = b.create<GenericOp>(
      loc, TypeRange({initTensor.getType()}), newInputs,
      ValueRange({identityTensor}), newMaps, newIteratorTypes);
  b.inlineRegionBefore(op->getRegion(0), genericOp.region(),
                       genericOp.region().begin());

  // Then create a new reduction that only reduces the newly added dimension
  // of the intermediate result into the original output operand.
  unsigned intermRank = newOutputShape.size();
  AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
  SmallVector<Value> outputOperands = op.getOutputOperands();
  SmallVector<StringRef> reductionIteratorTypes;
  SmallVector<AffineExpr> exprs;
  for (unsigned i : llvm::seq<unsigned>(0, intermRank)) {
    if (insertDimIndex == i) {
      reductionIteratorTypes.push_back(getReductionIteratorTypeName());
    } else {
      exprs.push_back(b.getAffineDimExpr(i));
      reductionIteratorTypes.push_back(getParallelIteratorTypeName());
    }
  }
  AffineMap outputMap = AffineMap::get(intermRank, 0, exprs, op.getContext());
  SmallVector<AffineMap> reductionMaps = {inputMap, outputMap};
  auto reduction = b.create<GenericOp>(
      loc, op->getResultTypes(), ValueRange({genericOp.getResult(0)}),
      outputOperands, reductionMaps, reductionIteratorTypes,
      [reductionOp](OpBuilder &b, Location loc, ValueRange inputs) {
        // Re-apply the original combiner to (partial value, accumulator).
        Operation *clonedReductionOp = b.clone(*reductionOp);
        clonedReductionOp->setOperand(0, inputs[0]);
        clonedReductionOp->setOperand(1, inputs[1]);
        b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
      });
  b.replaceOp(op, reduction.getResults());
  filter.replaceLinalgTransformationFilter(b, genericOp);
  filter.replaceLinalgTransformationFilter(b, reduction);
  return cast<LinalgOp>(genericOp.getOperation());
}
namespace {
/// Rewrite pattern that applies `splitReduction` to every LinalgOp accepted by
/// `filter`. `controlSplitReductionFn` chooses the split ratio and the
/// position of the new parallel dimension.
struct LinalgSplitReduction : public OpInterfaceRewritePattern<LinalgOp> {
  /// Construct a generic pattern applied to all LinalgOp that verify `filter`.
  /// Both callbacks are taken by value and moved into the pattern.
  LinalgSplitReduction(MLIRContext *context,
                       ControlSplitReductionFn controlSplitReductionFn,
                       LinalgTransformationFilter f, PatternBenefit benefit = 1)
      : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
        controlSplitReductionFn(std::move(controlSplitReductionFn)),
        filter(std::move(f)) {}

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    return splitReduction(rewriter, op, controlSplitReductionFn, filter);
  }

private:
  ControlSplitReductionFn controlSplitReductionFn;
  LinalgTransformationFilter filter;
};
} // namespace
/// Add the split-reduction rewrite pattern to `patterns`.
/// `controlSplitReductionFn` selects the split ratio and the insertion index of
/// the new parallel dimension, and `f` filters which ops are rewritten. Both
/// arguments are sink parameters and are moved into the pattern.
void linalg::populateSplitReductionPattern(
    RewritePatternSet &patterns,
    ControlSplitReductionFn controlSplitReductionFn,
    LinalgTransformationFilter f) {
  patterns.add<LinalgSplitReduction>(patterns.getContext(),
                                     std::move(controlSplitReductionFn),
                                     std::move(f));
}