[mlir][vector] NFC: Expose castAwayContractionLeadingOneDim

This commit exposes the transformation behind the pattern.
It is useful for more targeted application on a specific op
for once.

Reviewed By: kuhar

Differential Revision: https://reviews.llvm.org/D148758
This commit is contained in:
Lei Zhang 2023-04-21 09:41:01 -07:00
parent ca554ad7c2
commit eca7698a97
2 changed files with 122 additions and 106 deletions

View File

@ -44,6 +44,7 @@ enum class AtomicRMWKind : uint64_t;
} // namespace arith } // namespace arith
namespace vector { namespace vector {
class ContractionOp;
class TransferReadOp; class TransferReadOp;
class TransferWriteOp; class TransferWriteOp;
class VectorDialect; class VectorDialect;
@ -76,6 +77,11 @@ void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns,
void populateShapeCastFoldingPatterns(RewritePatternSet &patterns, void populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1); PatternBenefit benefit = 1);
/// Cast away the leading unit dim, if exists, for the given contract op.
/// Return success if the transformation applies; return failure otherwise.
LogicalResult castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
RewriterBase &rewriter);
/// Collect a set of leading one dimension removal patterns. /// Collect a set of leading one dimension removal patterns.
/// ///
/// These patterns insert vector.shape_cast to remove leading one dimensions /// These patterns insert vector.shape_cast to remove leading one dimensions

View File

