[mlir] Add isStatic* size check for ShapedTypes. NFCI. (#147085)

The motivation is to avoid negating `isDynamic*` checks and the double negations that result, and to allow `ShapedType::isStaticDim` to be passed directly to ADT functions such as `llvm::all_of` without wrapping it in a negating lambda.

Also add the new functions to the C and Python bindings.
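
As a rough usage sketch (not part of this commit; the wrapper function below is hypothetical and assumes the usual MLIR/LLVM headers), the new helpers let callers pass the static check straight to ADT functions instead of negating `isDynamic`:

#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/STLExtras.h"

// Hypothetical helper: returns true if every dimension of `type` is static.
static bool allDimsStatic(mlir::RankedTensorType type) {
  llvm::ArrayRef<int64_t> shape = type.getShape();
  // Before this commit: negate the dynamic check inside a lambda.
  bool viaDynamic = llvm::all_of(
      shape, [](int64_t d) { return !mlir::ShapedType::isDynamic(d); });
  // With this commit: pass ShapedType::isStatic directly, or query the
  // whole shape at once with ShapedType::isStaticShape.
  bool viaStatic = llvm::all_of(shape, mlir::ShapedType::isStatic);
  return viaDynamic && viaStatic && mlir::ShapedType::isStaticShape(shape);
}
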
Jakub Kuderski, 2025-07-07 14:57:27 -04:00 (committed by GitHub)
parent 0032148ea6
commit 6512ca7ddb
37 changed files with 206 additions and 118 deletions


@ -289,9 +289,12 @@ MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetRank(MlirType type);
/// Checks whether the given shaped type has a static shape. /// Checks whether the given shaped type has a static shape.
MLIR_CAPI_EXPORTED bool mlirShapedTypeHasStaticShape(MlirType type); MLIR_CAPI_EXPORTED bool mlirShapedTypeHasStaticShape(MlirType type);
/// Checks wither the dim-th dimension of the given shaped type is dynamic. /// Checks whether the dim-th dimension of the given shaped type is dynamic.
MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicDim(MlirType type, intptr_t dim); MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicDim(MlirType type, intptr_t dim);
/// Checks whether the dim-th dimension of the given shaped type is static.
MLIR_CAPI_EXPORTED bool mlirShapedTypeIsStaticDim(MlirType type, intptr_t dim);
/// Returns the dim-th dimension of the given ranked shaped type. /// Returns the dim-th dimension of the given ranked shaped type.
MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDimSize(MlirType type, MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDimSize(MlirType type,
intptr_t dim); intptr_t dim);
@ -300,17 +303,25 @@ MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDimSize(MlirType type,
/// in shaped types. /// in shaped types.
MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicSize(int64_t size); MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicSize(int64_t size);
/// Checks whether the given shaped type dimension value is statically-sized.
MLIR_CAPI_EXPORTED bool mlirShapedTypeIsStaticSize(int64_t size);
/// Returns the value indicating a dynamic size in a shaped type. Prefer /// Returns the value indicating a dynamic size in a shaped type. Prefer
/// mlirShapedTypeIsDynamicSize to direct comparisons with this value. /// mlirShapedTypeIsDynamicSize and mlirShapedTypeIsStaticSize to direct
/// comparisons with this value.
MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDynamicSize(void); MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDynamicSize(void);
/// Checks whether the given value is used as a placeholder for dynamic strides /// Checks whether the given value is used as a placeholder for dynamic strides
/// and offsets in shaped types. /// and offsets in shaped types.
MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicStrideOrOffset(int64_t val); MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicStrideOrOffset(int64_t val);
/// Checks whether the given dimension value of a stride or an offset is
/// statically-sized.
MLIR_CAPI_EXPORTED bool mlirShapedTypeIsStaticStrideOrOffset(int64_t val);
/// Returns the value indicating a dynamic stride or offset in a shaped type. /// Returns the value indicating a dynamic stride or offset in a shaped type.
/// Prefer mlirShapedTypeGetDynamicStrideOrOffset to direct comparisons with /// Prefer mlirShapedTypeIsDynamicStrideOrOffset and
/// this value. /// mlirShapedTypeIsStaticStrideOrOffset to direct comparisons with this value.
MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDynamicStrideOrOffset(void); MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDynamicStrideOrOffset(void);
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//


@ -36,7 +36,7 @@ def VectorElementTypeInterface : TypeInterface<"VectorElementTypeInterface"> {
This may change in the future, for example, to require types to provide This may change in the future, for example, to require types to provide
their size or alignment given a data layout. Please post an RFC before their size or alignment given a data layout. Please post an RFC before
adding this interface to additional types. Implementing this interface on adding this interface to additional types. Implementing this interface on
downstream types is discourged, until we specified the exact properties of downstream types is discouraged, until we specified the exact properties of
a vector element type in more detail. a vector element type in more detail.
}]; }];
} }
@ -221,7 +221,17 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
/// Whether the given shape has any size that indicates a dynamic dimension. /// Whether the given shape has any size that indicates a dynamic dimension.
static bool isDynamicShape(ArrayRef<int64_t> dSizes) { static bool isDynamicShape(ArrayRef<int64_t> dSizes) {
return any_of(dSizes, [](int64_t dSize) { return isDynamic(dSize); }); return llvm::any_of(dSizes, isDynamic);
}
/// Whether the given dimension size indicates a statically-sized dimension.
static constexpr bool isStatic(int64_t dValue) {
return dValue != kDynamic;
}
/// Whether the given shape has static dimensions only.
static bool isStaticShape(ArrayRef<int64_t> dSizes) {
return llvm::all_of(dSizes, isStatic);
} }
/// Return the number of elements present in the given shape. /// Return the number of elements present in the given shape.
@ -273,11 +283,18 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
return ::mlir::ShapedType::isDynamic($_type.getShape()[idx]); return ::mlir::ShapedType::isDynamic($_type.getShape()[idx]);
} }
/// Returns true if this dimension has a static size (for ranked types);
/// aborts for unranked types.
bool isStaticDim(unsigned idx) const {
assert(idx < getRank() && "invalid index for shaped type");
return ::mlir::ShapedType::isStatic($_type.getShape()[idx]);
}
/// Returns if this type has a static shape, i.e. if the type is ranked and /// Returns if this type has a static shape, i.e. if the type is ranked and
/// all dimensions have known size (>= 0). /// all dimensions have known size (>= 0).
bool hasStaticShape() const { bool hasStaticShape() const {
return $_type.hasRank() && return $_type.hasRank() &&
!::mlir::ShapedType::isDynamicShape($_type.getShape()); ::mlir::ShapedType::isStaticShape($_type.getShape());
} }
/// Returns if this type has a static shape and the shape is equal to /// Returns if this type has a static shape and the shape is equal to


@ -544,6 +544,15 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) {
nb::arg("dim"), nb::arg("dim"),
"Returns whether the dim-th dimension of the given shaped type is " "Returns whether the dim-th dimension of the given shaped type is "
"dynamic."); "dynamic.");
c.def(
"is_static_dim",
[](PyShapedType &self, intptr_t dim) -> bool {
self.requireHasRank();
return mlirShapedTypeIsStaticDim(self, dim);
},
nb::arg("dim"),
"Returns whether the dim-th dimension of the given shaped type is "
"static.");
c.def( c.def(
"get_dim_size", "get_dim_size",
[](PyShapedType &self, intptr_t dim) { [](PyShapedType &self, intptr_t dim) {
@ -558,6 +567,12 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) {
nb::arg("dim_size"), nb::arg("dim_size"),
"Returns whether the given dimension size indicates a dynamic " "Returns whether the given dimension size indicates a dynamic "
"dimension."); "dimension.");
c.def_static(
"is_static_size",
[](int64_t size) -> bool { return mlirShapedTypeIsStaticSize(size); },
nb::arg("dim_size"),
"Returns whether the given dimension size indicates a static "
"dimension.");
c.def( c.def(
"is_dynamic_stride_or_offset", "is_dynamic_stride_or_offset",
[](PyShapedType &self, int64_t val) -> bool { [](PyShapedType &self, int64_t val) -> bool {
@ -567,6 +582,15 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) {
nb::arg("dim_size"), nb::arg("dim_size"),
"Returns whether the given value is used as a placeholder for dynamic " "Returns whether the given value is used as a placeholder for dynamic "
"strides and offsets in shaped types."); "strides and offsets in shaped types.");
c.def(
"is_static_stride_or_offset",
[](PyShapedType &self, int64_t val) -> bool {
self.requireHasRank();
return mlirShapedTypeIsStaticStrideOrOffset(val);
},
nb::arg("dim_size"),
"Returns whether the given shaped type stride or offset value is "
"statically-sized.");
c.def_prop_ro( c.def_prop_ro(
"shape", "shape",
[](PyShapedType &self) { [](PyShapedType &self) {


@ -332,6 +332,11 @@ bool mlirShapedTypeIsDynamicDim(MlirType type, intptr_t dim) {
.isDynamicDim(static_cast<unsigned>(dim)); .isDynamicDim(static_cast<unsigned>(dim));
} }
bool mlirShapedTypeIsStaticDim(MlirType type, intptr_t dim) {
return llvm::cast<ShapedType>(unwrap(type))
.isStaticDim(static_cast<unsigned>(dim));
}
int64_t mlirShapedTypeGetDimSize(MlirType type, intptr_t dim) { int64_t mlirShapedTypeGetDimSize(MlirType type, intptr_t dim) {
return llvm::cast<ShapedType>(unwrap(type)) return llvm::cast<ShapedType>(unwrap(type))
.getDimSize(static_cast<unsigned>(dim)); .getDimSize(static_cast<unsigned>(dim));
@ -343,10 +348,18 @@ bool mlirShapedTypeIsDynamicSize(int64_t size) {
return ShapedType::isDynamic(size); return ShapedType::isDynamic(size);
} }
bool mlirShapedTypeIsStaticSize(int64_t size) {
return ShapedType::isStatic(size);
}
bool mlirShapedTypeIsDynamicStrideOrOffset(int64_t val) { bool mlirShapedTypeIsDynamicStrideOrOffset(int64_t val) {
return ShapedType::isDynamic(val); return ShapedType::isDynamic(val);
} }
bool mlirShapedTypeIsStaticStrideOrOffset(int64_t val) {
return ShapedType::isStatic(val);
}
int64_t mlirShapedTypeGetDynamicStrideOrOffset() { int64_t mlirShapedTypeGetDynamicStrideOrOffset() {
return ShapedType::kDynamic; return ShapedType::kDynamic;
} }


@ -53,7 +53,7 @@ MemRefDescriptor MemRefDescriptor::fromStaticShape(
// Extract all strides and offsets and verify they are static. // Extract all strides and offsets and verify they are static.
auto [strides, offset] = type.getStridesAndOffset(); auto [strides, offset] = type.getStridesAndOffset();
assert(!ShapedType::isDynamic(offset) && "expected static offset"); assert(ShapedType::isStatic(offset) && "expected static offset");
assert(!llvm::any_of(strides, ShapedType::isDynamic) && assert(!llvm::any_of(strides, ShapedType::isDynamic) &&
"expected static strides"); "expected static strides");


@ -609,7 +609,7 @@ bool LLVMTypeConverter::canConvertToBarePtr(BaseMemRefType type) {
if (ShapedType::isDynamic(stride)) if (ShapedType::isDynamic(stride))
return false; return false;
return !ShapedType::isDynamic(offset); return ShapedType::isStatic(offset);
} }
/// Convert a memref type to a bare pointer to the memref element type. /// Convert a memref type to a bare pointer to the memref element type.


@ -43,7 +43,7 @@ static constexpr LLVM::GEPNoWrapFlags kNoWrapFlags =
namespace { namespace {
static bool isStaticStrideOrOffset(int64_t strideOrOffset) { static bool isStaticStrideOrOffset(int64_t strideOrOffset) {
return !ShapedType::isDynamic(strideOrOffset); return ShapedType::isStatic(strideOrOffset);
} }
static FailureOr<LLVM::LLVMFuncOp> static FailureOr<LLVM::LLVMFuncOp>
@ -1468,7 +1468,7 @@ private:
Value stride = nullptr; Value stride = nullptr;
int64_t targetRank = targetMemRefType.getRank(); int64_t targetRank = targetMemRefType.getRank();
for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) { for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
if (!ShapedType::isDynamic(strides[i])) { if (ShapedType::isStatic(strides[i])) {
// If the stride for this dimension is dynamic, then use the product // If the stride for this dimension is dynamic, then use the product
// of the sizes of the inner dimensions. // of the sizes of the inner dimensions.
stride = stride =
@ -1722,7 +1722,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
ArrayRef<int64_t> shape, ValueRange dynamicSizes, unsigned idx, ArrayRef<int64_t> shape, ValueRange dynamicSizes, unsigned idx,
Type indexType) const { Type indexType) const {
assert(idx < shape.size()); assert(idx < shape.size());
if (!ShapedType::isDynamic(shape[idx])) if (ShapedType::isStatic(shape[idx]))
return createIndexAttrConstant(rewriter, loc, indexType, shape[idx]); return createIndexAttrConstant(rewriter, loc, indexType, shape[idx]);
// Count the number of dynamic dims in range [0, idx] // Count the number of dynamic dims in range [0, idx]
unsigned nDynamic = unsigned nDynamic =
@ -1738,7 +1738,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
ArrayRef<int64_t> strides, Value nextSize, ArrayRef<int64_t> strides, Value nextSize,
Value runningStride, unsigned idx, Type indexType) const { Value runningStride, unsigned idx, Type indexType) const {
assert(idx < strides.size()); assert(idx < strides.size());
if (!ShapedType::isDynamic(strides[idx])) if (ShapedType::isStatic(strides[idx]))
return createIndexAttrConstant(rewriter, loc, indexType, strides[idx]); return createIndexAttrConstant(rewriter, loc, indexType, strides[idx]);
if (nextSize) if (nextSize)
return runningStride return runningStride


@ -757,7 +757,7 @@ computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool,
// dimension greater than 1 with a different value is undefined behavior. // dimension greater than 1 with a different value is undefined behavior.
for (auto operand : operands) { for (auto operand : operands) {
auto size = cast<RankedTensorType>(operand.getType()).getDimSize(dim); auto size = cast<RankedTensorType>(operand.getType()).getDimSize(dim);
if (!ShapedType::isDynamic(size) && size > 1) if (ShapedType::isStatic(size) && size > 1)
return {rewriter.getIndexAttr(size), operand}; return {rewriter.getIndexAttr(size), operand};
} }


@ -83,7 +83,7 @@ TensorType inferReshapeExpandedType(TensorType inputType,
return totalSize / totalSizeNoPlaceholder; return totalSize / totalSizeNoPlaceholder;
}); });
bool resultIsStatic = !ShapedType::isDynamicShape(resultShape); bool resultIsStatic = ShapedType::isStaticShape(resultShape);
// A syntactic restriction in 'tensor.expand_shape' forbids a dynamically // A syntactic restriction in 'tensor.expand_shape' forbids a dynamically
// shaped input from being reshaped into a statically shaped result. We may // shaped input from being reshaped into a statically shaped result. We may
@ -305,7 +305,7 @@ public:
int64_t size = i.value(); int64_t size = i.value();
size_t index = i.index(); size_t index = i.index();
sizes.push_back(size == -1 ? ShapedType::kDynamic : size); sizes.push_back(size == -1 ? ShapedType::kDynamic : size);
if (!ShapedType::isDynamic(sizes.back())) if (ShapedType::isStatic(sizes.back()))
continue; continue;
auto dim = rewriter.create<tensor::DimOp>(loc, input, index); auto dim = rewriter.create<tensor::DimOp>(loc, input, index);


@ -44,7 +44,7 @@ FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue(
failed(target.getStridesAndOffset(targetStrides, targetOffset))) failed(target.getStridesAndOffset(targetStrides, targetOffset)))
return false; return false;
auto dynamicToStatic = [](int64_t a, int64_t b) { auto dynamicToStatic = [](int64_t a, int64_t b) {
return ShapedType::isDynamic(a) && !ShapedType::isDynamic(b); return ShapedType::isDynamic(a) && ShapedType::isStatic(b);
}; };
if (dynamicToStatic(sourceOffset, targetOffset)) if (dynamicToStatic(sourceOffset, targetOffset))
return false; return false;


@ -33,7 +33,7 @@ static bool hasFullyDynamicLayoutMap(MemRefType type) {
return false; return false;
if (!llvm::all_of(strides, ShapedType::isDynamic)) if (!llvm::all_of(strides, ShapedType::isDynamic))
return false; return false;
if (!ShapedType::isDynamic(offset)) if (ShapedType::isStatic(offset))
return false; return false;
return true; return true;
} }


@ -4564,7 +4564,7 @@ static SmallVector<OpFoldResult> getMixedTilesImpl(OpTy op) {
SmallVector<OpFoldResult> mixedInnerTiles; SmallVector<OpFoldResult> mixedInnerTiles;
unsigned dynamicValIndex = 0; unsigned dynamicValIndex = 0;
for (int64_t staticTile : op.getStaticInnerTiles()) { for (int64_t staticTile : op.getStaticInnerTiles()) {
if (!ShapedType::isDynamic(staticTile)) if (ShapedType::isStatic(staticTile))
mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile)); mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile));
else else
mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]); mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
@ -4829,7 +4829,7 @@ bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
std::optional<int64_t> constantTile = getConstantIntValue(tileSize); std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
if (!constantTile) { if (!constantTile) {
if (!ShapedType::isDynamic(outputTileSizes[pos]) && if (ShapedType::isStatic(outputTileSizes[pos]) &&
(inputShape[pos] % outputTileSizes[pos] != 0)) (inputShape[pos] % outputTileSizes[pos] != 0))
return true; return true;
} else if (inputShape[pos] % (*constantTile) != 0) { } else if (inputShape[pos] % (*constantTile) != 0) {
@ -4935,7 +4935,7 @@ SmallVector<OpFoldResult> PackOp::getResultShape(
// use dispatchIndexOpFoldResults on the result, and rely on exact number of // use dispatchIndexOpFoldResults on the result, and rely on exact number of
// dynamic dims returned by that. // dynamic dims returned by that.
for (unsigned i = 0; i < resultDims.size(); ++i) { for (unsigned i = 0; i < resultDims.size(); ++i) {
if (!ShapedType::isDynamic(resultTypeShape[i])) if (ShapedType::isStatic(resultTypeShape[i]))
continue; continue;
resultDims[i] = resultDims[i] =
getValueOrCreateConstantIndexOp(builder, loc, resultDims[i]); getValueOrCreateConstantIndexOp(builder, loc, resultDims[i]);


@ -2061,7 +2061,7 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
rewriter.setInsertionPoint(linalgTarget); rewriter.setInsertionPoint(linalgTarget);
for (OpOperand &operand : linalgTarget->getOpOperands()) { for (OpOperand &operand : linalgTarget->getOpOperands()) {
for (auto [i, dim] : llvm::enumerate(linalgTarget.getShape(&operand))) { for (auto [i, dim] : llvm::enumerate(linalgTarget.getShape(&operand))) {
if (!ShapedType::isDynamic(dim)) if (ShapedType::isStatic(dim))
continue; continue;
options.setSizeToPadTo(operand.getOperandNumber(), i, options.setSizeToPadTo(operand.getOperandNumber(), i,
tensor::getMixedSize(rewriter, tensor::getMixedSize(rewriter,


@ -335,7 +335,7 @@ VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
LinalgOp linalgOp) { LinalgOp linalgOp) {
// TODO: Support 0-d vectors. // TODO: Support 0-d vectors.
for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) { for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
if (!ShapedType::isDynamic(iterSpaceStaticSizes[vecDim])) { if (ShapedType::isStatic(iterSpaceStaticSizes[vecDim])) {
// Create constant index op for static dimensions. // Create constant index op for static dimensions.
iterSpaceValueSizes.push_back(rewriter.create<arith::ConstantIndexOp>( iterSpaceValueSizes.push_back(rewriter.create<arith::ConstantIndexOp>(
linalgOp.getLoc(), iterSpaceStaticSizes[vecDim])); linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
@ -1652,7 +1652,7 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
for (unsigned i = 0; i < vecToStoreRank; i++) for (unsigned i = 0; i < vecToStoreRank; i++)
inBoundsVal[i] = inBoundsVal[i] =
(destShape[destRank - vecToStoreRank + i] >= vecToStoreShape[i]) && (destShape[destRank - vecToStoreRank + i] >= vecToStoreShape[i]) &&
!ShapedType::isDynamic(destShape[destRank - vecToStoreRank + i]); ShapedType::isStatic(destShape[destRank - vecToStoreRank + i]);
} }
// If missing, initialize the write indices to 0. // If missing, initialize the write indices to 0.


@ -694,7 +694,7 @@ computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
int64_t shapeSize = shape[r]; int64_t shapeSize = shape[r];
std::optional<int64_t> sizeCst = getConstantIntValue(size); std::optional<int64_t> sizeCst = getConstantIntValue(size);
auto hasTileSizeOne = sizeCst == 1; auto hasTileSizeOne = sizeCst == 1;
auto dividesEvenly = sizeCst && !ShapedType::isDynamic(shapeSize) && auto dividesEvenly = sizeCst && ShapedType::isStatic(shapeSize) &&
((shapeSize % *sizeCst) == 0); ((shapeSize % *sizeCst) == 0);
if (!hasTileSizeOne && !dividesEvenly) { if (!hasTileSizeOne && !dividesEvenly) {
LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: shapeSize=" << shapeSize LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: shapeSize=" << shapeSize


@ -99,7 +99,7 @@ static void constifyIndexValues(SmallVectorImpl<OpFoldResult> &values,
"incorrect number of const values"); "incorrect number of const values");
for (auto [i, cstVal] : llvm::enumerate(constValues)) { for (auto [i, cstVal] : llvm::enumerate(constValues)) {
Builder builder(values[i].getContext()); Builder builder(values[i].getContext());
if (!ShapedType::isDynamic(cstVal)) { if (ShapedType::isStatic(cstVal)) {
// Constant value is known, use it directly. // Constant value is known, use it directly.
values[i] = builder.getIndexAttr(cstVal); values[i] = builder.getIndexAttr(cstVal);
continue; continue;
@ -189,7 +189,7 @@ struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) { for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
int64_t dimSize = memrefType.getDimSize(dim); int64_t dimSize = memrefType.getDimSize(dim);
// If this is already static dimension, keep it. // If this is already static dimension, keep it.
if (!ShapedType::isDynamic(dimSize)) { if (ShapedType::isStatic(dimSize)) {
newShapeConstants.push_back(dimSize); newShapeConstants.push_back(dimSize);
continue; continue;
} }
@ -615,21 +615,21 @@ bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) { for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
auto ss = std::get<0>(it), st = std::get<1>(it); auto ss = std::get<0>(it), st = std::get<1>(it);
if (ss != st) if (ss != st)
if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st)) if (ShapedType::isDynamic(ss) && ShapedType::isStatic(st))
return false; return false;
} }
// If cast is towards more static offset along any dimension, don't fold. // If cast is towards more static offset along any dimension, don't fold.
if (sourceOffset != resultOffset) if (sourceOffset != resultOffset)
if (ShapedType::isDynamic(sourceOffset) && if (ShapedType::isDynamic(sourceOffset) &&
!ShapedType::isDynamic(resultOffset)) ShapedType::isStatic(resultOffset))
return false; return false;
// If cast is towards more static strides along any dimension, don't fold. // If cast is towards more static strides along any dimension, don't fold.
for (auto it : llvm::zip(sourceStrides, resultStrides)) { for (auto it : llvm::zip(sourceStrides, resultStrides)) {
auto ss = std::get<0>(it), st = std::get<1>(it); auto ss = std::get<0>(it), st = std::get<1>(it);
if (ss != st) if (ss != st)
if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st)) if (ShapedType::isDynamic(ss) && ShapedType::isStatic(st))
return false; return false;
} }
@ -679,7 +679,7 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
for (unsigned i = 0, e = aT.getRank(); i != e; ++i) { for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i); int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
if (!ShapedType::isDynamic(aDim) && !ShapedType::isDynamic(bDim) && if (ShapedType::isStatic(aDim) && ShapedType::isStatic(bDim) &&
aDim != bDim) aDim != bDim)
return false; return false;
} }
@ -1862,7 +1862,7 @@ LogicalResult ReinterpretCastOp::verify() {
// Match sizes in result memref type and in static_sizes attribute. // Match sizes in result memref type and in static_sizes attribute.
for (auto [idx, resultSize, expectedSize] : for (auto [idx, resultSize, expectedSize] :
llvm::enumerate(resultType.getShape(), getStaticSizes())) { llvm::enumerate(resultType.getShape(), getStaticSizes())) {
if (!ShapedType::isDynamic(resultSize) && resultSize != expectedSize) if (ShapedType::isStatic(resultSize) && resultSize != expectedSize)
return emitError("expected result type with size = ") return emitError("expected result type with size = ")
<< (ShapedType::isDynamic(expectedSize) << (ShapedType::isDynamic(expectedSize)
? std::string("dynamic") ? std::string("dynamic")
@ -1881,7 +1881,7 @@ LogicalResult ReinterpretCastOp::verify() {
// Match offset in result memref type and in static_offsets attribute. // Match offset in result memref type and in static_offsets attribute.
int64_t expectedOffset = getStaticOffsets().front(); int64_t expectedOffset = getStaticOffsets().front();
if (!ShapedType::isDynamic(resultOffset) && resultOffset != expectedOffset) if (ShapedType::isStatic(resultOffset) && resultOffset != expectedOffset)
return emitError("expected result type with offset = ") return emitError("expected result type with offset = ")
<< (ShapedType::isDynamic(expectedOffset) << (ShapedType::isDynamic(expectedOffset)
? std::string("dynamic") ? std::string("dynamic")
@ -1891,7 +1891,7 @@ LogicalResult ReinterpretCastOp::verify() {
// Match strides in result memref type and in static_strides attribute. // Match strides in result memref type and in static_strides attribute.
for (auto [idx, resultStride, expectedStride] : for (auto [idx, resultStride, expectedStride] :
llvm::enumerate(resultStrides, getStaticStrides())) { llvm::enumerate(resultStrides, getStaticStrides())) {
if (!ShapedType::isDynamic(resultStride) && resultStride != expectedStride) if (ShapedType::isStatic(resultStride) && resultStride != expectedStride)
return emitError("expected result type with stride = ") return emitError("expected result type with stride = ")
<< (ShapedType::isDynamic(expectedStride) << (ShapedType::isDynamic(expectedStride)
? std::string("dynamic") ? std::string("dynamic")
@ -1928,7 +1928,7 @@ OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
} }
// reinterpret_cast(x) w/o offset/shape/stride changes -> x // reinterpret_cast(x) w/o offset/shape/stride changes -> x
if (!ShapedType::isDynamicShape(getType().getShape()) && if (ShapedType::isStaticShape(getType().getShape()) &&
src.getType() == getType() && getStaticOffsets().front() == 0) { src.getType() == getType() && getStaticOffsets().front() == 0) {
return src; return src;
} }
@ -2379,7 +2379,7 @@ LogicalResult ExpandShapeOp::verify() {
DenseI64ArrayAttr staticOutputShapes = getStaticOutputShapeAttr(); DenseI64ArrayAttr staticOutputShapes = getStaticOutputShapeAttr();
ArrayRef<int64_t> resShape = getResult().getType().getShape(); ArrayRef<int64_t> resShape = getResult().getType().getShape();
for (auto [pos, shape] : llvm::enumerate(resShape)) { for (auto [pos, shape] : llvm::enumerate(resShape)) {
if (!ShapedType::isDynamic(shape) && shape != staticOutputShapes[pos]) { if (ShapedType::isStatic(shape) && shape != staticOutputShapes[pos]) {
return emitOpError("invalid output shape provided at pos ") << pos; return emitOpError("invalid output shape provided at pos ") << pos;
} }
} }
@ -2422,7 +2422,7 @@ computeCollapsedLayoutMap(MemRefType srcType,
ArrayRef<int64_t> ref = llvm::ArrayRef(reassoc); ArrayRef<int64_t> ref = llvm::ArrayRef(reassoc);
while (srcShape[ref.back()] == 1 && ref.size() > 1) while (srcShape[ref.back()] == 1 && ref.size() > 1)
ref = ref.drop_back(); ref = ref.drop_back();
if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) { if (ShapedType::isStatic(srcShape[ref.back()]) || ref.size() == 1) {
resultStrides.push_back(srcStrides[ref.back()]); resultStrides.push_back(srcStrides[ref.back()]);
} else { } else {
// Dynamically-sized dims may turn out to be dims of size 1 at runtime, so // Dynamically-sized dims may turn out to be dims of size 1 at runtime, so
@ -3509,7 +3509,7 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
for (unsigned dim = 0, e = rank; dim < e; ++dim) { for (unsigned dim = 0, e = rank; dim < e; ++dim) {
int64_t dimSize = memrefType.getDimSize(dim); int64_t dimSize = memrefType.getDimSize(dim);
// If this is already static dimension, keep it. // If this is already static dimension, keep it.
if (!ShapedType::isDynamic(dimSize)) { if (ShapedType::isStatic(dimSize)) {
newShapeConstants.push_back(dimSize); newShapeConstants.push_back(dimSize);
continue; continue;
} }


@ -118,7 +118,7 @@ resolveSubviewStridedMetadata(RewriterBase &rewriter,
// Assert that the computed offset matches the offset of the result type of // Assert that the computed offset matches the offset of the result type of
// the subview op (if both are static). // the subview op (if both are static).
std::optional<int64_t> computedOffset = getConstantIntValue(finalOffset); std::optional<int64_t> computedOffset = getConstantIntValue(finalOffset);
if (computedOffset && !ShapedType::isDynamic(resultOffset)) if (computedOffset && ShapedType::isStatic(resultOffset))
assert(*computedOffset == resultOffset && assert(*computedOffset == resultOffset &&
"mismatch between computed offset and result type offset"); "mismatch between computed offset and result type offset");
#endif // NDEBUG #endif // NDEBUG
@ -158,7 +158,7 @@ resolveSubviewStridedMetadata(RewriterBase &rewriter,
// Assert that the computed stride matches the stride of the result type of // Assert that the computed stride matches the stride of the result type of
// the subview op (if both are static). // the subview op (if both are static).
std::optional<int64_t> computedStride = getConstantIntValue(strides[i]); std::optional<int64_t> computedStride = getConstantIntValue(strides[i]);
if (computedStride && !ShapedType::isDynamic(resultStrides[j])) if (computedStride && ShapedType::isStatic(resultStrides[j]))
assert(*computedStride == resultStrides[j] && assert(*computedStride == resultStrides[j] &&
"mismatch between computed stride and result type stride"); "mismatch between computed stride and result type stride");
++j; ++j;
@ -458,7 +458,7 @@ getCollapsedSize(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
MemRefType collapseShapeType = collapseShape.getResultType(); MemRefType collapseShapeType = collapseShape.getResultType();
uint64_t size = collapseShapeType.getDimSize(groupId); uint64_t size = collapseShapeType.getDimSize(groupId);
if (!ShapedType::isDynamic(size)) { if (ShapedType::isStatic(size)) {
collapsedSize.push_back(builder.getIndexAttr(size)); collapsedSize.push_back(builder.getIndexAttr(size));
return collapsedSize; return collapsedSize;
} }
@ -1091,7 +1091,7 @@ class ExtractStridedMetadataOpCastFolder
auto getConstantOrValue = [&rewriter](int64_t constant, auto getConstantOrValue = [&rewriter](int64_t constant,
OpFoldResult ofr) -> OpFoldResult { OpFoldResult ofr) -> OpFoldResult {
return !ShapedType::isDynamic(constant) return ShapedType::isStatic(constant)
? OpFoldResult(rewriter.getIndexAttr(constant)) ? OpFoldResult(rewriter.getIndexAttr(constant))
: ofr; : ofr;
}; };


@ -264,7 +264,7 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape,
// add halo sizes if requested // add halo sizes if requested
int haloAxis = 0; int haloAxis = 0;
for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) { for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
if (!ShapedType::isDynamic(outShape[tensorAxis]) && if (ShapedType::isStatic(outShape[tensorAxis]) &&
!innerSplitAxes.empty()) { !innerSplitAxes.empty()) {
if (haloSizes[haloAxis * 2] >= 0 && if (haloSizes[haloAxis * 2] >= 0 &&
haloSizes[haloAxis * 2 + 1] >= 0) { haloSizes[haloAxis * 2 + 1] >= 0) {
@ -415,7 +415,7 @@ LogicalResult MeshOp::verify() {
return emitOpError("rank of mesh is expected to be a positive integer"); return emitOpError("rank of mesh is expected to be a positive integer");
for (int64_t dimSize : getShape()) { for (int64_t dimSize : getShape()) {
if (dimSize < 0 && !ShapedType::isDynamic(dimSize)) if (dimSize < 0 && ShapedType::isStatic(dimSize))
return emitOpError("dimension size of a mesh is expected to be " return emitOpError("dimension size of a mesh is expected to be "
"non-negative or dynamic"); "non-negative or dynamic");
} }
@ -609,7 +609,7 @@ LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
auto shardedDimsOffsets = getStaticShardedDimsOffsets(); auto shardedDimsOffsets = getStaticShardedDimsOffsets();
if (!shardedDimsOffsets.empty()) { if (!shardedDimsOffsets.empty()) {
auto meshShape = mesh.value().getShape(); auto meshShape = mesh.value().getShape();
assert(!ShapedType::isDynamicShape(meshShape)); assert(ShapedType::isStaticShape(meshShape));
uint64_t pos = 0; uint64_t pos = 0;
for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(getSplitAxes())) { for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(getSplitAxes())) {
if (!innerSplitAxes.empty()) { if (!innerSplitAxes.empty()) {
@ -621,7 +621,7 @@ LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
if (shardedDimsOffsets.size() <= pos + i) { if (shardedDimsOffsets.size() <= pos + i) {
return emitError() << "sharded dims offsets has wrong size."; return emitError() << "sharded dims offsets has wrong size.";
} }
if (!ShapedType::isDynamic(shardedDimsOffsets[pos + i])) { if (ShapedType::isStatic(shardedDimsOffsets[pos + i])) {
if (shardedDimsOffsets[pos + i] < off) { if (shardedDimsOffsets[pos + i] < off) {
return emitError() return emitError()
<< "sharded dims offsets must be non-decreasing."; << "sharded dims offsets must be non-decreasing.";
@ -1036,8 +1036,8 @@ static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
} }
for (size_t i = 0; i < device.size(); ++i) { for (size_t i = 0; i < device.size(); ++i) {
if (!ShapedType::isDynamic(device[i]) && if (ShapedType::isStatic(device[i]) &&
!ShapedType::isDynamic(meshShape[meshAxes[i]]) && ShapedType::isStatic(meshShape[meshAxes[i]]) &&
meshShape[meshAxes[i]] <= device[i]) { meshShape[meshAxes[i]] <= device[i]) {
return emitError(loc) return emitError(loc)
<< "Out of bounds coordinate " << i << " for in-group device \"" << "Out of bounds coordinate " << i << " for in-group device \""
@ -1065,8 +1065,7 @@ static LogicalResult verifyDimensionCompatibility(Location loc,
int64_t expectedDimSize, int64_t expectedDimSize,
int64_t resultDimSize, int64_t resultDimSize,
int64_t resultAxis) { int64_t resultAxis) {
if (!ShapedType::isDynamic(resultDimSize) && if (ShapedType::isStatic(resultDimSize) && expectedDimSize != resultDimSize) {
expectedDimSize != resultDimSize) {
return emitError(loc) << "Dimension size mismatch for result axis " return emitError(loc) << "Dimension size mismatch for result axis "
<< resultAxis << ". Expected " << resultAxis << ". Expected "
<< (ShapedType::isDynamic(expectedDimSize) << (ShapedType::isDynamic(expectedDimSize)


@ -453,8 +453,8 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
auto srcHaloSizes = sourceSharding.getStaticHaloSizes(); auto srcHaloSizes = sourceSharding.getStaticHaloSizes();
auto tgtHaloSizes = targetSharding.getStaticHaloSizes(); auto tgtHaloSizes = targetSharding.getStaticHaloSizes();
assert(srcHaloSizes.empty() || srcHaloSizes.size() == tgtHaloSizes.size()); assert(srcHaloSizes.empty() || srcHaloSizes.size() == tgtHaloSizes.size());
assert(((srcHaloSizes.empty() || !ShapedType::isDynamicShape(srcHaloSizes)) && assert(((srcHaloSizes.empty() || ShapedType::isStaticShape(srcHaloSizes)) &&
!ShapedType::isDynamicShape(tgtHaloSizes) && ShapedType::isStaticShape(tgtHaloSizes) &&
sourceShard.getType().hasStaticShape()) && sourceShard.getType().hasStaticShape()) &&
"dynamic shapes/halos are not supported yet for mesh-spmdization"); "dynamic shapes/halos are not supported yet for mesh-spmdization");
auto rank = sourceShard.getType().getRank(); auto rank = sourceShard.getType().getRank();


@ -518,7 +518,7 @@ SparseTensorEncodingAttr::translateShape(ArrayRef<int64_t> srcShape,
SmallVector<AffineExpr> dimRep; SmallVector<AffineExpr> dimRep;
dimRep.reserve(srcShape.size()); dimRep.reserve(srcShape.size());
for (int64_t sz : srcShape) { for (int64_t sz : srcShape) {
if (!ShapedType::isDynamic(sz)) { if (ShapedType::isStatic(sz)) {
// Push back the max coordinate for the given dimension/level size. // Push back the max coordinate for the given dimension/level size.
dimRep.push_back(getAffineConstantExpr(sz - 1, getContext())); dimRep.push_back(getAffineConstantExpr(sz - 1, getContext()));
} else { } else {
@ -1531,7 +1531,7 @@ OpFoldResult LvlOp::fold(FoldAdaptor adaptor) {
}; };
SmallVector<Size> lvlShape = stt.getLvlShape(); SmallVector<Size> lvlShape = stt.getLvlShape();
if (!ShapedType::isDynamic(lvlShape[lvl])) if (ShapedType::isStatic(lvlShape[lvl]))
return getIndexAttr(lvlShape[lvl]); return getIndexAttr(lvlShape[lvl]);
return {}; return {};
@ -1876,7 +1876,7 @@ LogicalResult ConcatenateOp::verify() {
for (Dimension d = 0; d < dimRank; d++) { for (Dimension d = 0; d < dimRank; d++) {
const Size dstSh = dstTp.getDimShape()[d]; const Size dstSh = dstTp.getDimShape()[d];
if (d == concatDim) { if (d == concatDim) {
if (!ShapedType::isDynamic(dstSh)) { if (ShapedType::isStatic(dstSh)) {
// If we reach here, then all inputs have static shapes. So we // If we reach here, then all inputs have static shapes. So we
// can use `getDimShape()[d]` instead of `*getDynamicDimSize(d)` // can use `getDimShape()[d]` instead of `*getDynamicDimSize(d)`
// to avoid redundant assertions in the loop. // to avoid redundant assertions in the loop.
@ -1894,7 +1894,7 @@ LogicalResult ConcatenateOp::verify() {
Size prev = dstSh; Size prev = dstSh;
for (const auto src : getInputs()) { for (const auto src : getInputs()) {
const auto sh = getSparseTensorType(src).getDimShape()[d]; const auto sh = getSparseTensorType(src).getDimShape()[d];
if (!ShapedType::isDynamic(prev) && sh != prev) if (ShapedType::isStatic(prev) && sh != prev)
return emitError("All dimensions (expect for the concatenating one) " return emitError("All dimensions (expect for the concatenating one) "
"should be equal."); "should be equal.");
prev = sh; prev = sh;
@ -2058,7 +2058,7 @@ LogicalResult SortOp::verify() {
const auto checkDim = [&](Value v, Size minSize, const auto checkDim = [&](Value v, Size minSize,
const char *message) -> LogicalResult { const char *message) -> LogicalResult {
const Size sh = getMemRefType(v).getShape()[0]; const Size sh = getMemRefType(v).getShape()[0];
if (!ShapedType::isDynamic(sh) && sh < minSize) if (ShapedType::isStatic(sh) && sh < minSize)
return emitError( return emitError(
llvm::formatv("{0} got {1} < {2}", message, sh, minSize)); llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
return success(); return success();


@ -259,7 +259,7 @@ translateMap(linalg::GenericOp op, PatternRewriter &rewriter) {
// translation. // translation.
auto populateCstMapping = [ctx](DenseMap<AffineExpr, AffineExpr> &cstMapping, auto populateCstMapping = [ctx](DenseMap<AffineExpr, AffineExpr> &cstMapping,
unsigned pos, int64_t lvlSz) { unsigned pos, int64_t lvlSz) {
if (!ShapedType::isDynamic(lvlSz)) { if (ShapedType::isStatic(lvlSz)) {
auto c0 = getAffineConstantExpr(0, ctx); auto c0 = getAffineConstantExpr(0, ctx);
auto lvlExp = getAffineDimExpr(pos, ctx); auto lvlExp = getAffineDimExpr(pos, ctx);
auto szExp = getAffineConstantExpr(lvlSz, ctx); auto szExp = getAffineConstantExpr(lvlSz, ctx);


@ -1348,7 +1348,7 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
Level trailCOORank = stt.getLvlRank() - trailCOOStart; Level trailCOORank = stt.getLvlRank() - trailCOOStart;
// Sets up SparseTensorSpecifier. // Sets up SparseTensorSpecifier.
for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) { for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
assert(!ShapedType::isDynamic(stt.getDimShape()[lvl])); assert(ShapedType::isStatic(stt.getDimShape()[lvl]));
// Sets up the level size. // Sets up the level size.
auto lvlSize = constantIndex(rewriter, loc, stt.getLvlShape()[lvl]); auto lvlSize = constantIndex(rewriter, loc, stt.getLvlShape()[lvl]);


@ -86,7 +86,7 @@ static Value createOrFoldLvlCall(OpBuilder &builder, Location loc,
const Dimension dim = const Dimension dim =
stt.isIdentity() ? lvl : stt.getDimToLvl().getDimPosition(lvl); stt.isIdentity() ? lvl : stt.getDimToLvl().getDimPosition(lvl);
const Size sz = stt.getDynamicDimSize(dim); const Size sz = stt.getDynamicDimSize(dim);
if (!ShapedType::isDynamic(sz)) if (ShapedType::isStatic(sz))
return constantIndex(builder, loc, sz); return constantIndex(builder, loc, sz);
// If we cannot statically compute the size from the shape, then we // If we cannot statically compute the size from the shape, then we
// must dynamically query it. (In principle we could also dynamically // must dynamically query it. (In principle we could also dynamically
@ -103,7 +103,7 @@ static Value createOrFoldDimCall(OpBuilder &builder, Location loc,
SparseTensorType stt, Value tensor, SparseTensorType stt, Value tensor,
Dimension dim) { Dimension dim) {
const Size sz = stt.getDynamicDimSize(dim); const Size sz = stt.getDynamicDimSize(dim);
if (!ShapedType::isDynamic(sz)) if (ShapedType::isStatic(sz))
return constantIndex(builder, loc, sz); return constantIndex(builder, loc, sz);
if (stt.hasEncoding()) if (stt.hasEncoding())
return genDimSizeCall(builder, loc, tensor, dim); return genDimSizeCall(builder, loc, tensor, dim);


@ -1245,7 +1245,7 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
// by concatenate op verifier, which saves us from computing the offset // by concatenate op verifier, which saves us from computing the offset
// dynamically. // dynamically.
const Size sz = getSparseTensorType(input).getDynamicDimSize(conDim); const Size sz = getSparseTensorType(input).getDynamicDimSize(conDim);
assert(!ShapedType::isDynamic(sz)); assert(ShapedType::isStatic(sz));
offset = rewriter.create<arith::AddIOp>(loc, offset, offset = rewriter.create<arith::AddIOp>(loc, offset,
constantIndex(rewriter, loc, sz)); constantIndex(rewriter, loc, sz));
iterArg = foreachOp.getResult(0); iterArg = foreachOp.getResult(0);


@ -23,7 +23,7 @@ using namespace mlir::tensor;
static OpFoldResult getCollapsedOutputDimFromInputShape( static OpFoldResult getCollapsedOutputDimFromInputShape(
OpBuilder &builder, Location loc, int64_t dimIndex, Value src, OpBuilder &builder, Location loc, int64_t dimIndex, Value src,
ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociationMap) { ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociationMap) {
if (!ShapedType::isDynamic(dstStaticShape[dimIndex])) { if (ShapedType::isStatic(dstStaticShape[dimIndex])) {
// Static dimension: return Attribute. // Static dimension: return Attribute.
return builder.getIndexAttr(dstStaticShape[dimIndex]); return builder.getIndexAttr(dstStaticShape[dimIndex]);
} }


@ -292,7 +292,7 @@ bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
// If cast is towards more static sizes along any dimension, don't fold. // If cast is towards more static sizes along any dimension, don't fold.
for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) { for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
if (!ShapedType::isDynamic(std::get<0>(t)) && if (ShapedType::isStatic(std::get<0>(t)) &&
ShapedType::isDynamic(std::get<1>(t))) ShapedType::isDynamic(std::get<1>(t)))
return false; return false;
} }
@ -1235,7 +1235,7 @@ struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
// Case 2 : The tensor cast shape is static, but empty tensor result // Case 2 : The tensor cast shape is static, but empty tensor result
// shape is dynamic. // shape is dynamic.
if (!ShapedType::isDynamic(newDim)) { if (ShapedType::isStatic(newDim)) {
newMixedSizes.push_back(rewriter.getIndexAttr(newDim)); newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
continue; continue;
} }
@ -2197,7 +2197,7 @@ struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) { for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
for (uint64_t outDim : innerReassoc) { for (uint64_t outDim : innerReassoc) {
if (!ShapedType::isDynamic(newOutputShape[outDim])) if (ShapedType::isStatic(newOutputShape[outDim]))
continue; continue;
// If the cast's src type is dynamic, don't infer any of the // If the cast's src type is dynamic, don't infer any of the
@ -3579,7 +3579,7 @@ struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
continue; continue;
OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()]; OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()]; int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
assert(!ShapedType::isDynamic(sourceSize) && assert(ShapedType::isStatic(sourceSize) &&
"expected padded dimension to have a static size"); "expected padded dimension to have a static size");
if (getConstantIntValue(sliceSize) != sourceSize) { if (getConstantIntValue(sliceSize) != sourceSize) {
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(


@ -849,7 +849,7 @@ static LogicalResult verifyPoolingOp(T op) {
<< kernelSize << ") / " << strideSize; << kernelSize << ") / " << strideSize;
const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1; const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
if (!ShapedType::isDynamic(outputSize) && calculatedOutSize != outputSize) if (ShapedType::isStatic(outputSize) && calculatedOutSize != outputSize)
return op.emitOpError("calculated output ") return op.emitOpError("calculated output ")
<< dimName << " did not match expected: " << dimName << " did not match expected: "
<< "calculated=" << calculatedOutSize << "calculated=" << calculatedOutSize
@ -1301,12 +1301,12 @@ LogicalResult tosa::RFFT2dOp::verify() {
return success(); return success();
const int64_t height = inputType.getDimSize(1); const int64_t height = inputType.getDimSize(1);
if (!ShapedType::isDynamic(height) && if (ShapedType::isStatic(height) &&
failed(verifyDimIsPowerOfTwo(*this, height, "height"))) failed(verifyDimIsPowerOfTwo(*this, height, "height")))
return failure(); return failure();
const int64_t width = inputType.getDimSize(2); const int64_t width = inputType.getDimSize(2);
if (!ShapedType::isDynamic(width) && if (ShapedType::isStatic(width) &&
failed(verifyDimIsPowerOfTwo(*this, width, "width"))) failed(verifyDimIsPowerOfTwo(*this, width, "width")))
return failure(); return failure();
@ -1323,7 +1323,7 @@ LogicalResult tosa::RFFT2dOp::verify() {
// Output width dimension expected to be input_width / 2 + 1 // Output width dimension expected to be input_width / 2 + 1
const int64_t outputWidth = outputType.getDimSize(2); const int64_t outputWidth = outputType.getDimSize(2);
if (!ShapedType::isDynamic(width) && !ShapedType::isDynamic(outputWidth) && if (ShapedType::isStatic(width) && ShapedType::isStatic(outputWidth) &&
(outputWidth != (width / 2) + 1)) (outputWidth != (width / 2) + 1))
return emitOpError( return emitOpError(
"expected output width to be equal to input_width / 2 + 1, got ") "expected output width to be equal to input_width / 2 + 1, got ")
@ -1357,13 +1357,13 @@ LogicalResult tosa::FFT2dOp::verify() {
const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1), const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1),
inputImagType.getDimSize(1)); inputImagType.getDimSize(1));
if (!ShapedType::isDynamic(height) && if (ShapedType::isStatic(height) &&
failed(verifyDimIsPowerOfTwo(*this, height, "height"))) failed(verifyDimIsPowerOfTwo(*this, height, "height")))
return failure(); return failure();
const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2), const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2),
inputImagType.getDimSize(2)); inputImagType.getDimSize(2));
if (!ShapedType::isDynamic(width) && if (ShapedType::isStatic(width) &&
failed(verifyDimIsPowerOfTwo(*this, width, "width"))) failed(verifyDimIsPowerOfTwo(*this, width, "width")))
return failure(); return failure();
@ -1965,7 +1965,7 @@ LogicalResult tosa::TableOp::verify() {
for (auto it : llvm::enumerate(llvm::zip(inputDims, outputDims))) { for (auto it : llvm::enumerate(llvm::zip(inputDims, outputDims))) {
int64_t dim = it.index(); int64_t dim = it.index();
auto [inputDim, outputDim] = it.value(); auto [inputDim, outputDim] = it.value();
if (!ShapedType::isDynamic(outputDim) && outputDim != inputDim) { if (ShapedType::isStatic(outputDim) && outputDim != inputDim) {
return emitOpError() << "dim(result, " << dim << ") = " << outputDim return emitOpError() << "dim(result, " << dim << ") = " << outputDim
<< " doesn't match dim(input, " << dim << " doesn't match dim(input, " << dim
<< ") = " << inputDim; << ") = " << inputDim;
@ -2100,7 +2100,7 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
int64_t numElements = inputShape.getNumElements(); int64_t numElements = inputShape.getNumElements();
int64_t staticMul = 1; int64_t staticMul = 1;
for (auto val : newShapeValue) { for (auto val : newShapeValue) {
if (!ShapedType::isDynamic(val)) { if (ShapedType::isStatic(val)) {
staticMul *= val; staticMul *= val;
} }
} }
@ -2988,12 +2988,12 @@ static LogicalResult poolingInferReturnTypes(
int64_t height = inputShape.getDimSize(1); int64_t height = inputShape.getDimSize(1);
int64_t width = inputShape.getDimSize(2); int64_t width = inputShape.getDimSize(2);
if (!ShapedType::isDynamic(height)) { if (ShapedType::isStatic(height)) {
int64_t padded = height + pad[0] + pad[1] - kernel[0]; int64_t padded = height + pad[0] + pad[1] - kernel[0];
outputShape[1] = padded / stride[0] + 1; outputShape[1] = padded / stride[0] + 1;
} }
if (!ShapedType::isDynamic(width)) { if (ShapedType::isStatic(width)) {
int64_t padded = width + pad[2] + pad[3] - kernel[1]; int64_t padded = width + pad[2] + pad[3] - kernel[1];
outputShape[2] = padded / stride[1] + 1; outputShape[2] = padded / stride[1] + 1;
} }
@ -3042,16 +3042,14 @@ LogicalResult Conv2DOp::inferReturnTypeComponents(
llvm::ArrayRef<int64_t> stride = adaptor.getStride(); llvm::ArrayRef<int64_t> stride = adaptor.getStride();
llvm::ArrayRef<int64_t> padding = adaptor.getPad(); llvm::ArrayRef<int64_t> padding = adaptor.getPad();
if (!ShapedType::isDynamic(inputHeight) && if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
!ShapedType::isDynamic(weightHeight)) {
int64_t inputSize = inputHeight + padding[0] + padding[1]; int64_t inputSize = inputHeight + padding[0] + padding[1];
int64_t filterSize = (weightHeight - 1) * dilation[0] + 1; int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
int64_t unstridedResult = inputSize - filterSize + 1; int64_t unstridedResult = inputSize - filterSize + 1;
outputShape[1] = (unstridedResult - 1) / stride[0] + 1; outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
} }
if (!ShapedType::isDynamic(inputWidth) && if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
!ShapedType::isDynamic(weightWidth)) {
int64_t inputSize = inputWidth + padding[2] + padding[3]; int64_t inputSize = inputWidth + padding[2] + padding[3];
int64_t filterSize = (weightWidth - 1) * dilation[1] + 1; int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
int64_t unstridedResult = inputSize - filterSize + 1; int64_t unstridedResult = inputSize - filterSize + 1;
@ -3111,24 +3109,21 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
llvm::ArrayRef<int64_t> stride = adaptor.getStride(); llvm::ArrayRef<int64_t> stride = adaptor.getStride();
llvm::ArrayRef<int64_t> pad = adaptor.getPad(); llvm::ArrayRef<int64_t> pad = adaptor.getPad();
if (!ShapedType::isDynamic(inputDepth) && if (ShapedType::isStatic(inputDepth) && ShapedType::isStatic(weightDepth)) {
!ShapedType::isDynamic(weightDepth)) {
int32_t inputSize = inputDepth + pad[0] + pad[1]; int32_t inputSize = inputDepth + pad[0] + pad[1];
int32_t filterSize = (weightDepth - 1) * dilation[0] + 1; int32_t filterSize = (weightDepth - 1) * dilation[0] + 1;
int32_t unstridedResult = inputSize - filterSize + 1; int32_t unstridedResult = inputSize - filterSize + 1;
outputShape[1] = (unstridedResult - 1) / stride[0] + 1; outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
} }
if (!ShapedType::isDynamic(inputHeight) && if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
!ShapedType::isDynamic(weightHeight)) {
int32_t inputSize = inputHeight + pad[2] + pad[3]; int32_t inputSize = inputHeight + pad[2] + pad[3];
int32_t filterSize = (weightHeight - 1) * dilation[1] + 1; int32_t filterSize = (weightHeight - 1) * dilation[1] + 1;
int32_t unstridedResult = inputSize - filterSize + 1; int32_t unstridedResult = inputSize - filterSize + 1;
outputShape[2] = (unstridedResult - 1) / stride[1] + 1; outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
} }
if (!ShapedType::isDynamic(inputWidth) && if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
!ShapedType::isDynamic(weightWidth)) {
int32_t inputSize = inputWidth + pad[4] + pad[5]; int32_t inputSize = inputWidth + pad[4] + pad[5];
int32_t filterSize = (weightWidth - 1) * dilation[2] + 1; int32_t filterSize = (weightWidth - 1) * dilation[2] + 1;
int32_t unstridedResult = inputSize - filterSize + 1; int32_t unstridedResult = inputSize - filterSize + 1;
@ -3213,8 +3208,8 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
// If both inputChannels and depthChannels are available we can determine // If both inputChannels and depthChannels are available we can determine
// the output channels. // the output channels.
if (!ShapedType::isDynamic(inputChannels) && if (ShapedType::isStatic(inputChannels) &&
!ShapedType::isDynamic(depthChannels)) { ShapedType::isStatic(depthChannels)) {
outputShape[3] = inputChannels * depthChannels; outputShape[3] = inputChannels * depthChannels;
} }
@ -3230,16 +3225,14 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
llvm::ArrayRef<int64_t> padding = adaptor.getPad(); llvm::ArrayRef<int64_t> padding = adaptor.getPad();
llvm::ArrayRef<int64_t> stride = adaptor.getStride(); llvm::ArrayRef<int64_t> stride = adaptor.getStride();
if (!ShapedType::isDynamic(inputHeight) && if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
!ShapedType::isDynamic(weightHeight)) {
int64_t inputSize = inputHeight + padding[0] + padding[1]; int64_t inputSize = inputHeight + padding[0] + padding[1];
int64_t filterSize = (weightHeight - 1) * dilation[0] + 1; int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
int64_t unstridedResult = inputSize - filterSize + 1; int64_t unstridedResult = inputSize - filterSize + 1;
outputShape[1] = (unstridedResult - 1) / stride[0] + 1; outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
} }
if (!ShapedType::isDynamic(inputWidth) && if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
!ShapedType::isDynamic(weightWidth)) {
int64_t inputSize = inputWidth + padding[2] + padding[3]; int64_t inputSize = inputWidth + padding[2] + padding[3];
int64_t filterSize = (weightWidth - 1) * dilation[1] + 1; int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
int64_t unstridedResult = inputSize - filterSize + 1; int64_t unstridedResult = inputSize - filterSize + 1;
@@ -3299,16 +3292,14 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
   llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
   llvm::ArrayRef<int64_t> stride = adaptor.getStride();
-  if (!ShapedType::isDynamic(inputHeight) &&
-      !ShapedType::isDynamic(weightHeight)) {
+  if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
     int64_t calculateSize =
         (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
     outputShape[1] =
         ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
   }
-  if (!ShapedType::isDynamic(inputWidth) &&
-      !ShapedType::isDynamic(weightWidth)) {
+  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
     int64_t calculateSize =
         (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
     outputShape[2] =
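The transposed-convolution inference above uses the inverse relationship and only fills in output extents that are still dynamic:

$$\text{out} = (\text{in} - 1)\,s + p_{\text{before}} + p_{\text{after}} + k.$$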
@@ -3354,7 +3345,7 @@ LogicalResult TransposeConv2DOp::verify() {
   if (weightType) {
     const int64_t kernelHeight = weightType.getDimSize(1);
-    if (!ShapedType::isDynamic(kernelHeight)) {
+    if (ShapedType::isStatic(kernelHeight)) {
       if (failed(checkPadAgainstKernelDim(outPadTop, kernelHeight,
                                           "out_pad_top", "KH")))
         return failure();
@@ -3365,7 +3356,7 @@ LogicalResult TransposeConv2DOp::verify() {
     }
     const int64_t kernelWidth = weightType.getDimSize(2);
-    if (!ShapedType::isDynamic(kernelWidth)) {
+    if (ShapedType::isStatic(kernelWidth)) {
       if (failed(checkPadAgainstKernelDim(outPadLeft, kernelWidth,
                                           "out_pad_left", "KW")))
         return failure();
@@ -3388,8 +3379,8 @@ LogicalResult TransposeConv2DOp::verify() {
     const int64_t kernelHeight = weightType.getDimSize(1);
     const int64_t outputHeight = outputType.getDimSize(1);
-    if (!ShapedType::isDynamic(inputHeight) &&
-        !ShapedType::isDynamic(outputHeight)) {
+    if (ShapedType::isStatic(inputHeight) &&
+        ShapedType::isStatic(outputHeight)) {
       if (outputHeight !=
           (inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)
         return emitOpError(
@@ -3404,8 +3395,7 @@ LogicalResult TransposeConv2DOp::verify() {
     const int64_t kernelWidth = weightType.getDimSize(2);
     const int64_t outputWidth = outputType.getDimSize(2);
-    if (!ShapedType::isDynamic(inputWidth) &&
-        !ShapedType::isDynamic(outputWidth)) {
+    if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(outputWidth)) {
       if (outputWidth !=
           (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
         return emitOpError(


@@ -505,7 +505,7 @@ LogicalResult mlir::reshapeLikeShapesAreCompatible(
       linearizedStaticShape *= dim.value();
     }
     if (foundDynamicShape) {
-      if (!ShapedType::isDynamic(collapsedShape[map.index()])) {
+      if (ShapedType::isStatic(collapsedShape[map.index()])) {
         return emitError(
             "expected dimension " + Twine(map.index()) +
             " of collapsed type to be dynamic since one or more of the "


@@ -280,13 +280,13 @@ std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
 bool hasValidSizesOffsets(SmallVector<int64_t> sizesOrOffsets) {
   return llvm::none_of(sizesOrOffsets, [](int64_t value) {
-    return !ShapedType::isDynamic(value) && value < 0;
+    return ShapedType::isStatic(value) && value < 0;
   });
 }
 bool hasValidStrides(SmallVector<int64_t> strides) {
   return llvm::none_of(strides, [](int64_t value) {
-    return !ShapedType::isDynamic(value) && value == 0;
+    return ShapedType::isStatic(value) && value == 0;
   });
 }
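For context, an illustrative sketch (not part of the diff): the positive predicate avoids the `!isDynamic` double negation seen in the lambdas above, and where no extra condition is involved it can be passed to llvm ADT helpers directly — assuming a single ShapedType::isStatic(int64_t) overload.

// Sketch only. Assumes the usual MLIR/LLVM headers are available.
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"

using namespace mlir;

// Before: "static" spelled as a negated dynamic check inside a lambda.
static bool hasValidSizesOld(llvm::ArrayRef<int64_t> sizes) {
  return llvm::none_of(sizes, [](int64_t v) {
    return !ShapedType::isDynamic(v) && v < 0;
  });
}

// After: the positive predicate reads directly, with no double negation.
static bool hasValidSizesNew(llvm::ArrayRef<int64_t> sizes) {
  return llvm::none_of(sizes, [](int64_t v) {
    return ShapedType::isStatic(v) && v < 0;
  });
}

// Where no extra condition is needed, no wrapping lambda is required at all.
static bool allStatic(ShapedType type) {
  return llvm::all_of(type.getShape(), ShapedType::isStatic);
}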


@@ -1107,7 +1107,7 @@ Type ContractionOp::getExpectedMaskType() {
         rhsType.getScalableDims()[dimIdx];
   }
-  assert(!ShapedType::isDynamicShape(maskShape) &&
+  assert(ShapedType::isStaticShape(maskShape) &&
          "Mask shape couldn't be computed");
   return VectorType::get(maskShape,
@@ -2061,7 +2061,7 @@ static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
   // `opChange` is a flag. If it is true, it means to update `op` in place.
   bool opChange = false;
   for (unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
-    if (!ShapedType::isDynamic(staticPosition[i]))
+    if (ShapedType::isStatic(staticPosition[i]))
       continue;
     Attribute positionAttr = dynamicPositionAttr[index];
     Value position = dynamicPosition[index++];


@@ -339,7 +339,7 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
     // FIXME: This computation is too weak - it ignores the read indices.
     for (unsigned i = 0; i < readRank; i++)
       inBoundsVal[i] = (sourceShape[i] == inputVectorSizes[i]) &&
-                       !ShapedType::isDynamic(sourceShape[i]);
+                       ShapedType::isStatic(sourceShape[i]);
   }
   auto transferReadOp = builder.create<vector::TransferReadOp>(
       loc,


@@ -230,8 +230,8 @@ void StridedLayoutAttr::print(llvm::raw_ostream &os) const {
 /// Returns true if this layout is static, i.e. the strides and offset all have
 /// a known value > 0.
 bool StridedLayoutAttr::hasStaticLayout() const {
-  return !ShapedType::isDynamic(getOffset()) &&
-         !ShapedType::isDynamicShape(getStrides());
+  return ShapedType::isStatic(getOffset()) &&
+         ShapedType::isStaticShape(getStrides());
 }
 /// Returns the strided layout as an affine map.
@@ -1818,7 +1818,7 @@ AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
   // AffineExpr for offset.
   // Static case.
-  if (!ShapedType::isDynamic(offset)) {
+  if (ShapedType::isStatic(offset)) {
     auto cst = getAffineConstantExpr(offset, context);
     expr = cst;
   } else {
@@ -1834,7 +1834,7 @@ AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
     auto d = getAffineDimExpr(dim, context);
     AffineExpr mult;
     // Static case.
-    if (!ShapedType::isDynamic(stride))
+    if (ShapedType::isStatic(stride))
       mult = getAffineConstantExpr(stride, context);
     else
       // Dynamic case, new symbol for each new stride.
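For orientation, the affine map assembled here has the form

$$(d_0, \ldots, d_{n-1}) \mapsto \text{offset} + \sum_{i=0}^{n-1} \text{stride}_i \cdot d_i,$$

where each static offset or stride becomes an affine constant and each dynamic one binds a fresh symbol.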


@@ -321,7 +321,7 @@ RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
                          ArrayRef<int64_t> shape, Type elementType,
                          Attribute encoding) {
   for (int64_t s : shape)
-    if (s < 0 && !ShapedType::isDynamic(s))
+    if (s < 0 && ShapedType::isStatic(s))
       return emitError() << "invalid tensor dimension size";
   if (auto v = llvm::dyn_cast_or_null<VerifiableTensorEncoding>(encoding))
     if (failed(v.verifyEncoding(shape, elementType, emitError)))
@@ -644,7 +644,7 @@ LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
   // Negative sizes are not allowed except for `kDynamic`.
   for (int64_t s : shape)
-    if (s < 0 && !ShapedType::isDynamic(s))
+    if (s < 0 && ShapedType::isStatic(s))
       return emitError() << "invalid memref size";
   assert(layout && "missing layout specification");
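A reminder of why the sign test alone is not enough (illustrative sketch, not part of the diff; kDynamic is the sentinel defined by ShapedType, commonly int64_t's minimum value, which is itself negative):

#include "mlir/IR/BuiltinTypeInterfaces.h"

// A bare `s < 0` would also reject the dynamic-size sentinel, which is a
// negative value. Pairing it with isStatic() flags only genuinely invalid
// negative extents and lets kDynamic through.
static bool isInvalidSize(int64_t s) {
  return s < 0 && mlir::ShapedType::isStatic(s);
}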


@@ -62,7 +62,7 @@ LogicalResult mlir::verifyCompatibleShape(ArrayRef<int64_t> shape1,
   for (auto dims : llvm::zip(shape1, shape2)) {
     int64_t dim1 = std::get<0>(dims);
     int64_t dim2 = std::get<1>(dims);
-    if (!ShapedType::isDynamic(dim1) && !ShapedType::isDynamic(dim2) &&
+    if (ShapedType::isStatic(dim1) && ShapedType::isStatic(dim2) &&
         dim1 != dim2)
       return failure();
   }
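Equivalently, as a sketch rather than the library implementation: two extents are compatible unless both are static and disagree.

#include "mlir/IR/BuiltinTypeInterfaces.h"

// Dimensions match if either side is dynamic or both static sizes agree.
static bool compatibleDim(int64_t a, int64_t b) {
  return mlir::ShapedType::isDynamic(a) || mlir::ShapedType::isDynamic(b) ||
         a == b;
}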


@@ -124,12 +124,12 @@ mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) {
     return failure();
   for (int64_t offset : op.getStaticOffsets()) {
-    if (offset < 0 && !ShapedType::isDynamic(offset))
+    if (offset < 0 && ShapedType::isStatic(offset))
       return op->emitError("expected offsets to be non-negative, but got ")
              << offset;
   }
   for (int64_t size : op.getStaticSizes()) {
-    if (size < 0 && !ShapedType::isDynamic(size))
+    if (size < 0 && ShapedType::isStatic(size))
       return op->emitError("expected sizes to be non-negative, but got ")
             << size;
   }


@@ -2497,6 +2497,11 @@ class ShapedType(Type):
         Returns whether the given dimension size indicates a dynamic dimension.
         """
     @staticmethod
+    def is_static_size(dim_size: int) -> bool:
+        """
+        Returns whether the given dimension size indicates a static dimension.
+        """
+    @staticmethod
     def isinstance(other: Type) -> bool: ...
     def __init__(self, cast_from_type: Type) -> None: ...
     def get_dim_size(self, dim: int) -> int:
@@ -2507,10 +2512,18 @@ class ShapedType(Type):
         """
         Returns whether the dim-th dimension of the given shaped type is dynamic.
         """
+    def is_static_dim(self, dim: int) -> bool:
+        """
+        Returns whether the dim-th dimension of the given shaped type is static.
+        """
     def is_dynamic_stride_or_offset(self, dim_size: int) -> bool:
         """
         Returns whether the given value is used as a placeholder for dynamic strides and offsets in shaped types.
         """
+    def is_static_stride_or_offset(self, dim_size: int) -> bool:
+        """
+        Returns whether the given shaped type stride or offset value is statically-sized.
+        """
     @property
     def element_type(self) -> Type:
         """


@@ -330,8 +330,29 @@ def testConcreteShapedType():
     print("dim size:", vector.get_dim_size(1))
     # CHECK: is_dynamic_size: False
     print("is_dynamic_size:", vector.is_dynamic_size(3))
+    # CHECK: is_static_size: True
+    print("is_static_size:", vector.is_static_size(3))
     # CHECK: is_dynamic_stride_or_offset: False
     print("is_dynamic_stride_or_offset:", vector.is_dynamic_stride_or_offset(1))
+    # CHECK: is_static_stride_or_offset: True
+    print("is_static_stride_or_offset:", vector.is_static_stride_or_offset(1))
+    dynamic_size_val = vector.get_dynamic_size()
+    dynamic_stride_val = vector.get_dynamic_stride_or_offset()
+    # CHECK: is_dynamic_size_with_dynamic: True
+    print("is_dynamic_size_with_dynamic:", vector.is_dynamic_size(dynamic_size_val))
+    # CHECK: is_static_size_with_dynamic: False
+    print("is_static_size_with_dynamic:", vector.is_static_size(dynamic_size_val))
+    # CHECK: is_dynamic_stride_or_offset_with_dynamic: True
+    print(
+        "is_dynamic_stride_or_offset_with_dynamic:",
+        vector.is_dynamic_stride_or_offset(dynamic_stride_val),
+    )
+    # CHECK: is_static_stride_or_offset_with_dynamic: False
+    print(
+        "is_static_stride_or_offset_with_dynamic:",
+        vector.is_static_stride_or_offset(dynamic_stride_val),
+    )
     # CHECK: isinstance(ShapedType): True
     print("isinstance(ShapedType):", isinstance(vector, ShapedType))