Switch member calls to isa/dyn_cast/cast/... to free function calls. (#89356)

This change cleans up call sites. Next step is to mark the member
functions deprecated.

See https://mlir.llvm.org/deprecation and
https://discourse.llvm.org/t/preferred-casting-style-going-forward.
This commit is contained in:
Christian Sigg 2024-04-19 15:58:27 +02:00 committed by GitHub
parent ce2f6423f0
commit a5757c5b65
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
80 changed files with 240 additions and 264 deletions

View File

@ -142,7 +142,7 @@ mlir::transform::HasOperandSatisfyingOp::apply(
transform::detail::prepareValueMappings(
yieldedMappings, getBody().front().getTerminator()->getOperands(),
state);
results.setParams(getPosition().cast<OpResult>(),
results.setParams(cast<OpResult>(getPosition()),
{rewriter.getI32IntegerAttr(operand.getOperandNumber())});
for (auto &&[result, mapping] : llvm::zip(getResults(), yieldedMappings))
results.setMappedValues(result, mapping);

View File

@ -87,7 +87,7 @@ private:
void
populateIteratorTypes(Type t,
SmallVector<utils::IteratorType> &iterTypes) const {
RankedTensorType rankedTensorType = t.dyn_cast<RankedTensorType>();
RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(t);
if (!rankedTensorType) {
return;
}
@ -106,7 +106,7 @@ struct ElementwiseShardingInterface
ElementwiseShardingInterface<ElemwiseOp>, ElemwiseOp> {
SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
Value val = op->getOperand(0);
auto type = val.getType().dyn_cast<RankedTensorType>();
auto type = dyn_cast<RankedTensorType>(val.getType());
if (!type)
return {};
SmallVector<utils::IteratorType> types(type.getRank(),
@ -117,7 +117,7 @@ struct ElementwiseShardingInterface
SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
MLIRContext *ctx = op->getContext();
Value val = op->getOperand(0);
auto type = val.getType().dyn_cast<RankedTensorType>();
auto type = dyn_cast<RankedTensorType>(val.getType());
if (!type)
return {};
int64_t rank = type.getRank();

View File

@ -60,11 +60,11 @@ public:
if (llvm::isa<FloatType>(resElemType))
return impl::verifySameOperandsAndResultElementType(op);
if (auto resIntType = resElemType.dyn_cast<IntegerType>()) {
if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
IntegerType lhsIntType =
getElementTypeOrSelf(op->getOperand(0)).cast<IntegerType>();
cast<IntegerType>(getElementTypeOrSelf(op->getOperand(0)));
IntegerType rhsIntType =
getElementTypeOrSelf(op->getOperand(1)).cast<IntegerType>();
cast<IntegerType>(getElementTypeOrSelf(op->getOperand(1)));
if (lhsIntType != rhsIntType)
return op->emitOpError(
"requires the same element type for all operands");

View File

@ -154,7 +154,7 @@ public:
/// Support llvm style casting.
static bool classof(Attribute attr) {
auto fusedLoc = llvm::dyn_cast<FusedLoc>(attr);
return fusedLoc && fusedLoc.getMetadata().isa_and_nonnull<MetadataT>();
return fusedLoc && mlir::isa_and_nonnull<MetadataT>(fusedLoc.getMetadata());
}
};

View File

@ -135,7 +135,7 @@ MlirAttribute mlirLLVMDIExpressionAttrGet(MlirContext ctx, intptr_t nOperations,
unwrap(ctx),
llvm::map_to_vector(
unwrapList(nOperations, operations, attrStorage),
[](Attribute a) { return a.cast<DIExpressionElemAttr>(); })));
[](Attribute a) { return cast<DIExpressionElemAttr>(a); })));
}
MlirAttribute mlirLLVMDINullTypeAttrGet(MlirContext ctx) {
@ -165,7 +165,7 @@ MlirAttribute mlirLLVMDICompositeTypeAttrGet(
cast<DIScopeAttr>(unwrap(scope)), cast<DITypeAttr>(unwrap(baseType)),
DIFlags(flags), sizeInBits, alignInBits,
llvm::map_to_vector(unwrapList(nElements, elements, elementsStorage),
[](Attribute a) { return a.cast<DINodeAttr>(); })));
[](Attribute a) { return cast<DINodeAttr>(a); })));
}
MlirAttribute
@ -259,7 +259,7 @@ MlirAttribute mlirLLVMDISubroutineTypeAttrGet(MlirContext ctx,
return wrap(DISubroutineTypeAttr::get(
unwrap(ctx), callingConvention,
llvm::map_to_vector(unwrapList(nTypes, types, attrStorage),
[](Attribute a) { return a.cast<DITypeAttr>(); })));
[](Attribute a) { return cast<DITypeAttr>(a); })));
}
MlirAttribute mlirLLVMDISubprogramAttrGet(

View File

@ -311,11 +311,11 @@ MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
}
bool mlirVectorTypeIsScalable(MlirType type) {
return unwrap(type).cast<VectorType>().isScalable();
return cast<VectorType>(unwrap(type)).isScalable();
}
bool mlirVectorTypeIsDimScalable(MlirType type, intptr_t dim) {
return unwrap(type).cast<VectorType>().getScalableDims()[dim];
return cast<VectorType>(unwrap(type)).getScalableDims()[dim];
}
//===----------------------------------------------------------------------===//

View File

