Use type inference when building the TransferWriteOp in the TransferWritePermutationLowering. Previously, the result type was set to Type(), which triggered an assertion if the pattern was used with tensors instead of memrefs. Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D118758
278 lines
11 KiB
C++
278 lines
11 KiB
C++
//===- VectorTransferPermutationMapRewritePatterns.cpp - Xfer map rewrite -===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements rewrite patterns for the permutation_map attribute of
|
|
// vector.transfer operations.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
|
|
#include "mlir/Interfaces/VectorInterfaces.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::vector;
|
|
|
|
/// Transpose a vector transfer op's `in_bounds` attribute according to given
|
|
/// indices.
|
|
static ArrayAttr
|
|
transposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
|
|
const SmallVector<unsigned> &permutation) {
|
|
SmallVector<bool> newInBoundsValues;
|
|
for (unsigned pos : permutation)
|
|
newInBoundsValues.push_back(
|
|
attr.getValue()[pos].cast<BoolAttr>().getValue());
|
|
return builder.getBoolArrayAttr(newInBoundsValues);
|
|
}
|
|
|
|
/// Lower transfer_read op with permutation into a transfer_read with a
|
|
/// permutation map composed of leading zeros followed by a minor identiy +
|
|
/// vector.transpose op.
|
|
/// Ex:
|
|
/// vector.transfer_read ...
|
|
/// permutation_map: (d0, d1, d2) -> (0, d1)
|
|
/// into:
|
|
/// %v = vector.transfer_read ...
|
|
/// permutation_map: (d0, d1, d2) -> (d1, 0)
|
|
/// vector.transpose %v, [1, 0]
|
|
///
|
|
/// vector.transfer_read ...
|
|
/// permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
|
|
/// into:
|
|
/// %v = vector.transfer_read ...
|
|
/// permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
|
|
/// vector.transpose %v, [0, 1, 3, 2, 4]
|
|
/// Note that an alternative is to transform it to linalg.transpose +
|
|
/// vector.transfer_read to do the transpose in memory instead.
|
|
struct TransferReadPermutationLowering
|
|
: public OpRewritePattern<vector::TransferReadOp> {
|
|
using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(vector::TransferReadOp op,
|
|
PatternRewriter &rewriter) const override {
|
|
// TODO: support 0-d corner case.
|
|
if (op.getTransferRank() == 0)
|
|
return failure();
|
|
|
|
SmallVector<unsigned> permutation;
|
|
AffineMap map = op.permutation_map();
|
|
if (map.getNumResults() == 0)
|
|
return failure();
|
|
if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
|
|
return failure();
|
|
AffineMap permutationMap =
|
|
map.getPermutationMap(permutation, op.getContext());
|
|
if (permutationMap.isIdentity())
|
|
return failure();
|
|
|
|
permutationMap = map.getPermutationMap(permutation, op.getContext());
|
|
// Caluclate the map of the new read by applying the inverse permutation.
|
|
permutationMap = inversePermutation(permutationMap);
|
|
AffineMap newMap = permutationMap.compose(map);
|
|
// Apply the reverse transpose to deduce the type of the transfer_read.
|
|
ArrayRef<int64_t> originalShape = op.getVectorType().getShape();
|
|
SmallVector<int64_t> newVectorShape(originalShape.size());
|
|
for (const auto &pos : llvm::enumerate(permutation)) {
|
|
newVectorShape[pos.value()] = originalShape[pos.index()];
|
|
}
|
|
|
|
// Transpose mask operand.
|
|
Value newMask;
|
|
if (op.mask()) {
|
|
// Remove unused dims from the permutation map. E.g.:
|
|
// E.g.: (d0, d1, d2, d3, d4, d5) -> (d5, 0, d3, 0, d2)
|
|
// comp = (d0, d1, d2) -> (d2, 0, d1, 0 d0)
|
|
auto comp = compressUnusedDims(map);
|
|
// Get positions of remaining result dims.
|
|
// E.g.: (d0, d1, d2) -> (d2, 0, d1, 0 d0)
|
|
// maskTransposeIndices = [ 2, 1, 0]
|
|
SmallVector<int64_t> maskTransposeIndices;
|
|
for (unsigned i = 0; i < comp.getNumResults(); ++i) {
|
|
if (auto expr = comp.getResult(i).dyn_cast<AffineDimExpr>())
|
|
maskTransposeIndices.push_back(expr.getPosition());
|
|
}
|
|
|
|
newMask = rewriter.create<vector::TransposeOp>(op.getLoc(), op.mask(),
|
|
maskTransposeIndices);
|
|
}
|
|
|
|
// Transpose in_bounds attribute.
|
|
ArrayAttr newInBoundsAttr =
|
|
op.in_bounds() ? transposeInBoundsAttr(
|
|
rewriter, op.in_bounds().getValue(), permutation)
|
|
: ArrayAttr();
|
|
|
|
// Generate new transfer_read operation.
|
|
VectorType newReadType =
|
|
VectorType::get(newVectorShape, op.getVectorType().getElementType());
|
|
Value newRead = rewriter.create<vector::TransferReadOp>(
|
|
op.getLoc(), newReadType, op.source(), op.indices(),
|
|
AffineMapAttr::get(newMap), op.padding(), newMask, newInBoundsAttr);
|
|
|
|
// Transpose result of transfer_read.
|
|
SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
|
|
rewriter.replaceOpWithNewOp<vector::TransposeOp>(op, newRead,
|
|
transposePerm);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Lower transfer_write op with permutation into a transfer_write with a
|
|
/// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
|
|
/// Ex:
|
|
/// vector.transfer_write %v ...
|
|
/// permutation_map: (d0, d1, d2) -> (d2, d0, d1)
|
|
/// into:
|
|
/// %tmp = vector.transpose %v, [2, 0, 1]
|
|
/// vector.transfer_write %tmp ...
|
|
/// permutation_map: (d0, d1, d2) -> (d0, d1, d2)
|
|
///
|
|
/// vector.transfer_write %v ...
|
|
/// permutation_map: (d0, d1, d2, d3) -> (d3, d2)
|
|
/// into:
|
|
/// %tmp = vector.transpose %v, [1, 0]
|
|
/// %v = vector.transfer_write %tmp ...
|
|
/// permutation_map: (d0, d1, d2, d3) -> (d2, d3)
|
|
struct TransferWritePermutationLowering
|
|
: public OpRewritePattern<vector::TransferWriteOp> {
|
|
using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(vector::TransferWriteOp op,
|
|
PatternRewriter &rewriter) const override {
|
|
// TODO: support 0-d corner case.
|
|
if (op.getTransferRank() == 0)
|
|
return failure();
|
|
|
|
SmallVector<unsigned> permutation;
|
|
AffineMap map = op.permutation_map();
|
|
if (map.isMinorIdentity())
|
|
return failure();
|
|
if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
|
|
return failure();
|
|
|
|
// Remove unused dims from the permutation map. E.g.:
|
|
// E.g.: (d0, d1, d2, d3, d4, d5) -> (d5, d3, d4)
|
|
// comp = (d0, d1, d2) -> (d2, d0, d1)
|
|
auto comp = compressUnusedDims(map);
|
|
// Get positions of remaining result dims.
|
|
SmallVector<int64_t> indices;
|
|
llvm::transform(comp.getResults(), std::back_inserter(indices),
|
|
[](AffineExpr expr) {
|
|
return expr.dyn_cast<AffineDimExpr>().getPosition();
|
|
});
|
|
|
|
// Transpose mask operand.
|
|
Value newMask = op.mask() ? rewriter.create<vector::TransposeOp>(
|
|
op.getLoc(), op.mask(), indices)
|
|
: Value();
|
|
|
|
// Transpose in_bounds attribute.
|
|
ArrayAttr newInBoundsAttr =
|
|
op.in_bounds() ? transposeInBoundsAttr(
|
|
rewriter, op.in_bounds().getValue(), permutation)
|
|
: ArrayAttr();
|
|
|
|
// Generate new transfer_write operation.
|
|
Value newVec =
|
|
rewriter.create<vector::TransposeOp>(op.getLoc(), op.vector(), indices);
|
|
auto newMap = AffineMap::getMinorIdentityMap(
|
|
map.getNumDims(), map.getNumResults(), rewriter.getContext());
|
|
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
|
|
op, newVec, op.source(), op.indices(), AffineMapAttr::get(newMap),
|
|
newMask, newInBoundsAttr);
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Lower transfer_read op with broadcast in the leading dimensions into
|
|
/// transfer_read of lower rank + vector.broadcast.
|
|
/// Ex: vector.transfer_read ...
|
|
/// permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
|
|
/// into:
|
|
/// %v = vector.transfer_read ...
|
|
/// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
|
|
/// vector.broadcast %v
|
|
struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
|
|
using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(vector::TransferReadOp op,
|
|
PatternRewriter &rewriter) const override {
|
|
// TODO: support 0-d corner case.
|
|
if (op.getTransferRank() == 0)
|
|
return failure();
|
|
|
|
AffineMap map = op.permutation_map();
|
|
unsigned numLeadingBroadcast = 0;
|
|
for (auto expr : map.getResults()) {
|
|
auto dimExpr = expr.dyn_cast<AffineConstantExpr>();
|
|
if (!dimExpr || dimExpr.getValue() != 0)
|
|
break;
|
|
numLeadingBroadcast++;
|
|
}
|
|
// If there are no leading zeros in the map there is nothing to do.
|
|
if (numLeadingBroadcast == 0)
|
|
return failure();
|
|
VectorType originalVecType = op.getVectorType();
|
|
unsigned reducedShapeRank = originalVecType.getRank() - numLeadingBroadcast;
|
|
// Calculate new map, vector type and masks without the leading zeros.
|
|
AffineMap newMap = AffineMap::get(
|
|
map.getNumDims(), 0, map.getResults().take_back(reducedShapeRank),
|
|
op.getContext());
|
|
// Only remove the leading zeros if the rest of the map is a minor identity
|
|
// with broadasting. Otherwise we first want to permute the map.
|
|
if (!newMap.isMinorIdentityWithBroadcasting())
|
|
return failure();
|
|
|
|
// TODO: support zero-dimension vectors natively. See:
|
|
// https://llvm.discourse.group/t/should-we-have-0-d-vectors/3097.
|
|
// In the meantime, lower these to a scalar load when they pop up.
|
|
if (reducedShapeRank == 0) {
|
|
Value newRead;
|
|
if (op.getShapedType().isa<TensorType>()) {
|
|
newRead = rewriter.create<tensor::ExtractOp>(op.getLoc(), op.source(),
|
|
op.indices());
|
|
} else {
|
|
newRead = rewriter.create<memref::LoadOp>(
|
|
op.getLoc(), originalVecType.getElementType(), op.source(),
|
|
op.indices());
|
|
}
|
|
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
|
|
newRead);
|
|
return success();
|
|
}
|
|
SmallVector<int64_t> newShape = llvm::to_vector<4>(
|
|
originalVecType.getShape().take_back(reducedShapeRank));
|
|
// Vector rank cannot be zero. Handled by TransferReadToVectorLoadLowering.
|
|
if (newShape.empty())
|
|
return failure();
|
|
VectorType newReadType =
|
|
VectorType::get(newShape, originalVecType.getElementType());
|
|
ArrayAttr newInBoundsAttr =
|
|
op.in_bounds()
|
|
? rewriter.getArrayAttr(
|
|
op.in_boundsAttr().getValue().take_back(reducedShapeRank))
|
|
: ArrayAttr();
|
|
Value newRead = rewriter.create<vector::TransferReadOp>(
|
|
op.getLoc(), newReadType, op.source(), op.indices(),
|
|
AffineMapAttr::get(newMap), op.padding(), op.mask(), newInBoundsAttr);
|
|
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
|
|
newRead);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Registers the three permutation-map lowering patterns defined in this file
/// (read permutation, write permutation and leading-broadcast rank reduction)
/// into `patterns`, using the context owned by the pattern set.
void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<TransferReadPermutationLowering>(context);
  patterns.add<TransferWritePermutationLowering>(context);
  patterns.add<TransferOpReduceRank>(context);
}
|