llvm-project/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp
Christopher Bate 9f1221521f Recommit "[mlir][vector] Allow unroll of contraction in arbitrary order"
Fixed issue with vector.contract default unroll permutation.

Adds support for vector unroll transformations to unroll in different
orders. For example, the vector.contract can be unrolled into a
smaller set of contractions. There is a choice of how to unroll the
decomposition based on the traversal order of (dim0, dim1, dim2).
The choice of traversal order can now be specified by a callback which is
given by the caller of the transform. For now, only the
vector.contract, vector.transfer_read/transfer_write operations
support the callback.

Differential Revision: https://reviews.llvm.org/D127004
2022-06-09 14:01:19 -06:00

831 lines
35 KiB
C++

//===- VectorUnrollDistribute.cpp - patterns to do vector unrolling -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to do vector unrolling and vector distribution.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include <numeric>
#define DEBUG_TYPE "vector-unrolling"
using namespace mlir;
using namespace mlir::vector;
/// Returns the element offsets of slice number `index` when a vector of shape
/// `originalShape` is unrolled into slices of shape `targetShape`.
static SmallVector<int64_t, 4> getVectorOffset(ArrayRef<int64_t> originalShape,
                                               ArrayRef<int64_t> targetShape,
                                               int64_t index) {
  // Turn the linear slice index into a per-dimension slice position ...
  SmallVector<int64_t, 4> strides = computeStrides(originalShape, targetShape);
  SmallVector<int64_t, 4> slicePosition = delinearize(strides, index);
  // ... and convert that position into offsets expressed in elements.
  return computeElementOffsetsFromVectorSliceOffsets(targetShape,
                                                     slicePosition);
}
/// A functor that accomplishes the same thing as `getVectorOffset` but allows
/// for reordering the traversal of the dimensions. The order of traversal is
/// given in "for loop order" (outer to inner).
namespace {
class DecomposeShapeIterator {
public:
  /// Prepares the traversal of a vector of shape `originalShape` in slices of
  /// shape `targetShape`. `loopOrder` lists the dimensions in "for loop
  /// order", i.e. from the outermost loop to the innermost one. All three
  /// arrays are expected to have the same number of entries.
  DecomposeShapeIterator(ArrayRef<int64_t> originalShape,
                         ArrayRef<int64_t> targetShape,
                         ArrayRef<int64_t> loopOrder)
      : vectorShape(targetShape.begin(), targetShape.end()),
        loopOrder(loopOrder.begin(), loopOrder.end()),
        sliceStrides(originalShape.size()) {
    assert(originalShape.size() == targetShape.size());
    assert(loopOrder.size() == targetShape.size());

    // Number of slices along every dimension; their product is the total
    // number of slices.
    const unsigned rank = originalShape.size();
    SmallVector<int64_t> slicesPerDim(rank);
    for (unsigned dim = 0; dim < rank; ++dim) {
      slicesPerDim[dim] = ceilDiv(originalShape[dim], targetShape[dim]);
      maxIndexVal *= slicesPerDim[dim];
    }

    // Walk the loop order backwards so that the innermost dimension gets the
    // smallest stride and the outermost dimension the largest one.
    int64_t stride = 1;
    for (int64_t dim : llvm::reverse(loopOrder)) {
      sliceStrides[dim] = stride;
      stride *= slicesPerDim[dim];
    }
  }

  /// Converts the linear slice `index` into a tuple of per-dimension slice
  /// positions (in units of `vectorShape`-sized slices), following the
  /// traversal order given at construction.
  SmallVector<int64_t> delinearize(int64_t index) const {
    SmallVector<int64_t> vectorOffsets(sliceStrides.size());
    int64_t remainder = index;
    // Peel off the dimensions from the largest stride to the smallest one.
    for (int64_t dim : loopOrder) {
      const int64_t stride = sliceStrides[dim];
      vectorOffsets[dim] = remainder / stride;
      remainder %= stride;
    }
    return vectorOffsets;
  }

  /// Total number of slices, i.e. the exclusive upper bound of valid indices.
  int64_t maxIndex() const { return maxIndexVal; }