@ -279,6 +279,121 @@ struct CastAwayTransferWriteLeadingOneDim
} }
}; };
} // namespace
LogicalResult
mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
RewriterBase &rewriter) {
VectorType oldAccType = contractOp.getAccType().dyn_cast<VectorType>();
if (oldAccType == nullptr)
return failure();
if (oldAccType.getRank() < 2)
return failure();
if (oldAccType.getShape()[0] != 1)
return failure();
// currently we support only dropping one dim but the pattern can be applied
// greedily to drop more.
int64_t dropDim = 1;
auto oldIndexingMaps = contractOp.getIndexingMapsArray();
SmallVector<AffineMap> newIndexingMaps;
auto oldIteratorTypes = contractOp.getIteratorTypes();
SmallVector<Attribute> newIteratorTypes;
int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0);
if (!isParallelIterator(oldIteratorTypes[dimToDrop]))
// only parallel type iterators can be dropped.
return failure();
for (const auto &it : llvm::enumerate(oldIteratorTypes)) {
int64_t currDim = it.index();
if (currDim == dimToDrop)
continue;
newIteratorTypes.push_back(it.value());
}
SmallVector<Value> operands = {contractOp.getLhs(), contractOp.getRhs(),
contractOp.getAcc()};
SmallVector<Value> newOperands;
for (const auto &it : llvm::enumerate(oldIndexingMaps)) {
// Check if the dim to be dropped exists as a leading dim in the operand
// if it does then we use vector.extract to drop it.
bool validExtract = false;
SmallVector<AffineExpr> results;
auto map = it.value();
int64_t orginalZeroDim = it.value().getDimPosition(0);
if (orginalZeroDim != dimToDrop) {
// There are two reasons to be in this path, 1. We need to
// tranpose the operand to make the dim to be dropped
// leading. 2. The dim to be dropped does not exist and in
// that case we dont want to add a unit tranpose but we must
// check all the indices to make sure this is the case.
bool tranposeNeeded = false;
SmallVector<int64_t> perm;
SmallVector<AffineExpr> transposeResults;
for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
int64_t currDim = map.getDimPosition(i);
if (currDim == dimToDrop) {
tranposeNeeded = true;
perm.insert(perm.begin(), i);
auto targetExpr = rewriter.getAffineDimExpr(currDim);
transposeResults.insert(transposeResults.begin(), targetExpr);
} else {
perm.push_back(i);
auto targetExpr = rewriter.getAffineDimExpr(currDim);
transposeResults.push_back(targetExpr);
}
}
// Do the tranpose now if needed so that we can drop the
// correct dim using extract later.
if (tranposeNeeded) {
map = AffineMap::get(map.getNumDims(), 0, transposeResults,
contractOp.getContext());
operands[it.index()] = rewriter.create<vector::TransposeOp>(
contractOp.getLoc(), operands[it.index()], perm);
}
}
// We have taken care to have the dim to be dropped be
// the leading dim. If its still not leading that means it
// does not exist in this operand and hence we do not need
// an extract.
if (map.getDimPosition(0) == dimToDrop)
validExtract = true;
for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
int64_t currDim = map.getDimPosition(i);
if (currDim == dimToDrop)
// This is the dim we are dropping.
continue;
auto targetExpr = rewriter.getAffineDimExpr(
currDim < dimToDrop ? currDim : currDim - 1);
results.push_back(targetExpr);
}
newIndexingMaps.push_back(AffineMap::get(map.getNumDims() - 1, 0, results,
contractOp.getContext()));
// Extract if its a valid extraction, otherwise use the operand
// without extraction.
newOperands.push_back(
validExtract ? rewriter.create<vector::ExtractOp>(contractOp.getLoc(),
operands[it.index()],
splatZero(dropDim))
: operands[it.index()]);
}
auto newContractOp = rewriter.create<vector::ContractionOp>(
contractOp.getLoc(), newOperands[0], newOperands[1], newOperands[2],
rewriter.getAffineMapArrayAttr(newIndexingMaps),
rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind());
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
contractOp, contractOp->getResultTypes()[0], newContractOp);
return success();
}
namespace {
/// Turns vector.contract on vector with leading 1 dimensions into /// Turns vector.contract on vector with leading 1 dimensions into
/// vector.extract followed by vector.contract on vector without leading /// vector.extract followed by vector.contract on vector without leading
/// 1 dimensions. Also performs tranpose of lhs and rhs operands if required /// 1 dimensions. Also performs tranpose of lhs and rhs operands if required
@ -289,112 +404,7 @@ struct CastAwayContractionLeadingOneDim
LogicalResult matchAndRewrite(vector::ContractionOp contractOp, LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
VectorType oldAccType = contractOp.getAccType().dyn_cast<VectorType>(); return castAwayContractionLeadingOneDim(contractOp, rewriter);
if (oldAccType == nullptr)
return failure();
if (oldAccType.getRank() < 2)
return failure();
if (oldAccType.getShape()[0] != 1)
return failure();
// currently we support only dropping one dim but the pattern can be applied
// greedily to drop more.
int64_t dropDim = 1;
auto oldIndexingMaps = contractOp.getIndexingMapsArray();
SmallVector<AffineMap> newIndexingMaps;
auto oldIteratorTypes = contractOp.getIteratorTypes();
SmallVector<Attribute> newIteratorTypes;
int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0);
if (!isParallelIterator(oldIteratorTypes[dimToDrop]))
// only parallel type iterators can be dropped.
return failure();
for (const auto &it : llvm::enumerate(oldIteratorTypes)) {
int64_t currDim = it.index();
if (currDim == dimToDrop)
continue;
newIteratorTypes.push_back(it.value());
}
SmallVector<Value> operands = {contractOp.getLhs(), contractOp.getRhs(),
contractOp.getAcc()};
SmallVector<Value> newOperands;
for (const auto &it : llvm::enumerate(oldIndexingMaps)) {
// Check if the dim to be dropped exists as a leading dim in the operand
// if it does then we use vector.extract to drop it.
bool validExtract = false;
SmallVector<AffineExpr> results;
auto map = it.value();
int64_t orginalZeroDim = it.value().getDimPosition(0);
if (orginalZeroDim != dimToDrop) {
// There are two reasons to be in this path, 1. We need to
// tranpose the operand to make the dim to be dropped
// leading. 2. The dim to be dropped does not exist and in
// that case we dont want to add a unit tranpose but we must
// check all the indices to make sure this is the case.
bool tranposeNeeded = false;
SmallVector<int64_t> perm;
SmallVector<AffineExpr> transposeResults;
for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
int64_t currDim = map.getDimPosition(i);
if (currDim == dimToDrop) {
tranposeNeeded = true;
perm.insert(perm.begin(), i);
auto targetExpr = rewriter.getAffineDimExpr(currDim);
transposeResults.insert(transposeResults.begin(), targetExpr);
} else {
perm.push_back(i);
auto targetExpr = rewriter.getAffineDimExpr(currDim);
transposeResults.push_back(targetExpr);
}
}
// Do the tranpose now if needed so that we can drop the
// correct dim using extract later.
if (tranposeNeeded) {
map = AffineMap::get(map.getNumDims(), 0, transposeResults,
contractOp.getContext());
operands[it.index()] = rewriter.create<vector::TransposeOp>(
contractOp.getLoc(), operands[it.index()], perm);
}
}
// We have taken care to have the dim to be dropped be
// the leading dim. If its still not leading that means it
// does not exist in this operand and hence we do not need
// an extract.
if (map.getDimPosition(0) == dimToDrop)
validExtract = true;
for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
int64_t currDim = map.getDimPosition(i);
if (currDim == dimToDrop)
// This is the dim we are dropping.
continue;
auto targetExpr = rewriter.getAffineDimExpr(
currDim < dimToDrop ? currDim : currDim - 1);
results.push_back(targetExpr);
}
newIndexingMaps.push_back(AffineMap::get(map.getNumDims() - 1, 0, results,
contractOp.getContext()));
// Extract if its a valid extraction, otherwise use the operand
// without extraction.
newOperands.push_back(validExtract
? rewriter.create<vector::ExtractOp>(
contractOp.getLoc(), operands[it.index()],
splatZero(dropDim))
: operands[it.index()]);
}
auto newContractOp = rewriter.create<vector::ContractionOp>(
contractOp.getLoc(), newOperands[0], newOperands[1], newOperands[2],
rewriter.getAffineMapArrayAttr(newIndexingMaps),
rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind());
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
contractOp, contractOp->getResultTypes()[0], newContractOp);
return success();
} }
}; };