@ -371,7 +371,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
bool isUnsigned, Value llvmInput,
SmallVector<Value, 4> &operands) {
Type inputType = llvmInput.getType();
auto vectorType = inputType.dyn_cast<VectorType>();
auto vectorType = dyn_cast<VectorType>(inputType);
Type elemType = vectorType.getElementType();
if (elemType.isBF16())
@ -414,7 +414,7 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
Value output, int32_t subwordOffset,
bool clamp, SmallVector<Value, 4> &operands) {
Type inputType = output.getType();
auto vectorType = inputType.dyn_cast<VectorType>();
auto vectorType = dyn_cast<VectorType>(inputType);
Type elemType = vectorType.getElementType();
if (elemType.isBF16())
output = rewriter.create<LLVM::BitcastOp>(
@ -569,9 +569,8 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
/// on the architecture you are compiling for.
static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
Chipset chipset) {
auto sourceVectorType = wmma.getSourceA().getType().dyn_cast<VectorType>();
auto destVectorType = wmma.getDestC().getType().dyn_cast<VectorType>();
auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType());
auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType());
auto elemSourceType = sourceVectorType.getElementType();
auto elemDestType = destVectorType.getElementType();
@ -727,7 +726,7 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
Type f32 = getTypeConverter()->convertType(op.getResult().getType());
Value source = adaptor.getSource();
auto sourceVecType = op.getSource().getType().dyn_cast<VectorType>();
auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
Type sourceElemType = getElementTypeOrSelf(op.getSource());
// Extend to a v4i8
if (!sourceVecType || sourceVecType.getNumElements() < 4) {

View File

@ -65,7 +65,7 @@ static Value castF32To(Type elementType, Value f32, Location loc,
LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
Type inType = op.getIn().getType();
if (auto inVecType = inType.dyn_cast<VectorType>()) {
if (auto inVecType = dyn_cast<VectorType>(inType)) {
if (inVecType.isScalable())
return failure();
if (inVecType.getShape().size() > 1)
@ -81,13 +81,13 @@ void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
Location loc = op.getLoc();
Value in = op.getIn();
Type outElemType = getElementTypeOrSelf(op.getOut().getType());
if (!in.getType().isa<VectorType>()) {
if (!isa<VectorType>(in.getType())) {
Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
loc, rewriter.getF32Type(), in, 0);
Value result = castF32To(outElemType, asFloat, loc, rewriter);
return rewriter.replaceOp(op, result);
}
VectorType inType = in.getType().cast<VectorType>();
VectorType inType = cast<VectorType>(in.getType());
int64_t numElements = inType.getNumElements();
Value zero = rewriter.create<arith::ConstantOp>(
loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
@ -179,7 +179,7 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
if (op.getRoundingmodeAttr())
return failure();
Type outType = op.getOut().getType();
if (auto outVecType = outType.dyn_cast<VectorType>()) {
if (auto outVecType = dyn_cast<VectorType>(outType)) {
if (outVecType.isScalable())
return failure();
if (outVecType.getShape().size() > 1)
@ -202,7 +202,7 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
if (saturateFP8)
in = clampInput(rewriter, loc, outElemType, in);
VectorType truncResType = VectorType::get(4, outElemType);
if (!in.getType().isa<VectorType>()) {
if (!isa<VectorType>(in.getType())) {
Value asFloat = castToF32(in, loc, rewriter);
Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
@ -210,7 +210,7 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
Value result = rewriter.create<vector::ExtractOp>(loc, asF8s, 0);
return rewriter.replaceOp(op, result);
}
VectorType outType = op.getOut().getType().cast<VectorType>();
VectorType outType = cast<VectorType>(op.getOut().getType());
int64_t numElements = outType.getNumElements();
Value zero = rewriter.create<arith::ConstantOp>(
loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));

View File

@ -214,7 +214,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
llvm::enumerate(gpuFuncOp.getArgumentTypes())) {
auto remapping = signatureConversion.getInputMapping(idx);
NamedAttrList argAttr =
argAttrs ? argAttrs[idx].cast<DictionaryAttr>() : NamedAttrList();
argAttrs ? cast<DictionaryAttr>(argAttrs[idx]) : NamedAttrList();
auto copyAttribute = [&](StringRef attrName) {
Attribute attr = argAttr.erase(attrName);
if (!attr)
@ -234,9 +234,8 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
return;
}
for (size_t i = 0, e = remapping->size; i < e; ++i) {
if (llvmFuncOp.getArgument(remapping->inputNo + i)
.getType()
.isa<LLVM::LLVMPointerType>()) {
if (isa<LLVM::LLVMPointerType>(
llvmFuncOp.getArgument(remapping->inputNo + i).getType())) {
llvmFuncOp.setArgAttr(remapping->inputNo + i, attrName, attr);
}
}

View File

@ -668,7 +668,7 @@ static int32_t getCuSparseLtDataTypeFrom(Type type) {
static int32_t getCuSparseDataTypeFrom(Type type) {
if (llvm::isa<ComplexType>(type)) {
// get the element type
auto elementType = type.cast<ComplexType>().getElementType();
auto elementType = cast<ComplexType>(type).getElementType();
if (elementType.isBF16())
return 15; // CUDA_C_16BF
if (elementType.isF16())

View File

@ -1579,7 +1579,7 @@ struct NVGPUWarpgroupMmaStoreOpLowering
if (offset)
ti = makeAdd(ti, makeConst(offset));
auto structType = matrixD.getType().cast<LLVM::LLVMStructType>();
auto structType = cast<LLVM::LLVMStructType>(matrixD.getType());
// Number of 32-bit registers owns per thread
constexpr unsigned numAdjacentRegisters = 2;
@ -1606,9 +1606,9 @@ struct NVGPUWarpgroupMmaStoreOpLowering
int offset = 0;
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value matriDValue = adaptor.getMatrixD();
auto stype = matriDValue.getType().cast<LLVM::LLVMStructType>();
auto stype = cast<LLVM::LLVMStructType>(matriDValue.getType());
for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) {
auto structType = matrixD.cast<LLVM::LLVMStructType>();
auto structType = cast<LLVM::LLVMStructType>(matrixD);
Value innerStructValue = b.create<LLVM::ExtractValueOp>(matriDValue, idx);
storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
offset += structType.getBody().size();
@ -1626,13 +1626,9 @@ struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
LLVM::LLVMStructType packStructType =
getTypeConverter()
->convertType(op.getMatrixC().getType())
.cast<LLVM::LLVMStructType>();
Type elemType = packStructType.getBody()
.front()
.cast<LLVM::LLVMStructType>()
LLVM::LLVMStructType packStructType = cast<LLVM::LLVMStructType>(
getTypeConverter()->convertType(op.getMatrixC().getType()));
Type elemType = cast<LLVM::LLVMStructType>(packStructType.getBody().front())
.getBody()
.front();
Value zero = b.create<LLVM::ConstantOp>(elemType, b.getZeroAttr(elemType));
@ -1640,7 +1636,7 @@ struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
SmallVector<Value> innerStructs;
// Unpack the structs and set all values to zero
for (auto [idx, s] : llvm::enumerate(packStructType.getBody())) {
auto structType = s.cast<LLVM::LLVMStructType>();
auto structType = cast<LLVM::LLVMStructType>(s);
Value structValue = b.create<LLVM::ExtractValueOp>(packStruct, idx);
for (unsigned i = 0; i < structType.getBody().size(); ++i) {
structValue = b.create<LLVM::InsertValueOp>(

View File

@ -618,7 +618,7 @@ static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
static SmallVector<Value> expandInputRanks(PatternRewriter &rewriter,
Location loc, Operation *operation) {
auto rank =
operation->getResultTypes().front().cast<RankedTensorType>().getRank();
cast<RankedTensorType>(operation->getResultTypes().front()).getRank();
return llvm::map_to_vector(operation->getOperands(), [&](Value operand) {
return expandRank(rewriter, loc, operand, rank);
});
@ -680,7 +680,7 @@ computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool,
// dimension, that is the target size. An occurrence of an additional static
// dimension greater than 1 with a different value is undefined behavior.
for (auto operand : operands) {
auto size = operand.getType().cast<RankedTensorType>().getDimSize(dim);
auto size = cast<RankedTensorType>(operand.getType()).getDimSize(dim);
if (!ShapedType::isDynamic(size) && size > 1)
return {rewriter.getIndexAttr(size), operand};
}
@ -688,7 +688,7 @@ computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool,
// Filter operands with dynamic dimension
auto operandsWithDynamicDim =
llvm::to_vector(llvm::make_filter_range(operands, [&](Value operand) {
return operand.getType().cast<RankedTensorType>().isDynamicDim(dim);
return cast<RankedTensorType>(operand.getType()).isDynamicDim(dim);
}));
// If no operand has a dynamic dimension, it means all sizes were 1
@ -718,7 +718,7 @@ static std::pair<SmallVector<OpFoldResult>, SmallVector<Value>>
computeTargetShape(PatternRewriter &rewriter, Location loc,
IndexPool &indexPool, ValueRange operands) {
assert(!operands.empty());
auto rank = operands.front().getType().cast<RankedTensorType>().getRank();
auto rank = cast<RankedTensorType>(operands.front().getType()).getRank();
SmallVector<OpFoldResult> targetShape;
SmallVector<Value> masterOperands;
for (auto dim : llvm::seq<int64_t>(0, rank)) {
@ -735,7 +735,7 @@ static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
int64_t dim, OpFoldResult targetSize,
Value masterOperand) {
// Nothing to do if this is a static dimension
auto rankedTensorType = operand.getType().cast<RankedTensorType>();
auto rankedTensorType = cast<RankedTensorType>(operand.getType());
if (!rankedTensorType.isDynamicDim(dim))
return operand;
@ -817,7 +817,7 @@ static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
IndexPool &indexPool, Value operand,
ArrayRef<OpFoldResult> targetShape,
ArrayRef<Value> masterOperands) {
int64_t rank = operand.getType().cast<RankedTensorType>().getRank();
int64_t rank = cast<RankedTensorType>(operand.getType()).getRank();
assert((int64_t)targetShape.size() == rank);
assert((int64_t)masterOperands.size() == rank);
for (auto index : llvm::seq<int64_t>(0, rank))
@ -848,8 +848,7 @@ emitElementwiseComputation(PatternRewriter &rewriter, Location loc,
Operation *operation, ValueRange operands,
ArrayRef<OpFoldResult> targetShape) {
// Generate output tensor
auto resultType =
operation->getResultTypes().front().cast<RankedTensorType>();
auto resultType = cast<RankedTensorType>(operation->getResultTypes().front());
Value outputTensor = rewriter.create<tensor::EmptyOp>(
loc, targetShape, resultType.getElementType());
@ -2274,8 +2273,7 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
llvm::SmallVector<int64_t, 3> staticSizes;
dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);
auto elementType =
input.getType().cast<RankedTensorType>().getElementType();
auto elementType = cast<RankedTensorType>(input.getType()).getElementType();
return RankedTensorType::get(staticSizes, elementType);
}
@ -2327,7 +2325,7 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
auto loc = rfft2d.getLoc();
auto input = rfft2d.getInput();
auto elementType =
input.getType().cast<ShapedType>().getElementType().cast<FloatType>();
cast<FloatType>(cast<ShapedType>(input.getType()).getElementType());
// Compute the output type and set of dynamic sizes
llvm::SmallVector<Value> dynamicSizes;

View File

@ -1204,10 +1204,10 @@ convertElementwiseOp(RewriterBase &rewriter, Operation *op,
return rewriter.notifyMatchFailure(op, "no mapping");
matrixOperands.push_back(it->second);
}
auto resultType = matrixOperands[0].getType().cast<gpu::MMAMatrixType>();
auto resultType = cast<gpu::MMAMatrixType>(matrixOperands[0].getType());
if (opType == gpu::MMAElementwiseOp::EXTF) {
// The floating point extension case has a different result type.
auto vectorType = op->getResultTypes()[0].cast<VectorType>();
auto vectorType = cast<VectorType>(op->getResultTypes()[0]);
resultType = gpu::MMAMatrixType::get(resultType.getShape(),
vectorType.getElementType(),
resultType.getOperand());

View File

@ -631,8 +631,7 @@ static Value createMaskNeutralValue(ConversionPatternRewriter &rewriter,
Type vectorType) {
const auto &floatSemantics = cast<FloatType>(llvmType).getFloatSemantics();
auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics);
auto denseValue =
DenseElementsAttr::get(vectorType.cast<ShapedType>(), value);
auto denseValue = DenseElementsAttr::get(cast<ShapedType>(vectorType), value);
return rewriter.create<LLVM::ConstantOp>(loc, vectorType, denseValue);
}

View File

@ -227,8 +227,8 @@ LogicalResult WMMAOp::verify() {
Type sourceAType = getSourceA().getType();
Type destType = getDestC().getType();
VectorType sourceVectorAType = sourceAType.dyn_cast<VectorType>();
VectorType destVectorType = destType.dyn_cast<VectorType>();
VectorType sourceVectorAType = dyn_cast<VectorType>(sourceAType);
VectorType destVectorType = dyn_cast<VectorType>(destType);
Type sourceAElemType = sourceVectorAType.getElementType();
Type destElemType = destVectorType.getElementType();

View File

@ -26,7 +26,7 @@ struct ConstantOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto constantOp = cast<arith::ConstantOp>(op);
auto type = constantOp.getType().dyn_cast<RankedTensorType>();
auto type = dyn_cast<RankedTensorType>(constantOp.getType());
// Only ranked tensors are supported.
if (!type)

View File

@ -106,7 +106,7 @@ void mlir::arith::populateEmulateUnsupportedFloatsConversions(
targetType](Type type) -> std::optional<Type> {
if (llvm::is_contained(sourceTypes, type))
return targetType;
if (auto shaped = type.dyn_cast<ShapedType>())
if (auto shaped = dyn_cast<ShapedType>(type))
if (llvm::is_contained(sourceTypes, shaped.getElementType()))
return shaped.clone(targetType);
// All other types legal

View File

@ -99,7 +99,7 @@ public:
Value extsiLhs;
Value extsiRhs;
if (auto lhsExtInType =
origLhsExtOp.getIn().getType().dyn_cast<mlir::VectorType>()) {
dyn_cast<mlir::VectorType>(origLhsExtOp.getIn().getType())) {
if (lhsExtInType.getElementTypeBitWidth() <= 8) {
Type targetLhsExtTy =
matchContainerType(rewriter.getI8Type(), lhsExtInType);
@ -108,7 +108,7 @@ public:
}
}
if (auto rhsExtInType =
origRhsExtOp.getIn().getType().dyn_cast<mlir::VectorType>()) {
dyn_cast<mlir::VectorType>(origRhsExtOp.getIn().getType())) {
if (rhsExtInType.getElementTypeBitWidth() <= 8) {
Type targetRhsExtTy =
matchContainerType(rewriter.getI8Type(), rhsExtInType);
@ -161,9 +161,9 @@ public:
extractOperand(op.getAcc(), accPermutationMap, accOffsets);
auto inputElementType =
tiledLhs.getType().cast<ShapedType>().getElementType();
cast<ShapedType>(tiledLhs.getType()).getElementType();
auto accElementType =
tiledAcc.getType().cast<ShapedType>().getElementType();
cast<ShapedType>(tiledAcc.getType()).getElementType();
auto inputExpandedType = VectorType::get({2, 8}, inputElementType);
auto outputExpandedType = VectorType::get({2, 2}, accElementType);
@ -175,9 +175,9 @@ public:
auto emptyOperand = rewriter.create<arith::ConstantOp>(
loc, expandedTypeType, rewriter.getZeroAttr(expandedTypeType));
SmallVector<int64_t> offsets(
emptyOperand.getType().cast<ShapedType>().getRank(), 0);
cast<ShapedType>(emptyOperand.getType()).getRank(), 0);
SmallVector<int64_t> strides(
tiledOperand.getType().cast<ShapedType>().getRank(), 1);
cast<ShapedType>(tiledOperand.getType()).getRank(), 1);
return rewriter.createOrFold<vector::InsertStridedSliceOp>(
loc, tiledOperand, emptyOperand, offsets, strides);
};
@ -214,7 +214,7 @@ public:
// Insert the tiled result back into the non tiled result of the
// contract op.
SmallVector<int64_t> strides(
tiledRes.getType().cast<ShapedType>().getRank(), 1);
cast<ShapedType>(tiledRes.getType()).getRank(), 1);
result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
loc, tiledRes, result, accOffsets, strides);
}

View File

@ -39,7 +39,7 @@ static Value buildBoolValue(OpBuilder &builder, Location loc, bool value) {
return builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(value));
}
static bool isMemref(Value v) { return v.getType().isa<BaseMemRefType>(); }
static bool isMemref(Value v) { return isa<BaseMemRefType>(v.getType()); }
//===----------------------------------------------------------------------===//
// Ownership
@ -222,8 +222,8 @@ bool ValueComparator::operator()(const Value &lhs, const Value &rhs) const {
return false;
// Block arguments are less than results.
bool lhsIsBBArg = lhs.isa<BlockArgument>();
if (lhsIsBBArg != rhs.isa<BlockArgument>()) {
bool lhsIsBBArg = isa<BlockArgument>(lhs);
if (lhsIsBBArg != isa<BlockArgument>(rhs)) {
return lhsIsBBArg;
}

View File

@ -684,7 +684,7 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
// Op is not bufferizable.
auto memSpace =
options.defaultMemorySpaceFn(value.getType().cast<TensorType>());
options.defaultMemorySpaceFn(cast<TensorType>(value.getType()));
if (!memSpace.has_value())
return op->emitError("could not infer memory space");
@ -939,7 +939,7 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
// If we do not know the memory space and there is no default memory space,
// report a failure.
auto memSpace =
options.defaultMemorySpaceFn(value.getType().cast<TensorType>());
options.defaultMemorySpaceFn(cast<TensorType>(value.getType()));
if (!memSpace.has_value())
return op->emitError("could not infer memory space");
@ -987,7 +987,7 @@ bufferization::detail::unknownGetAliasingValues(OpOperand &opOperand) {
for (Region &region : opOperand.getOwner()->getRegions())
if (!region.getBlocks().empty())
for (BlockArgument bbArg : region.getBlocks().front().getArguments())
if (bbArg.getType().isa<TensorType>())
if (isa<TensorType>(bbArg.getType()))
r.addAlias({bbArg, BufferRelation::Unknown, /*isDefinite=*/false});
return r;
}

View File

@ -46,7 +46,7 @@ static Value buildBoolValue(OpBuilder &builder, Location loc, bool value) {
return builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(value));
}
static bool isMemref(Value v) { return v.getType().isa<BaseMemRefType>(); }
static bool isMemref(Value v) { return isa<BaseMemRefType>(v.getType()); }
/// Return "true" if the given op is guaranteed to have neither "Allocate" nor
/// "Free" side effects.

View File

@ -378,7 +378,7 @@ OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
if (!rhs)
return {};
ArrayAttr arrayAttr = rhs.dyn_cast<ArrayAttr>();
ArrayAttr arrayAttr = dyn_cast<ArrayAttr>(rhs);
if (!arrayAttr || arrayAttr.size() != 2)
return {};

View File

@ -17,7 +17,7 @@
using namespace mlir;
using namespace mlir::bufferization;
static bool isMemref(Value v) { return v.getType().isa<BaseMemRefType>(); }
static bool isMemref(Value v) { return isa<BaseMemRefType>(v.getType()); }
namespace {
/// While CondBranchOp also implement the BranchOpInterface, we add a

View File

@ -160,13 +160,13 @@ LogicalResult AddOp::verify() {
Type lhsType = getLhs().getType();
Type rhsType = getRhs().getType();
if (lhsType.isa<emitc::PointerType>() && rhsType.isa<emitc::PointerType>())
if (isa<emitc::PointerType>(lhsType) && isa<emitc::PointerType>(rhsType))
return emitOpError("requires that at most one operand is a pointer");
if ((lhsType.isa<emitc::PointerType>() &&
!rhsType.isa<IntegerType, emitc::OpaqueType>()) ||
(rhsType.isa<emitc::PointerType>() &&
!lhsType.isa<IntegerType, emitc::OpaqueType>()))
if ((isa<emitc::PointerType>(lhsType) &&
!isa<IntegerType, emitc::OpaqueType>(rhsType)) ||
(isa<emitc::PointerType>(rhsType) &&
!isa<IntegerType, emitc::OpaqueType>(lhsType)))
return emitOpError("requires that one operand is an integer or of opaque "
"type if the other is a pointer");
@ -778,16 +778,16 @@ LogicalResult SubOp::verify() {
Type rhsType = getRhs().getType();
Type resultType = getResult().getType();
if (rhsType.isa<emitc::PointerType>() && !lhsType.isa<emitc::PointerType>())
if (isa<emitc::PointerType>(rhsType) && !isa<emitc::PointerType>(lhsType))
return emitOpError("rhs can only be a pointer if lhs is a pointer");
if (lhsType.isa<emitc::PointerType>() &&
!rhsType.isa<IntegerType, emitc::OpaqueType, emitc::PointerType>())
if (isa<emitc::PointerType>(lhsType) &&
!isa<IntegerType, emitc::OpaqueType, emitc::PointerType>(rhsType))
return emitOpError("requires that rhs is an integer, pointer or of opaque "
"type if lhs is a pointer");
if (lhsType.isa<emitc::PointerType>() && rhsType.isa<emitc::PointerType>() &&
!resultType.isa<IntegerType, emitc::OpaqueType>())
if (isa<emitc::PointerType>(lhsType) && isa<emitc::PointerType>(rhsType) &&
!isa<IntegerType, emitc::OpaqueType>(resultType))
return emitOpError("requires that the result is an integer or of opaque "
"type if lhs and rhs are pointers");
return success();

View File

@ -196,7 +196,7 @@ getSubgroupMmaNativeVectorSize(Operation *op, int64_t m, int64_t n, int64_t k) {
auto extract = dyn_cast<vector::ExtractStridedSliceOp>(users);
if (!extract)
return std::nullopt;
auto vecType = extract.getResult().getType().cast<VectorType>();
auto vecType = cast<VectorType>(extract.getResult().getType());
if (sliceType && sliceType != vecType)
return std::nullopt;
sliceType = vecType;
@ -204,7 +204,7 @@ getSubgroupMmaNativeVectorSize(Operation *op, int64_t m, int64_t n, int64_t k) {
return llvm::to_vector(sliceType.getShape());
}
if ((OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1)) {
if (auto vecType = op->getResultTypes()[0].dyn_cast<VectorType>()) {
if (auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0])) {
// TODO: The condition for unrolling elementwise should be restricted
// only to operations that need unrolling (connected to the contract).
if (vecType.getRank() < 2)
@ -219,7 +219,7 @@ getSubgroupMmaNativeVectorSize(Operation *op, int64_t m, int64_t n, int64_t k) {
auto extract = dyn_cast<vector::ExtractStridedSliceOp>(users);
if (!extract)
return std::nullopt;
auto vecType = extract.getResult().getType().cast<VectorType>();
auto vecType = cast<VectorType>(extract.getResult().getType());
if (sliceType && sliceType != vecType)
return std::nullopt;
sliceType = vecType;

View File

@ -354,7 +354,7 @@ static WalkResult loadOperation(
// Gather the variadicities of each result
for (Attribute attr : resultsOp->getVariadicity())
resultVariadicity.push_back(attr.cast<VariadicityAttr>().getValue());
resultVariadicity.push_back(cast<VariadicityAttr>(attr).getValue());
}
// Gather which constraint slots correspond to attributes constraints
@ -367,7 +367,7 @@ static WalkResult loadOperation(
for (const auto &[name, value] : llvm::zip(names, values)) {
for (auto [i, constr] : enumerate(constrToValue)) {
if (constr == value) {
attributesContraints[name.cast<StringAttr>()] = i;
attributesContraints[cast<StringAttr>(name)] = i;
break;
}
}

View File

@ -42,7 +42,7 @@ static char getRegisterType(Type type) {
return 'f';
if (type.isF64())
return 'd';
if (auto ptr = type.dyn_cast<LLVM::LLVMPointerType>()) {
if (auto ptr = dyn_cast<LLVM::LLVMPointerType>(type)) {
// Shared address spaces is addressed with 32-bit pointers.
if (ptr.getAddressSpace() == kSharedMemorySpace) {
return 'r';

View File

@ -559,7 +559,7 @@ static void destructureIndices(Type currType, ArrayRef<GEPArg> indices,
// we don't do anything here. The verifier will catch it and emit a proper
// error. All other canonicalization is done in the fold method.
bool requiresConst = !rawConstantIndices.empty() &&
currType.isa_and_nonnull<LLVMStructType>();
isa_and_nonnull<LLVMStructType>(currType);
if (Value val = llvm::dyn_cast_if_present<Value>(iter)) {
APInt intC;
if (requiresConst && matchPattern(val, m_ConstantInt(&intC)) &&
@ -2564,14 +2564,14 @@ LogicalResult LLVM::ConstantOp::verify() {
}
// See the comment for getLLVMConstant for more details about why 8-bit
// floats can be represented by integers.
if (getType().isa<IntegerType>() && !getType().isInteger(floatWidth)) {
if (isa<IntegerType>(getType()) && !getType().isInteger(floatWidth)) {
return emitOpError() << "expected integer type of width " << floatWidth;
}
}
if (auto splatAttr = dyn_cast<SplatElementsAttr>(getValue())) {
if (!getType().isa<VectorType>() && !getType().isa<LLVM::LLVMArrayType>() &&
!getType().isa<LLVM::LLVMFixedVectorType>() &&
!getType().isa<LLVM::LLVMScalableVectorType>())
if (!isa<VectorType>(getType()) && !isa<LLVM::LLVMArrayType>(getType()) &&
!isa<LLVM::LLVMFixedVectorType>(getType()) &&
!isa<LLVM::LLVMScalableVectorType>(getType()))
return emitOpError() << "expected vector or array type";
}
return success();

View File

@ -319,7 +319,7 @@ LogicalResult LLVM::StoreOp::ensureOnlySafeAccesses(
static Type getTypeAtIndex(const DestructurableMemorySlot &slot,
Attribute index) {
auto subelementIndexMap =
slot.elemType.cast<DestructurableTypeInterface>().getSubelementIndexMap();
cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap();
if (!subelementIndexMap)
return {};
assert(!subelementIndexMap->empty());
@ -913,8 +913,7 @@ bool LLVM::MemsetOp::canRewire(const DestructurableMemorySlot &slot,
if (getIsVolatile())
return false;
if (!slot.elemType.cast<DestructurableTypeInterface>()
.getSubelementIndexMap())
if (!cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap())
return false;
if (!areAllIndicesI32(slot))
@ -928,7 +927,7 @@ DeletionKind LLVM::MemsetOp::rewire(const DestructurableMemorySlot &slot,
RewriterBase &rewriter,
const DataLayout &dataLayout) {
std::optional<DenseMap<Attribute, Type>> types =
slot.elemType.cast<DestructurableTypeInterface>().getSubelementIndexMap();
cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap();
IntegerAttr memsetLenAttr;
bool successfulMatch =
@ -1047,8 +1046,7 @@ static bool memcpyCanRewire(MemcpyLike op, const DestructurableMemorySlot &slot,
if (op.getIsVolatile())
return false;
if (!slot.elemType.cast<DestructurableTypeInterface>()
.getSubelementIndexMap())
if (!cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap())
return false;
if (!areAllIndicesI32(slot))

View File

@ -475,7 +475,7 @@ LogicalResult SplitStores::matchAndRewrite(StoreOp store,
}
}
auto destructurableType = typeHint.dyn_cast<DestructurableTypeInterface>();
auto destructurableType = dyn_cast<DestructurableTypeInterface>(typeHint);
if (!destructurableType)
return failure();

View File

@ -202,9 +202,9 @@ DiagnosedSilenceableFailure transform::MatchStructuredBodyOp::matchOperation(
body,
[&](Operation *elem, Operation *red) {
return elem->getName().getStringRef() ==
(*contractionOps)[0].cast<StringAttr>().getValue() &&
cast<StringAttr>((*contractionOps)[0]).getValue() &&
red->getName().getStringRef() ==
(*contractionOps)[1].cast<StringAttr>().getValue();
cast<StringAttr>((*contractionOps)[1]).getValue();
},
os);
if (result)
@ -259,11 +259,11 @@ transform::MatchStructuredClassifyContractionDimsOp::matchOperation(
return builder.getI64IntegerAttr(value);
}));
};
results.setParams(getBatch().cast<OpResult>(),
results.setParams(cast<OpResult>(getBatch()),
makeI64Attrs(contractionDims->batch));
results.setParams(getM().cast<OpResult>(), makeI64Attrs(contractionDims->m));
results.setParams(getN().cast<OpResult>(), makeI64Attrs(contractionDims->n));
results.setParams(getK().cast<OpResult>(), makeI64Attrs(contractionDims->k));
results.setParams(cast<OpResult>(getM()), makeI64Attrs(contractionDims->m));
results.setParams(cast<OpResult>(getN()), makeI64Attrs(contractionDims->n));
results.setParams(cast<OpResult>(getK()), makeI64Attrs(contractionDims->k));
return DiagnosedSilenceableFailure::success();
}
@ -288,17 +288,17 @@ transform::MatchStructuredClassifyConvolutionDimsOp::matchOperation(
return builder.getI64IntegerAttr(value);
}));
};
results.setParams(getBatch().cast<OpResult>(),
results.setParams(cast<OpResult>(getBatch()),
makeI64Attrs(convolutionDims->batch));
results.setParams(getOutputImage().cast<OpResult>(),
results.setParams(cast<OpResult>(getOutputImage()),
makeI64Attrs(convolutionDims->outputImage));
results.setParams(getOutputChannel().cast<OpResult>(),
results.setParams(cast<OpResult>(getOutputChannel()),
makeI64Attrs(convolutionDims->outputChannel));
results.setParams(getFilterLoop().cast<OpResult>(),
results.setParams(cast<OpResult>(getFilterLoop()),
makeI64Attrs(convolutionDims->filterLoop));
results.setParams(getInputChannel().cast<OpResult>(),
results.setParams(cast<OpResult>(getInputChannel()),
makeI64Attrs(convolutionDims->inputChannel));
results.setParams(getDepth().cast<OpResult>(),
results.setParams(cast<OpResult>(getDepth()),
makeI64Attrs(convolutionDims->depth));
auto makeI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
@ -307,9 +307,9 @@ transform::MatchStructuredClassifyConvolutionDimsOp::matchOperation(
return builder.getI64IntegerAttr(value);
}));
};
results.setParams(getStrides().cast<OpResult>(),
results.setParams(cast<OpResult>(getStrides()),
makeI64AttrsFromI64(convolutionDims->strides));
results.setParams(getDilations().cast<OpResult>(),
results.setParams(cast<OpResult>(getDilations()),
makeI64AttrsFromI64(convolutionDims->dilations));
return DiagnosedSilenceableFailure::success();
}

View File

@ -1219,7 +1219,7 @@ transform::MatchOp::apply(transform::TransformRewriter &rewriter,
// All the operands must must be equal to the specified type
auto typeattr =
dyn_cast<mlir::TypeAttr>(getFilterOperandTypes().value()[0]);
Type t = typeattr.getValue().cast<::mlir::Type>();
Type t = cast<::mlir::Type>(typeattr.getValue());
if (!llvm::all_of(op->getOperandTypes(),
[&](Type operandType) { return operandType == t; }))
return;
@ -1234,7 +1234,7 @@ transform::MatchOp::apply(transform::TransformRewriter &rewriter,
for (auto [attr, operandType] :
llvm::zip_equal(getFilterOperandTypes().value(), operandTypes)) {
auto typeattr = cast<mlir::TypeAttr>(attr);
Type type = typeattr.getValue().cast<::mlir::Type>();
Type type = cast<::mlir::Type>(typeattr.getValue());
if (type != operandType)
return;
@ -2665,7 +2665,7 @@ transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
if (scalableSizes[ofrIdx]) {
auto val = b.create<arith::ConstantIndexOp>(
getLoc(), attr.cast<IntegerAttr>().getInt());
getLoc(), cast<IntegerAttr>(attr).getInt());
Value vscale =
b.create<vector::VectorScaleOp>(getLoc(), b.getIndexType());
sizes.push_back(

View File

@ -60,7 +60,7 @@ static void createMemcpy(OpBuilder &b, Location loc, Value tensorSource,
const linalg::BufferizeToAllocationOptions &options) {
auto tensorType = dyn_cast<RankedTensorType>(tensorSource.getType());
assert(tensorType && "expected ranked tensor");
assert(memrefDest.getType().isa<MemRefType>() && "expected ranked memref");
assert(isa<MemRefType>(memrefDest.getType()) && "expected ranked memref");
switch (options.memcpyOp) {
case linalg::BufferizeToAllocationOptions::MemcpyOp::
@ -496,10 +496,10 @@ Value linalg::bufferizeToAllocation(
if (op == nestedOp)
return;
if (llvm::any_of(nestedOp->getOperands(),
[](Value v) { return v.getType().isa<TensorType>(); }))
[](Value v) { return isa<TensorType>(v.getType()); }))
llvm_unreachable("ops with nested tensor ops are not supported yet");
if (llvm::any_of(nestedOp->getResults(),
[](Value v) { return v.getType().isa<TensorType>(); }))
[](Value v) { return isa<TensorType>(v.getType()); }))
llvm_unreachable("ops with nested tensor ops are not supported yet");
});
}
@ -508,7 +508,7 @@ Value linalg::bufferizeToAllocation(
// Gather tensor results.
SmallVector<OpResult> tensorResults;
for (OpResult result : op->getResults()) {
if (!result.getType().isa<TensorType>())
if (!isa<TensorType>(result.getType()))
continue;
// Unranked tensors are not supported
if (!isa<RankedTensorType>(result.getType()))

View File

@ -49,7 +49,7 @@ LogicalResult linalg::linalgOpAnchoredEmptyTensorEliminationStep(
for (OpOperand *in : op.getDpsInputOperands()) {
// Skip non-tensor operands.
if (!in->get().getType().isa<RankedTensorType>())
if (!isa<RankedTensorType>(in->get().getType()))
continue;
// Find tensor.empty ops on the reverse SSA use-def chain. Only follow

View File

@ -405,7 +405,7 @@ static FailureOr<ForallTilingResult> tileToForallOpImpl(
for (OpOperand &outOperand : destinationStyleOp.getDpsInitsMutable()) {
// Swap tensor inits with the corresponding block argument of the
// scf.forall op. Memref inits remain as is.
if (outOperand.get().getType().isa<TensorType>()) {
if (isa<TensorType>(outOperand.get().getType())) {
auto *it = llvm::find(dest, outOperand.get());
assert(it != dest.end() && "could not find destination tensor");
unsigned destNum = std::distance(dest.begin(), it);

View File

@ -557,7 +557,7 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
Value dest = tensor::PackOp::createDestinationTensor(
rewriter, loc, operand, innerPackSizes, innerPos,
/*outerDimsPerm=*/{});
ShapedType operandType = operand.getType().cast<ShapedType>();
ShapedType operandType = cast<ShapedType>(operand.getType());
bool areConstantTiles =
llvm::all_of(innerPackSizes, [](OpFoldResult tile) {
return getConstantIntValue(tile).has_value();
@ -565,7 +565,7 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
if (areConstantTiles && operandType.hasStaticShape() &&
!tensor::PackOp::requirePaddingValue(
operandType.getShape(), innerPos,
dest.getType().cast<ShapedType>().getShape(), {},
cast<ShapedType>(dest.getType()).getShape(), {},
innerPackSizes)) {
packOps.push_back(rewriter.create<tensor::PackOp>(
loc, operand, dest, innerPos, innerPackSizes));

View File

@ -3410,8 +3410,8 @@ struct Conv1DGenerator
// * shape_cast(broadcast(filter))
// * broadcast(shuffle(filter))
// Opt for the option without shape_cast to simplify the codegen.
auto rhsSize = rhs.getType().cast<VectorType>().getShape()[0];
auto resSize = res.getType().cast<VectorType>().getShape()[1];
auto rhsSize = cast<VectorType>(rhs.getType()).getShape()[0];
auto resSize = cast<VectorType>(res.getType()).getShape()[1];
SmallVector<int64_t, 16> indicies;
for (int i = 0; i < resSize / rhsSize; ++i) {

View File

@ -173,8 +173,8 @@ transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter,
}
// Assemble results.
results.set(getGlobal().cast<OpResult>(), globalOps);
results.set(getGetGlobal().cast<OpResult>(), getGlobalOps);
results.set(cast<OpResult>(getGlobal()), globalOps);
results.set(cast<OpResult>(getGetGlobal()), getGlobalOps);
return DiagnosedSilenceableFailure::success();
}

View File

@ -254,7 +254,7 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
LogicalResult
matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto convertedType = adaptor.getMemref().getType().cast<MemRefType>();
auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
auto convertedElementType = convertedType.getElementType();
auto oldElementType = op.getMemRefType().getElementType();
int srcBits = oldElementType.getIntOrFloatBitWidth();
@ -351,7 +351,7 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
LogicalResult
matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto convertedType = adaptor.getMemref().getType().cast<MemRefType>();
auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
int srcBits = op.getMemRefType().getElementTypeBitWidth();
int dstBits = convertedType.getElementTypeBitWidth();
auto dstIntegerType = rewriter.getIntegerType(dstBits);

View File

@ -68,7 +68,7 @@ struct ExpandReallocOpPattern : public OpRewritePattern<memref::ReallocOp> {
// Get the size of the original buffer.
int64_t inputSize =
op.getSource().getType().cast<BaseMemRefType>().getDimSize(0);
cast<BaseMemRefType>(op.getSource().getType()).getDimSize(0);
OpFoldResult currSize = rewriter.getIndexAttr(inputSize);
if (ShapedType::isDynamic(inputSize)) {
Value dimZero = getValueOrCreateConstantIndexOp(rewriter, loc,
@ -79,7 +79,7 @@ struct ExpandReallocOpPattern : public OpRewritePattern<memref::ReallocOp> {
// Get the requested size that the new buffer should have.
int64_t outputSize =
op.getResult().getType().cast<BaseMemRefType>().getDimSize(0);
cast<BaseMemRefType>(op.getResult().getType()).getDimSize(0);
OpFoldResult targetSize = ShapedType::isDynamic(outputSize)
? OpFoldResult{op.getDynamicResultSize()}
: rewriter.getIndexAttr(outputSize);
@ -127,7 +127,7 @@ struct ExpandReallocOpPattern : public OpRewritePattern<memref::ReallocOp> {
// is already bigger than the requested size, the cast represents a
// subview operation.
Value casted = builder.create<memref::ReinterpretCastOp>(
loc, op.getResult().getType().cast<MemRefType>(), op.getSource(),
loc, cast<MemRefType>(op.getResult().getType()), op.getSource(),
rewriter.getIndexAttr(0), ArrayRef<OpFoldResult>{targetSize},
ArrayRef<OpFoldResult>{rewriter.getIndexAttr(1)});
builder.create<scf::YieldOp>(loc, casted);

View File

@ -169,7 +169,7 @@ ShapedType mesh::shardShapedType(ShapedType shape, MeshOp mesh,
}
Type mesh::shardType(Type type, MeshOp mesh, MeshShardingAttr sharding) {
RankedTensorType rankedTensorType = type.dyn_cast<RankedTensorType>();
RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(type);
if (rankedTensorType) {
return shardShapedType(rankedTensorType, mesh, sharding);
}
@ -281,7 +281,8 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
}
bool MeshShardingAttr::operator==(Attribute rhs) const {
MeshShardingAttr rhsAsMeshShardingAttr = rhs.dyn_cast<MeshShardingAttr>();
MeshShardingAttr rhsAsMeshShardingAttr =
mlir::dyn_cast<MeshShardingAttr>(rhs);
return rhsAsMeshShardingAttr && *this == rhsAsMeshShardingAttr;
}
@ -484,15 +485,15 @@ static LogicalResult verifyDimensionCompatibility(Location loc,
static LogicalResult verifyGatherOperandAndResultShape(
Value operand, Value result, int64_t gatherAxis,
ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
auto resultRank = result.getType().template cast<ShapedType>().getRank();
auto resultRank = cast<ShapedType>(result.getType()).getRank();
if (gatherAxis < 0 || gatherAxis >= resultRank) {
return emitError(result.getLoc())
<< "Gather axis " << gatherAxis << " is out of bounds [0, "
<< resultRank << ").";
}
ShapedType operandType = operand.getType().cast<ShapedType>();
ShapedType resultType = result.getType().cast<ShapedType>();
ShapedType operandType = cast<ShapedType>(operand.getType());
ShapedType resultType = cast<ShapedType>(result.getType());
auto deviceGroupSize =
DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
@ -511,8 +512,8 @@ static LogicalResult verifyGatherOperandAndResultShape(
static LogicalResult verifyAllToAllOperandAndResultShape(
Value operand, Value result, int64_t splitAxis, int64_t concatAxis,
ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
ShapedType operandType = operand.getType().cast<ShapedType>();
ShapedType resultType = result.getType().cast<ShapedType>();
ShapedType operandType = cast<ShapedType>(operand.getType());
ShapedType resultType = cast<ShapedType>(result.getType());
for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
if ((axis != splitAxis && axis != concatAxis) || splitAxis == concatAxis) {
if (failed(verifyDimensionCompatibility(
@ -556,8 +557,8 @@ static LogicalResult verifyAllToAllOperandAndResultShape(
static LogicalResult verifyScatterOrSliceOperandAndResultShape(
Value operand, Value result, int64_t tensorAxis,
ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
ShapedType operandType = operand.getType().cast<ShapedType>();
ShapedType resultType = result.getType().cast<ShapedType>();
ShapedType operandType = cast<ShapedType>(operand.getType());
ShapedType resultType = cast<ShapedType>(result.getType());
for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
if (axis != tensorAxis) {
if (failed(verifyDimensionCompatibility(

View File

@ -97,7 +97,7 @@ checkOperandAffineExpr(AffineExpr expr, unsigned numDims) {
FailureOr<std::pair<bool, MeshShardingAttr>>
mesh::getMeshShardingAttr(OpResult result) {
Value val = result.cast<Value>();
Value val = cast<Value>(result);
bool anyShardedForDef = llvm::any_of(val.getUsers(), [](Operation *user) {
auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
if (!shardOp)
@ -178,7 +178,7 @@ LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() {
return failure();
for (OpResult result : op->getResults()) {
auto resultType = result.getType().dyn_cast<RankedTensorType>();
auto resultType = dyn_cast<RankedTensorType>(result.getType());
if (!resultType)
return failure();
AffineMap map = maps[numOperands + result.getResultNumber()];
@ -404,7 +404,7 @@ static LogicalResult addShardOp(OpBuilder &b, OpResult result,
if (succeeded(maybeSharding) && !maybeSharding->first)
return success();
auto resultType = result.getType().cast<RankedTensorType>();
auto resultType = cast<RankedTensorType>(result.getType());
SmallVector<SmallVector<MeshAxis>> splitAxes(resultType.getRank());
SmallVector<MeshAxis> partialAxes;
@ -457,7 +457,7 @@ static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
if (succeeded(maybeShardingAttr) && maybeShardingAttr->first)
return success();
Value operand = opOperand.get();
auto operandType = operand.getType().cast<RankedTensorType>();
auto operandType = cast<RankedTensorType>(operand.getType());
SmallVector<SmallVector<MeshAxis>> splitAxes(operandType.getRank());
unsigned numDims = map.getNumDims();
for (auto it : llvm::enumerate(map.getResults())) {
@ -526,7 +526,7 @@ LogicalResult mesh::detail::defaultAddShardingAnnotations(
static bool
isValueCompatibleWithFullReplicationSharding(Value value,
MeshShardingAttr sharding) {
if (value.getType().isa<RankedTensorType>()) {
if (isa<RankedTensorType>(value.getType())) {
return sharding && isFullReplication(sharding);
}

View File

@ -86,14 +86,13 @@ handlePartialAxesDuringResharding(OpBuilder &builder,
}
builder.setInsertionPointAfterValue(sourceShard);
TypedValue<ShapedType> resultValue =
TypedValue<ShapedType> resultValue = cast<TypedValue<ShapedType>>(
builder
.create<AllReduceOp>(sourceShard.getLoc(), sourceShard.getType(),
sourceSharding.getMesh().getLeafReference(),
allReduceMeshAxes, sourceShard,
sourceSharding.getPartialType())
.getResult()
.cast<TypedValue<ShapedType>>();
.getResult());
llvm::SmallVector<MeshAxis> remainingPartialAxes;
llvm::copy_if(sourceShardingPartialAxesSet,
@ -135,13 +134,12 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
MeshShardingAttr sourceSharding,
TypedValue<ShapedType> sourceShard, MeshOp mesh,
int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
TypedValue<ShapedType> targetShard =
TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
builder
.create<AllSliceOp>(sourceShard, mesh,
ArrayRef<MeshAxis>(splitMeshAxis),
splitTensorAxis)
.getResult()
.cast<TypedValue<ShapedType>>();
.getResult());
MeshShardingAttr targetSharding = targetShardingInSplitLastAxis(
builder.getContext(), sourceSharding, splitTensorAxis, splitMeshAxis);
return {targetShard, targetSharding};
@ -278,10 +276,8 @@ unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
APInt(64, splitTensorAxis));
ShapedType targetShape =
shardShapedType(sourceUnshardedShape, mesh, targetSharding);
TypedValue<ShapedType> targetShard =
builder.create<tensor::CastOp>(targetShape, allGatherResult)
.getResult()
.cast<TypedValue<ShapedType>>();
TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
builder.create<tensor::CastOp>(targetShape, allGatherResult).getResult());
return {targetShard, targetSharding};
}
@ -413,10 +409,8 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
ShapedType targetShape =
shardShapedType(sourceUnshardedShape, mesh, targetSharding);
TypedValue<ShapedType> targetShard =
builder.create<tensor::CastOp>(targetShape, allToAllResult)
.getResult()
.cast<TypedValue<ShapedType>>();
TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
builder.create<tensor::CastOp>(targetShape, allToAllResult).getResult());
return {targetShard, targetSharding};
}
@ -505,7 +499,7 @@ TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
return reshard(
implicitLocOpBuilder, mesh, source.getShard(), target.getShard(),
source.getSrc().cast<TypedValue<ShapedType>>(), sourceShardValue);
cast<TypedValue<ShapedType>>(source.getSrc()), sourceShardValue);
}
TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
@ -533,23 +527,22 @@ SmallVector<Type>
shardedBlockArgumentTypes(Block &block,
SymbolTableCollection &symbolTableCollection) {
SmallVector<Type> res;
llvm::transform(block.getArguments(), std::back_inserter(res),
[&symbolTableCollection](BlockArgument arg) {
auto rankedTensorArg =
arg.dyn_cast<TypedValue<RankedTensorType>>();
if (!rankedTensorArg) {
return arg.getType();
}
llvm::transform(
block.getArguments(), std::back_inserter(res),
[&symbolTableCollection](BlockArgument arg) {
auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
if (!rankedTensorArg) {
return arg.getType();
}
assert(rankedTensorArg.hasOneUse());
Operation *useOp = *rankedTensorArg.getUsers().begin();
ShardOp shardOp = llvm::dyn_cast<ShardOp>(useOp);
assert(shardOp);
MeshOp mesh = getMesh(shardOp, symbolTableCollection);
return shardShapedType(rankedTensorArg.getType(), mesh,
shardOp.getShardAttr())
.cast<Type>();
});
assert(rankedTensorArg.hasOneUse());
Operation *useOp = *rankedTensorArg.getUsers().begin();
ShardOp shardOp = llvm::dyn_cast<ShardOp>(useOp);
assert(shardOp);
MeshOp mesh = getMesh(shardOp, symbolTableCollection);
return cast<Type>(shardShapedType(rankedTensorArg.getType(), mesh,
shardOp.getShardAttr()));
});
return res;
}
@ -587,7 +580,7 @@ static SmallVector<MeshShardingAttr> getOperandShardings(Operation &op) {
res.reserve(op.getNumOperands());
llvm::transform(op.getOperands(), std::back_inserter(res), [](Value operand) {
TypedValue<RankedTensorType> rankedTensor =
operand.dyn_cast<TypedValue<RankedTensorType>>();
dyn_cast<TypedValue<RankedTensorType>>(operand);
if (!rankedTensor) {
return MeshShardingAttr();
}
@ -608,7 +601,7 @@ static SmallVector<MeshShardingAttr> getResultShardings(Operation &op) {
llvm::transform(op.getResults(), std::back_inserter(res),
[](OpResult result) {
TypedValue<RankedTensorType> rankedTensor =
result.dyn_cast<TypedValue<RankedTensorType>>();
dyn_cast<TypedValue<RankedTensorType>>(result);
if (!rankedTensor) {
return MeshShardingAttr();
}
@ -636,9 +629,8 @@ spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
} else {
// Insert resharding.
assert(!srcShardOp.getAnnotateForUsers() && shardOp.getAnnotateForUsers());
TypedValue<ShapedType> srcSpmdValue =
spmdizationMap.lookup(srcShardOp.getOperand())
.cast<TypedValue<ShapedType>>();
TypedValue<ShapedType> srcSpmdValue = cast<TypedValue<ShapedType>>(
spmdizationMap.lookup(srcShardOp.getOperand()));
targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
symbolTableCollection);
}

View File

@ -133,7 +133,7 @@ struct AllSliceOpLowering
// insert tensor.extract_slice
RankedTensorType operandType =
op.getOperand().getType().cast<RankedTensorType>();
cast<RankedTensorType>(op.getOperand().getType());
SmallVector<OpFoldResult> sizes;
for (int64_t i = 0; i < operandType.getRank(); ++i) {
if (i == sliceAxis) {
@ -202,10 +202,9 @@ createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
ImplicitLocOpBuilder &builder) {
Operation::result_range meshShape =
builder.create<mesh::MeshShapeOp>(mesh, axes).getResults();
return arith::createProduct(builder, builder.getLoc(),
llvm::to_vector_of<Value>(meshShape),
builder.getIndexType())
.cast<TypedValue<IndexType>>();
return cast<TypedValue<IndexType>>(arith::createProduct(
builder, builder.getLoc(), llvm::to_vector_of<Value>(meshShape),
builder.getIndexType()));
}
TypedValue<IndexType> createProcessLinearIndex(StringRef mesh,

View File

@ -651,7 +651,7 @@ private:
template <typename ApplyFn, typename ReduceFn>
static void foreachIndividualVectorElement(Value vector, ApplyFn applyFn,
ReduceFn reduceFn) {
VectorType vectorType = vector.getType().cast<VectorType>();
VectorType vectorType = cast<VectorType>(vector.getType());
auto vectorShape = vectorType.getShape();
auto strides = computeStrides(vectorShape);
for (int64_t idx = 0, e = vectorShape[0] * strides[0]; idx < e; ++idx) {
@ -779,11 +779,11 @@ FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) {
Value lhsMemRef = linalgOp.getDpsInputOperand(0)->get();
Value rhsMemRef = linalgOp.getDpsInputOperand(1)->get();
Value resMemRef = linalgOp.getDpsInitOperand(0)->get();
assert(lhsMemRef.getType().cast<MemRefType>().getRank() == 2 &&
assert(cast<MemRefType>(lhsMemRef.getType()).getRank() == 2 &&
"expected lhs to be a 2D memref");
assert(rhsMemRef.getType().cast<MemRefType>().getRank() == 2 &&
assert(cast<MemRefType>(rhsMemRef.getType()).getRank() == 2 &&
"expected rhs to be a 2D memref");
assert(resMemRef.getType().cast<MemRefType>().getRank() == 2 &&
assert(cast<MemRefType>(resMemRef.getType()).getRank() == 2 &&
"expected res to be a 2D memref");
int64_t m = cast<MemRefType>(lhsMemRef.getType()).getShape()[0];

View File

@ -1318,7 +1318,7 @@ static LogicalResult verifyPrivateVarList(OpType &op) {
for (auto privateVarInfo : llvm::zip_equal(privateVars, privatizers)) {
Type varType = std::get<0>(privateVarInfo).getType();
SymbolRefAttr privatizerSym =
std::get<1>(privateVarInfo).template cast<SymbolRefAttr>();
cast<SymbolRefAttr>(std::get<1>(privateVarInfo));
PrivateClauseOp privatizerOp =
SymbolTable::lookupNearestSymbolFrom<PrivateClauseOp>(op,
privatizerSym);

View File

@ -145,9 +145,9 @@ void VarInfo::setNum(Var::Num n) {
/// mismatches.
LLVM_ATTRIBUTE_UNUSED static llvm::SMLoc
minSMLoc(AsmParser &parser, llvm::SMLoc sm1, llvm::SMLoc sm2) {
const auto loc1 = parser.getEncodedSourceLoc(sm1).dyn_cast<FileLineColLoc>();
const auto loc1 = dyn_cast<FileLineColLoc>(parser.getEncodedSourceLoc(sm1));
assert(loc1 && "Could not get `FileLineColLoc` for first `SMLoc`");
const auto loc2 = parser.getEncodedSourceLoc(sm2).dyn_cast<FileLineColLoc>();
const auto loc2 = dyn_cast<FileLineColLoc>(parser.getEncodedSourceLoc(sm2));
assert(loc2 && "Could not get `FileLineColLoc` for second `SMLoc`");
if (loc1.getFilename() != loc2.getFilename())
return SMLoc();

View File

@ -2078,7 +2078,7 @@ struct SparseTensorAsmDialectInterface : public OpAsmDialectInterface {
using OpAsmDialectInterface::OpAsmDialectInterface;
AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
if (attr.isa<SparseTensorEncodingAttr>()) {
if (isa<SparseTensorEncodingAttr>(attr)) {
os << "sparse";
return AliasResult::OverridableAlias;
}

View File

@ -29,7 +29,7 @@ LogicalResult sparse_tensor::detail::stageWithSortImpl(
Location loc = op.getLoc();
Type finalTp = op->getOpResult(0).getType();
SparseTensorType dstStt(finalTp.cast<RankedTensorType>());
SparseTensorType dstStt(cast<RankedTensorType>(finalTp));
Type srcCOOTp = dstStt.getCOOType(/*ordered=*/false);
// Clones the original operation but changing the output to an unordered COO.

View File

@ -25,7 +25,7 @@ DiagnosedSilenceableFailure transform::MatchSparseInOut::matchOperation(
return emitSilenceableFailure(current->getLoc(),
"operation has no sparse input or output");
}
results.set(getResult().cast<OpResult>(), state.getPayloadOps(getTarget()));
results.set(cast<OpResult>(getResult()), state.getPayloadOps(getTarget()));
return DiagnosedSilenceableFailure::success();
}

View File

@ -42,7 +42,7 @@ static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
if (kind == SparseTensorFieldKind::PosMemRef ||
kind == SparseTensorFieldKind::CrdMemRef ||
kind == SparseTensorFieldKind::ValMemRef) {
auto rtp = t.cast<ShapedType>();
auto rtp = cast<ShapedType>(t);
if (!directOut) {
rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
if (extraTypes)
@ -97,7 +97,7 @@ static void convVals(OpBuilder &builder, Location loc, TypeRange types,
mem = builder.create<sparse_tensor::ToValuesOp>(loc, inputs[0]);
toVals.push_back(mem);
} else {
ShapedType rtp = t.cast<ShapedType>();
ShapedType rtp = cast<ShapedType>(t);
rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
inputs.push_back(extraVals[extra++]);
retTypes.push_back(rtp);

View File

@ -502,7 +502,7 @@ private:
for (const AffineExpr l : order.getResults()) {
unsigned loopId = llvm::cast<AffineDimExpr>(l).getPosition();
auto itTp =
linalgOp.getIteratorTypes()[loopId].cast<linalg::IteratorTypeAttr>();
cast<linalg::IteratorTypeAttr>(linalgOp.getIteratorTypes()[loopId]);
if (linalg::isReductionIterator(itTp.getValue()))
break; // terminate at first reduction
nest++;

View File

@ -476,8 +476,8 @@ private:
if (!sel)
return std::nullopt;
auto tVal = sel.getTrueValue().dyn_cast<BlockArgument>();
auto fVal = sel.getFalseValue().dyn_cast<BlockArgument>();
auto tVal = dyn_cast<BlockArgument>(sel.getTrueValue());
auto fVal = dyn_cast<BlockArgument>(sel.getFalseValue());
// TODO: For simplicity, we only handle cases where both true/false value
// are directly loaded the input tensor. We can probably admit more cases
// in theory.
@ -487,7 +487,7 @@ private:
// Helper lambda to determine whether the value is loaded from a dense input
// or is a loop invariant.
auto isValFromDenseInputOrInvariant = [&op](Value v) -> bool {
if (auto bArg = v.dyn_cast<BlockArgument>();
if (auto bArg = dyn_cast<BlockArgument>(v);
bArg && !isSparseTensor(op.getDpsInputOperand(bArg.getArgNumber())))
return true;
// If the value is defined outside the loop, it is a loop invariant.

View File

@ -165,7 +165,7 @@ Value sparse_tensor::genCast(OpBuilder &builder, Location loc, Value value,
Value sparse_tensor::genScalarToTensor(OpBuilder &builder, Location loc,
Value elem, Type dstTp) {
if (auto rtp = dstTp.dyn_cast<RankedTensorType>()) {
if (auto rtp = dyn_cast<RankedTensorType>(dstTp)) {
// Scalars can only be converted to 0-ranked tensors.
assert(rtp.getRank() == 0);
elem = sparse_tensor::genCast(builder, loc, elem, rtp.getElementType());

View File

@ -157,8 +157,7 @@ IterationGraphSorter::IterationGraphSorter(
// The number of results of the map should match the rank of the tensor.
assert(llvm::all_of(llvm::zip(loop2InsLvl, ins), [](auto mvPair) {
auto [m, v] = mvPair;
return m.getNumResults() ==
v.getType().template cast<ShapedType>().getRank();
return m.getNumResults() == cast<ShapedType>(v.getType()).getRank();
}));
itGraph.resize(getNumLoops(), std::vector<bool>(getNumLoops(), false));

View File

@ -820,7 +820,7 @@ struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
if (!destOp)
return failure();
auto resultIndex = source.cast<OpResult>().getResultNumber();
auto resultIndex = cast<OpResult>(source).getResultNumber();
auto *initOperand = destOp.getDpsInitOperand(resultIndex);
rewriter.modifyOpInPlace(
@ -3475,7 +3475,7 @@ SplatOp::reifyResultShapes(OpBuilder &builder,
OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
auto constOperand = adaptor.getInput();
if (!constOperand.isa_and_nonnull<IntegerAttr, FloatAttr>())
if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
return {};
// Do not fold if the splat is not statically shaped
@ -4307,7 +4307,7 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
/// unpack(destinationStyleOp(x)) -> unpack(x)
if (auto dstStyleOp =
unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
auto destValue = unPackOp.getDest().cast<OpResult>();
auto destValue = cast<OpResult>(unPackOp.getDest());
Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
rewriter.modifyOpInPlace(unPackOp,
[&]() { unPackOp.setDpsInitOperand(0, newDest); });

View File

@ -32,7 +32,7 @@ namespace {
struct MatMulOpSharding
: public ShardingInterface::ExternalModel<MatMulOpSharding, MatMulOp> {
SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
auto tensorType = op->getResult(0).getType().dyn_cast<RankedTensorType>();
auto tensorType = dyn_cast<RankedTensorType>(op->getResult(0).getType());
if (!tensorType)
return {};
@ -48,7 +48,7 @@ struct MatMulOpSharding
}
SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
auto tensorType = op->getResult(0).getType().dyn_cast<RankedTensorType>();
auto tensorType = dyn_cast<RankedTensorType>(op->getResult(0).getType());
if (!tensorType)
return {};
MLIRContext *ctx = op->getContext();

View File

@ -285,7 +285,7 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
return failure();
}
if (inputElementType.isa<FloatType>()) {
if (isa<FloatType>(inputElementType)) {
// Unlike integer types, floating point types can represent infinity.
auto minClamp = op.getMinFp();
auto maxClamp = op.getMaxFp();

View File

@ -168,7 +168,7 @@ ParseResult mlir::tosa::parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
return parser.emitError(parser.getCurrentLocation())
<< "expected attribute";
}
if (auto typedAttr = attr.dyn_cast<TypedAttr>()) {
if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
typeAttr = TypeAttr::get(typedAttr.getType());
}
return success();
@ -186,7 +186,7 @@ ParseResult mlir::tosa::parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
Attribute attr) {
bool needsSpace = false;
auto typedAttr = attr.dyn_cast_or_null<TypedAttr>();
auto typedAttr = dyn_cast_or_null<TypedAttr>(attr);
if (!typedAttr || typedAttr.getType() != type.getValue()) {
p << ": ";
p.printAttribute(type);

View File

@ -371,7 +371,7 @@ struct ReduceConstantOptimization : public OpRewritePattern<OperationType> {
auto reductionAxis = op.getAxis();
const auto denseElementsAttr = constOp.getValue();
const auto shapedOldElementsValues =
denseElementsAttr.getType().cast<ShapedType>();
cast<ShapedType>(denseElementsAttr.getType());
if (!llvm::isa<IntegerType>(shapedOldElementsValues.getElementType()))
return rewriter.notifyMatchFailure(

View File

@ -357,7 +357,7 @@ private:
bool levelCheckTransposeConv2d(Operation *op) {
if (auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
if (ShapedType filterType =
transpose.getFilter().getType().dyn_cast<ShapedType>()) {
dyn_cast<ShapedType>(transpose.getFilter().getType())) {
auto shape = filterType.getShape();
assert(shape.size() == 4);
// level check kernel sizes for kH and KW

View File

@ -21,16 +21,15 @@ DiagnosedSilenceableFailure
transform::DebugEmitRemarkAtOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
if (getAt().getType().isa<TransformHandleTypeInterface>()) {
if (isa<TransformHandleTypeInterface>(getAt().getType())) {
auto payload = state.getPayloadOps(getAt());
for (Operation *op : payload)
op->emitRemark() << getMessage();
return DiagnosedSilenceableFailure::success();
}
assert(
getAt().getType().isa<transform::TransformValueHandleTypeInterface>() &&
"unhandled kind of transform type");
assert(isa<transform::TransformValueHandleTypeInterface>(getAt().getType()) &&
"unhandled kind of transform type");
auto describeValue = [](Diagnostic &os, Value value) {
os << "value handle points to ";

View File

@ -1615,7 +1615,7 @@ transform::GetTypeOp::apply(transform::TransformRewriter &rewriter,
}
params.push_back(TypeAttr::get(type));
}
results.setParams(getResult().cast<OpResult>(), params);
results.setParams(cast<OpResult>(getResult()), params);
return DiagnosedSilenceableFailure::success();
}
@ -2217,14 +2217,14 @@ transform::NumAssociationsOp::apply(transform::TransformRewriter &rewriter,
llvm_unreachable("unknown kind of transform dialect type");
return 0;
});
results.setParams(getNum().cast<OpResult>(),
results.setParams(cast<OpResult>(getNum()),
rewriter.getI64IntegerAttr(numAssociations));
return DiagnosedSilenceableFailure::success();
}
LogicalResult transform::NumAssociationsOp::verify() {
// Verify that the result type accepts an i64 attribute as payload.
auto resultType = getNum().getType().cast<TransformParamTypeInterface>();
auto resultType = cast<TransformParamTypeInterface>(getNum().getType());
return resultType
.checkPayload(getLoc(), {Builder(getContext()).getI64IntegerAttr(0)})
.checkAndReport();

View File

@ -44,7 +44,7 @@ DiagnosedSilenceableFailure
transform::AffineMapParamType::checkPayload(Location loc,
ArrayRef<Attribute> payload) const {
for (Attribute attr : payload) {
if (!attr.isa<AffineMapAttr>()) {
if (!mlir::isa<AffineMapAttr>(attr)) {
return emitSilenceableError(loc)
<< "expected affine map attribute, got " << attr;
}
@ -144,7 +144,7 @@ DiagnosedSilenceableFailure
transform::TypeParamType::checkPayload(Location loc,
ArrayRef<Attribute> payload) const {
for (Attribute attr : payload) {
if (!attr.isa<TypeAttr>()) {
if (!mlir::isa<TypeAttr>(attr)) {
return emitSilenceableError(loc)
<< "expected type attribute, got " << attr;
}

View File

@ -6169,7 +6169,7 @@ void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
auto constOperand = adaptor.getInput();
if (!constOperand.isa_and_nonnull<IntegerAttr, FloatAttr>())
if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
return {};
// SplatElementsAttr::get treats single value for second arg as being a splat.

View File

@ -57,7 +57,7 @@ static Value extendMaskRank(OpBuilder &builder, Location loc, Value vec,
Value broadcasted = extendVectorRank(builder, loc, vec, addedRank);
SmallVector<int64_t> permutation;
for (int64_t i = addedRank,
e = broadcasted.getType().cast<VectorType>().getRank();
e = cast<VectorType>(broadcasted.getType()).getRank();
i < e; ++i)
permutation.push_back(i);
for (int64_t i = 0; i < addedRank; ++i)

View File

@ -403,7 +403,7 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
// Such transposes do not materially effect the underlying vector and can
// be omitted. EG: perm [1, 0, 2] applied to vector<1x1x8xi32>
bool transposeNonOuterUnitDims = false;
auto operandShape = operands[it.index()].getType().cast<ShapedType>();
auto operandShape = cast<ShapedType>(operands[it.index()].getType());
for (auto [index, dim] :
llvm::enumerate(ArrayRef<int64_t>(perm).drop_back(1))) {
if (dim != static_cast<int64_t>(index) &&

View File

@ -63,7 +63,7 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
// new mask index) only happens on the last dimension of the vectors.
Operation *newMask = nullptr;
SmallVector<int64_t> shape(
maskOp->getResultTypes()[0].cast<VectorType>().getShape());
cast<VectorType>(maskOp->getResultTypes()[0]).getShape());
shape.back() = numElements;
auto newMaskType = VectorType::get(shape, rewriter.getI1Type());
if (createMaskOp) {

View File

@ -171,7 +171,7 @@ static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
/// is first inserted, followed by a `memref.cast`.
static Value castToCompatibleMemRefType(OpBuilder &b, Value memref,
MemRefType compatibleMemRefType) {
MemRefType sourceType = memref.getType().cast<MemRefType>();
MemRefType sourceType = cast<MemRefType>(memref.getType());
Value res = memref;
if (sourceType.getMemorySpace() != compatibleMemRefType.getMemorySpace()) {
sourceType = MemRefType::get(

View File

@ -127,7 +127,7 @@ LogicalResult CreateNdDescOp::verify() {
// check source type matches the rank if it is a memref.
// It also should have the same ElementType as TensorDesc.
auto memrefTy = getSourceType().dyn_cast<MemRefType>();
auto memrefTy = dyn_cast<MemRefType>(getSourceType());
if (memrefTy) {
invalidRank |= (memrefTy.getRank() != rank);
invalidElemTy |= memrefTy.getElementType() != getElementType();

View File

@ -711,7 +711,7 @@ AffineMap mlir::foldAttributesIntoMap(Builder &b, AffineMap map,
for (int64_t i = 0; i < map.getNumDims(); ++i) {
if (auto attr = operands[i].dyn_cast<Attribute>()) {
dimReplacements.push_back(
b.getAffineConstantExpr(attr.cast<IntegerAttr>().getInt()));
b.getAffineConstantExpr(cast<IntegerAttr>(attr).getInt()));
} else {
dimReplacements.push_back(b.getAffineDimExpr(numDims++));
remainingValues.push_back(operands[i].get<Value>());
@ -721,7 +721,7 @@ AffineMap mlir::foldAttributesIntoMap(Builder &b, AffineMap map,
for (int64_t i = 0; i < map.getNumSymbols(); ++i) {
if (auto attr = operands[i + map.getNumDims()].dyn_cast<Attribute>()) {
symReplacements.push_back(
b.getAffineConstantExpr(attr.cast<IntegerAttr>().getInt()));
b.getAffineConstantExpr(cast<IntegerAttr>(attr).getInt()));
} else {
symReplacements.push_back(b.getAffineSymbolExpr(numSymbols++));
remainingValues.push_back(operands[i + map.getNumDims()].get<Value>());

View File

@ -1154,7 +1154,7 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultRank(Operation *op) {
// delegate function that returns rank of shaped type with known rank
auto getRank = [](const Type type) {
return type.cast<ShapedType>().getRank();
return cast<ShapedType>(type).getRank();
};
auto rank = !rankedOperandTypes.empty() ? getRank(*rankedOperandTypes.begin())

View File

@ -2489,7 +2489,7 @@ static void genMapInfos(llvm::IRBuilderBase &builder,
auto addDevInfos = [&, fail](auto devOperands, auto devOpType) -> void {
for (const auto &devOp : devOperands) {
// TODO: Only LLVMPointerTypes are handled.
if (!devOp.getType().template isa<LLVM::LLVMPointerType>())
if (!isa<LLVM::LLVMPointerType>(devOp.getType()))
return fail();
llvm::Value *mapOpValue = moduleTranslation.lookupValue(devOp);
@ -3083,10 +3083,9 @@ convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute,
std::vector<llvm::GlobalVariable *> generatedRefs;
std::vector<llvm::Triple> targetTriple;
auto targetTripleAttr =
op->getParentOfType<mlir::ModuleOp>()
->getAttr(LLVM::LLVMDialect::getTargetTripleAttrName())
.dyn_cast_or_null<mlir::StringAttr>();
auto targetTripleAttr = dyn_cast_or_null<mlir::StringAttr>(
op->getParentOfType<mlir::ModuleOp>()->getAttr(
LLVM::LLVMDialect::getTargetTripleAttrName()));
if (targetTripleAttr)
targetTriple.emplace_back(targetTripleAttr.data());
@ -3328,7 +3327,7 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
attribute.getName())
.Case("omp.is_target_device",
[&](Attribute attr) {
if (auto deviceAttr = attr.dyn_cast<BoolAttr>()) {
if (auto deviceAttr = dyn_cast<BoolAttr>(attr)) {
llvm::OpenMPIRBuilderConfig &config =
moduleTranslation.getOpenMPBuilder()->Config;
config.setIsTargetDevice(deviceAttr.getValue());
@ -3338,7 +3337,7 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
})
.Case("omp.is_gpu",
[&](Attribute attr) {
if (auto gpuAttr = attr.dyn_cast<BoolAttr>()) {
if (auto gpuAttr = dyn_cast<BoolAttr>(attr)) {
llvm::OpenMPIRBuilderConfig &config =
moduleTranslation.getOpenMPBuilder()->Config;
config.setIsGPU(gpuAttr.getValue());
@ -3348,7 +3347,7 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
})
.Case("omp.host_ir_filepath",
[&](Attribute attr) {
if (auto filepathAttr = attr.dyn_cast<StringAttr>()) {
if (auto filepathAttr = dyn_cast<StringAttr>(attr)) {
llvm::OpenMPIRBuilder *ompBuilder =
moduleTranslation.getOpenMPBuilder();
ompBuilder->loadOffloadInfoMetadata(filepathAttr.getValue());
@ -3358,13 +3357,13 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
})
.Case("omp.flags",
[&](Attribute attr) {
if (auto rtlAttr = attr.dyn_cast<omp::FlagsAttr>())
if (auto rtlAttr = dyn_cast<omp::FlagsAttr>(attr))
return convertFlagsAttr(op, rtlAttr, moduleTranslation);
return failure();
})
.Case("omp.version",
[&](Attribute attr) {
if (auto versionAttr = attr.dyn_cast<omp::VersionAttr>()) {
if (auto versionAttr = dyn_cast<omp::VersionAttr>(attr)) {
llvm::OpenMPIRBuilder *ompBuilder =
moduleTranslation.getOpenMPBuilder();
ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp",
@ -3376,15 +3375,14 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
.Case("omp.declare_target",
[&](Attribute attr) {
if (auto declareTargetAttr =
attr.dyn_cast<omp::DeclareTargetAttr>())
dyn_cast<omp::DeclareTargetAttr>(attr))
return convertDeclareTargetAttr(op, declareTargetAttr,
moduleTranslation);
return failure();
})
.Case("omp.requires",
[&](Attribute attr) {
if (auto requiresAttr =
attr.dyn_cast<omp::ClauseRequiresAttr>()) {
if (auto requiresAttr = dyn_cast<omp::ClauseRequiresAttr>(attr)) {
using Requires = omp::ClauseRequires;
Requires flags = requiresAttr.getValue();
llvm::OpenMPIRBuilderConfig &config =

View File

@ -29,8 +29,8 @@ using mlir::LLVM::detail::createIntrinsicCall;
/// option around.
static llvm::Type *getXlenType(Attribute opcodeAttr,
LLVM::ModuleTranslation &moduleTranslation) {
auto intAttr = opcodeAttr.cast<IntegerAttr>();
unsigned xlenWidth = intAttr.getType().cast<IntegerType>().getWidth();
auto intAttr = cast<IntegerAttr>(opcodeAttr);
unsigned xlenWidth = cast<IntegerType>(intAttr.getType()).getWidth();
return llvm::Type::getIntNTy(moduleTranslation.getLLVMContext(), xlenWidth);
}

View File

@ -25,7 +25,7 @@ namespace {
/// according to LLVM's encoding:
/// https://lists.llvm.org/pipermail/llvm-dev/2020-October/145850.html
static std::pair<unsigned, VectorType> legalizeVectorType(const Type &type) {
VectorType vt = type.cast<VectorType>();
VectorType vt = cast<VectorType>(type);
// To simplify test pass, avoid multi-dimensional vectors.
if (!vt || vt.getRank() != 1)
return {0, nullptr};
@ -39,7 +39,7 @@ static std::pair<unsigned, VectorType> legalizeVectorType(const Type &type) {
sew = 32;
else if (eltTy.isF64())
sew = 64;
else if (auto intTy = eltTy.dyn_cast<IntegerType>())
else if (auto intTy = dyn_cast<IntegerType>(eltTy))
sew = intTy.getWidth();
else
return {0, nullptr};

View File

@ -67,12 +67,11 @@ struct TestMeshReshardingRewritePattern : OpRewritePattern<ShardOp> {
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
ShapedType sourceShardShape =
shardShapedType(op.getResult().getType(), mesh, op.getShard());
TypedValue<ShapedType> sourceShard =
TypedValue<ShapedType> sourceShard = cast<TypedValue<ShapedType>>(
builder
.create<UnrealizedConversionCastOp>(sourceShardShape,
op.getOperand())
->getResult(0)
.cast<TypedValue<ShapedType>>();
->getResult(0));
TypedValue<ShapedType> targetShard =
reshard(builder, mesh, op, targetShardOp, sourceShard);
Value newTargetUnsharded =

View File

@ -61,7 +61,7 @@ LogicalResult TestDialectLLVMIRTranslationInterface::amendOperation(
}
bool createSymbol = false;
if (auto boolAttr = attr.dyn_cast<BoolAttr>())
if (auto boolAttr = dyn_cast<BoolAttr>(attr))
createSymbol = boolAttr.getValue();
if (createSymbol) {

View File

@ -44,7 +44,7 @@ void TestAffineWalk::runOnOperation() {
// Test whether the walk is being correctly interrupted.
m.walk([](Operation *op) {
for (NamedAttribute attr : op->getAttrs()) {
auto mapAttr = attr.getValue().dyn_cast<AffineMapAttr>();
auto mapAttr = dyn_cast<AffineMapAttr>(attr.getValue());
if (!mapAttr)
return;
checkMod(mapAttr.getAffineMap(), op->getLoc());

View File

@ -51,7 +51,7 @@ struct TestElementsAttrInterface
InFlightDiagnostic diag = op->emitError()
<< "Test iterating `" << type << "`: ";
if (!attr.getElementType().isa<mlir::IntegerType>()) {
if (!isa<mlir::IntegerType>(attr.getElementType())) {
diag << "expected element type to be an integer type";
return;
}

View File

@ -61,7 +61,7 @@ static LogicalResult customTypeRangeResultConstraint(PatternRewriter &rewriter,
PDLResultList &results,
ArrayRef<PDLValue> args) {
auto *op = args[0].cast<Operation *>();
int numTypes = args[1].cast<Attribute>().cast<IntegerAttr>().getInt();
int numTypes = cast<IntegerAttr>(args[1].cast<Attribute>()).getInt();
if (op->getName().getStringRef() == "test.success_op") {
SmallVector<Type> types;