  /// Returns the element offsets of the slice with linear index `index`,
  /// using the ordering given by `loopOrder`.
  SmallVector<int64_t> getVectorOffset(int64_t index) const {
    SmallVector<int64_t> slicePosition = delinearize(index);
    return computeElementOffsetsFromVectorSliceOffsets(vectorShape,
                                                       slicePosition);
  }

private:
  /// Shape of a single slice (the unroll target shape).
  SmallVector<int64_t, 4> vectorShape;
  /// Dimensions in traversal order, outermost first.
  SmallVector<int64_t> loopOrder;
  /// Stride of every dimension in the linear slice index.
  SmallVector<int64_t> sliceStrides;
  /// Product of the slice counts of all dimensions.
  int64_t maxIndexVal{1};
};
} // namespace
/// Compute the indices of one slice of a transfer op: every index that the
/// permutation map associates with a vector dimension is shifted by the
/// corresponding entry of `elementOffsets`. Results of the permutation map
/// that are the constant 0 (broadcast dimensions) are left untouched.
static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
                                               ArrayRef<Value> indices,
                                               AffineMap permutationMap,
                                               Location loc,
                                               OpBuilder &builder) {
  MLIRContext *ctx = builder.getContext();
  // A result of the form "constant 0" denotes a broadcast dimension.
  auto isBroadcast = [](AffineExpr expr) {
    auto constExpr = expr.dyn_cast<AffineConstantExpr>();
    return constExpr && constExpr.getValue() == 0;
  };

  // Start from the original indices and patch the ones that are sliced.
  SmallVector<Value> slicedIndices(indices.begin(), indices.end());
  auto results = permutationMap.getResults();
  for (unsigned i = 0, e = results.size(); i < e; ++i) {
    AffineExpr result = results[i];
    if (isBroadcast(result))
      continue;
    unsigned pos = result.cast<AffineDimExpr>().getPosition();
    // indices[pos] + elementOffsets[i], expressed as an affine.apply.
    auto shifted = getAffineDimExpr(0, ctx) +
                   getAffineConstantExpr(elementOffsets[i], ctx);
    auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, shifted);
    slicedIndices[pos] = builder.create<AffineApplyOp>(loc, map, indices[pos]);
  }
  return slicedIndices;
}
// Creates a new operation with the same name and attributes as `op`, but
// taking `operands` as operands and producing results of `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
                                              Operation *op,
                                              ArrayRef<Value> operands,
                                              ArrayRef<Type> resultTypes) {
  auto opName = op->getName().getIdentifier();
  return builder.create(loc, opName, operands, resultTypes, op->getAttrs());
}
/// Return the target shape for unrolling for the given `op`. Return llvm::None
/// if the op shouldn't be or cannot be unrolled.
///
/// The op is rejected when any of the following holds:
///  - `options.filterConstraint` is set and fails for `op`;
///  - `op` does not implement `VectorUnrollOpInterface`, or the interface
///    reports no shape for unrolling;
///  - `options.nativeShape` returns no shape for `op`;
///  - the shape ratio between the unroll shape and the native shape cannot be
///    computed, or is 1 in every dimension (nothing to unroll).
static Optional<SmallVector<int64_t, 4>>
getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
  if (options.filterConstraint && failed(options.filterConstraint(op)))
    return llvm::None;
  // Note the trailing space in the first literal: adjacent string literals are
  // concatenated verbatim.
  assert(options.nativeShape &&
         "vector unrolling expects the native shape or native "
         "shape call back function to be set");
  auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
  if (!unrollableVectorOp)
    return llvm::None;
  auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
  if (!maybeUnrollShape)
    return llvm::None;
  Optional<SmallVector<int64_t, 4>> targetShape = options.nativeShape(op);
  if (!targetShape)
    return llvm::None;
  auto maybeShapeRatio = shapeRatio(*maybeUnrollShape, *targetShape);
  // A ratio of all ones means the op already has the native shape.
  if (!maybeShapeRatio ||
      llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; }))
    return llvm::None;
  return targetShape;
}
/// Returns the order in which the `numLoops` dimensions of `op` are traversed
/// during unrolling. The order provided by `options.traversalOrderCallback` is
/// used when the callback is set and returns a value; otherwise the identity
/// order 0, 1, ..., numLoops - 1 is returned.
static SmallVector<int64_t>
getUnrollOrder(unsigned numLoops, Operation *op,
               const vector::UnrollVectorOptions &options) {
  if (options.traversalOrderCallback != nullptr) {
    Optional<SmallVector<int64_t>> order = options.traversalOrderCallback(op);
    if (order.hasValue())
      return std::move(*order);
  }
  // No user-provided order: fall back to the natural dimension order.
  SmallVector<int64_t> identityOrder =
      llvm::to_vector(llvm::seq<int64_t>(0, static_cast<int64_t>(numLoops)));
  return identityOrder;
}
namespace {
/// Unrolls a `vector.transfer_read` into reads of the target shape and inserts
/// every small read into a vector of the original type. Reads of rank 0 and
/// masked reads are not handled.
struct UnrollTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  UnrollTransferReadPattern(MLIRContext *context,
                            const vector::UnrollVectorOptions &options)
      : OpRewritePattern<vector::TransferReadOp>(context, /*benefit=*/1),
        options(options) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (readOp.getTransferRank() == 0 || readOp.getMask())
      return failure();
    auto targetShape = getTargetShape(options, readOp);
    if (!targetShape)
      return failure();

