[mlir] Introduce trailingNDimsContiguous
for MemRefs (#78247)
Extracts logic from `vector::isContiguousSlice` to check whether the trailing dim of a memref are contiguous into a dedicated hook in BuiitinTypes.{h|cpp}. Follow-up for https://github.com/llvm/llvm-project/pull/76848.
This commit is contained in:
parent
44436a9c6b
commit
9478bf0ce6
@ -518,6 +518,16 @@ bool isStrided(MemRefType t);
|
||||
/// stride. Also return "true" for types with no strides.
|
||||
bool isLastMemrefDimUnitStride(MemRefType type);
|
||||
|
||||
/// Return "true" if the last N dimensions of the given type are contiguous.
|
||||
///
|
||||
/// Examples:
|
||||
/// - memref<5x4x3x2xi8, strided<[24, 6, 2, 1]> is contiguous when
|
||||
/// considering both _all_ and _only_ the trailing 3 dims,
|
||||
/// - memref<5x4x3x2xi8, strided<[48, 6, 2, 1]> is _only_ contiguous when
|
||||
/// considering the trailing 3 dims.
|
||||
///
|
||||
bool trailingNDimsContiguous(MemRefType type, int64_t n);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_IR_BUILTINTYPES_H
|
||||
|
@ -257,38 +257,13 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
|
||||
ArrayRef<int64_t> vectorShape = vectorType.getShape();
|
||||
auto vecRank = vectorType.getRank();
|
||||
|
||||
if (!trailingNDimsContiguous(memrefType, vecRank))
|
||||
return false;
|
||||
|
||||
// Extract the trailing dims and strides of the input memref
|
||||
auto memrefShape = memrefType.getShape().take_back(vecRank);
|
||||
int64_t offset;
|
||||
SmallVector<int64_t> stridesFull;
|
||||
if (!succeeded(getStridesAndOffset(memrefType, stridesFull, offset)))
|
||||
return false;
|
||||
auto strides = ArrayRef<int64_t>(stridesFull).take_back(vecRank);
|
||||
memrefType.getLayout().isIdentity();
|
||||
|
||||
// TODO: Add support for memref with trailing dynamic shapes. Memrefs
|
||||
// with leading dynamic dimensions are already supported.
|
||||
if (ShapedType::isDynamicShape(memrefShape))
|
||||
return false;
|
||||
|
||||
// Cond 1: Check whether `memrefType` is contiguous.
|
||||
if (!strides.empty()) {
|
||||
// Cond 1.1: A contiguous memref will always have a unit trailing stride.
|
||||
if (strides.back() != 1)
|
||||
return false;
|
||||
|
||||
// Cond 1.2: Strides of a contiguous memref have to match the flattened
|
||||
// dims.
|
||||
strides = strides.drop_back(1);
|
||||
SmallVector<int64_t> flattenedDims;
|
||||
for (size_t i = 1; i < memrefShape.size(); i++)
|
||||
flattenedDims.push_back(mlir::computeProduct(memrefShape.take_back(i)));
|
||||
|
||||
if (!llvm::equal(strides, llvm::reverse(flattenedDims)))
|
||||
return false;
|
||||
}
|
||||
|
||||
// Cond 2: Compare the dims of `vectorType` against `memrefType` (in reverse).
|
||||
// Compare the dims of `vectorType` against `memrefType` (in reverse).
|
||||
// In the most basic case, all dims will match.
|
||||
auto firstNonMatchingDim =
|
||||
std::mismatch(vectorShape.rbegin(), vectorShape.rend(),
|
||||
|
@ -967,3 +967,35 @@ bool mlir::isLastMemrefDimUnitStride(MemRefType type) {
|
||||
auto successStrides = getStridesAndOffset(type, strides, offset);
|
||||
return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
|
||||
}
|
||||
|
||||
bool mlir::trailingNDimsContiguous(MemRefType type, int64_t n) {
|
||||
if (!isLastMemrefDimUnitStride(type))
|
||||
return false;
|
||||
|
||||
auto memrefShape = type.getShape().take_back(n);
|
||||
if (ShapedType::isDynamicShape(memrefShape))
|
||||
return false;
|
||||
|
||||
if (type.getLayout().isIdentity())
|
||||
return true;
|
||||
|
||||
int64_t offset;
|
||||
SmallVector<int64_t> stridesFull;
|
||||
if (!succeeded(getStridesAndOffset(type, stridesFull, offset)))
|
||||
return false;
|
||||
auto strides = ArrayRef<int64_t>(stridesFull).take_back(n);
|
||||
|
||||
if (strides.empty())
|
||||
return true;
|
||||
|
||||
// Check whether strides match "flattened" dims.
|
||||
SmallVector<int64_t> flattenedDims;
|
||||
auto dimProduct = 1;
|
||||
for (auto dim : llvm::reverse(memrefShape.drop_front(1))) {
|
||||
dimProduct *= dim;
|
||||
flattenedDims.push_back(dimProduct);
|
||||
}
|
||||
|
||||
strides = strides.drop_back(1);
|
||||
return llvm::equal(strides, llvm::reverse(flattenedDims));
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user