[mlir] Introduce trailingNDimsContiguous for MemRefs (#78247)

Extracts logic from `vector::isContiguousSlice` to check whether
the trailing dim of a memref are contiguous into a dedicated hook
in BuiitinTypes.{h|cpp}.

Follow-up for https://github.com/llvm/llvm-project/pull/76848.
This commit is contained in:
Andrzej Warzyński 2024-02-17 08:47:10 +00:00 committed by GitHub
parent 44436a9c6b
commit 9478bf0ce6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 46 additions and 29 deletions

View File

@ -518,6 +518,16 @@ bool isStrided(MemRefType t);
/// stride. Also return "true" for types with no strides. /// stride. Also return "true" for types with no strides.
bool isLastMemrefDimUnitStride(MemRefType type); bool isLastMemrefDimUnitStride(MemRefType type);
/// Return "true" if the last N dimensions of the given type are contiguous.
///
/// Examples:
/// - memref<5x4x3x2xi8, strided<[24, 6, 2, 1]> is contiguous when
/// considering both _all_ and _only_ the trailing 3 dims,
/// - memref<5x4x3x2xi8, strided<[48, 6, 2, 1]> is _only_ contiguous when
/// considering the trailing 3 dims.
///
bool trailingNDimsContiguous(MemRefType type, int64_t n);
} // namespace mlir } // namespace mlir
#endif // MLIR_IR_BUILTINTYPES_H #endif // MLIR_IR_BUILTINTYPES_H

View File

@ -257,38 +257,13 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
ArrayRef<int64_t> vectorShape = vectorType.getShape(); ArrayRef<int64_t> vectorShape = vectorType.getShape();
auto vecRank = vectorType.getRank(); auto vecRank = vectorType.getRank();
if (!trailingNDimsContiguous(memrefType, vecRank))
return false;
// Extract the trailing dims and strides of the input memref // Extract the trailing dims and strides of the input memref
auto memrefShape = memrefType.getShape().take_back(vecRank); auto memrefShape = memrefType.getShape().take_back(vecRank);
int64_t offset;
SmallVector<int64_t> stridesFull;
if (!succeeded(getStridesAndOffset(memrefType, stridesFull, offset)))
return false;
auto strides = ArrayRef<int64_t>(stridesFull).take_back(vecRank);
memrefType.getLayout().isIdentity();
// TODO: Add support for memref with trailing dynamic shapes. Memrefs // Compare the dims of `vectorType` against `memrefType` (in reverse).
// with leading dynamic dimensions are already supported.
if (ShapedType::isDynamicShape(memrefShape))
return false;
// Cond 1: Check whether `memrefType` is contiguous.
if (!strides.empty()) {
// Cond 1.1: A contiguous memref will always have a unit trailing stride.
if (strides.back() != 1)
return false;
// Cond 1.2: Strides of a contiguous memref have to match the flattened
// dims.
strides = strides.drop_back(1);
SmallVector<int64_t> flattenedDims;
for (size_t i = 1; i < memrefShape.size(); i++)
flattenedDims.push_back(mlir::computeProduct(memrefShape.take_back(i)));
if (!llvm::equal(strides, llvm::reverse(flattenedDims)))
return false;
}
// Cond 2: Compare the dims of `vectorType` against `memrefType` (in reverse).
// In the most basic case, all dims will match. // In the most basic case, all dims will match.
auto firstNonMatchingDim = auto firstNonMatchingDim =
std::mismatch(vectorShape.rbegin(), vectorShape.rend(), std::mismatch(vectorShape.rbegin(), vectorShape.rend(),

View File

@ -967,3 +967,35 @@ bool mlir::isLastMemrefDimUnitStride(MemRefType type) {
auto successStrides = getStridesAndOffset(type, strides, offset); auto successStrides = getStridesAndOffset(type, strides, offset);
return succeeded(successStrides) && (strides.empty() || strides.back() == 1); return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
} }
bool mlir::trailingNDimsContiguous(MemRefType type, int64_t n) {
if (!isLastMemrefDimUnitStride(type))
return false;
auto memrefShape = type.getShape().take_back(n);
if (ShapedType::isDynamicShape(memrefShape))
return false;
if (type.getLayout().isIdentity())
return true;
int64_t offset;
SmallVector<int64_t> stridesFull;
if (!succeeded(getStridesAndOffset(type, stridesFull, offset)))
return false;
auto strides = ArrayRef<int64_t>(stridesFull).take_back(n);
if (strides.empty())
return true;
// Check whether strides match "flattened" dims.
SmallVector<int64_t> flattenedDims;
auto dimProduct = 1;
for (auto dim : llvm::reverse(memrefShape.drop_front(1))) {
dimProduct *= dim;
flattenedDims.push_back(dimProduct);
}
strides = strides.drop_back(1);
return llvm::equal(strides, llvm::reverse(flattenedDims));
}