    Location loc = readOp.getLoc();
    auto fullVectorType = readOp.getVectorType();
    ArrayRef<int64_t> fullShape = fullVectorType.getShape();
    SmallVector<int64_t, 4> unitStrides(targetShape->size(), 1);

    // The slices are inserted one by one into a zero-initialized vector.
    Value assembled = rewriter.create<arith::ConstantOp>(
        loc, fullVectorType, rewriter.getZeroAttr(fullVectorType));
    auto sliceType =
        VectorType::get(*targetShape, fullVectorType.getElementType());
    SmallVector<Value, 4> baseIndices(readOp.getIndices().begin(),
                                      readOp.getIndices().end());

    SmallVector<int64_t> loopOrder =
        getUnrollOrder(fullShape.size(), readOp, options);
    DecomposeShapeIterator sliceIterator(fullShape, *targetShape, loopOrder);
    for (int64_t slice = 0; slice < sliceIterator.maxIndex(); ++slice) {
      SmallVector<int64_t, 4> sliceOffsets =
          sliceIterator.getVectorOffset(slice);
      // Shift the read indices to the start of this slice.
      SmallVector<Value, 4> sliceIndices = sliceTransferIndices(
          sliceOffsets, baseIndices, readOp.getPermutationMap(), loc, rewriter);
      auto sliceRead = rewriter.create<vector::TransferReadOp>(
          loc, sliceType, readOp.getSource(), sliceIndices,
          readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
          readOp.getInBoundsAttr());
      assembled = rewriter.create<vector::InsertStridedSliceOp>(
          loc, sliceRead, assembled, sliceOffsets, unitStrides);
    }
    rewriter.replaceOp(readOp, assembled);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
/// Unrolls a `vector.transfer_write` into writes of the target shape, each
/// writing one strided slice of the original vector. When the write produces a
/// result (tensor case), the writes are chained through their results. Writes
/// of rank 0 and masked writes are not handled.
struct UnrollTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  UnrollTransferWritePattern(MLIRContext *context,
                             const vector::UnrollVectorOptions &options)
      : OpRewritePattern<vector::TransferWriteOp>(context, /*benefit=*/1),
        options(options) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (writeOp.getTransferRank() == 0 || writeOp.getMask())
      return failure();
    auto targetShape = getTargetShape(options, writeOp);
    if (!targetShape)
      return failure();

    Location loc = writeOp.getLoc();
    auto fullVectorType = writeOp.getVectorType();
    ArrayRef<int64_t> fullShape = fullVectorType.getShape();
    SmallVector<int64_t, 4> unitStrides(targetShape->size(), 1);
    SmallVector<Value, 4> baseIndices(writeOp.getIndices().begin(),
                                      writeOp.getIndices().end());

    SmallVector<int64_t> loopOrder =
        getUnrollOrder(fullShape.size(), writeOp, options);
    DecomposeShapeIterator sliceIterator(fullShape, *targetShape, loopOrder);

    // Result of the most recent sliced write; stays null when the writes
    // produce no result.
    Value chainedResult;
    for (int64_t slice = 0; slice < sliceIterator.maxIndex(); ++slice) {
      SmallVector<int64_t, 4> sliceOffsets =
          sliceIterator.getVectorOffset(slice);
      Value sliceVector = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, writeOp.getVector(), sliceOffsets, *targetShape, unitStrides);
      SmallVector<Value, 4> sliceIndices =
          sliceTransferIndices(sliceOffsets, baseIndices,
                               writeOp.getPermutationMap(), loc, rewriter);
      // Write into the result of the previous write if there is one.
      Value destination = chainedResult ? chainedResult : writeOp.getSource();
      Operation *sliceWrite = rewriter.create<vector::TransferWriteOp>(
          loc, sliceVector, destination, sliceIndices,
          writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
      // For the tensor case update the destination for the next transfer write.
      if (!sliceWrite->getResults().empty())
        chainedResult = sliceWrite->getResult(0);
    }

    if (!chainedResult) {
      rewriter.eraseOp(writeOp);
      return success();
    }
    rewriter.replaceOp(writeOp, chainedResult);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
/// Key traits that allow offset vectors (`SmallVector<int64_t>`) to be used as
/// keys of an `llvm::DenseMap`.
struct OffsetMapInfo {
  /// Key marking an empty bucket: the one-element vector {-1}.
  static SmallVector<int64_t> getEmptyKey() {
    return {static_cast<int64_t>(-1)};
  }

  /// Key marking a deleted bucket: the one-element vector {-2}.
  static SmallVector<int64_t> getTombstoneKey() {
    return {static_cast<int64_t>(-2)};
  }

