[mlir] Add isStatic* size check for ShapedTypes. NFCI. (#147085)

The motivation is to avoid having to negate `isDynamic*` checks, to avoid double negations, and to allow `ShapedType::isStaticDim` to be used in ADT functions without wrapping it in a lambda that performs the negation. The new functions are also added to the C and Python bindings.

commit 6512ca7ddb (parent 0032148ea6)
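To make the motivation concrete, here is a minimal sketch (not part of the patch; the helper name is hypothetical) contrasting the old negated-lambda pattern with the new predicate that composes directly with llvm ADT helpers:

#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/STLExtras.h"
#include <cassert>

using namespace mlir;

// Hypothetical helper: returns true if every dimension of `type` is static.
static bool allDimsStatic(ShapedType type) {
  ArrayRef<int64_t> shape = type.getShape();

  // Before: the dynamic check had to be negated inside a lambda.
  bool oldStyle =
      !llvm::any_of(shape, [](int64_t d) { return ShapedType::isDynamic(d); });

  // After: ShapedType::isStatic plugs straight into ADT functions.
  bool newStyle = llvm::all_of(shape, ShapedType::isStatic);

  assert(oldStyle == newStyle && "the two formulations are equivalent");
  return newStyle;
}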
@@ -289,9 +289,12 @@ MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetRank(MlirType type);
 /// Checks whether the given shaped type has a static shape.
 MLIR_CAPI_EXPORTED bool mlirShapedTypeHasStaticShape(MlirType type);
 
-/// Checks wither the dim-th dimension of the given shaped type is dynamic.
+/// Checks whether the dim-th dimension of the given shaped type is dynamic.
 MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicDim(MlirType type, intptr_t dim);
 
+/// Checks whether the dim-th dimension of the given shaped type is static.
+MLIR_CAPI_EXPORTED bool mlirShapedTypeIsStaticDim(MlirType type, intptr_t dim);
+
 /// Returns the dim-th dimension of the given ranked shaped type.
 MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDimSize(MlirType type,
                                                     intptr_t dim);

@@ -300,17 +303,25 @@ MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDimSize(MlirType type,
 /// in shaped types.
 MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicSize(int64_t size);
 
+/// Checks whether the given shaped type dimension value is statically-sized.
+MLIR_CAPI_EXPORTED bool mlirShapedTypeIsStaticSize(int64_t size);
+
 /// Returns the value indicating a dynamic size in a shaped type. Prefer
-/// mlirShapedTypeIsDynamicSize to direct comparisons with this value.
+/// mlirShapedTypeIsDynamicSize and mlirShapedTypeIsStaticSize to direct
+/// comparisons with this value.
 MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDynamicSize(void);
 
 /// Checks whether the given value is used as a placeholder for dynamic strides
 /// and offsets in shaped types.
 MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicStrideOrOffset(int64_t val);
 
+/// Checks whether the given dimension value of a stride or an offset is
+/// statically-sized.
+MLIR_CAPI_EXPORTED bool mlirShapedTypeIsStaticStrideOrOffset(int64_t val);
+
 /// Returns the value indicating a dynamic stride or offset in a shaped type.
-/// Prefer mlirShapedTypeGetDynamicStrideOrOffset to direct comparisons with
-/// this value.
+/// Prefer mlirShapedTypeIsDynamicStrideOrOffset and
+/// mlirShapedTypeIsStaticStrideOrOffset to direct comparisons with this value.
 MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDynamicStrideOrOffset(void);
 
 //===----------------------------------------------------------------------===//
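As a hedged usage sketch of the C API additions above (the helper function itself is illustrative and not part of MLIR):

#include "mlir-c/BuiltinTypes.h"

#include <cstdio>

// Hypothetical helper: classify each dimension of a ranked shaped type using
// the new isStatic predicates instead of negating the isDynamic ones.
static void printDimKinds(MlirType type) {
  if (!mlirTypeIsAShaped(type) || !mlirShapedTypeHasRank(type))
    return;
  intptr_t rank = mlirShapedTypeGetRank(type);
  for (intptr_t d = 0; d < rank; ++d) {
    if (mlirShapedTypeIsStaticDim(type, d))
      std::printf("dim %ld: static, size %ld\n", (long)d,
                  (long)mlirShapedTypeGetDimSize(type, d));
    else
      std::printf("dim %ld: dynamic\n", (long)d);
  }
  // Size values can also be classified without comparing against the sentinel.
  std::printf("sentinel classified as dynamic: %d\n",
              (int)mlirShapedTypeIsDynamicSize(mlirShapedTypeGetDynamicSize()));
}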
@@ -36,7 +36,7 @@ def VectorElementTypeInterface : TypeInterface<"VectorElementTypeInterface"> {
     This may change in the future, for example, to require types to provide
     their size or alignment given a data layout. Please post an RFC before
     adding this interface to additional types. Implementing this interface on
-    downstream types is discourged, until we specified the exact properties of
+    downstream types is discouraged, until we specified the exact properties of
     a vector element type in more detail.
   }];
 }

@@ -221,7 +221,17 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
 
     /// Whether the given shape has any size that indicates a dynamic dimension.
     static bool isDynamicShape(ArrayRef<int64_t> dSizes) {
-      return any_of(dSizes, [](int64_t dSize) { return isDynamic(dSize); });
+      return llvm::any_of(dSizes, isDynamic);
     }
 
+    /// Whether the given dimension size indicates a statically-sized dimension.
+    static constexpr bool isStatic(int64_t dValue) {
+      return dValue != kDynamic;
+    }
+
+    /// Whether the given shape has static dimensions only.
+    static bool isStaticShape(ArrayRef<int64_t> dSizes) {
+      return llvm::all_of(dSizes, isStatic);
+    }
+
     /// Return the number of elements present in the given shape.

@@ -273,11 +283,18 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
       return ::mlir::ShapedType::isDynamic($_type.getShape()[idx]);
     }
 
+    /// Returns true if this dimension has a static size (for ranked types);
+    /// aborts for unranked types.
+    bool isStaticDim(unsigned idx) const {
+      assert(idx < getRank() && "invalid index for shaped type");
+      return ::mlir::ShapedType::isStatic($_type.getShape()[idx]);
+    }
+
     /// Returns if this type has a static shape, i.e. if the type is ranked and
     /// all dimensions have known size (>= 0).
     bool hasStaticShape() const {
       return $_type.hasRank() &&
-             !::mlir::ShapedType::isDynamicShape($_type.getShape());
+             ::mlir::ShapedType::isStaticShape($_type.getShape());
     }
 
     /// Returns if this type has a static shape and the shape is equal to
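A short sketch of how the new interface members read from C++ once this lands (the function and the example shape are assumptions for illustration only):

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

// Hypothetical demonstration of the per-dimension and whole-shape queries.
static void demoStaticQueries(MLIRContext *ctx) {
  // Shape {4, ?}: one static and one dynamic dimension.
  auto tensor = RankedTensorType::get({4, ShapedType::kDynamic},
                                      Float32Type::get(ctx));

  bool dim0Static = tensor.isStaticDim(0);   // true
  bool dim1Dynamic = tensor.isDynamicDim(1); // true
  bool allStatic = tensor.hasStaticShape();  // false
  bool shapeStatic =
      ShapedType::isStaticShape(tensor.getShape()); // false

  (void)dim0Static;
  (void)dim1Dynamic;
  (void)allStatic;
  (void)shapeStatic;
}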
@@ -544,6 +544,15 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) {
       nb::arg("dim"),
       "Returns whether the dim-th dimension of the given shaped type is "
       "dynamic.");
+  c.def(
+      "is_static_dim",
+      [](PyShapedType &self, intptr_t dim) -> bool {
+        self.requireHasRank();
+        return mlirShapedTypeIsStaticDim(self, dim);
+      },
+      nb::arg("dim"),
+      "Returns whether the dim-th dimension of the given shaped type is "
+      "static.");
   c.def(
       "get_dim_size",
       [](PyShapedType &self, intptr_t dim) {

@@ -558,6 +567,12 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) {
       nb::arg("dim_size"),
       "Returns whether the given dimension size indicates a dynamic "
       "dimension.");
+  c.def_static(
+      "is_static_size",
+      [](int64_t size) -> bool { return mlirShapedTypeIsStaticSize(size); },
+      nb::arg("dim_size"),
+      "Returns whether the given dimension size indicates a static "
+      "dimension.");
   c.def(
       "is_dynamic_stride_or_offset",
       [](PyShapedType &self, int64_t val) -> bool {

@@ -567,6 +582,15 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) {
       nb::arg("dim_size"),
       "Returns whether the given value is used as a placeholder for dynamic "
       "strides and offsets in shaped types.");
+  c.def(
+      "is_static_stride_or_offset",
+      [](PyShapedType &self, int64_t val) -> bool {
+        self.requireHasRank();
+        return mlirShapedTypeIsStaticStrideOrOffset(val);
+      },
+      nb::arg("dim_size"),
+      "Returns whether the given shaped type stride or offset value is "
+      "statically-sized.");
   c.def_prop_ro(
       "shape",
       [](PyShapedType &self) {
@@ -332,6 +332,11 @@ bool mlirShapedTypeIsDynamicDim(MlirType type, intptr_t dim) {
       .isDynamicDim(static_cast<unsigned>(dim));
 }
 
+bool mlirShapedTypeIsStaticDim(MlirType type, intptr_t dim) {
+  return llvm::cast<ShapedType>(unwrap(type))
+      .isStaticDim(static_cast<unsigned>(dim));
+}
+
 int64_t mlirShapedTypeGetDimSize(MlirType type, intptr_t dim) {
   return llvm::cast<ShapedType>(unwrap(type))
       .getDimSize(static_cast<unsigned>(dim));

@@ -343,10 +348,18 @@ bool mlirShapedTypeIsDynamicSize(int64_t size) {
   return ShapedType::isDynamic(size);
 }
 
+bool mlirShapedTypeIsStaticSize(int64_t size) {
+  return ShapedType::isStatic(size);
+}
+
 bool mlirShapedTypeIsDynamicStrideOrOffset(int64_t val) {
   return ShapedType::isDynamic(val);
 }
 
+bool mlirShapedTypeIsStaticStrideOrOffset(int64_t val) {
+  return ShapedType::isStatic(val);
+}
+
 int64_t mlirShapedTypeGetDynamicStrideOrOffset() {
   return ShapedType::kDynamic;
 }
@@ -53,7 +53,7 @@ MemRefDescriptor MemRefDescriptor::fromStaticShape(
 
   // Extract all strides and offsets and verify they are static.
   auto [strides, offset] = type.getStridesAndOffset();
-  assert(!ShapedType::isDynamic(offset) && "expected static offset");
+  assert(ShapedType::isStatic(offset) && "expected static offset");
   assert(!llvm::any_of(strides, ShapedType::isDynamic) &&
          "expected static strides");
 

@@ -609,7 +609,7 @@ bool LLVMTypeConverter::canConvertToBarePtr(BaseMemRefType type) {
     if (ShapedType::isDynamic(stride))
       return false;
 
-  return !ShapedType::isDynamic(offset);
+  return ShapedType::isStatic(offset);
 }
 
 /// Convert a memref type to a bare pointer to the memref element type.

@@ -43,7 +43,7 @@ static constexpr LLVM::GEPNoWrapFlags kNoWrapFlags =
 namespace {
 
 static bool isStaticStrideOrOffset(int64_t strideOrOffset) {
-  return !ShapedType::isDynamic(strideOrOffset);
+  return ShapedType::isStatic(strideOrOffset);
 }
 
 static FailureOr<LLVM::LLVMFuncOp>

@@ -1468,7 +1468,7 @@ private:
     Value stride = nullptr;
     int64_t targetRank = targetMemRefType.getRank();
     for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
-      if (!ShapedType::isDynamic(strides[i])) {
+      if (ShapedType::isStatic(strides[i])) {
        // If the stride for this dimension is dynamic, then use the product
        // of the sizes of the inner dimensions.
        stride =

@@ -1722,7 +1722,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
                 ArrayRef<int64_t> shape, ValueRange dynamicSizes, unsigned idx,
                 Type indexType) const {
     assert(idx < shape.size());
-    if (!ShapedType::isDynamic(shape[idx]))
+    if (ShapedType::isStatic(shape[idx]))
       return createIndexAttrConstant(rewriter, loc, indexType, shape[idx]);
     // Count the number of dynamic dims in range [0, idx]
     unsigned nDynamic =

@@ -1738,7 +1738,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
                 ArrayRef<int64_t> strides, Value nextSize,
                 Value runningStride, unsigned idx, Type indexType) const {
     assert(idx < strides.size());
-    if (!ShapedType::isDynamic(strides[idx]))
+    if (ShapedType::isStatic(strides[idx]))
       return createIndexAttrConstant(rewriter, loc, indexType, strides[idx]);
     if (nextSize)
       return runningStride
@ -757,7 +757,7 @@ computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool,
|
||||
// dimension greater than 1 with a different value is undefined behavior.
|
||||
for (auto operand : operands) {
|
||||
auto size = cast<RankedTensorType>(operand.getType()).getDimSize(dim);
|
||||
if (!ShapedType::isDynamic(size) && size > 1)
|
||||
if (ShapedType::isStatic(size) && size > 1)
|
||||
return {rewriter.getIndexAttr(size), operand};
|
||||
}
|
||||
|
||||
|
@ -83,7 +83,7 @@ TensorType inferReshapeExpandedType(TensorType inputType,
|
||||
return totalSize / totalSizeNoPlaceholder;
|
||||
});
|
||||
|
||||
bool resultIsStatic = !ShapedType::isDynamicShape(resultShape);
|
||||
bool resultIsStatic = ShapedType::isStaticShape(resultShape);
|
||||
|
||||
// A syntactic restriction in 'tensor.expand_shape' forbids a dynamically
|
||||
// shaped input from being reshaped into a statically shaped result. We may
|
||||
@ -305,7 +305,7 @@ public:
|
||||
int64_t size = i.value();
|
||||
size_t index = i.index();
|
||||
sizes.push_back(size == -1 ? ShapedType::kDynamic : size);
|
||||
if (!ShapedType::isDynamic(sizes.back()))
|
||||
if (ShapedType::isStatic(sizes.back()))
|
||||
continue;
|
||||
|
||||
auto dim = rewriter.create<tensor::DimOp>(loc, input, index);
|
||||
|
@ -44,7 +44,7 @@ FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue(
|
||||
failed(target.getStridesAndOffset(targetStrides, targetOffset)))
|
||||
return false;
|
||||
auto dynamicToStatic = [](int64_t a, int64_t b) {
|
||||
return ShapedType::isDynamic(a) && !ShapedType::isDynamic(b);
|
||||
return ShapedType::isDynamic(a) && ShapedType::isStatic(b);
|
||||
};
|
||||
if (dynamicToStatic(sourceOffset, targetOffset))
|
||||
return false;
|
||||
|
@ -33,7 +33,7 @@ static bool hasFullyDynamicLayoutMap(MemRefType type) {
|
||||
return false;
|
||||
if (!llvm::all_of(strides, ShapedType::isDynamic))
|
||||
return false;
|
||||
if (!ShapedType::isDynamic(offset))
|
||||
if (ShapedType::isStatic(offset))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
@ -4564,7 +4564,7 @@ static SmallVector<OpFoldResult> getMixedTilesImpl(OpTy op) {
|
||||
SmallVector<OpFoldResult> mixedInnerTiles;
|
||||
unsigned dynamicValIndex = 0;
|
||||
for (int64_t staticTile : op.getStaticInnerTiles()) {
|
||||
if (!ShapedType::isDynamic(staticTile))
|
||||
if (ShapedType::isStatic(staticTile))
|
||||
mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile));
|
||||
else
|
||||
mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
|
||||
@ -4829,7 +4829,7 @@ bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
|
||||
std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
|
||||
|
||||
if (!constantTile) {
|
||||
if (!ShapedType::isDynamic(outputTileSizes[pos]) &&
|
||||
if (ShapedType::isStatic(outputTileSizes[pos]) &&
|
||||
(inputShape[pos] % outputTileSizes[pos] != 0))
|
||||
return true;
|
||||
} else if (inputShape[pos] % (*constantTile) != 0) {
|
||||
@ -4935,7 +4935,7 @@ SmallVector<OpFoldResult> PackOp::getResultShape(
|
||||
// use dispatchIndexOpFoldResults on the result, and rely on exact number of
|
||||
// dynamic dims returned by that.
|
||||
for (unsigned i = 0; i < resultDims.size(); ++i) {
|
||||
if (!ShapedType::isDynamic(resultTypeShape[i]))
|
||||
if (ShapedType::isStatic(resultTypeShape[i]))
|
||||
continue;
|
||||
resultDims[i] =
|
||||
getValueOrCreateConstantIndexOp(builder, loc, resultDims[i]);
|
||||
|
@ -2061,7 +2061,7 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
|
||||
rewriter.setInsertionPoint(linalgTarget);
|
||||
for (OpOperand &operand : linalgTarget->getOpOperands()) {
|
||||
for (auto [i, dim] : llvm::enumerate(linalgTarget.getShape(&operand))) {
|
||||
if (!ShapedType::isDynamic(dim))
|
||||
if (ShapedType::isStatic(dim))
|
||||
continue;
|
||||
options.setSizeToPadTo(operand.getOperandNumber(), i,
|
||||
tensor::getMixedSize(rewriter,
|
||||
|
@ -335,7 +335,7 @@ VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
|
||||
LinalgOp linalgOp) {
|
||||
// TODO: Support 0-d vectors.
|
||||
for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
|
||||
if (!ShapedType::isDynamic(iterSpaceStaticSizes[vecDim])) {
|
||||
if (ShapedType::isStatic(iterSpaceStaticSizes[vecDim])) {
|
||||
// Create constant index op for static dimensions.
|
||||
iterSpaceValueSizes.push_back(rewriter.create<arith::ConstantIndexOp>(
|
||||
linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
|
||||
@ -1652,7 +1652,7 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
|
||||
for (unsigned i = 0; i < vecToStoreRank; i++)
|
||||
inBoundsVal[i] =
|
||||
(destShape[destRank - vecToStoreRank + i] >= vecToStoreShape[i]) &&
|
||||
!ShapedType::isDynamic(destShape[destRank - vecToStoreRank + i]);
|
||||
ShapedType::isStatic(destShape[destRank - vecToStoreRank + i]);
|
||||
}
|
||||
|
||||
// If missing, initialize the write indices to 0.
|
||||
|
@ -694,7 +694,7 @@ computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
|
||||
int64_t shapeSize = shape[r];
|
||||
std::optional<int64_t> sizeCst = getConstantIntValue(size);
|
||||
auto hasTileSizeOne = sizeCst == 1;
|
||||
auto dividesEvenly = sizeCst && !ShapedType::isDynamic(shapeSize) &&
|
||||
auto dividesEvenly = sizeCst && ShapedType::isStatic(shapeSize) &&
|
||||
((shapeSize % *sizeCst) == 0);
|
||||
if (!hasTileSizeOne && !dividesEvenly) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: shapeSize=" << shapeSize
|
||||
|
@ -99,7 +99,7 @@ static void constifyIndexValues(SmallVectorImpl<OpFoldResult> &values,
|
||||
"incorrect number of const values");
|
||||
for (auto [i, cstVal] : llvm::enumerate(constValues)) {
|
||||
Builder builder(values[i].getContext());
|
||||
if (!ShapedType::isDynamic(cstVal)) {
|
||||
if (ShapedType::isStatic(cstVal)) {
|
||||
// Constant value is known, use it directly.
|
||||
values[i] = builder.getIndexAttr(cstVal);
|
||||
continue;
|
||||
@ -189,7 +189,7 @@ struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
|
||||
for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
|
||||
int64_t dimSize = memrefType.getDimSize(dim);
|
||||
// If this is already static dimension, keep it.
|
||||
if (!ShapedType::isDynamic(dimSize)) {
|
||||
if (ShapedType::isStatic(dimSize)) {
|
||||
newShapeConstants.push_back(dimSize);
|
||||
continue;
|
||||
}
|
||||
@ -615,21 +615,21 @@ bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
|
||||
for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
|
||||
auto ss = std::get<0>(it), st = std::get<1>(it);
|
||||
if (ss != st)
|
||||
if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
|
||||
if (ShapedType::isDynamic(ss) && ShapedType::isStatic(st))
|
||||
return false;
|
||||
}
|
||||
|
||||
// If cast is towards more static offset along any dimension, don't fold.
|
||||
if (sourceOffset != resultOffset)
|
||||
if (ShapedType::isDynamic(sourceOffset) &&
|
||||
!ShapedType::isDynamic(resultOffset))
|
||||
ShapedType::isStatic(resultOffset))
|
||||
return false;
|
||||
|
||||
// If cast is towards more static strides along any dimension, don't fold.
|
||||
for (auto it : llvm::zip(sourceStrides, resultStrides)) {
|
||||
auto ss = std::get<0>(it), st = std::get<1>(it);
|
||||
if (ss != st)
|
||||
if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
|
||||
if (ShapedType::isDynamic(ss) && ShapedType::isStatic(st))
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -679,7 +679,7 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
|
||||
|
||||
for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
|
||||
int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
|
||||
if (!ShapedType::isDynamic(aDim) && !ShapedType::isDynamic(bDim) &&
|
||||
if (ShapedType::isStatic(aDim) && ShapedType::isStatic(bDim) &&
|
||||
aDim != bDim)
|
||||
return false;
|
||||
}
|
||||
@ -1862,7 +1862,7 @@ LogicalResult ReinterpretCastOp::verify() {
|
||||
// Match sizes in result memref type and in static_sizes attribute.
|
||||
for (auto [idx, resultSize, expectedSize] :
|
||||
llvm::enumerate(resultType.getShape(), getStaticSizes())) {
|
||||
if (!ShapedType::isDynamic(resultSize) && resultSize != expectedSize)
|
||||
if (ShapedType::isStatic(resultSize) && resultSize != expectedSize)
|
||||
return emitError("expected result type with size = ")
|
||||
<< (ShapedType::isDynamic(expectedSize)
|
||||
? std::string("dynamic")
|
||||
@ -1881,7 +1881,7 @@ LogicalResult ReinterpretCastOp::verify() {
|
||||
|
||||
// Match offset in result memref type and in static_offsets attribute.
|
||||
int64_t expectedOffset = getStaticOffsets().front();
|
||||
if (!ShapedType::isDynamic(resultOffset) && resultOffset != expectedOffset)
|
||||
if (ShapedType::isStatic(resultOffset) && resultOffset != expectedOffset)
|
||||
return emitError("expected result type with offset = ")
|
||||
<< (ShapedType::isDynamic(expectedOffset)
|
||||
? std::string("dynamic")
|
||||
@ -1891,7 +1891,7 @@ LogicalResult ReinterpretCastOp::verify() {
|
||||
// Match strides in result memref type and in static_strides attribute.
|
||||
for (auto [idx, resultStride, expectedStride] :
|
||||
llvm::enumerate(resultStrides, getStaticStrides())) {
|
||||
if (!ShapedType::isDynamic(resultStride) && resultStride != expectedStride)
|
||||
if (ShapedType::isStatic(resultStride) && resultStride != expectedStride)
|
||||
return emitError("expected result type with stride = ")
|
||||
<< (ShapedType::isDynamic(expectedStride)
|
||||
? std::string("dynamic")
|
||||
@ -1928,7 +1928,7 @@ OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
|
||||
}
|
||||
|
||||
// reinterpret_cast(x) w/o offset/shape/stride changes -> x
|
||||
if (!ShapedType::isDynamicShape(getType().getShape()) &&
|
||||
if (ShapedType::isStaticShape(getType().getShape()) &&
|
||||
src.getType() == getType() && getStaticOffsets().front() == 0) {
|
||||
return src;
|
||||
}
|
||||
@ -2379,7 +2379,7 @@ LogicalResult ExpandShapeOp::verify() {
|
||||
DenseI64ArrayAttr staticOutputShapes = getStaticOutputShapeAttr();
|
||||
ArrayRef<int64_t> resShape = getResult().getType().getShape();
|
||||
for (auto [pos, shape] : llvm::enumerate(resShape)) {
|
||||
if (!ShapedType::isDynamic(shape) && shape != staticOutputShapes[pos]) {
|
||||
if (ShapedType::isStatic(shape) && shape != staticOutputShapes[pos]) {
|
||||
return emitOpError("invalid output shape provided at pos ") << pos;
|
||||
}
|
||||
}
|
||||
@ -2422,7 +2422,7 @@ computeCollapsedLayoutMap(MemRefType srcType,
|
||||
ArrayRef<int64_t> ref = llvm::ArrayRef(reassoc);
|
||||
while (srcShape[ref.back()] == 1 && ref.size() > 1)
|
||||
ref = ref.drop_back();
|
||||
if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
|
||||
if (ShapedType::isStatic(srcShape[ref.back()]) || ref.size() == 1) {
|
||||
resultStrides.push_back(srcStrides[ref.back()]);
|
||||
} else {
|
||||
// Dynamically-sized dims may turn out to be dims of size 1 at runtime, so
|
||||
@ -3509,7 +3509,7 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
|
||||
for (unsigned dim = 0, e = rank; dim < e; ++dim) {
|
||||
int64_t dimSize = memrefType.getDimSize(dim);
|
||||
// If this is already static dimension, keep it.
|
||||
if (!ShapedType::isDynamic(dimSize)) {
|
||||
if (ShapedType::isStatic(dimSize)) {
|
||||
newShapeConstants.push_back(dimSize);
|
||||
continue;
|
||||
}
|
||||
|
@ -118,7 +118,7 @@ resolveSubviewStridedMetadata(RewriterBase &rewriter,
|
||||
// Assert that the computed offset matches the offset of the result type of
|
||||
// the subview op (if both are static).
|
||||
std::optional<int64_t> computedOffset = getConstantIntValue(finalOffset);
|
||||
if (computedOffset && !ShapedType::isDynamic(resultOffset))
|
||||
if (computedOffset && ShapedType::isStatic(resultOffset))
|
||||
assert(*computedOffset == resultOffset &&
|
||||
"mismatch between computed offset and result type offset");
|
||||
#endif // NDEBUG
|
||||
@ -158,7 +158,7 @@ resolveSubviewStridedMetadata(RewriterBase &rewriter,
|
||||
// Assert that the computed stride matches the stride of the result type of
|
||||
// the subview op (if both are static).
|
||||
std::optional<int64_t> computedStride = getConstantIntValue(strides[i]);
|
||||
if (computedStride && !ShapedType::isDynamic(resultStrides[j]))
|
||||
if (computedStride && ShapedType::isStatic(resultStrides[j]))
|
||||
assert(*computedStride == resultStrides[j] &&
|
||||
"mismatch between computed stride and result type stride");
|
||||
++j;
|
||||
@ -458,7 +458,7 @@ getCollapsedSize(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
|
||||
MemRefType collapseShapeType = collapseShape.getResultType();
|
||||
|
||||
uint64_t size = collapseShapeType.getDimSize(groupId);
|
||||
if (!ShapedType::isDynamic(size)) {
|
||||
if (ShapedType::isStatic(size)) {
|
||||
collapsedSize.push_back(builder.getIndexAttr(size));
|
||||
return collapsedSize;
|
||||
}
|
||||
@ -1091,7 +1091,7 @@ class ExtractStridedMetadataOpCastFolder
|
||||
|
||||
auto getConstantOrValue = [&rewriter](int64_t constant,
|
||||
OpFoldResult ofr) -> OpFoldResult {
|
||||
return !ShapedType::isDynamic(constant)
|
||||
return ShapedType::isStatic(constant)
|
||||
? OpFoldResult(rewriter.getIndexAttr(constant))
|
||||
: ofr;
|
||||
};
|
||||
|
@ -264,7 +264,7 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape,
|
||||
// add halo sizes if requested
|
||||
int haloAxis = 0;
|
||||
for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
|
||||
if (!ShapedType::isDynamic(outShape[tensorAxis]) &&
|
||||
if (ShapedType::isStatic(outShape[tensorAxis]) &&
|
||||
!innerSplitAxes.empty()) {
|
||||
if (haloSizes[haloAxis * 2] >= 0 &&
|
||||
haloSizes[haloAxis * 2 + 1] >= 0) {
|
||||
@ -415,7 +415,7 @@ LogicalResult MeshOp::verify() {
|
||||
return emitOpError("rank of mesh is expected to be a positive integer");
|
||||
|
||||
for (int64_t dimSize : getShape()) {
|
||||
if (dimSize < 0 && !ShapedType::isDynamic(dimSize))
|
||||
if (dimSize < 0 && ShapedType::isStatic(dimSize))
|
||||
return emitOpError("dimension size of a mesh is expected to be "
|
||||
"non-negative or dynamic");
|
||||
}
|
||||
@ -609,7 +609,7 @@ LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
||||
auto shardedDimsOffsets = getStaticShardedDimsOffsets();
|
||||
if (!shardedDimsOffsets.empty()) {
|
||||
auto meshShape = mesh.value().getShape();
|
||||
assert(!ShapedType::isDynamicShape(meshShape));
|
||||
assert(ShapedType::isStaticShape(meshShape));
|
||||
uint64_t pos = 0;
|
||||
for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(getSplitAxes())) {
|
||||
if (!innerSplitAxes.empty()) {
|
||||
@ -621,7 +621,7 @@ LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
||||
if (shardedDimsOffsets.size() <= pos + i) {
|
||||
return emitError() << "sharded dims offsets has wrong size.";
|
||||
}
|
||||
if (!ShapedType::isDynamic(shardedDimsOffsets[pos + i])) {
|
||||
if (ShapedType::isStatic(shardedDimsOffsets[pos + i])) {
|
||||
if (shardedDimsOffsets[pos + i] < off) {
|
||||
return emitError()
|
||||
<< "sharded dims offsets must be non-decreasing.";
|
||||
@ -1036,8 +1036,8 @@ static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < device.size(); ++i) {
|
||||
if (!ShapedType::isDynamic(device[i]) &&
|
||||
!ShapedType::isDynamic(meshShape[meshAxes[i]]) &&
|
||||
if (ShapedType::isStatic(device[i]) &&
|
||||
ShapedType::isStatic(meshShape[meshAxes[i]]) &&
|
||||
meshShape[meshAxes[i]] <= device[i]) {
|
||||
return emitError(loc)
|
||||
<< "Out of bounds coordinate " << i << " for in-group device \""
|
||||
@ -1065,8 +1065,7 @@ static LogicalResult verifyDimensionCompatibility(Location loc,
|
||||
int64_t expectedDimSize,
|
||||
int64_t resultDimSize,
|
||||
int64_t resultAxis) {
|
||||
if (!ShapedType::isDynamic(resultDimSize) &&
|
||||
expectedDimSize != resultDimSize) {
|
||||
if (ShapedType::isStatic(resultDimSize) && expectedDimSize != resultDimSize) {
|
||||
return emitError(loc) << "Dimension size mismatch for result axis "
|
||||
<< resultAxis << ". Expected "
|
||||
<< (ShapedType::isDynamic(expectedDimSize)
|
||||
|
@ -453,8 +453,8 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
|
||||
auto srcHaloSizes = sourceSharding.getStaticHaloSizes();
|
||||
auto tgtHaloSizes = targetSharding.getStaticHaloSizes();
|
||||
assert(srcHaloSizes.empty() || srcHaloSizes.size() == tgtHaloSizes.size());
|
||||
assert(((srcHaloSizes.empty() || !ShapedType::isDynamicShape(srcHaloSizes)) &&
|
||||
!ShapedType::isDynamicShape(tgtHaloSizes) &&
|
||||
assert(((srcHaloSizes.empty() || ShapedType::isStaticShape(srcHaloSizes)) &&
|
||||
ShapedType::isStaticShape(tgtHaloSizes) &&
|
||||
sourceShard.getType().hasStaticShape()) &&
|
||||
"dynamic shapes/halos are not supported yet for mesh-spmdization");
|
||||
auto rank = sourceShard.getType().getRank();
|
||||
|
@ -518,7 +518,7 @@ SparseTensorEncodingAttr::translateShape(ArrayRef<int64_t> srcShape,
|
||||
SmallVector<AffineExpr> dimRep;
|
||||
dimRep.reserve(srcShape.size());
|
||||
for (int64_t sz : srcShape) {
|
||||
if (!ShapedType::isDynamic(sz)) {
|
||||
if (ShapedType::isStatic(sz)) {
|
||||
// Push back the max coordinate for the given dimension/level size.
|
||||
dimRep.push_back(getAffineConstantExpr(sz - 1, getContext()));
|
||||
} else {
|
||||
@ -1531,7 +1531,7 @@ OpFoldResult LvlOp::fold(FoldAdaptor adaptor) {
|
||||
};
|
||||
|
||||
SmallVector<Size> lvlShape = stt.getLvlShape();
|
||||
if (!ShapedType::isDynamic(lvlShape[lvl]))
|
||||
if (ShapedType::isStatic(lvlShape[lvl]))
|
||||
return getIndexAttr(lvlShape[lvl]);
|
||||
|
||||
return {};
|
||||
@ -1876,7 +1876,7 @@ LogicalResult ConcatenateOp::verify() {
|
||||
for (Dimension d = 0; d < dimRank; d++) {
|
||||
const Size dstSh = dstTp.getDimShape()[d];
|
||||
if (d == concatDim) {
|
||||
if (!ShapedType::isDynamic(dstSh)) {
|
||||
if (ShapedType::isStatic(dstSh)) {
|
||||
// If we reach here, then all inputs have static shapes. So we
|
||||
// can use `getDimShape()[d]` instead of `*getDynamicDimSize(d)`
|
||||
// to avoid redundant assertions in the loop.
|
||||
@ -1894,7 +1894,7 @@ LogicalResult ConcatenateOp::verify() {
|
||||
Size prev = dstSh;
|
||||
for (const auto src : getInputs()) {
|
||||
const auto sh = getSparseTensorType(src).getDimShape()[d];
|
||||
if (!ShapedType::isDynamic(prev) && sh != prev)
|
||||
if (ShapedType::isStatic(prev) && sh != prev)
|
||||
return emitError("All dimensions (expect for the concatenating one) "
|
||||
"should be equal.");
|
||||
prev = sh;
|
||||
@ -2058,7 +2058,7 @@ LogicalResult SortOp::verify() {
|
||||
const auto checkDim = [&](Value v, Size minSize,
|
||||
const char *message) -> LogicalResult {
|
||||
const Size sh = getMemRefType(v).getShape()[0];
|
||||
if (!ShapedType::isDynamic(sh) && sh < minSize)
|
||||
if (ShapedType::isStatic(sh) && sh < minSize)
|
||||
return emitError(
|
||||
llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
|
||||
return success();
|
||||
|
@ -259,7 +259,7 @@ translateMap(linalg::GenericOp op, PatternRewriter &rewriter) {
|
||||
// translation.
|
||||
auto populateCstMapping = [ctx](DenseMap<AffineExpr, AffineExpr> &cstMapping,
|
||||
unsigned pos, int64_t lvlSz) {
|
||||
if (!ShapedType::isDynamic(lvlSz)) {
|
||||
if (ShapedType::isStatic(lvlSz)) {
|
||||
auto c0 = getAffineConstantExpr(0, ctx);
|
||||
auto lvlExp = getAffineDimExpr(pos, ctx);
|
||||
auto szExp = getAffineConstantExpr(lvlSz, ctx);
|
||||
|
@ -1348,7 +1348,7 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
|
||||
Level trailCOORank = stt.getLvlRank() - trailCOOStart;
|
||||
// Sets up SparseTensorSpecifier.
|
||||
for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
|
||||
assert(!ShapedType::isDynamic(stt.getDimShape()[lvl]));
|
||||
assert(ShapedType::isStatic(stt.getDimShape()[lvl]));
|
||||
|
||||
// Sets up the level size.
|
||||
auto lvlSize = constantIndex(rewriter, loc, stt.getLvlShape()[lvl]);
|
||||
|
@ -86,7 +86,7 @@ static Value createOrFoldLvlCall(OpBuilder &builder, Location loc,
|
||||
const Dimension dim =
|
||||
stt.isIdentity() ? lvl : stt.getDimToLvl().getDimPosition(lvl);
|
||||
const Size sz = stt.getDynamicDimSize(dim);
|
||||
if (!ShapedType::isDynamic(sz))
|
||||
if (ShapedType::isStatic(sz))
|
||||
return constantIndex(builder, loc, sz);
|
||||
// If we cannot statically compute the size from the shape, then we
|
||||
// must dynamically query it. (In principle we could also dynamically
|
||||
@ -103,7 +103,7 @@ static Value createOrFoldDimCall(OpBuilder &builder, Location loc,
|
||||
SparseTensorType stt, Value tensor,
|
||||
Dimension dim) {
|
||||
const Size sz = stt.getDynamicDimSize(dim);
|
||||
if (!ShapedType::isDynamic(sz))
|
||||
if (ShapedType::isStatic(sz))
|
||||
return constantIndex(builder, loc, sz);
|
||||
if (stt.hasEncoding())
|
||||
return genDimSizeCall(builder, loc, tensor, dim);
|
||||
|
@ -1245,7 +1245,7 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
|
||||
// by concatenate op verifier, which saves us from computing the offset
|
||||
// dynamically.
|
||||
const Size sz = getSparseTensorType(input).getDynamicDimSize(conDim);
|
||||
assert(!ShapedType::isDynamic(sz));
|
||||
assert(ShapedType::isStatic(sz));
|
||||
offset = rewriter.create<arith::AddIOp>(loc, offset,
|
||||
constantIndex(rewriter, loc, sz));
|
||||
iterArg = foreachOp.getResult(0);
|
||||
|
@ -23,7 +23,7 @@ using namespace mlir::tensor;
|
||||
static OpFoldResult getCollapsedOutputDimFromInputShape(
|
||||
OpBuilder &builder, Location loc, int64_t dimIndex, Value src,
|
||||
ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociationMap) {
|
||||
if (!ShapedType::isDynamic(dstStaticShape[dimIndex])) {
|
||||
if (ShapedType::isStatic(dstStaticShape[dimIndex])) {
|
||||
// Static dimension: return Attribute.
|
||||
return builder.getIndexAttr(dstStaticShape[dimIndex]);
|
||||
}
|
||||
|
@ -292,7 +292,7 @@ bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
|
||||
|
||||
// If cast is towards more static sizes along any dimension, don't fold.
|
||||
for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
|
||||
if (!ShapedType::isDynamic(std::get<0>(t)) &&
|
||||
if (ShapedType::isStatic(std::get<0>(t)) &&
|
||||
ShapedType::isDynamic(std::get<1>(t)))
|
||||
return false;
|
||||
}
|
||||
@ -1235,7 +1235,7 @@ struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
|
||||
|
||||
// Case 2 : The tensor cast shape is static, but empty tensor result
|
||||
// shape is dynamic.
|
||||
if (!ShapedType::isDynamic(newDim)) {
|
||||
if (ShapedType::isStatic(newDim)) {
|
||||
newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
|
||||
continue;
|
||||
}
|
||||
@ -2197,7 +2197,7 @@ struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
|
||||
|
||||
for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
|
||||
for (uint64_t outDim : innerReassoc) {
|
||||
if (!ShapedType::isDynamic(newOutputShape[outDim]))
|
||||
if (ShapedType::isStatic(newOutputShape[outDim]))
|
||||
continue;
|
||||
|
||||
// If the cast's src type is dynamic, don't infer any of the
|
||||
@ -3579,7 +3579,7 @@ struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
|
||||
continue;
|
||||
OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
|
||||
int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
|
||||
assert(!ShapedType::isDynamic(sourceSize) &&
|
||||
assert(ShapedType::isStatic(sourceSize) &&
|
||||
"expected padded dimension to have a static size");
|
||||
if (getConstantIntValue(sliceSize) != sourceSize) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -849,7 +849,7 @@ static LogicalResult verifyPoolingOp(T op) {
|
||||
<< kernelSize << ") / " << strideSize;
|
||||
|
||||
const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
|
||||
if (!ShapedType::isDynamic(outputSize) && calculatedOutSize != outputSize)
|
||||
if (ShapedType::isStatic(outputSize) && calculatedOutSize != outputSize)
|
||||
return op.emitOpError("calculated output ")
|
||||
<< dimName << " did not match expected: "
|
||||
<< "calculated=" << calculatedOutSize
|
||||
@ -1301,12 +1301,12 @@ LogicalResult tosa::RFFT2dOp::verify() {
|
||||
return success();
|
||||
|
||||
const int64_t height = inputType.getDimSize(1);
|
||||
if (!ShapedType::isDynamic(height) &&
|
||||
if (ShapedType::isStatic(height) &&
|
||||
failed(verifyDimIsPowerOfTwo(*this, height, "height")))
|
||||
return failure();
|
||||
|
||||
const int64_t width = inputType.getDimSize(2);
|
||||
if (!ShapedType::isDynamic(width) &&
|
||||
if (ShapedType::isStatic(width) &&
|
||||
failed(verifyDimIsPowerOfTwo(*this, width, "width")))
|
||||
return failure();
|
||||
|
||||
@ -1323,7 +1323,7 @@ LogicalResult tosa::RFFT2dOp::verify() {
|
||||
|
||||
// Output width dimension expected to be input_width / 2 + 1
|
||||
const int64_t outputWidth = outputType.getDimSize(2);
|
||||
if (!ShapedType::isDynamic(width) && !ShapedType::isDynamic(outputWidth) &&
|
||||
if (ShapedType::isStatic(width) && ShapedType::isStatic(outputWidth) &&
|
||||
(outputWidth != (width / 2) + 1))
|
||||
return emitOpError(
|
||||
"expected output width to be equal to input_width / 2 + 1, got ")
|
||||
@ -1357,13 +1357,13 @@ LogicalResult tosa::FFT2dOp::verify() {
|
||||
|
||||
const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1),
|
||||
inputImagType.getDimSize(1));
|
||||
if (!ShapedType::isDynamic(height) &&
|
||||
if (ShapedType::isStatic(height) &&
|
||||
failed(verifyDimIsPowerOfTwo(*this, height, "height")))
|
||||
return failure();
|
||||
|
||||
const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2),
|
||||
inputImagType.getDimSize(2));
|
||||
if (!ShapedType::isDynamic(width) &&
|
||||
if (ShapedType::isStatic(width) &&
|
||||
failed(verifyDimIsPowerOfTwo(*this, width, "width")))
|
||||
return failure();
|
||||
|
||||
@ -1965,7 +1965,7 @@ LogicalResult tosa::TableOp::verify() {
|
||||
for (auto it : llvm::enumerate(llvm::zip(inputDims, outputDims))) {
|
||||
int64_t dim = it.index();
|
||||
auto [inputDim, outputDim] = it.value();
|
||||
if (!ShapedType::isDynamic(outputDim) && outputDim != inputDim) {
|
||||
if (ShapedType::isStatic(outputDim) && outputDim != inputDim) {
|
||||
return emitOpError() << "dim(result, " << dim << ") = " << outputDim
|
||||
<< " doesn't match dim(input, " << dim
|
||||
<< ") = " << inputDim;
|
||||
@ -2100,7 +2100,7 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
|
||||
int64_t numElements = inputShape.getNumElements();
|
||||
int64_t staticMul = 1;
|
||||
for (auto val : newShapeValue) {
|
||||
if (!ShapedType::isDynamic(val)) {
|
||||
if (ShapedType::isStatic(val)) {
|
||||
staticMul *= val;
|
||||
}
|
||||
}
|
||||
@ -2988,12 +2988,12 @@ static LogicalResult poolingInferReturnTypes(
|
||||
int64_t height = inputShape.getDimSize(1);
|
||||
int64_t width = inputShape.getDimSize(2);
|
||||
|
||||
if (!ShapedType::isDynamic(height)) {
|
||||
if (ShapedType::isStatic(height)) {
|
||||
int64_t padded = height + pad[0] + pad[1] - kernel[0];
|
||||
outputShape[1] = padded / stride[0] + 1;
|
||||
}
|
||||
|
||||
if (!ShapedType::isDynamic(width)) {
|
||||
if (ShapedType::isStatic(width)) {
|
||||
int64_t padded = width + pad[2] + pad[3] - kernel[1];
|
||||
outputShape[2] = padded / stride[1] + 1;
|
||||
}
|
||||
@ -3042,16 +3042,14 @@ LogicalResult Conv2DOp::inferReturnTypeComponents(
|
||||
llvm::ArrayRef<int64_t> stride = adaptor.getStride();
|
||||
llvm::ArrayRef<int64_t> padding = adaptor.getPad();
|
||||
|
||||
if (!ShapedType::isDynamic(inputHeight) &&
|
||||
!ShapedType::isDynamic(weightHeight)) {
|
||||
if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
|
||||
int64_t inputSize = inputHeight + padding[0] + padding[1];
|
||||
int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
|
||||
int64_t unstridedResult = inputSize - filterSize + 1;
|
||||
outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
|
||||
}
|
||||
|
||||
if (!ShapedType::isDynamic(inputWidth) &&
|
||||
!ShapedType::isDynamic(weightWidth)) {
|
||||
if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
|
||||
int64_t inputSize = inputWidth + padding[2] + padding[3];
|
||||
int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
|
||||
int64_t unstridedResult = inputSize - filterSize + 1;
|
||||
@ -3111,24 +3109,21 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
|
||||
llvm::ArrayRef<int64_t> stride = adaptor.getStride();
|
||||
llvm::ArrayRef<int64_t> pad = adaptor.getPad();
|
||||
|
||||
if (!ShapedType::isDynamic(inputDepth) &&
|
||||
!ShapedType::isDynamic(weightDepth)) {
|
||||
if (ShapedType::isStatic(inputDepth) && ShapedType::isStatic(weightDepth)) {
|
||||
int32_t inputSize = inputDepth + pad[0] + pad[1];
|
||||
int32_t filterSize = (weightDepth - 1) * dilation[0] + 1;
|
||||
int32_t unstridedResult = inputSize - filterSize + 1;
|
||||
outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
|
||||
}
|
||||
|
||||
if (!ShapedType::isDynamic(inputHeight) &&
|
||||
!ShapedType::isDynamic(weightHeight)) {
|
||||
if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
|
||||
int32_t inputSize = inputHeight + pad[2] + pad[3];
|
||||
int32_t filterSize = (weightHeight - 1) * dilation[1] + 1;
|
||||
int32_t unstridedResult = inputSize - filterSize + 1;
|
||||
outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
|
||||
}
|
||||
|
||||
if (!ShapedType::isDynamic(inputWidth) &&
|
||||
!ShapedType::isDynamic(weightWidth)) {
|
||||
if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
|
||||
int32_t inputSize = inputWidth + pad[4] + pad[5];
|
||||
int32_t filterSize = (weightWidth - 1) * dilation[2] + 1;
|
||||
int32_t unstridedResult = inputSize - filterSize + 1;
|
||||
@ -3213,8 +3208,8 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
|
||||
|
||||
// If both inputChannels and depthChannels are available we can determine
|
||||
// the output channels.
|
||||
if (!ShapedType::isDynamic(inputChannels) &&
|
||||
!ShapedType::isDynamic(depthChannels)) {
|
||||
if (ShapedType::isStatic(inputChannels) &&
|
||||
ShapedType::isStatic(depthChannels)) {
|
||||
outputShape[3] = inputChannels * depthChannels;
|
||||
}
|
||||
|
||||
@ -3230,16 +3225,14 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
|
||||
llvm::ArrayRef<int64_t> padding = adaptor.getPad();
|
||||
llvm::ArrayRef<int64_t> stride = adaptor.getStride();
|
||||
|
||||
if (!ShapedType::isDynamic(inputHeight) &&
|
||||
!ShapedType::isDynamic(weightHeight)) {
|
||||
if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
|
||||
int64_t inputSize = inputHeight + padding[0] + padding[1];
|
||||
int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
|
||||
int64_t unstridedResult = inputSize - filterSize + 1;
|
||||
outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
|
||||
}
|
||||
|
||||
if (!ShapedType::isDynamic(inputWidth) &&
|
||||
!ShapedType::isDynamic(weightWidth)) {
|
||||
if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
|
||||
int64_t inputSize = inputWidth + padding[2] + padding[3];
|
||||
int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
|
||||
int64_t unstridedResult = inputSize - filterSize + 1;
|
||||
@ -3299,16 +3292,14 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
|
||||
llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
|
||||
llvm::ArrayRef<int64_t> stride = adaptor.getStride();
|
||||
|
||||
if (!ShapedType::isDynamic(inputHeight) &&
|
||||
!ShapedType::isDynamic(weightHeight)) {
|
||||
if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
|
||||
int64_t calculateSize =
|
||||
(inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
|
||||
outputShape[1] =
|
||||
ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
|
||||
}
|
||||
|
||||
if (!ShapedType::isDynamic(inputWidth) &&
|
||||
!ShapedType::isDynamic(weightWidth)) {
|
||||
if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
|
||||
int64_t calculateSize =
|
||||
(inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
|
||||
outputShape[2] =
|
||||
@ -3354,7 +3345,7 @@ LogicalResult TransposeConv2DOp::verify() {
|
||||
|
||||
if (weightType) {
|
||||
const int64_t kernelHeight = weightType.getDimSize(1);
|
||||
if (!ShapedType::isDynamic(kernelHeight)) {
|
||||
if (ShapedType::isStatic(kernelHeight)) {
|
||||
if (failed(checkPadAgainstKernelDim(outPadTop, kernelHeight,
|
||||
"out_pad_top", "KH")))
|
||||
return failure();
|
||||
@ -3365,7 +3356,7 @@ LogicalResult TransposeConv2DOp::verify() {
|
||||
}
|
||||
|
||||
const int64_t kernelWidth = weightType.getDimSize(2);
|
||||
if (!ShapedType::isDynamic(kernelWidth)) {
|
||||
if (ShapedType::isStatic(kernelWidth)) {
|
||||
if (failed(checkPadAgainstKernelDim(outPadLeft, kernelWidth,
|
||||
"out_pad_left", "KW")))
|
||||
return failure();
|
||||
@ -3388,8 +3379,8 @@ LogicalResult TransposeConv2DOp::verify() {
|
||||
const int64_t kernelHeight = weightType.getDimSize(1);
|
||||
const int64_t outputHeight = outputType.getDimSize(1);
|
||||
|
||||
if (!ShapedType::isDynamic(inputHeight) &&
|
||||
!ShapedType::isDynamic(outputHeight)) {
|
||||
if (ShapedType::isStatic(inputHeight) &&
|
||||
ShapedType::isStatic(outputHeight)) {
|
||||
if (outputHeight !=
|
||||
(inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)
|
||||
return emitOpError(
|
||||
@ -3404,8 +3395,7 @@ LogicalResult TransposeConv2DOp::verify() {
|
||||
const int64_t kernelWidth = weightType.getDimSize(2);
|
||||
const int64_t outputWidth = outputType.getDimSize(2);
|
||||
|
||||
if (!ShapedType::isDynamic(inputWidth) &&
|
||||
!ShapedType::isDynamic(outputWidth)) {
|
||||
if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(outputWidth)) {
|
||||
if (outputWidth !=
|
||||
(inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
|
||||
return emitOpError(
|
||||
|
@ -505,7 +505,7 @@ LogicalResult mlir::reshapeLikeShapesAreCompatible(
|
||||
linearizedStaticShape *= dim.value();
|
||||
}
|
||||
if (foundDynamicShape) {
|
||||
if (!ShapedType::isDynamic(collapsedShape[map.index()])) {
|
||||
if (ShapedType::isStatic(collapsedShape[map.index()])) {
|
||||
return emitError(
|
||||
"expected dimension " + Twine(map.index()) +
|
||||
" of collapsed type to be dynamic since one or more of the "
|
||||
|
@ -280,13 +280,13 @@ std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
|
||||
|
||||
bool hasValidSizesOffsets(SmallVector<int64_t> sizesOrOffsets) {
|
||||
return llvm::none_of(sizesOrOffsets, [](int64_t value) {
|
||||
return !ShapedType::isDynamic(value) && value < 0;
|
||||
return ShapedType::isStatic(value) && value < 0;
|
||||
});
|
||||
}
|
||||
|
||||
bool hasValidStrides(SmallVector<int64_t> strides) {
|
||||
return llvm::none_of(strides, [](int64_t value) {
|
||||
return !ShapedType::isDynamic(value) && value == 0;
|
||||
return ShapedType::isStatic(value) && value == 0;
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -1107,7 +1107,7 @@ Type ContractionOp::getExpectedMaskType() {
|
||||
rhsType.getScalableDims()[dimIdx];
|
||||
}
|
||||
|
||||
assert(!ShapedType::isDynamicShape(maskShape) &&
|
||||
assert(ShapedType::isStaticShape(maskShape) &&
|
||||
"Mask shape couldn't be computed");
|
||||
|
||||
return VectorType::get(maskShape,
|
||||
@ -2061,7 +2061,7 @@ static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
|
||||
// `opChange` is a flag. If it is true, it means to update `op` in place.
|
||||
bool opChange = false;
|
||||
for (unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
|
||||
if (!ShapedType::isDynamic(staticPosition[i]))
|
||||
if (ShapedType::isStatic(staticPosition[i]))
|
||||
continue;
|
||||
Attribute positionAttr = dynamicPositionAttr[index];
|
||||
Value position = dynamicPosition[index++];
|
||||
|
@ -339,7 +339,7 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
|
||||
// FIXME: This computation is too weak - it ignores the read indices.
|
||||
for (unsigned i = 0; i < readRank; i++)
|
||||
inBoundsVal[i] = (sourceShape[i] == inputVectorSizes[i]) &&
|
||||
!ShapedType::isDynamic(sourceShape[i]);
|
||||
ShapedType::isStatic(sourceShape[i]);
|
||||
}
|
||||
auto transferReadOp = builder.create<vector::TransferReadOp>(
|
||||
loc,
|
||||
|
@ -230,8 +230,8 @@ void StridedLayoutAttr::print(llvm::raw_ostream &os) const {
|
||||
/// Returns true if this layout is static, i.e. the strides and offset all have
|
||||
/// a known value > 0.
|
||||
bool StridedLayoutAttr::hasStaticLayout() const {
|
||||
return !ShapedType::isDynamic(getOffset()) &&
|
||||
!ShapedType::isDynamicShape(getStrides());
|
||||
return ShapedType::isStatic(getOffset()) &&
|
||||
ShapedType::isStaticShape(getStrides());
|
||||
}
|
||||
|
||||
/// Returns the strided layout as an affine map.
|
||||
@ -1818,7 +1818,7 @@ AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
|
||||
|
||||
// AffineExpr for offset.
|
||||
// Static case.
|
||||
if (!ShapedType::isDynamic(offset)) {
|
||||
if (ShapedType::isStatic(offset)) {
|
||||
auto cst = getAffineConstantExpr(offset, context);
|
||||
expr = cst;
|
||||
} else {
|
||||
@ -1834,7 +1834,7 @@ AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
|
||||
auto d = getAffineDimExpr(dim, context);
|
||||
AffineExpr mult;
|
||||
// Static case.
|
||||
if (!ShapedType::isDynamic(stride))
|
||||
if (ShapedType::isStatic(stride))
|
||||
mult = getAffineConstantExpr(stride, context);
|
||||
else
|
||||
// Dynamic case, new symbol for each new stride.
|
||||
|
@ -321,7 +321,7 @@ RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
|
||||
ArrayRef<int64_t> shape, Type elementType,
|
||||
Attribute encoding) {
|
||||
for (int64_t s : shape)
|
||||
if (s < 0 && !ShapedType::isDynamic(s))
|
||||
if (s < 0 && ShapedType::isStatic(s))
|
||||
return emitError() << "invalid tensor dimension size";
|
||||
if (auto v = llvm::dyn_cast_or_null<VerifiableTensorEncoding>(encoding))
|
||||
if (failed(v.verifyEncoding(shape, elementType, emitError)))
|
||||
@ -644,7 +644,7 @@ LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
|
||||
|
||||
// Negative sizes are not allowed except for `kDynamic`.
|
||||
for (int64_t s : shape)
|
||||
if (s < 0 && !ShapedType::isDynamic(s))
|
||||
if (s < 0 && ShapedType::isStatic(s))
|
||||
return emitError() << "invalid memref size";
|
||||
|
||||
assert(layout && "missing layout specification");
|
||||
|
@ -62,7 +62,7 @@ LogicalResult mlir::verifyCompatibleShape(ArrayRef<int64_t> shape1,
|
||||
for (auto dims : llvm::zip(shape1, shape2)) {
|
||||
int64_t dim1 = std::get<0>(dims);
|
||||
int64_t dim2 = std::get<1>(dims);
|
||||
if (!ShapedType::isDynamic(dim1) && !ShapedType::isDynamic(dim2) &&
|
||||
if (ShapedType::isStatic(dim1) && ShapedType::isStatic(dim2) &&
|
||||
dim1 != dim2)
|
||||
return failure();
|
||||
}
|
||||
|
@ -124,12 +124,12 @@ mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) {
|
||||
return failure();
|
||||
|
||||
for (int64_t offset : op.getStaticOffsets()) {
|
||||
if (offset < 0 && !ShapedType::isDynamic(offset))
|
||||
if (offset < 0 && ShapedType::isStatic(offset))
|
||||
return op->emitError("expected offsets to be non-negative, but got ")
|
||||
<< offset;
|
||||
}
|
||||
for (int64_t size : op.getStaticSizes()) {
|
||||
if (size < 0 && !ShapedType::isDynamic(size))
|
||||
if (size < 0 && ShapedType::isStatic(size))
|
||||
return op->emitError("expected sizes to be non-negative, but got ")
|
||||
<< size;
|
||||
}
|
||||
|
@@ -2497,6 +2497,11 @@ class ShapedType(Type):
         Returns whether the given dimension size indicates a dynamic dimension.
         """
+    @staticmethod
+    def is_static_size(dim_size: int) -> bool:
+        """
+        Returns whether the given dimension size indicates a static dimension.
+        """
     @staticmethod
     def isinstance(other: Type) -> bool: ...
     def __init__(self, cast_from_type: Type) -> None: ...
     def get_dim_size(self, dim: int) -> int:

@@ -2507,10 +2512,18 @@ class ShapedType(Type):
         """
         Returns whether the dim-th dimension of the given shaped type is dynamic.
         """
+    def is_static_dim(self, dim: int) -> bool:
+        """
+        Returns whether the dim-th dimension of the given shaped type is static.
+        """
     def is_dynamic_stride_or_offset(self, dim_size: int) -> bool:
         """
         Returns whether the given value is used as a placeholder for dynamic strides and offsets in shaped types.
         """
+    def is_static_stride_or_offset(self, dim_size: int) -> bool:
+        """
+        Returns whether the given shaped type stride or offset value is statically-sized.
+        """
     @property
     def element_type(self) -> Type:
         """
@@ -330,8 +330,29 @@ def testConcreteShapedType():
         print("dim size:", vector.get_dim_size(1))
         # CHECK: is_dynamic_size: False
         print("is_dynamic_size:", vector.is_dynamic_size(3))
+        # CHECK: is_static_size: True
+        print("is_static_size:", vector.is_static_size(3))
         # CHECK: is_dynamic_stride_or_offset: False
         print("is_dynamic_stride_or_offset:", vector.is_dynamic_stride_or_offset(1))
+        # CHECK: is_static_stride_or_offset: True
+        print("is_static_stride_or_offset:", vector.is_static_stride_or_offset(1))
 
+        dynamic_size_val = vector.get_dynamic_size()
+        dynamic_stride_val = vector.get_dynamic_stride_or_offset()
+        # CHECK: is_dynamic_size_with_dynamic: True
+        print("is_dynamic_size_with_dynamic:", vector.is_dynamic_size(dynamic_size_val))
+        # CHECK: is_static_size_with_dynamic: False
+        print("is_static_size_with_dynamic:", vector.is_static_size(dynamic_size_val))
+        # CHECK: is_dynamic_stride_or_offset_with_dynamic: True
+        print(
+            "is_dynamic_stride_or_offset_with_dynamic:",
+            vector.is_dynamic_stride_or_offset(dynamic_stride_val),
+        )
+        # CHECK: is_static_stride_or_offset_with_dynamic: False
+        print(
+            "is_static_stride_or_offset_with_dynamic:",
+            vector.is_static_stride_or_offset(dynamic_stride_val),
+        )
         # CHECK: isinstance(ShapedType): True
         print("isinstance(ShapedType):", isinstance(vector, ShapedType))