[mlir][vector] NFC: Expose castAwayContractionLeadingOneDim
This commit exposes the transformation behind the pattern. It is useful for more targeted application on a specific op for once. Reviewed By: kuhar Differential Revision: https://reviews.llvm.org/D148758
This commit is contained in:
parent
ca554ad7c2
commit
eca7698a97
@ -44,6 +44,7 @@ enum class AtomicRMWKind : uint64_t;
|
||||
} // namespace arith
|
||||
|
||||
namespace vector {
|
||||
class ContractionOp;
|
||||
class TransferReadOp;
|
||||
class TransferWriteOp;
|
||||
class VectorDialect;
|
||||
@ -76,6 +77,11 @@ void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||
void populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
|
||||
PatternBenefit benefit = 1);
|
||||
|
||||
/// Cast away the leading unit dim, if exists, for the given contract op.
|
||||
/// Return success if the transformation applies; return failure otherwise.
|
||||
LogicalResult castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
|
||||
RewriterBase &rewriter);
|
||||
|
||||
/// Collect a set of leading one dimension removal patterns.
|
||||
///
|
||||
/// These patterns insert vector.shape_cast to remove leading one dimensions
|
||||
|
@ -279,6 +279,121 @@ struct CastAwayTransferWriteLeadingOneDim
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult
|
||||
mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
|
||||
RewriterBase &rewriter) {
|
||||
VectorType oldAccType = contractOp.getAccType().dyn_cast<VectorType>();
|
||||
if (oldAccType == nullptr)
|
||||
return failure();
|
||||
if (oldAccType.getRank() < 2)
|
||||
return failure();
|
||||
if (oldAccType.getShape()[0] != 1)
|
||||
return failure();
|
||||
// currently we support only dropping one dim but the pattern can be applied
|
||||
// greedily to drop more.
|
||||
int64_t dropDim = 1;
|
||||
|
||||
auto oldIndexingMaps = contractOp.getIndexingMapsArray();
|
||||
SmallVector<AffineMap> newIndexingMaps;
|
||||
|
||||
auto oldIteratorTypes = contractOp.getIteratorTypes();
|
||||
SmallVector<Attribute> newIteratorTypes;
|
||||
|
||||
int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0);
|
||||
|
||||
if (!isParallelIterator(oldIteratorTypes[dimToDrop]))
|
||||
// only parallel type iterators can be dropped.
|
||||
return failure();
|
||||
|
||||
for (const auto &it : llvm::enumerate(oldIteratorTypes)) {
|
||||
int64_t currDim = it.index();
|
||||
if (currDim == dimToDrop)
|
||||
continue;
|
||||
newIteratorTypes.push_back(it.value());
|
||||
}
|
||||
|
||||
SmallVector<Value> operands = {contractOp.getLhs(), contractOp.getRhs(),
|
||||
contractOp.getAcc()};
|
||||
SmallVector<Value> newOperands;
|
||||
|
||||
for (const auto &it : llvm::enumerate(oldIndexingMaps)) {
|
||||
// Check if the dim to be dropped exists as a leading dim in the operand
|
||||
// if it does then we use vector.extract to drop it.
|
||||
bool validExtract = false;
|
||||
SmallVector<AffineExpr> results;
|
||||
auto map = it.value();
|
||||
int64_t orginalZeroDim = it.value().getDimPosition(0);
|
||||
if (orginalZeroDim != dimToDrop) {
|
||||
// There are two reasons to be in this path, 1. We need to
|
||||
// tranpose the operand to make the dim to be dropped
|
||||
// leading. 2. The dim to be dropped does not exist and in
|
||||
// that case we dont want to add a unit tranpose but we must
|
||||
// check all the indices to make sure this is the case.
|
||||
bool tranposeNeeded = false;
|
||||
SmallVector<int64_t> perm;
|
||||
SmallVector<AffineExpr> transposeResults;
|
||||
|
||||
for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
|
||||
int64_t currDim = map.getDimPosition(i);
|
||||
if (currDim == dimToDrop) {
|
||||
tranposeNeeded = true;
|
||||
perm.insert(perm.begin(), i);
|
||||
auto targetExpr = rewriter.getAffineDimExpr(currDim);
|
||||
transposeResults.insert(transposeResults.begin(), targetExpr);
|
||||
} else {
|
||||
perm.push_back(i);
|
||||
auto targetExpr = rewriter.getAffineDimExpr(currDim);
|
||||
transposeResults.push_back(targetExpr);
|
||||
}
|
||||
}
|
||||
// Do the tranpose now if needed so that we can drop the
|
||||
// correct dim using extract later.
|
||||
if (tranposeNeeded) {
|
||||
map = AffineMap::get(map.getNumDims(), 0, transposeResults,
|
||||
contractOp.getContext());
|
||||
operands[it.index()] = rewriter.create<vector::TransposeOp>(
|
||||
contractOp.getLoc(), operands[it.index()], perm);
|
||||
}
|
||||
}
|
||||
// We have taken care to have the dim to be dropped be
|
||||
// the leading dim. If its still not leading that means it
|
||||
// does not exist in this operand and hence we do not need
|
||||
// an extract.
|
||||
if (map.getDimPosition(0) == dimToDrop)
|
||||
validExtract = true;
|
||||
|
||||
for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
|
||||
int64_t currDim = map.getDimPosition(i);
|
||||
if (currDim == dimToDrop)
|
||||
// This is the dim we are dropping.
|
||||
continue;
|
||||
auto targetExpr = rewriter.getAffineDimExpr(
|
||||
currDim < dimToDrop ? currDim : currDim - 1);
|
||||
results.push_back(targetExpr);
|
||||
}
|
||||
newIndexingMaps.push_back(AffineMap::get(map.getNumDims() - 1, 0, results,
|
||||
contractOp.getContext()));
|
||||
// Extract if its a valid extraction, otherwise use the operand
|
||||
// without extraction.
|
||||
newOperands.push_back(
|
||||
validExtract ? rewriter.create<vector::ExtractOp>(contractOp.getLoc(),
|
||||
operands[it.index()],
|
||||
splatZero(dropDim))
|
||||
: operands[it.index()]);
|
||||
}
|
||||
auto newContractOp = rewriter.create<vector::ContractionOp>(
|
||||
contractOp.getLoc(), newOperands[0], newOperands[1], newOperands[2],
|
||||
rewriter.getAffineMapArrayAttr(newIndexingMaps),
|
||||
rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind());
|
||||
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
|
||||
contractOp, contractOp->getResultTypes()[0], newContractOp);
|
||||
return success();
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
/// Turns vector.contract on vector with leading 1 dimensions into
|
||||
/// vector.extract followed by vector.contract on vector without leading
|
||||
/// 1 dimensions. Also performs tranpose of lhs and rhs operands if required
|
||||
@ -289,112 +404,7 @@ struct CastAwayContractionLeadingOneDim
|
||||
|
||||
LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
VectorType oldAccType = contractOp.getAccType().dyn_cast<VectorType>();
|
||||
if (oldAccType == nullptr)
|
||||
return failure();
|
||||
if (oldAccType.getRank() < 2)
|
||||
return failure();
|
||||
if (oldAccType.getShape()[0] != 1)
|
||||
return failure();
|
||||
// currently we support only dropping one dim but the pattern can be applied
|
||||
// greedily to drop more.
|
||||
int64_t dropDim = 1;
|
||||
|
||||
auto oldIndexingMaps = contractOp.getIndexingMapsArray();
|
||||
SmallVector<AffineMap> newIndexingMaps;
|
||||
|
||||
auto oldIteratorTypes = contractOp.getIteratorTypes();
|
||||
SmallVector<Attribute> newIteratorTypes;
|
||||
|
||||
int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0);
|
||||
|
||||
if (!isParallelIterator(oldIteratorTypes[dimToDrop]))
|
||||
// only parallel type iterators can be dropped.
|
||||
return failure();
|
||||
|
||||
for (const auto &it : llvm::enumerate(oldIteratorTypes)) {
|
||||
int64_t currDim = it.index();
|
||||
if (currDim == dimToDrop)
|
||||
continue;
|
||||
newIteratorTypes.push_back(it.value());
|
||||
}
|
||||
|
||||
SmallVector<Value> operands = {contractOp.getLhs(), contractOp.getRhs(),
|
||||
contractOp.getAcc()};
|
||||
SmallVector<Value> newOperands;
|
||||
|
||||
for (const auto &it : llvm::enumerate(oldIndexingMaps)) {
|
||||
// Check if the dim to be dropped exists as a leading dim in the operand
|
||||
// if it does then we use vector.extract to drop it.
|
||||
bool validExtract = false;
|
||||
SmallVector<AffineExpr> results;
|
||||
auto map = it.value();
|
||||
int64_t orginalZeroDim = it.value().getDimPosition(0);
|
||||
if (orginalZeroDim != dimToDrop) {
|
||||
// There are two reasons to be in this path, 1. We need to
|
||||
// tranpose the operand to make the dim to be dropped
|
||||
// leading. 2. The dim to be dropped does not exist and in
|
||||
// that case we dont want to add a unit tranpose but we must
|
||||
// check all the indices to make sure this is the case.
|
||||
bool tranposeNeeded = false;
|
||||
SmallVector<int64_t> perm;
|
||||
SmallVector<AffineExpr> transposeResults;
|
||||
|
||||
for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
|
||||
int64_t currDim = map.getDimPosition(i);
|
||||
if (currDim == dimToDrop) {
|
||||
tranposeNeeded = true;
|
||||
perm.insert(perm.begin(), i);
|
||||
auto targetExpr = rewriter.getAffineDimExpr(currDim);
|
||||
transposeResults.insert(transposeResults.begin(), targetExpr);
|
||||
} else {
|
||||
perm.push_back(i);
|
||||
auto targetExpr = rewriter.getAffineDimExpr(currDim);
|
||||
transposeResults.push_back(targetExpr);
|
||||
}
|
||||
}
|
||||
// Do the tranpose now if needed so that we can drop the
|
||||
// correct dim using extract later.
|
||||
if (tranposeNeeded) {
|
||||
map = AffineMap::get(map.getNumDims(), 0, transposeResults,
|
||||
contractOp.getContext());
|
||||
operands[it.index()] = rewriter.create<vector::TransposeOp>(
|
||||
contractOp.getLoc(), operands[it.index()], perm);
|
||||
}
|
||||
}
|
||||
// We have taken care to have the dim to be dropped be
|
||||
// the leading dim. If its still not leading that means it
|
||||
// does not exist in this operand and hence we do not need
|
||||
// an extract.
|
||||
if (map.getDimPosition(0) == dimToDrop)
|
||||
validExtract = true;
|
||||
|
||||
for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
|
||||
int64_t currDim = map.getDimPosition(i);
|
||||
if (currDim == dimToDrop)
|
||||
// This is the dim we are dropping.
|
||||
continue;
|
||||
auto targetExpr = rewriter.getAffineDimExpr(
|
||||
currDim < dimToDrop ? currDim : currDim - 1);
|
||||
results.push_back(targetExpr);
|
||||
}
|
||||
newIndexingMaps.push_back(AffineMap::get(map.getNumDims() - 1, 0, results,
|
||||
contractOp.getContext()));
|
||||
// Extract if its a valid extraction, otherwise use the operand
|
||||
// without extraction.
|
||||
newOperands.push_back(validExtract
|
||||
? rewriter.create<vector::ExtractOp>(
|
||||
contractOp.getLoc(), operands[it.index()],
|
||||
splatZero(dropDim))
|
||||
: operands[it.index()]);
|
||||
}
|
||||
auto newContractOp = rewriter.create<vector::ContractionOp>(
|
||||
contractOp.getLoc(), newOperands[0], newOperands[1], newOperands[2],
|
||||
rewriter.getAffineMapArrayAttr(newIndexingMaps),
|
||||
rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind());
|
||||
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
|
||||
contractOp, contractOp->getResultTypes()[0], newContractOp);
|
||||
return success();
|
||||
return castAwayContractionLeadingOneDim(contractOp, rewriter);
|
||||
}
|
||||
};
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user