  /// Hash over all elements of the offset vector.
  static unsigned getHashValue(const SmallVector<int64_t> &v) {
    auto hash = llvm::hash_combine_range(v.begin(), v.end());
    return static_cast<unsigned>(hash);
  }

  /// Two keys are equal when the vectors compare equal.
  static bool isEqual(const SmallVector<int64_t> &lhs,
                      const SmallVector<int64_t> &rhs) {
    return lhs == rhs;
  }
};
/// Unrolls a `vector.contract` into contractions of the target shape. Slices of
/// the lhs/rhs (and of the masks, when present) are extracted for every
/// iteration; the accumulator slice is taken from the result of the previous
/// contraction that produced the same destination offsets, if any. The final
/// accumulator slices are inserted into a zero-initialized result vector.
struct UnrollContractionPattern
: public OpRewritePattern<vector::ContractionOp> {
UnrollContractionPattern(MLIRContext *context,
const vector::UnrollVectorOptions &options)
: OpRewritePattern<vector::ContractionOp>(context, /*benefit=*/1),
options(options) {}
LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
PatternRewriter &rewriter) const override {
auto targetShape = getTargetShape(options, contractOp);
if (!targetShape)
return failure();
auto dstVecType = contractOp.getResultType().cast<VectorType>();
SmallVector<int64_t, 4> originalSize = *contractOp.getShapeForUnroll();
Location loc = contractOp.getLoc();
unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
AffineMap dstAffineMap = contractOp.getIndexingMaps()[accIndex];
// Maps destination offsets to the latest accumulator value computed for
// them. A MapVector is used, so iteration follows insertion order.
llvm::MapVector<
SmallVector<int64_t>, Value,
llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
accCache;
// The traversal order covers all iterators of the contraction and may be
// overridden through the unroll options.
SmallVector<int64_t> loopOrder = getUnrollOrder(
contractOp.getIteratorTypes().size(), contractOp, options);
DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
loopOrder);
const int64_t sliceCount = indexToOffsets.maxIndex();
for (int64_t i = 0; i < sliceCount; i++) {
SmallVector<int64_t, 4> offsets = indexToOffsets.getVectorOffset(i);
SmallVector<Value, 4> slicesOperands(contractOp.getNumOperands());
// Helper to compute the new shape of each operand and extract the slice.
// The slice shape is the target shape seen through `permutationMap`; the
// extracted value is stored in `slicesOperands[index]`.
auto extractOperand = [&](unsigned index, Value operand,
AffineMap permutationMap,
ArrayRef<int64_t> operandOffets) {
SmallVector<int64_t> operandShape = applyPermutationMap(
permutationMap, ArrayRef<int64_t>(*targetShape));
SmallVector<int64_t, 4> operandStrides(operandOffets.size(), 1);
slicesOperands[index] = rewriter.create<vector::ExtractStridedSliceOp>(
loc, operand, operandOffets, operandShape, operandStrides);
};
// Extract the new lhs operand.
AffineMap lhsPermutationMap = contractOp.getIndexingMaps()[0];
SmallVector<int64_t> lhsOffets =
applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffets);
// If there is a mask associated to lhs, extract it as well.
if (slicesOperands.size() > 3)
extractOperand(3, contractOp.getMasks()[0], lhsPermutationMap,
lhsOffets);
// Extract the new rhs operand.
AffineMap rhsPermutationMap = contractOp.getIndexingMaps()[1];
SmallVector<int64_t> rhsOffets =
applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffets);
// If there is a mask associated to rhs, extract it as well.
if (slicesOperands.size() > 4)
extractOperand(4, contractOp.getMasks()[1], rhsPermutationMap,
rhsOffets);
// NOTE(review): the accumulator map is fetched with the literal index 2
// here while `dstAffineMap` above uses `getAccOperandIndex()`; presumably
// both denote the same map - confirm.
AffineMap accPermutationMap = contractOp.getIndexingMaps()[2];
SmallVector<int64_t> accOffets =
applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
// If a version of the accumulator has already been computed, use it
// otherwise extract the first version from the original operand.
auto accIt = accCache.find(accOffets);
if (accIt != accCache.end())
slicesOperands[2] = accIt->second;
else
extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffets);
SmallVector<int64_t> dstShape =
applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape));
auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
// Create the small contraction with the attributes of the original op.
Operation *newOp = cloneOpWithOperandsAndTypes(
rewriter, loc, contractOp, slicesOperands, targetType);
SmallVector<int64_t> dstOffets =
applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(offsets));
// Save the accumulated value until all the loops are unrolled since
// reduction loop keep updating the accumulator.
accCache[dstOffets] = newOp->getResult(0);
}
// Assemble back the accumulator into a single vector.
Value result = rewriter.create<arith::ConstantOp>(
loc, dstVecType, rewriter.getZeroAttr(dstVecType));
for (const auto &it : accCache) {
SmallVector<int64_t> dstStrides(it.first.size(), 1);
result = rewriter.create<vector::InsertStridedSliceOp>(
loc, it.second, result, it.first, dstStrides);
}
rewriter.replaceOp(contractOp, result);
return success();
}
private:
vector::UnrollVectorOptions options;
};
/// Unrolls a `vector.multi_reduction` into reductions over slices of the
/// target shape. Partial results that land on the same destination offsets
/// are combined with the reduction kind of the original op, and the combined
/// slices are inserted into a zero-initialized vector of the destination type.
struct UnrollMultiReductionPattern
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  UnrollMultiReductionPattern(MLIRContext *context,
                              const vector::UnrollVectorOptions &options)
      : OpRewritePattern<vector::MultiDimReductionOp>(context, /*benefit=*/1),
        options(options) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    Optional<SmallVector<int64_t, 4>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();

    SmallVector<int64_t, 4> originalSize = *reductionOp.getShapeForUnroll();
    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
    // Maps destination offsets to the partial result accumulated so far.
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;
    // Total number of slices the source vector is split into.
    int64_t sliceCount = computeMaxLinearIndex(ratio);
    Location loc = reductionOp.getLoc();

    for (int64_t slice = 0; slice < sliceCount; slice++) {
      SmallVector<int64_t, 4> offsets =
          getVectorOffset(originalSize, *targetShape, slice);
      SmallVector<int64_t, 4> operandStrides(offsets.size(), 1);
      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, reductionOp.getOperand(), offsets, *targetShape, operandStrides);

      // Only the dimensions that are not reduced survive in the destination.
      SmallVector<int64_t> dstShape;
      SmallVector<int64_t> destOffset;
      for (size_t dim = 0, rank = targetShape->size(); dim < rank; ++dim) {
        if (reductionOp.isReducedDim(dim))
          continue;
        destOffset.push_back(offsets[dim]);
        dstShape.push_back((*targetShape)[dim]);
      }

      auto targetType = VectorType::get(
          dstShape, reductionOp.getSourceVectorType().getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp,
                                                     slicedOperand, targetType);
      Value partial = newOp->getResult(0);
      // Keep accumulating until all slices are unrolled: several slices can
      // contribute to the same destination offsets.
      auto accIt = accCache.find(destOffset);
      if (accIt != accCache.end())
        partial = makeArithReduction(rewriter, loc, reductionOp.getKind(),
                                     partial, accIt->second);
      accCache[destOffset] = partial;
    }

    // Assemble back the accumulator into a single vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, reductionOp.getDestType(),
        rewriter.getZeroAttr(reductionOp.getDestType()));
    for (const auto &entry : accCache) {
      SmallVector<int64_t> dstStrides(entry.first.size(), 1);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, entry.second, result, entry.first, dstStrides);
    }
    rewriter.replaceOp(reductionOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
/// Unrolls any single-result operation with elementwise-mappable traits into
/// copies of the operation working on slices of the target shape. Vector
/// operands are sliced, non-vector operands are forwarded unchanged, and the
/// partial results are inserted into a zero-initialized result vector.
struct UnrollElementwisePattern : public RewritePattern {
  UnrollElementwisePattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options)
      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
        options(options) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
      return failure();
    auto targetShape = getTargetShape(options, op);
    if (!targetShape)
      return failure();

    auto dstVecType = op->getResult(0).getType().cast<VectorType>();
    SmallVector<int64_t, 4> originalSize =
        *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
    SmallVector<int64_t, 4> ratio = *shapeRatio(originalSize, *targetShape);
    int64_t sliceCount = computeMaxLinearIndex(ratio);
    Location loc = op->getLoc();

    // Prepare the result vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstVecType, rewriter.getZeroAttr(dstVecType));
    SmallVector<int64_t, 4> strides(targetShape->size(), 1);
    VectorType newVecType =
        VectorType::get(*targetShape, dstVecType.getElementType());

    for (int64_t slice = 0; slice < sliceCount; slice++) {
      SmallVector<int64_t, 4> offsets =
          getVectorOffset(originalSize, *targetShape, slice);
      SmallVector<Value, 4> sliceOperands;
      for (OpOperand &operand : op->getOpOperands()) {
        Value value = operand.get();
        // Only vector operands are sliced; everything else is passed as is.
        if (value.getType().dyn_cast<VectorType>())
          value = rewriter.create<vector::ExtractStridedSliceOp>(
              loc, value, offsets, *targetShape, strides);
        sliceOperands.push_back(value);
      }
      Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, op,
                                                     sliceOperands, newVecType);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, newOp->getResult(0), result, offsets, strides);
    }
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
/// Canonicalize an extract_map using the result of a pointwise operation by
/// moving the extract_map onto the vector operands of that operation.
/// Transforms:
/// %v = arith.addf %a, %b : vector<32xf32>
/// %dv = vector.extract_map %v[%id] : vector<32xf32> to vector<1xf32>
/// to:
/// %da = vector.extract_map %a[%id] : vector<32xf32> to vector<1xf32>
/// %db = vector.extract_map %b[%id] : vector<32xf32> to vector<1xf32>
/// %dv = arith.addf %da, %db : vector<1xf32>
struct PointwiseExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
  using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
                                PatternRewriter &rewriter) const override {
    // Only single-result elementwise producers are handled.
    Operation *definedOp = extract.getVector().getDefiningOp();
    if (!definedOp)
      return failure();
    if (!OpTrait::hasElementwiseMappableTraits(definedOp) ||
        definedOp->getNumResults() != 1)
      return failure();

    Location loc = extract.getLoc();
    SmallVector<Value, 4> distributedOperands;
    for (OpOperand &operand : definedOp->getOpOperands()) {
      Value value = operand.get();
      // Vector operands get their own extract_map with the shape of the
      // original extract and their own element type; other operands are
      // forwarded unchanged.
      if (auto vecType = value.getType().dyn_cast<VectorType>())
        value = rewriter.create<vector::ExtractMapOp>(
            loc,
            VectorType::get(extract.getResultType().getShape(),
                            vecType.getElementType()),
            value, extract.getIds());
      distributedOperands.push_back(value);
    }
    Operation *newOp = cloneOpWithOperandsAndTypes(
        rewriter, loc, definedOp, distributedOperands, extract.getResultType());
    rewriter.replaceOp(extract, newOp->getResult(0));
    return success();
  }
};
/// Canonicalize an extract_map using the result of a contract operation.
/// The extract_map is propagated to the operands of the contraction, and the
/// contraction is re-created on the distributed operands.
struct ContractExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
  using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
                                PatternRewriter &rewriter) const override {
    Operation *definedOp = extract.getVector().getDefiningOp();
    auto contract = dyn_cast_or_null<vector::ContractionOp>(definedOp);
    if (!contract)
      return failure();

    Location loc = contract.getLoc();
    unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
    AffineMap accMap = contract.getIndexingMaps()[accIndex];
    // Record, for every iteration dimension appearing in the accumulator map,
    // the size it has after distribution. Only parallel dimensions are being
    // distributed, reduction dimensions are untouched.
    DenseMap<int64_t, int64_t> distributedSizes;
    for (unsigned res = 0, e = accMap.getNumResults(); res < e; ++res)
      distributedSizes[accMap.getDimPosition(res)] =
          extract.getResultType().getDimSize(res);

    SmallVector<Value, 4> distributedOperands;
    for (const auto &indexedMap : llvm::enumerate(contract.getIndexingMaps())) {
      // Compute the vector type of this operand after distribution.
      Value operand = contract->getOperand(indexedMap.index());
      auto vecType = operand.getType().cast<VectorType>();
      SmallVector<int64_t> operandShape(vecType.getShape().begin(),
                                        vecType.getShape().end());
      const auto &operandMap = indexedMap.value();
      for (unsigned res = 0, e = operandMap.getNumResults(); res < e; ++res) {
        // Dimensions missing from the map are reductions and keep their
        // original size.
        auto found = distributedSizes.find(operandMap.getDimPosition(res));
        if (found != distributedSizes.end())
          operandShape[res] = found->second;
      }
      VectorType newVecType =
          VectorType::get(operandShape, vecType.getElementType());
      distributedOperands.push_back(rewriter.create<vector::ExtractMapOp>(
          loc, newVecType, operand, extract.getIds()));
    }

    Operation *newOp = cloneOpWithOperandsAndTypes(
        rewriter, loc, definedOp, distributedOperands,
        extract.getResult().getType());
    rewriter.replaceOp(extract, newOp->getResult(0));
    return success();
  }
};
/// Converts TransferRead op used by ExtractMap op into a smaller dimension
/// TransferRead.
/// Example:
/// ```
/// %a = vector.transfer_read %A[%c0, %c0, %c0], %cf0:
/// memref<64x64x64xf32>, vector<64x4x32xf32>
/// %e = vector.extract_map %a[%id] : vector<64x4x32xf32> to vector<2x4x1xf32>
/// ```
/// to:
/// ```
/// %id1 = affine.apply affine_map<()[s0] -> (s0 * 2)> (%id)
/// %e = vector.transfer_read %A[%id1, %c0, %id1], %cf0 :
/// memref<64x64x64xf32>, vector<2x4x1xf32>
/// ```
///
/// The pattern only applies when the read has a single use, that use is an
/// ExtractMap op, and the read is neither 0-d nor masked. The original read is
/// replaced by an InsertMap of the smaller read into a zero constant of the
/// original vector type.
struct TransferReadExtractPattern
: public OpRewritePattern<vector::TransferReadOp> {
TransferReadExtractPattern(MLIRContext *context)
: OpRewritePattern<vector::TransferReadOp>(context) {}
LogicalResult matchAndRewrite(vector::TransferReadOp read,
PatternRewriter &rewriter) const override {
// TODO: support 0-d corner case.
if (read.getTransferRank() == 0)
return failure();
// The read must feed exactly one user, and that user must be an
// extract_map.
if (!read.getResult().hasOneUse())
return failure();
auto extract =
dyn_cast<vector::ExtractMapOp>(*read.getResult().getUsers().begin());
if (!extract)
return failure();
// Masked reads are not handled.
if (read.getMask())
return failure();
SmallVector<Value, 4> indices(read.getIndices().begin(),
read.getIndices().end());
// Compose the extract map with the permutation map of the read; each result
// is paired below with the matching result of the extract map.
AffineMap indexMap = extract.map().compose(read.getPermutationMap());
// Position of the next id of the extract_map to consume.
unsigned idCount = 0;
ImplicitLocOpBuilder lb(read.getLoc(), rewriter);
for (auto it :
llvm::zip(indexMap.getResults(), extract.map().getResults())) {
AffineExpr d0, d1;
bindDims(read.getContext(), d0, d1);
// Results of the composed map that are not plain dimensions are skipped
// (and do not consume an id).
auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
if (!indexExpr)
continue;
unsigned indexPos = indexExpr.getPosition();
unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
// New index = old index + (size of the extracted vector dim) * id.
auto scale = getAffineConstantExpr(
extract.getResultType().getDimSize(vectorPos), read.getContext());
indices[indexPos] = makeComposedAffineApply(
rewriter, read.getLoc(), d0 + scale * d1,
{indices[indexPos], extract.getIds()[idCount++]});
}
// Read only the distributed part, using the type of the extract_map result.
Value newRead = lb.create<vector::TransferReadOp>(
extract.getType(), read.getSource(), indices,
read.getPermutationMapAttr(), read.getPadding(), read.getMask(),
read.getInBoundsAttr());
// Re-insert the small read into a zero vector of the original type so that
// the existing extract_map user still sees a value of the original type.
Value dest = lb.create<arith::ConstantOp>(
read.getType(), rewriter.getZeroAttr(read.getType()));
newRead = lb.create<vector::InsertMapOp>(newRead, dest, extract.getIds());
rewriter.replaceOp(read, newRead);
return success();
}
};
/// Folds a `vector.insert_map` into the `vector.transfer_write` that consumes
/// it: instead of writing the vector produced by the insert_map, the source
/// vector of the insert_map is written directly, with the write indices
/// shifted by `id * size` along the distributed dimensions. Writes of rank 0
/// and masked writes are not handled.
///
/// When the original write produces a result (write into a tensor), its uses
/// are redirected to the result of the new write instead of erasing an op that
/// may still be in use.
struct TransferWriteInsertPattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  TransferWriteInsertPattern(MLIRContext *context)
      : OpRewritePattern<vector::TransferWriteOp>(context) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (write.getTransferRank() == 0)
      return failure();
    auto insert = write.getVector().getDefiningOp<vector::InsertMapOp>();
    if (!insert)
      return failure();
    if (write.getMask())
      return failure();

    SmallVector<Value, 4> indices(write.getIndices().begin(),
                                  write.getIndices().end());
    AffineMap indexMap = insert.map().compose(write.getPermutationMap());
    // Position of the next id of the insert_map to consume.
    unsigned idCount = 0;
    Location loc = write.getLoc();
    // d0 is the original index, d1 the id; the same for every dimension.
    AffineExpr d0, d1;
    bindDims(write.getContext(), d0, d1);
    for (auto it :
         llvm::zip(indexMap.getResults(), insert.map().getResults())) {
      // Results that are not plain dimensions are skipped and do not consume
      // an id.
      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
      // New index = old index + (size of the inserted vector dim) * id.
      auto scale = getAffineConstantExpr(
          insert.getSourceVectorType().getDimSize(vectorPos),
          write.getContext());
      indices[indexPos] = makeComposedAffineApply(
          rewriter, loc, d0 + scale * d1,
          {indices[indexPos], insert.getIds()[idCount++]});
    }

    auto newWrite = rewriter.create<vector::TransferWriteOp>(
        loc, insert.getVector(), write.getSource(), indices,
        write.getPermutationMapAttr(), write.getInBoundsAttr());
    // A write without results can simply be dropped; otherwise its result
    // must be replaced by the result of the new write.
    if (write->getNumResults() == 0)
      rewriter.eraseOp(write);
    else
      rewriter.replaceOp(write, newWrite->getResults());
    return success();
  }
};
/// Unrolls a `vector.reduction` into reductions over slices of the target
/// shape and combines the partial results with the reduction kind of the
/// original op.
struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
  UnrollReductionPattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options)
      : OpRewritePattern<vector::ReductionOp>(context, /*benefit=*/1),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    Optional<SmallVector<int64_t, 4>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();

    SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
    // Only the first entry of the shape ratio is used as the slice count.
    int64_t ratio = (*shapeRatio(originalSize, *targetShape))[0];

    // Create unrolled vector reduction.
    Location loc = reductionOp.getLoc();
    Value accumulator;
    for (int64_t slice = 0; slice < ratio; ++slice) {
      SmallVector<int64_t> offsets =
          getVectorOffset(originalSize, *targetShape, slice);
      SmallVector<int64_t> strides(offsets.size(), 1);
      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, reductionOp.getVector(), offsets, *targetShape, strides);
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
      Value partial = newOp->getResult(0);
      // The first partial result seeds the accumulator; later ones are
      // combined with it.
      accumulator = accumulator
                        ? makeArithReduction(rewriter, loc,
                                             reductionOp.getKind(), accumulator,
                                             partial)
                        : partial;
    }
    rewriter.replaceOp(reductionOp, accumulator);
    return success();
  }

