
This commit refactors the getStridesAndOffset() method on MemRefType to simply call `MemRefLayoutAttrInterface::getStridesAndOffset(shape, strides, offset)`, allowing downstream users and future layouts (e.g., a potential contiguous layout) to implement it without needing to patch BuiltinTypes or conform their affine maps to the canonical strided form.
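For illustration, here is a minimal sketch of what a downstream layout could look like under this scheme. The attribute name `ContiguousLayoutAttr` and its row-major semantics are hypothetical, and its ODS definition is assumed to declare `MemRefLayoutAttrInterface`; only the `(shape, strides, offset)` hook described above is taken from this patch.

```cpp
// Hypothetical downstream layout attribute. Instead of expressing its layout
// as an affine map in canonical strided form, it answers the interface hook
// directly.
LogicalResult
ContiguousLayoutAttr::getStridesAndOffset(ArrayRef<int64_t> shape,
                                          SmallVectorImpl<int64_t> &strides,
                                          int64_t &offset) {
  // Row-major contiguous layout: the offset is 0 and each stride is the
  // product of the trailing dimension sizes; a dynamic trailing size makes
  // every stride to its left dynamic as well.
  offset = 0;
  strides.resize(shape.size());
  int64_t running = 1;
  for (int64_t i = static_cast<int64_t>(shape.size()) - 1; i >= 0; --i) {
    strides[i] = running;
    if (running == ShapedType::kDynamic || ShapedType::isDynamic(shape[i]))
      running = ShapedType::kDynamic;
    else
      running *= shape[i];
  }
  return success();
}
```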
//===- BuiltinAttributeInterfaces.cpp -------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "llvm/ADT/Sequence.h"

using namespace mlir;
using namespace mlir::detail;

//===----------------------------------------------------------------------===//
/// Tablegen Interface Definitions
//===----------------------------------------------------------------------===//

#include "mlir/IR/BuiltinAttributeInterfaces.cpp.inc"

//===----------------------------------------------------------------------===//
// ElementsAttr
//===----------------------------------------------------------------------===//

Type ElementsAttr::getElementType(ElementsAttr elementsAttr) {
  return elementsAttr.getShapedType().getElementType();
}

int64_t ElementsAttr::getNumElements(ElementsAttr elementsAttr) {
  return elementsAttr.getShapedType().getNumElements();
}

bool ElementsAttr::isValidIndex(ShapedType type, ArrayRef<uint64_t> index) {
  // Verify that the rank of the indices matches the held type.
  int64_t rank = type.getRank();
  if (rank == 0 && index.size() == 1 && index[0] == 0)
    return true;
  if (rank != static_cast<int64_t>(index.size()))
    return false;

  // Verify that all of the indices are within the shape dimensions.
  ArrayRef<int64_t> shape = type.getShape();
  return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) {
    int64_t dim = static_cast<int64_t>(index[i]);
    return 0 <= dim && dim < shape[i];
  });
}
bool ElementsAttr::isValidIndex(ElementsAttr elementsAttr,
                                ArrayRef<uint64_t> index) {
  return isValidIndex(elementsAttr.getShapedType(), index);
}

uint64_t ElementsAttr::getFlattenedIndex(Type type, ArrayRef<uint64_t> index) {
  ShapedType shapeType = llvm::cast<ShapedType>(type);
  assert(isValidIndex(shapeType, index) &&
         "expected valid multi-dimensional index");

  // Reduce the provided multidimensional index into a flattened 1D row-major
  // index.
  auto rank = shapeType.getRank();
  ArrayRef<int64_t> shape = shapeType.getShape();
  uint64_t valueIndex = 0;
  uint64_t dimMultiplier = 1;
  for (int i = rank - 1; i >= 0; --i) {
    valueIndex += index[i] * dimMultiplier;
    dimMultiplier *= shape[i];
  }
  return valueIndex;
}
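
// Worked example (illustrative): for a shaped type of shape [2, 3, 4], the
// row-major flattening above maps index [1, 2, 3] to 3 * 1 + 2 * 4 + 1 * 12
// = 23, i.e. each dimension's multiplier is the product of the sizes of all
// trailing dimensions.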

//===----------------------------------------------------------------------===//
// MemRefLayoutAttrInterface
//===----------------------------------------------------------------------===//

LogicalResult mlir::detail::verifyAffineMapAsLayout(
    AffineMap m, ArrayRef<int64_t> shape,
    function_ref<InFlightDiagnostic()> emitError) {
  if (m.getNumDims() != shape.size())
    return emitError() << "memref layout mismatch between rank and affine map: "
                       << shape.size() << " != " << m.getNumDims();

  return success();
}
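
// Illustrative failure case: a rank-2 type such as memref<4x4xf32> paired
// with the map (d0) -> (d0) is rejected here, since the map has 1 dimension
// while the shape has rank 2.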

// Fallback cases for terminal dim/sym/cst that are not part of a binary op (
// i.e. single term). Accumulate the AffineExpr into the existing one.
static void extractStridesFromTerm(AffineExpr e,
                                   AffineExpr multiplicativeFactor,
                                   MutableArrayRef<AffineExpr> strides,
                                   AffineExpr &offset) {
  if (auto dim = dyn_cast<AffineDimExpr>(e))
    strides[dim.getPosition()] =
        strides[dim.getPosition()] + multiplicativeFactor;
  else
    offset = offset + e * multiplicativeFactor;
}

/// Takes a single AffineExpr `e` and populates the `strides` array with the
/// strides expressions for each dim position.
/// The convention is that the strides for dimensions d0, .. dn appear in
/// order to make indexing intuitive into the result.
static LogicalResult extractStrides(AffineExpr e,
                                    AffineExpr multiplicativeFactor,
                                    MutableArrayRef<AffineExpr> strides,
                                    AffineExpr &offset) {
  auto bin = dyn_cast<AffineBinaryOpExpr>(e);
  if (!bin) {
    extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
    return success();
  }

  if (bin.getKind() == AffineExprKind::CeilDiv ||
      bin.getKind() == AffineExprKind::FloorDiv ||
      bin.getKind() == AffineExprKind::Mod)
    return failure();

  if (bin.getKind() == AffineExprKind::Mul) {
    auto dim = dyn_cast<AffineDimExpr>(bin.getLHS());
    if (dim) {
      strides[dim.getPosition()] =
          strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
      return success();
    }
    // LHS and RHS may both contain complex expressions of dims. Try one path
    // and if it fails try the other. This is guaranteed to succeed because
    // only one path may have a `dim`, otherwise this is not an AffineExpr in
    // the first place.
    if (bin.getLHS().isSymbolicOrConstant())
      return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
                            strides, offset);
    return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
                          strides, offset);
  }

  if (bin.getKind() == AffineExprKind::Add) {
    auto res1 =
        extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
    auto res2 =
        extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
    return success(succeeded(res1) && succeeded(res2));
  }

  llvm_unreachable("unexpected binary operation");
}
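
// Illustrative walk-through: for the expression d0 * 64 + d1 + 7 (with an
// initial multiplicative factor of 1), the Add cases recurse into each term,
// the Mul case adds 64 to the stride of d0, the bare d1 term adds 1 to the
// stride of d1, and the constant 7 accumulates into the offset, yielding
// strides [64, 1] and offset 7.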

/// A stride specification is a list of integer values that are either static
/// or dynamic (encoded with ShapedType::kDynamic). Strides encode
/// the distance in the number of elements between successive entries along a
/// particular dimension.
///
/// For example, `memref<42x16xf32, (64 * d0 + d1)>` specifies a view into a
/// non-contiguous memory region of `42` by `16` `f32` elements in which the
/// distance between two consecutive elements along the outer dimension is `64`
/// and the distance between two consecutive elements along the inner dimension
/// is `1`.
///
/// The convention is that the strides for dimensions d0, .. dn appear in
/// order to make indexing intuitive into the result.
static LogicalResult getStridesAndOffset(AffineMap m, ArrayRef<int64_t> shape,
                                         SmallVectorImpl<AffineExpr> &strides,
                                         AffineExpr &offset) {
  if (m.getNumResults() != 1 && !m.isIdentity())
    return failure();

  auto zero = getAffineConstantExpr(0, m.getContext());
  auto one = getAffineConstantExpr(1, m.getContext());
  offset = zero;
  strides.assign(shape.size(), zero);

  // Canonical case for empty map.
  if (m.isIdentity()) {
    // 0-D corner case, offset is already 0.
    if (shape.empty())
      return success();
    auto stridedExpr = makeCanonicalStridedLayoutExpr(shape, m.getContext());
    if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
      return success();
    assert(false && "unexpected failure: extract strides in canonical layout");
  }

  // Non-canonical case requires more work.
  auto stridedExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (failed(extractStrides(stridedExpr, one, strides, offset))) {
    offset = AffineExpr();
    strides.clear();
    return failure();
  }

  // Simplify results to allow folding to constants and simple checks.
  unsigned numDims = m.getNumDims();
  unsigned numSymbols = m.getNumSymbols();
  offset = simplifyAffineExpr(offset, numDims, numSymbols);
  for (auto &stride : strides)
    stride = simplifyAffineExpr(stride, numDims, numSymbols);

  return success();
}
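
// Illustrative results of the extraction above: for a shape of [42, 16], the
// map (d0, d1) -> (64 * d0 + d1) yields strides [64, 1] and offset 0, while
// (d0, d1)[s0] -> (64 * d0 + d1 + s0) keeps strides [64, 1] but leaves the
// offset as the symbolic expression s0 (reported as dynamic once lowered to
// integers below).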

LogicalResult mlir::detail::getAffineMapStridesAndOffset(
    AffineMap map, ArrayRef<int64_t> shape, SmallVectorImpl<int64_t> &strides,
    int64_t &offset) {
  AffineExpr offsetExpr;
  SmallVector<AffineExpr, 4> strideExprs;
  if (failed(::getStridesAndOffset(map, shape, strideExprs, offsetExpr)))
    return failure();
  if (auto cst = llvm::dyn_cast<AffineConstantExpr>(offsetExpr))
    offset = cst.getValue();
  else
    offset = ShapedType::kDynamic;
  for (auto e : strideExprs) {
    if (auto c = llvm::dyn_cast<AffineConstantExpr>(e))
      strides.push_back(c.getValue());
    else
      strides.push_back(ShapedType::kDynamic);
  }
  return success();
}
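
// Usage sketch (illustrative; assumes the failable getStridesAndOffset member
// on MemRefType that the patch description refers to, which now forwards to
// the layout attribute's MemRefLayoutAttrInterface hook regardless of whether
// the layout is an affine map or a custom attribute):
//
//   SmallVector<int64_t> strides;
//   int64_t offset;
//   if (failed(memrefType.getStridesAndOffset(strides, offset)))
//     return emitError(loc, "expected a strided memref layout");
//   // Each stride and the offset are either concrete or ShapedType::kDynamic.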