[mlir][vector] NFC: Expose castAwayContractionLeadingOneDim
This commit exposes the transformation behind the pattern. It is useful for more targeted application on a specific op for once. Reviewed By: kuhar Differential Revision: https://reviews.llvm.org/D148758
This commit is contained in:
parent
ca554ad7c2
commit
eca7698a97
@ -44,6 +44,7 @@ enum class AtomicRMWKind : uint64_t;
|
|||||||
} // namespace arith
|
} // namespace arith
|
||||||
|
|
||||||
namespace vector {
|
namespace vector {
|
||||||
|
class ContractionOp;
|
||||||
class TransferReadOp;
|
class TransferReadOp;
|
||||||
class TransferWriteOp;
|
class TransferWriteOp;
|
||||||
class VectorDialect;
|
class VectorDialect;
|
||||||
@ -76,6 +77,11 @@ void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns,
|
|||||||
void populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
|
void populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
|
||||||
PatternBenefit benefit = 1);
|
PatternBenefit benefit = 1);
|
||||||
|
|
||||||
|
/// Cast away the leading unit dim, if exists, for the given contract op.
|
||||||
|
/// Return success if the transformation applies; return failure otherwise.
|
||||||
|
LogicalResult castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
|
||||||
|
RewriterBase &rewriter);
|
||||||
|
|
||||||
/// Collect a set of leading one dimension removal patterns.
|
/// Collect a set of leading one dimension removal patterns.
|
||||||
///
|
///
|
||||||
/// These patterns insert vector.shape_cast to remove leading one dimensions
|
/// These patterns insert vector.shape_cast to remove leading one dimensions
|
||||||
|
@ -279,6 +279,121 @@ struct CastAwayTransferWriteLeadingOneDim
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
LogicalResult
|
||||||
|
mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
|
||||||
|
RewriterBase &rewriter) {
|
||||||
|
VectorType oldAccType = contractOp.getAccType().dyn_cast<VectorType>();
|
||||||
|
if (oldAccType == nullptr)
|
||||||
|
return failure();
|
||||||
|
if (oldAccType.getRank() < 2)
|
||||||
|
return failure();
|
||||||
|
if (oldAccType.getShape()[0] != 1)
|
||||||
|
return failure();
|
||||||
|
// currently we support only dropping one dim but the pattern can be applied
|
||||||
|
// greedily to drop more.
|
||||||
|
int64_t dropDim = 1;
|
||||||
|
|
||||||
|
auto oldIndexingMaps = contractOp.getIndexingMapsArray();
|
||||||
|
SmallVector<AffineMap> newIndexingMaps;
|
||||||
|
|
||||||
|
auto oldIteratorTypes = contractOp.getIteratorTypes();
|
||||||
|
SmallVector<Attribute> newIteratorTypes;
|
||||||
|
|
||||||
|
int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0);
|
||||||
|
|
||||||
|
if (!isParallelIterator(oldIteratorTypes[dimToDrop]))
|
||||||
|
// only parallel type iterators can be dropped.
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
for (const auto &it : llvm::enumerate(oldIteratorTypes)) {
|
||||||
|
int64_t currDim = it.index();
|
||||||
|
if (currDim == dimToDrop)
|
||||||
|
continue;
|
||||||
|
newIteratorTypes.push_back(it.value());
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Value> operands = {contractOp.getLhs(), contractOp.getRhs(),
|
||||||
|
contractOp.getAcc()};
|
||||||
|
SmallVector<Value> newOperands;
|
||||||
|
|
||||||
|
for (const auto &it : llvm::enumerate(oldIndexingMaps)) {
|
||||||
|
// Check if the dim to be dropped exists as a leading dim in the operand
|
||||||
|
// if it does then we use vector.extract to drop it.
|
||||||
|
bool validExtract = false;
|
||||||
|
SmallVector<AffineExpr> results;
|
||||||
|
auto map = it.value();
|
||||||
|
int64_t orginalZeroDim = it.value().getDimPosition(0);
|
||||||
|
if (orginalZeroDim != dimToDrop) {
|
||||||
|
// There are two reasons to be in this path, 1. We need to
|
||||||
|
// tranpose the operand to make the dim to be dropped
|
||||||
|
// leading. 2. The dim to be dropped does not exist and in
|
||||||
|
// that case we dont want to add a unit tranpose but we must
|
||||||
|
// check all the indices to make sure this is the case.
|
||||||
|
bool tranposeNeeded = false;
|
||||||
|
SmallVector<int64_t> perm;
|
||||||
|
SmallVector<AffineExpr> transposeResults;
|
||||||
|
|
||||||
|
for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
|
||||||
|
int64_t currDim = map.getDimPosition(i);
|
||||||
|
if (currDim == dimToDrop) {
|
||||||
|
tranposeNeeded = true;
|
||||||
|
perm.insert(perm.begin(), i);
|
||||||
|
auto targetExpr = rewriter.getAffineDimExpr(currDim);
|
||||||
|
transposeResults.insert(transposeResults.begin(), targetExpr);
|
||||||
|
} else {
|
||||||
|
perm.push_back(i);
|
||||||
|
auto targetExpr = rewriter.getAffineDimExpr(currDim);
|
||||||
|
transposeResults.push_back(targetExpr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Do the tranpose now if needed so that we can drop the
|
||||||
|
// correct dim using extract later.
|
||||||
|
if (tranposeNeeded) {
|
||||||
|
map = AffineMap::get(map.getNumDims(), 0, transposeResults,
|
||||||
|
contractOp.getContext());
|
||||||
|
operands[it.index()] = rewriter.create<vector::TransposeOp>(
|
||||||
|
contractOp.getLoc(), operands[it.index()], perm);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// We have taken care to have the dim to be dropped be
|
||||||
|
// the leading dim. If its still not leading that means it
|
||||||
|
// does not exist in this operand and hence we do not need
|
||||||
|
// an extract.
|
||||||
|
if (map.getDimPosition(0) == dimToDrop)
|
||||||
|
validExtract = true;
|
||||||
|
|
||||||
|
for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
|
||||||
|
int64_t currDim = map.getDimPosition(i);
|
||||||
|
if (currDim == dimToDrop)
|
||||||
|
// This is the dim we are dropping.
|
||||||
|
continue;
|
||||||
|
auto targetExpr = rewriter.getAffineDimExpr(
|
||||||
|
currDim < dimToDrop ? currDim : currDim - 1);
|
||||||
|
results.push_back(targetExpr);
|
||||||
|
}
|
||||||
|
newIndexingMaps.push_back(AffineMap::get(map.getNumDims() - 1, 0, results,
|
||||||
|
contractOp.getContext()));
|
||||||
|
// Extract if its a valid extraction, otherwise use the operand
|
||||||
|
// without extraction.
|
||||||
|
newOperands.push_back(
|
||||||
|
validExtract ? rewriter.create<vector::ExtractOp>(contractOp.getLoc(),
|
||||||
|
operands[it.index()],
|
||||||
|
splatZero(dropDim))
|
||||||
|
: operands[it.index()]);
|
||||||
|
}
|
||||||
|
auto newContractOp = rewriter.create<vector::ContractionOp>(
|
||||||
|
contractOp.getLoc(), newOperands[0], newOperands[1], newOperands[2],
|
||||||
|
rewriter.getAffineMapArrayAttr(newIndexingMaps),
|
||||||
|
rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind());
|
||||||
|
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
|
||||||
|
contractOp, contractOp->getResultTypes()[0], newContractOp);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
/// Turns vector.contract on vector with leading 1 dimensions into
|
/// Turns vector.contract on vector with leading 1 dimensions into
|
||||||
/// vector.extract followed by vector.contract on vector without leading
|
/// vector.extract followed by vector.contract on vector without leading
|
||||||
/// 1 dimensions. Also performs tranpose of lhs and rhs operands if required
|
/// 1 dimensions. Also performs tranpose of lhs and rhs operands if required
|
||||||
@ -289,112 +404,7 @@ struct CastAwayContractionLeadingOneDim
|
|||||||
|
|
||||||
LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
|
LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
VectorType oldAccType = contractOp.getAccType().dyn_cast<VectorType>();
|
return castAwayContractionLeadingOneDim(contractOp, rewriter);
|
||||||
if (oldAccType == nullptr)
|
|
||||||
return failure();
|
|
||||||
if (oldAccType.getRank() < 2)
|
|
||||||
return failure();
|
|
||||||
if (oldAccType.getShape()[0] != 1)
|
|
||||||
return failure();
|
|
||||||
// currently we support only dropping one dim but the pattern can be applied
|
|
||||||
// greedily to drop more.
|
|
||||||
int64_t dropDim = 1;
|
|
||||||
|
|
||||||
auto oldIndexingMaps = contractOp.getIndexingMapsArray();
|
|
||||||
SmallVector<AffineMap> newIndexingMaps;
|
|
||||||
|
|
||||||
auto oldIteratorTypes = contractOp.getIteratorTypes();
|
|
||||||
SmallVector<Attribute> newIteratorTypes;
|
|
||||||
|
|
||||||
int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0);
|
|
||||||
|
|
||||||
if (!isParallelIterator(oldIteratorTypes[dimToDrop]))
|
|
||||||
// only parallel type iterators can be dropped.
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
for (const auto &it : llvm::enumerate(oldIteratorTypes)) {
|
|
||||||
int64_t currDim = it.index();
|
|
||||||
if (currDim == dimToDrop)
|
|
||||||
continue;
|
|
||||||
newIteratorTypes.push_back(it.value());
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<Value> operands = {contractOp.getLhs(), contractOp.getRhs(),
|
|
||||||
contractOp.getAcc()};
|
|
||||||
SmallVector<Value> newOperands;
|
|
||||||
|
|
||||||
for (const auto &it : llvm::enumerate(oldIndexingMaps)) {
|
|
||||||
// Check if the dim to be dropped exists as a leading dim in the operand
|
|
||||||
// if it does then we use vector.extract to drop it.
|
|
||||||
bool validExtract = false;
|
|
||||||
SmallVector<AffineExpr> results;
|
|
||||||
auto map = it.value();
|
|
||||||
int64_t orginalZeroDim = it.value().getDimPosition(0);
|
|
||||||
if (orginalZeroDim != dimToDrop) {
|
|
||||||
// There are two reasons to be in this path, 1. We need to
|
|
||||||
// tranpose the operand to make the dim to be dropped
|
|
||||||
// leading. 2. The dim to be dropped does not exist and in
|
|
||||||
// that case we dont want to add a unit tranpose but we must
|
|
||||||
// check all the indices to make sure this is the case.
|
|
||||||
bool tranposeNeeded = false;
|
|
||||||
SmallVector<int64_t> perm;
|
|
||||||
SmallVector<AffineExpr> transposeResults;
|
|
||||||
|
|
||||||
for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
|
|
||||||
int64_t currDim = map.getDimPosition(i);
|
|
||||||
if (currDim == dimToDrop) {
|
|
||||||
tranposeNeeded = true;
|
|
||||||
perm.insert(perm.begin(), i);
|
|
||||||
auto targetExpr = rewriter.getAffineDimExpr(currDim);
|
|
||||||
transposeResults.insert(transposeResults.begin(), targetExpr);
|
|
||||||
} else {
|
|
||||||
perm.push_back(i);
|
|
||||||
auto targetExpr = rewriter.getAffineDimExpr(currDim);
|
|
||||||
transposeResults.push_back(targetExpr);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Do the tranpose now if needed so that we can drop the
|
|
||||||
// correct dim using extract later.
|
|
||||||
if (tranposeNeeded) {
|
|
||||||
map = AffineMap::get(map.getNumDims(), 0, transposeResults,
|
|
||||||
contractOp.getContext());
|
|
||||||
operands[it.index()] = rewriter.create<vector::TransposeOp>(
|
|
||||||
contractOp.getLoc(), operands[it.index()], perm);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// We have taken care to have the dim to be dropped be
|
|
||||||
// the leading dim. If its still not leading that means it
|
|
||||||
// does not exist in this operand and hence we do not need
|
|
||||||
// an extract.
|
|
||||||
if (map.getDimPosition(0) == dimToDrop)
|
|
||||||
validExtract = true;
|
|
||||||
|
|
||||||
for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
|
|
||||||
int64_t currDim = map.getDimPosition(i);
|
|
||||||
if (currDim == dimToDrop)
|
|
||||||
// This is the dim we are dropping.
|
|
||||||
continue;
|
|
||||||
auto targetExpr = rewriter.getAffineDimExpr(
|
|
||||||
currDim < dimToDrop ? currDim : currDim - 1);
|
|
||||||
results.push_back(targetExpr);
|
|
||||||
}
|
|
||||||
newIndexingMaps.push_back(AffineMap::get(map.getNumDims() - 1, 0, results,
|
|
||||||
contractOp.getContext()));
|
|
||||||
// Extract if its a valid extraction, otherwise use the operand
|
|
||||||
// without extraction.
|
|
||||||
newOperands.push_back(validExtract
|
|
||||||
? rewriter.create<vector::ExtractOp>(
|
|
||||||
contractOp.getLoc(), operands[it.index()],
|
|
||||||
splatZero(dropDim))
|
|
||||||
: operands[it.index()]);
|
|
||||||
}
|
|
||||||
auto newContractOp = rewriter.create<vector::ContractionOp>(
|
|
||||||
contractOp.getLoc(), newOperands[0], newOperands[1], newOperands[2],
|
|
||||||
rewriter.getAffineMapArrayAttr(newIndexingMaps),
|
|
||||||
rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind());
|
|
||||||
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
|
|
||||||
contractOp, contractOp->getResultTypes()[0], newContractOp);
|
|
||||||
return success();
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user