private:
  const vector::UnrollVectorOptions options;
};
/// Unrolls a `vector.transpose` into transposes of slices. For every slice of
/// the result, the matching (permuted) slice of the source is extracted,
/// transposed, and inserted into a zero-initialized result vector. Results of
/// rank 0 are not handled.
struct UnrollTranposePattern : public OpRewritePattern<vector::TransposeOp> {
  UnrollTranposePattern(MLIRContext *context,
                        const vector::UnrollVectorOptions &options)
      : OpRewritePattern<vector::TransposeOp>(context, /*benefit=*/1),
        options(options) {}

  LogicalResult matchAndRewrite(vector::TransposeOp tranposeOp,
                                PatternRewriter &rewriter) const override {
    if (tranposeOp.getResultType().getRank() == 0)
      return failure();
    auto targetShape = getTargetShape(options, tranposeOp);
    if (!targetShape)
      return failure();

    Location loc = tranposeOp.getLoc();
    auto resultVectorType = tranposeOp.getResultType();
    ArrayRef<int64_t> resultShape = resultVectorType.getShape();
    SmallVector<int64_t, 4> unitStrides(targetShape->size(), 1);
    SmallVector<int64_t, 4> ratio = *shapeRatio(resultShape, *targetShape);
    int64_t sliceCount = computeMaxLinearIndex(ratio);

    // Prepare the result vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    SmallVector<int64_t> permutation;
    tranposeOp.getTransp(permutation);

    for (int64_t slice = 0; slice < sliceCount; slice++) {
      SmallVector<int64_t, 4> resultOffsets =
          getVectorOffset(resultShape, *targetShape, slice);
      // Map the offsets and the shape of the result slice back to the source:
      // result dimension `dstDim` reads source dimension `permutation[dstDim]`.
      SmallVector<int64_t, 4> sourceOffsets(resultOffsets.size());
      SmallVector<int64_t, 4> sourceShape(resultOffsets.size());
      for (unsigned dstDim = 0, e = permutation.size(); dstDim < e; ++dstDim) {
        int64_t srcDim = permutation[dstDim];
        sourceOffsets[srcDim] = resultOffsets[dstDim];
        sourceShape[srcDim] = (*targetShape)[dstDim];
      }
      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, tranposeOp.getVector(), sourceOffsets, sourceShape, unitStrides);
      Value tranposedSlice =
          rewriter.create<vector::TransposeOp>(loc, slicedOperand, permutation);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, tranposedSlice, result, resultOffsets, unitStrides);
    }
    rewriter.replaceOp(tranposeOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
} // namespace
/// Adds all vector unrolling patterns defined in this file to `patterns`,
/// each configured with `options`.
void mlir::vector::populateVectorUnrollPatterns(
    RewritePatternSet &patterns, const UnrollVectorOptions &options) {
  auto *context = patterns.getContext();
  patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
               UnrollContractionPattern, UnrollElementwisePattern,
               UnrollReductionPattern, UnrollMultiReductionPattern,
               UnrollTranposePattern>(context, options);
}
/// Adds the patterns that propagate `vector.extract_map` / `vector.insert_map`
/// through pointwise ops, contractions and transfer ops to `patterns`.
void mlir::vector::populatePropagateVectorDistributionPatterns(
    RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<PointwiseExtractPattern, ContractExtractPattern,
               TransferReadExtractPattern, TransferWriteInsertPattern>(
      context);
}