Switch member calls to isa/dyn_cast/cast/... to free function calls. (#89356)
This change cleans up call sites. Next step is to mark the member functions deprecated. See https://mlir.llvm.org/deprecation and https://discourse.llvm.org/t/preferred-casting-style-going-forward.
This commit is contained in:
parent
ce2f6423f0
commit
a5757c5b65
@ -142,7 +142,7 @@ mlir::transform::HasOperandSatisfyingOp::apply(
|
||||
transform::detail::prepareValueMappings(
|
||||
yieldedMappings, getBody().front().getTerminator()->getOperands(),
|
||||
state);
|
||||
results.setParams(getPosition().cast<OpResult>(),
|
||||
results.setParams(cast<OpResult>(getPosition()),
|
||||
{rewriter.getI32IntegerAttr(operand.getOperandNumber())});
|
||||
for (auto &&[result, mapping] : llvm::zip(getResults(), yieldedMappings))
|
||||
results.setMappedValues(result, mapping);
|
||||
|
||||
@ -87,7 +87,7 @@ private:
|
||||
void
|
||||
populateIteratorTypes(Type t,
|
||||
SmallVector<utils::IteratorType> &iterTypes) const {
|
||||
RankedTensorType rankedTensorType = t.dyn_cast<RankedTensorType>();
|
||||
RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(t);
|
||||
if (!rankedTensorType) {
|
||||
return;
|
||||
}
|
||||
@ -106,7 +106,7 @@ struct ElementwiseShardingInterface
|
||||
ElementwiseShardingInterface<ElemwiseOp>, ElemwiseOp> {
|
||||
SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
|
||||
Value val = op->getOperand(0);
|
||||
auto type = val.getType().dyn_cast<RankedTensorType>();
|
||||
auto type = dyn_cast<RankedTensorType>(val.getType());
|
||||
if (!type)
|
||||
return {};
|
||||
SmallVector<utils::IteratorType> types(type.getRank(),
|
||||
@ -117,7 +117,7 @@ struct ElementwiseShardingInterface
|
||||
SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
|
||||
MLIRContext *ctx = op->getContext();
|
||||
Value val = op->getOperand(0);
|
||||
auto type = val.getType().dyn_cast<RankedTensorType>();
|
||||
auto type = dyn_cast<RankedTensorType>(val.getType());
|
||||
if (!type)
|
||||
return {};
|
||||
int64_t rank = type.getRank();
|
||||
|
||||
@ -60,11 +60,11 @@ public:
|
||||
if (llvm::isa<FloatType>(resElemType))
|
||||
return impl::verifySameOperandsAndResultElementType(op);
|
||||
|
||||
if (auto resIntType = resElemType.dyn_cast<IntegerType>()) {
|
||||
if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
|
||||
IntegerType lhsIntType =
|
||||
getElementTypeOrSelf(op->getOperand(0)).cast<IntegerType>();
|
||||
cast<IntegerType>(getElementTypeOrSelf(op->getOperand(0)));
|
||||
IntegerType rhsIntType =
|
||||
getElementTypeOrSelf(op->getOperand(1)).cast<IntegerType>();
|
||||
cast<IntegerType>(getElementTypeOrSelf(op->getOperand(1)));
|
||||
if (lhsIntType != rhsIntType)
|
||||
return op->emitOpError(
|
||||
"requires the same element type for all operands");
|
||||
|
||||
@ -154,7 +154,7 @@ public:
|
||||
/// Support llvm style casting.
|
||||
static bool classof(Attribute attr) {
|
||||
auto fusedLoc = llvm::dyn_cast<FusedLoc>(attr);
|
||||
return fusedLoc && fusedLoc.getMetadata().isa_and_nonnull<MetadataT>();
|
||||
return fusedLoc && mlir::isa_and_nonnull<MetadataT>(fusedLoc.getMetadata());
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@ -135,7 +135,7 @@ MlirAttribute mlirLLVMDIExpressionAttrGet(MlirContext ctx, intptr_t nOperations,
|
||||
unwrap(ctx),
|
||||
llvm::map_to_vector(
|
||||
unwrapList(nOperations, operations, attrStorage),
|
||||
[](Attribute a) { return a.cast<DIExpressionElemAttr>(); })));
|
||||
[](Attribute a) { return cast<DIExpressionElemAttr>(a); })));
|
||||
}
|
||||
|
||||
MlirAttribute mlirLLVMDINullTypeAttrGet(MlirContext ctx) {
|
||||
@ -165,7 +165,7 @@ MlirAttribute mlirLLVMDICompositeTypeAttrGet(
|
||||
cast<DIScopeAttr>(unwrap(scope)), cast<DITypeAttr>(unwrap(baseType)),
|
||||
DIFlags(flags), sizeInBits, alignInBits,
|
||||
llvm::map_to_vector(unwrapList(nElements, elements, elementsStorage),
|
||||
[](Attribute a) { return a.cast<DINodeAttr>(); })));
|
||||
[](Attribute a) { return cast<DINodeAttr>(a); })));
|
||||
}
|
||||
|
||||
MlirAttribute
|
||||
@ -259,7 +259,7 @@ MlirAttribute mlirLLVMDISubroutineTypeAttrGet(MlirContext ctx,
|
||||
return wrap(DISubroutineTypeAttr::get(
|
||||
unwrap(ctx), callingConvention,
|
||||
llvm::map_to_vector(unwrapList(nTypes, types, attrStorage),
|
||||
[](Attribute a) { return a.cast<DITypeAttr>(); })));
|
||||
[](Attribute a) { return cast<DITypeAttr>(a); })));
|
||||
}
|
||||
|
||||
MlirAttribute mlirLLVMDISubprogramAttrGet(
|
||||
|
||||
@ -311,11 +311,11 @@ MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
|
||||
}
|
||||
|
||||
bool mlirVectorTypeIsScalable(MlirType type) {
|
||||
return unwrap(type).cast<VectorType>().isScalable();
|
||||
return cast<VectorType>(unwrap(type)).isScalable();
|
||||
}
|
||||
|
||||
bool mlirVectorTypeIsDimScalable(MlirType type, intptr_t dim) {
|
||||
return unwrap(type).cast<VectorType>().getScalableDims()[dim];
|
||||
return cast<VectorType>(unwrap(type)).getScalableDims()[dim];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@ -371,7 +371,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
|
||||
bool isUnsigned, Value llvmInput,
|
||||
SmallVector<Value, 4> &operands) {
|
||||
Type inputType = llvmInput.getType();
|
||||
auto vectorType = inputType.dyn_cast<VectorType>();
|
||||
auto vectorType = dyn_cast<VectorType>(inputType);
|
||||
Type elemType = vectorType.getElementType();
|
||||
|
||||
if (elemType.isBF16())
|
||||
@ -414,7 +414,7 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
|
||||
Value output, int32_t subwordOffset,
|
||||
bool clamp, SmallVector<Value, 4> &operands) {
|
||||
Type inputType = output.getType();
|
||||
auto vectorType = inputType.dyn_cast<VectorType>();
|
||||
auto vectorType = dyn_cast<VectorType>(inputType);
|
||||
Type elemType = vectorType.getElementType();
|
||||
if (elemType.isBF16())
|
||||
output = rewriter.create<LLVM::BitcastOp>(
|
||||
@ -569,9 +569,8 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
|
||||
/// on the architecture you are compiling for.
|
||||
static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
|
||||
Chipset chipset) {
|
||||
|
||||
auto sourceVectorType = wmma.getSourceA().getType().dyn_cast<VectorType>();
|
||||
auto destVectorType = wmma.getDestC().getType().dyn_cast<VectorType>();
|
||||
auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType());
|
||||
auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType());
|
||||
auto elemSourceType = sourceVectorType.getElementType();
|
||||
auto elemDestType = destVectorType.getElementType();
|
||||
|
||||
@ -727,7 +726,7 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
|
||||
Type f32 = getTypeConverter()->convertType(op.getResult().getType());
|
||||
|
||||
Value source = adaptor.getSource();
|
||||
auto sourceVecType = op.getSource().getType().dyn_cast<VectorType>();
|
||||
auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
|
||||
Type sourceElemType = getElementTypeOrSelf(op.getSource());
|
||||
// Extend to a v4i8
|
||||
if (!sourceVecType || sourceVecType.getNumElements() < 4) {
|
||||
|
||||
@ -65,7 +65,7 @@ static Value castF32To(Type elementType, Value f32, Location loc,
|
||||
|
||||
LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
|
||||
Type inType = op.getIn().getType();
|
||||
if (auto inVecType = inType.dyn_cast<VectorType>()) {
|
||||
if (auto inVecType = dyn_cast<VectorType>(inType)) {
|
||||
if (inVecType.isScalable())
|
||||
return failure();
|
||||
if (inVecType.getShape().size() > 1)
|
||||
@ -81,13 +81,13 @@ void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
|
||||
Location loc = op.getLoc();
|
||||
Value in = op.getIn();
|
||||
Type outElemType = getElementTypeOrSelf(op.getOut().getType());
|
||||
if (!in.getType().isa<VectorType>()) {
|
||||
if (!isa<VectorType>(in.getType())) {
|
||||
Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
|
||||
loc, rewriter.getF32Type(), in, 0);
|
||||
Value result = castF32To(outElemType, asFloat, loc, rewriter);
|
||||
return rewriter.replaceOp(op, result);
|
||||
}
|
||||
VectorType inType = in.getType().cast<VectorType>();
|
||||
VectorType inType = cast<VectorType>(in.getType());
|
||||
int64_t numElements = inType.getNumElements();
|
||||
Value zero = rewriter.create<arith::ConstantOp>(
|
||||
loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
|
||||
@ -179,7 +179,7 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
|
||||
if (op.getRoundingmodeAttr())
|
||||
return failure();
|
||||
Type outType = op.getOut().getType();
|
||||
if (auto outVecType = outType.dyn_cast<VectorType>()) {
|
||||
if (auto outVecType = dyn_cast<VectorType>(outType)) {
|
||||
if (outVecType.isScalable())
|
||||
return failure();
|
||||
if (outVecType.getShape().size() > 1)
|
||||
@ -202,7 +202,7 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
|
||||
if (saturateFP8)
|
||||
in = clampInput(rewriter, loc, outElemType, in);
|
||||
VectorType truncResType = VectorType::get(4, outElemType);
|
||||
if (!in.getType().isa<VectorType>()) {
|
||||
if (!isa<VectorType>(in.getType())) {
|
||||
Value asFloat = castToF32(in, loc, rewriter);
|
||||
Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
|
||||
loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
|
||||
@ -210,7 +210,7 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
|
||||
Value result = rewriter.create<vector::ExtractOp>(loc, asF8s, 0);
|
||||
return rewriter.replaceOp(op, result);
|
||||
}
|
||||
VectorType outType = op.getOut().getType().cast<VectorType>();
|
||||
VectorType outType = cast<VectorType>(op.getOut().getType());
|
||||
int64_t numElements = outType.getNumElements();
|
||||
Value zero = rewriter.create<arith::ConstantOp>(
|
||||
loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
|
||||
|
||||
@ -214,7 +214,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
|
||||
llvm::enumerate(gpuFuncOp.getArgumentTypes())) {
|
||||
auto remapping = signatureConversion.getInputMapping(idx);
|
||||
NamedAttrList argAttr =
|
||||
argAttrs ? argAttrs[idx].cast<DictionaryAttr>() : NamedAttrList();
|
||||
argAttrs ? cast<DictionaryAttr>(argAttrs[idx]) : NamedAttrList();
|
||||
auto copyAttribute = [&](StringRef attrName) {
|
||||
Attribute attr = argAttr.erase(attrName);
|
||||
if (!attr)
|
||||
@ -234,9 +234,8 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
|
||||
return;
|
||||
}
|
||||
for (size_t i = 0, e = remapping->size; i < e; ++i) {
|
||||
if (llvmFuncOp.getArgument(remapping->inputNo + i)
|
||||
.getType()
|
||||
.isa<LLVM::LLVMPointerType>()) {
|
||||
if (isa<LLVM::LLVMPointerType>(
|
||||
llvmFuncOp.getArgument(remapping->inputNo + i).getType())) {
|
||||
llvmFuncOp.setArgAttr(remapping->inputNo + i, attrName, attr);
|
||||
}
|
||||
}
|
||||
|
||||
@ -668,7 +668,7 @@ static int32_t getCuSparseLtDataTypeFrom(Type type) {
|
||||
static int32_t getCuSparseDataTypeFrom(Type type) {
|
||||
if (llvm::isa<ComplexType>(type)) {
|
||||
// get the element type
|
||||
auto elementType = type.cast<ComplexType>().getElementType();
|
||||
auto elementType = cast<ComplexType>(type).getElementType();
|
||||
if (elementType.isBF16())
|
||||
return 15; // CUDA_C_16BF
|
||||
if (elementType.isF16())
|
||||
|
||||
@ -1579,7 +1579,7 @@ struct NVGPUWarpgroupMmaStoreOpLowering
|
||||
if (offset)
|
||||
ti = makeAdd(ti, makeConst(offset));
|
||||
|
||||
auto structType = matrixD.getType().cast<LLVM::LLVMStructType>();
|
||||
auto structType = cast<LLVM::LLVMStructType>(matrixD.getType());
|
||||
|
||||
// Number of 32-bit registers owns per thread
|
||||
constexpr unsigned numAdjacentRegisters = 2;
|
||||
@ -1606,9 +1606,9 @@ struct NVGPUWarpgroupMmaStoreOpLowering
|
||||
int offset = 0;
|
||||
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
|
||||
Value matriDValue = adaptor.getMatrixD();
|
||||
auto stype = matriDValue.getType().cast<LLVM::LLVMStructType>();
|
||||
auto stype = cast<LLVM::LLVMStructType>(matriDValue.getType());
|
||||
for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) {
|
||||
auto structType = matrixD.cast<LLVM::LLVMStructType>();
|
||||
auto structType = cast<LLVM::LLVMStructType>(matrixD);
|
||||
Value innerStructValue = b.create<LLVM::ExtractValueOp>(matriDValue, idx);
|
||||
storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
|
||||
offset += structType.getBody().size();
|
||||
@ -1626,13 +1626,9 @@ struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
|
||||
matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
|
||||
LLVM::LLVMStructType packStructType =
|
||||
getTypeConverter()
|
||||
->convertType(op.getMatrixC().getType())
|
||||
.cast<LLVM::LLVMStructType>();
|
||||
Type elemType = packStructType.getBody()
|
||||
.front()
|
||||
.cast<LLVM::LLVMStructType>()
|
||||
LLVM::LLVMStructType packStructType = cast<LLVM::LLVMStructType>(
|
||||
getTypeConverter()->convertType(op.getMatrixC().getType()));
|
||||
Type elemType = cast<LLVM::LLVMStructType>(packStructType.getBody().front())
|
||||
.getBody()
|
||||
.front();
|
||||
Value zero = b.create<LLVM::ConstantOp>(elemType, b.getZeroAttr(elemType));
|
||||
@ -1640,7 +1636,7 @@ struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
|
||||
SmallVector<Value> innerStructs;
|
||||
// Unpack the structs and set all values to zero
|
||||
for (auto [idx, s] : llvm::enumerate(packStructType.getBody())) {
|
||||
auto structType = s.cast<LLVM::LLVMStructType>();
|
||||
auto structType = cast<LLVM::LLVMStructType>(s);
|
||||
Value structValue = b.create<LLVM::ExtractValueOp>(packStruct, idx);
|
||||
for (unsigned i = 0; i < structType.getBody().size(); ++i) {
|
||||
structValue = b.create<LLVM::InsertValueOp>(
|
||||
|
||||
@ -618,7 +618,7 @@ static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
|
||||
static SmallVector<Value> expandInputRanks(PatternRewriter &rewriter,
|
||||
Location loc, Operation *operation) {
|
||||
auto rank =
|
||||
operation->getResultTypes().front().cast<RankedTensorType>().getRank();
|
||||
cast<RankedTensorType>(operation->getResultTypes().front()).getRank();
|
||||
return llvm::map_to_vector(operation->getOperands(), [&](Value operand) {
|
||||
return expandRank(rewriter, loc, operand, rank);
|
||||
});
|
||||
@ -680,7 +680,7 @@ computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool,
|
||||
// dimension, that is the target size. An occurrence of an additional static
|
||||
// dimension greater than 1 with a different value is undefined behavior.
|
||||
for (auto operand : operands) {
|
||||
auto size = operand.getType().cast<RankedTensorType>().getDimSize(dim);
|
||||
auto size = cast<RankedTensorType>(operand.getType()).getDimSize(dim);
|
||||
if (!ShapedType::isDynamic(size) && size > 1)
|
||||
return {rewriter.getIndexAttr(size), operand};
|
||||
}
|
||||
@ -688,7 +688,7 @@ computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool,
|
||||
// Filter operands with dynamic dimension
|
||||
auto operandsWithDynamicDim =
|
||||
llvm::to_vector(llvm::make_filter_range(operands, [&](Value operand) {
|
||||
return operand.getType().cast<RankedTensorType>().isDynamicDim(dim);
|
||||
return cast<RankedTensorType>(operand.getType()).isDynamicDim(dim);
|
||||
}));
|
||||
|
||||
// If no operand has a dynamic dimension, it means all sizes were 1
|
||||
@ -718,7 +718,7 @@ static std::pair<SmallVector<OpFoldResult>, SmallVector<Value>>
|
||||
computeTargetShape(PatternRewriter &rewriter, Location loc,
|
||||
IndexPool &indexPool, ValueRange operands) {
|
||||
assert(!operands.empty());
|
||||
auto rank = operands.front().getType().cast<RankedTensorType>().getRank();
|
||||
auto rank = cast<RankedTensorType>(operands.front().getType()).getRank();
|
||||
SmallVector<OpFoldResult> targetShape;
|
||||
SmallVector<Value> masterOperands;
|
||||
for (auto dim : llvm::seq<int64_t>(0, rank)) {
|
||||
@ -735,7 +735,7 @@ static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
|
||||
int64_t dim, OpFoldResult targetSize,
|
||||
Value masterOperand) {
|
||||
// Nothing to do if this is a static dimension
|
||||
auto rankedTensorType = operand.getType().cast<RankedTensorType>();
|
||||
auto rankedTensorType = cast<RankedTensorType>(operand.getType());
|
||||
if (!rankedTensorType.isDynamicDim(dim))
|
||||
return operand;
|
||||
|
||||
@ -817,7 +817,7 @@ static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
|
||||
IndexPool &indexPool, Value operand,
|
||||
ArrayRef<OpFoldResult> targetShape,
|
||||
ArrayRef<Value> masterOperands) {
|
||||
int64_t rank = operand.getType().cast<RankedTensorType>().getRank();
|
||||
int64_t rank = cast<RankedTensorType>(operand.getType()).getRank();
|
||||
assert((int64_t)targetShape.size() == rank);
|
||||
assert((int64_t)masterOperands.size() == rank);
|
||||
for (auto index : llvm::seq<int64_t>(0, rank))
|
||||
@ -848,8 +848,7 @@ emitElementwiseComputation(PatternRewriter &rewriter, Location loc,
|
||||
Operation *operation, ValueRange operands,
|
||||
ArrayRef<OpFoldResult> targetShape) {
|
||||
// Generate output tensor
|
||||
auto resultType =
|
||||
operation->getResultTypes().front().cast<RankedTensorType>();
|
||||
auto resultType = cast<RankedTensorType>(operation->getResultTypes().front());
|
||||
Value outputTensor = rewriter.create<tensor::EmptyOp>(
|
||||
loc, targetShape, resultType.getElementType());
|
||||
|
||||
@ -2274,8 +2273,7 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
|
||||
llvm::SmallVector<int64_t, 3> staticSizes;
|
||||
dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);
|
||||
|
||||
auto elementType =
|
||||
input.getType().cast<RankedTensorType>().getElementType();
|
||||
auto elementType = cast<RankedTensorType>(input.getType()).getElementType();
|
||||
return RankedTensorType::get(staticSizes, elementType);
|
||||
}
|
||||
|
||||
@ -2327,7 +2325,7 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
|
||||
auto loc = rfft2d.getLoc();
|
||||
auto input = rfft2d.getInput();
|
||||
auto elementType =
|
||||
input.getType().cast<ShapedType>().getElementType().cast<FloatType>();
|
||||
cast<FloatType>(cast<ShapedType>(input.getType()).getElementType());
|
||||
|
||||
// Compute the output type and set of dynamic sizes
|
||||
llvm::SmallVector<Value> dynamicSizes;
|
||||
|
||||
@ -1204,10 +1204,10 @@ convertElementwiseOp(RewriterBase &rewriter, Operation *op,
|
||||
return rewriter.notifyMatchFailure(op, "no mapping");
|
||||
matrixOperands.push_back(it->second);
|
||||
}
|
||||
auto resultType = matrixOperands[0].getType().cast<gpu::MMAMatrixType>();
|
||||
auto resultType = cast<gpu::MMAMatrixType>(matrixOperands[0].getType());
|
||||
if (opType == gpu::MMAElementwiseOp::EXTF) {
|
||||
// The floating point extension case has a different result type.
|
||||
auto vectorType = op->getResultTypes()[0].cast<VectorType>();
|
||||
auto vectorType = cast<VectorType>(op->getResultTypes()[0]);
|
||||
resultType = gpu::MMAMatrixType::get(resultType.getShape(),
|
||||
vectorType.getElementType(),
|
||||
resultType.getOperand());
|
||||
|
||||
@ -631,8 +631,7 @@ static Value createMaskNeutralValue(ConversionPatternRewriter &rewriter,
|
||||
Type vectorType) {
|
||||
const auto &floatSemantics = cast<FloatType>(llvmType).getFloatSemantics();
|
||||
auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics);
|
||||
auto denseValue =
|
||||
DenseElementsAttr::get(vectorType.cast<ShapedType>(), value);
|
||||
auto denseValue = DenseElementsAttr::get(cast<ShapedType>(vectorType), value);
|
||||
return rewriter.create<LLVM::ConstantOp>(loc, vectorType, denseValue);
|
||||
}
|
||||
|
||||
|
||||
@ -227,8 +227,8 @@ LogicalResult WMMAOp::verify() {
|
||||
Type sourceAType = getSourceA().getType();
|
||||
Type destType = getDestC().getType();
|
||||
|
||||
VectorType sourceVectorAType = sourceAType.dyn_cast<VectorType>();
|
||||
VectorType destVectorType = destType.dyn_cast<VectorType>();
|
||||
VectorType sourceVectorAType = dyn_cast<VectorType>(sourceAType);
|
||||
VectorType destVectorType = dyn_cast<VectorType>(destType);
|
||||
|
||||
Type sourceAElemType = sourceVectorAType.getElementType();
|
||||
Type destElemType = destVectorType.getElementType();
|
||||
|
||||
@ -26,7 +26,7 @@ struct ConstantOpInterface
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
const BufferizationOptions &options) const {
|
||||
auto constantOp = cast<arith::ConstantOp>(op);
|
||||
auto type = constantOp.getType().dyn_cast<RankedTensorType>();
|
||||
auto type = dyn_cast<RankedTensorType>(constantOp.getType());
|
||||
|
||||
// Only ranked tensors are supported.
|
||||
if (!type)
|
||||
|
||||
@ -106,7 +106,7 @@ void mlir::arith::populateEmulateUnsupportedFloatsConversions(
|
||||
targetType](Type type) -> std::optional<Type> {
|
||||
if (llvm::is_contained(sourceTypes, type))
|
||||
return targetType;
|
||||
if (auto shaped = type.dyn_cast<ShapedType>())
|
||||
if (auto shaped = dyn_cast<ShapedType>(type))
|
||||
if (llvm::is_contained(sourceTypes, shaped.getElementType()))
|
||||
return shaped.clone(targetType);
|
||||
// All other types legal
|
||||
|
||||
@ -99,7 +99,7 @@ public:
|
||||
Value extsiLhs;
|
||||
Value extsiRhs;
|
||||
if (auto lhsExtInType =
|
||||
origLhsExtOp.getIn().getType().dyn_cast<mlir::VectorType>()) {
|
||||
dyn_cast<mlir::VectorType>(origLhsExtOp.getIn().getType())) {
|
||||
if (lhsExtInType.getElementTypeBitWidth() <= 8) {
|
||||
Type targetLhsExtTy =
|
||||
matchContainerType(rewriter.getI8Type(), lhsExtInType);
|
||||
@ -108,7 +108,7 @@ public:
|
||||
}
|
||||
}
|
||||
if (auto rhsExtInType =
|
||||
origRhsExtOp.getIn().getType().dyn_cast<mlir::VectorType>()) {
|
||||
dyn_cast<mlir::VectorType>(origRhsExtOp.getIn().getType())) {
|
||||
if (rhsExtInType.getElementTypeBitWidth() <= 8) {
|
||||
Type targetRhsExtTy =
|
||||
matchContainerType(rewriter.getI8Type(), rhsExtInType);
|
||||
@ -161,9 +161,9 @@ public:
|
||||
extractOperand(op.getAcc(), accPermutationMap, accOffsets);
|
||||
|
||||
auto inputElementType =
|
||||
tiledLhs.getType().cast<ShapedType>().getElementType();
|
||||
cast<ShapedType>(tiledLhs.getType()).getElementType();
|
||||
auto accElementType =
|
||||
tiledAcc.getType().cast<ShapedType>().getElementType();
|
||||
cast<ShapedType>(tiledAcc.getType()).getElementType();
|
||||
auto inputExpandedType = VectorType::get({2, 8}, inputElementType);
|
||||
auto outputExpandedType = VectorType::get({2, 2}, accElementType);
|
||||
|
||||
@ -175,9 +175,9 @@ public:
|
||||
auto emptyOperand = rewriter.create<arith::ConstantOp>(
|
||||
loc, expandedTypeType, rewriter.getZeroAttr(expandedTypeType));
|
||||
SmallVector<int64_t> offsets(
|
||||
emptyOperand.getType().cast<ShapedType>().getRank(), 0);
|
||||
cast<ShapedType>(emptyOperand.getType()).getRank(), 0);
|
||||
SmallVector<int64_t> strides(
|
||||
tiledOperand.getType().cast<ShapedType>().getRank(), 1);
|
||||
cast<ShapedType>(tiledOperand.getType()).getRank(), 1);
|
||||
return rewriter.createOrFold<vector::InsertStridedSliceOp>(
|
||||
loc, tiledOperand, emptyOperand, offsets, strides);
|
||||
};
|
||||
@ -214,7 +214,7 @@ public:
|
||||
// Insert the tiled result back into the non tiled result of the
|
||||
// contract op.
|
||||
SmallVector<int64_t> strides(
|
||||
tiledRes.getType().cast<ShapedType>().getRank(), 1);
|
||||
cast<ShapedType>(tiledRes.getType()).getRank(), 1);
|
||||
result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
|
||||
loc, tiledRes, result, accOffsets, strides);
|
||||
}
|
||||
|
||||
@ -39,7 +39,7 @@ static Value buildBoolValue(OpBuilder &builder, Location loc, bool value) {
|
||||
return builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(value));
|
||||
}
|
||||
|
||||
static bool isMemref(Value v) { return v.getType().isa<BaseMemRefType>(); }
|
||||
static bool isMemref(Value v) { return isa<BaseMemRefType>(v.getType()); }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Ownership
|
||||
@ -222,8 +222,8 @@ bool ValueComparator::operator()(const Value &lhs, const Value &rhs) const {
|
||||
return false;
|
||||
|
||||
// Block arguments are less than results.
|
||||
bool lhsIsBBArg = lhs.isa<BlockArgument>();
|
||||
if (lhsIsBBArg != rhs.isa<BlockArgument>()) {
|
||||
bool lhsIsBBArg = isa<BlockArgument>(lhs);
|
||||
if (lhsIsBBArg != isa<BlockArgument>(rhs)) {
|
||||
return lhsIsBBArg;
|
||||
}
|
||||
|
||||
|
||||
@ -684,7 +684,7 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
|
||||
|
||||
// Op is not bufferizable.
|
||||
auto memSpace =
|
||||
options.defaultMemorySpaceFn(value.getType().cast<TensorType>());
|
||||
options.defaultMemorySpaceFn(cast<TensorType>(value.getType()));
|
||||
if (!memSpace.has_value())
|
||||
return op->emitError("could not infer memory space");
|
||||
|
||||
@ -939,7 +939,7 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
|
||||
// If we do not know the memory space and there is no default memory space,
|
||||
// report a failure.
|
||||
auto memSpace =
|
||||
options.defaultMemorySpaceFn(value.getType().cast<TensorType>());
|
||||
options.defaultMemorySpaceFn(cast<TensorType>(value.getType()));
|
||||
if (!memSpace.has_value())
|
||||
return op->emitError("could not infer memory space");
|
||||
|
||||
@ -987,7 +987,7 @@ bufferization::detail::unknownGetAliasingValues(OpOperand &opOperand) {
|
||||
for (Region ®ion : opOperand.getOwner()->getRegions())
|
||||
if (!region.getBlocks().empty())
|
||||
for (BlockArgument bbArg : region.getBlocks().front().getArguments())
|
||||
if (bbArg.getType().isa<TensorType>())
|
||||
if (isa<TensorType>(bbArg.getType()))
|
||||
r.addAlias({bbArg, BufferRelation::Unknown, /*isDefinite=*/false});
|
||||
return r;
|
||||
}
|
||||
|
||||
@ -46,7 +46,7 @@ static Value buildBoolValue(OpBuilder &builder, Location loc, bool value) {
|
||||
return builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(value));
|
||||
}
|
||||
|
||||
static bool isMemref(Value v) { return v.getType().isa<BaseMemRefType>(); }
|
||||
static bool isMemref(Value v) { return isa<BaseMemRefType>(v.getType()); }
|
||||
|
||||
/// Return "true" if the given op is guaranteed to have neither "Allocate" nor
|
||||
/// "Free" side effects.
|
||||
|
||||
@ -378,7 +378,7 @@ OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
|
||||
if (!rhs)
|
||||
return {};
|
||||
|
||||
ArrayAttr arrayAttr = rhs.dyn_cast<ArrayAttr>();
|
||||
ArrayAttr arrayAttr = dyn_cast<ArrayAttr>(rhs);
|
||||
if (!arrayAttr || arrayAttr.size() != 2)
|
||||
return {};
|
||||
|
||||
|
||||
@ -17,7 +17,7 @@
|
||||
using namespace mlir;
|
||||
using namespace mlir::bufferization;
|
||||
|
||||
static bool isMemref(Value v) { return v.getType().isa<BaseMemRefType>(); }
|
||||
static bool isMemref(Value v) { return isa<BaseMemRefType>(v.getType()); }
|
||||
|
||||
namespace {
|
||||
/// While CondBranchOp also implement the BranchOpInterface, we add a
|
||||
|
||||
@ -160,13 +160,13 @@ LogicalResult AddOp::verify() {
|
||||
Type lhsType = getLhs().getType();
|
||||
Type rhsType = getRhs().getType();
|
||||
|
||||
if (lhsType.isa<emitc::PointerType>() && rhsType.isa<emitc::PointerType>())
|
||||
if (isa<emitc::PointerType>(lhsType) && isa<emitc::PointerType>(rhsType))
|
||||
return emitOpError("requires that at most one operand is a pointer");
|
||||
|
||||
if ((lhsType.isa<emitc::PointerType>() &&
|
||||
!rhsType.isa<IntegerType, emitc::OpaqueType>()) ||
|
||||
(rhsType.isa<emitc::PointerType>() &&
|
||||
!lhsType.isa<IntegerType, emitc::OpaqueType>()))
|
||||
if ((isa<emitc::PointerType>(lhsType) &&
|
||||
!isa<IntegerType, emitc::OpaqueType>(rhsType)) ||
|
||||
(isa<emitc::PointerType>(rhsType) &&
|
||||
!isa<IntegerType, emitc::OpaqueType>(lhsType)))
|
||||
return emitOpError("requires that one operand is an integer or of opaque "
|
||||
"type if the other is a pointer");
|
||||
|
||||
@ -778,16 +778,16 @@ LogicalResult SubOp::verify() {
|
||||
Type rhsType = getRhs().getType();
|
||||
Type resultType = getResult().getType();
|
||||
|
||||
if (rhsType.isa<emitc::PointerType>() && !lhsType.isa<emitc::PointerType>())
|
||||
if (isa<emitc::PointerType>(rhsType) && !isa<emitc::PointerType>(lhsType))
|
||||
return emitOpError("rhs can only be a pointer if lhs is a pointer");
|
||||
|
||||
if (lhsType.isa<emitc::PointerType>() &&
|
||||
!rhsType.isa<IntegerType, emitc::OpaqueType, emitc::PointerType>())
|
||||
if (isa<emitc::PointerType>(lhsType) &&
|
||||
!isa<IntegerType, emitc::OpaqueType, emitc::PointerType>(rhsType))
|
||||
return emitOpError("requires that rhs is an integer, pointer or of opaque "
|
||||
"type if lhs is a pointer");
|
||||
|
||||
if (lhsType.isa<emitc::PointerType>() && rhsType.isa<emitc::PointerType>() &&
|
||||
!resultType.isa<IntegerType, emitc::OpaqueType>())
|
||||
if (isa<emitc::PointerType>(lhsType) && isa<emitc::PointerType>(rhsType) &&
|
||||
!isa<IntegerType, emitc::OpaqueType>(resultType))
|
||||
return emitOpError("requires that the result is an integer or of opaque "
|
||||
"type if lhs and rhs are pointers");
|
||||
return success();
|
||||
|
||||
@ -196,7 +196,7 @@ getSubgroupMmaNativeVectorSize(Operation *op, int64_t m, int64_t n, int64_t k) {
|
||||
auto extract = dyn_cast<vector::ExtractStridedSliceOp>(users);
|
||||
if (!extract)
|
||||
return std::nullopt;
|
||||
auto vecType = extract.getResult().getType().cast<VectorType>();
|
||||
auto vecType = cast<VectorType>(extract.getResult().getType());
|
||||
if (sliceType && sliceType != vecType)
|
||||
return std::nullopt;
|
||||
sliceType = vecType;
|
||||
@ -204,7 +204,7 @@ getSubgroupMmaNativeVectorSize(Operation *op, int64_t m, int64_t n, int64_t k) {
|
||||
return llvm::to_vector(sliceType.getShape());
|
||||
}
|
||||
if ((OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1)) {
|
||||
if (auto vecType = op->getResultTypes()[0].dyn_cast<VectorType>()) {
|
||||
if (auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0])) {
|
||||
// TODO: The condition for unrolling elementwise should be restricted
|
||||
// only to operations that need unrolling (connected to the contract).
|
||||
if (vecType.getRank() < 2)
|
||||
@ -219,7 +219,7 @@ getSubgroupMmaNativeVectorSize(Operation *op, int64_t m, int64_t n, int64_t k) {
|
||||
auto extract = dyn_cast<vector::ExtractStridedSliceOp>(users);
|
||||
if (!extract)
|
||||
return std::nullopt;
|
||||
auto vecType = extract.getResult().getType().cast<VectorType>();
|
||||
auto vecType = cast<VectorType>(extract.getResult().getType());
|
||||
if (sliceType && sliceType != vecType)
|
||||
return std::nullopt;
|
||||
sliceType = vecType;
|
||||
|
||||
@ -354,7 +354,7 @@ static WalkResult loadOperation(
|
||||
|
||||
// Gather the variadicities of each result
|
||||
for (Attribute attr : resultsOp->getVariadicity())
|
||||
resultVariadicity.push_back(attr.cast<VariadicityAttr>().getValue());
|
||||
resultVariadicity.push_back(cast<VariadicityAttr>(attr).getValue());
|
||||
}
|
||||
|
||||
// Gather which constraint slots correspond to attributes constraints
|
||||
@ -367,7 +367,7 @@ static WalkResult loadOperation(
|
||||
for (const auto &[name, value] : llvm::zip(names, values)) {
|
||||
for (auto [i, constr] : enumerate(constrToValue)) {
|
||||
if (constr == value) {
|
||||
attributesContraints[name.cast<StringAttr>()] = i;
|
||||
attributesContraints[cast<StringAttr>(name)] = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@ -42,7 +42,7 @@ static char getRegisterType(Type type) {
|
||||
return 'f';
|
||||
if (type.isF64())
|
||||
return 'd';
|
||||
if (auto ptr = type.dyn_cast<LLVM::LLVMPointerType>()) {
|
||||
if (auto ptr = dyn_cast<LLVM::LLVMPointerType>(type)) {
|
||||
// Shared address spaces is addressed with 32-bit pointers.
|
||||
if (ptr.getAddressSpace() == kSharedMemorySpace) {
|
||||
return 'r';
|
||||
|
||||
@ -559,7 +559,7 @@ static void destructureIndices(Type currType, ArrayRef<GEPArg> indices,
|
||||
// we don't do anything here. The verifier will catch it and emit a proper
|
||||
// error. All other canonicalization is done in the fold method.
|
||||
bool requiresConst = !rawConstantIndices.empty() &&
|
||||
currType.isa_and_nonnull<LLVMStructType>();
|
||||
isa_and_nonnull<LLVMStructType>(currType);
|
||||
if (Value val = llvm::dyn_cast_if_present<Value>(iter)) {
|
||||
APInt intC;
|
||||
if (requiresConst && matchPattern(val, m_ConstantInt(&intC)) &&
|
||||
@ -2564,14 +2564,14 @@ LogicalResult LLVM::ConstantOp::verify() {
|
||||
}
|
||||
// See the comment for getLLVMConstant for more details about why 8-bit
|
||||
// floats can be represented by integers.
|
||||
if (getType().isa<IntegerType>() && !getType().isInteger(floatWidth)) {
|
||||
if (isa<IntegerType>(getType()) && !getType().isInteger(floatWidth)) {
|
||||
return emitOpError() << "expected integer type of width " << floatWidth;
|
||||
}
|
||||
}
|
||||
if (auto splatAttr = dyn_cast<SplatElementsAttr>(getValue())) {
|
||||
if (!getType().isa<VectorType>() && !getType().isa<LLVM::LLVMArrayType>() &&
|
||||
!getType().isa<LLVM::LLVMFixedVectorType>() &&
|
||||
!getType().isa<LLVM::LLVMScalableVectorType>())
|
||||
if (!isa<VectorType>(getType()) && !isa<LLVM::LLVMArrayType>(getType()) &&
|
||||
!isa<LLVM::LLVMFixedVectorType>(getType()) &&
|
||||
!isa<LLVM::LLVMScalableVectorType>(getType()))
|
||||
return emitOpError() << "expected vector or array type";
|
||||
}
|
||||
return success();
|
||||
|
||||
@ -319,7 +319,7 @@ LogicalResult LLVM::StoreOp::ensureOnlySafeAccesses(
|
||||
static Type getTypeAtIndex(const DestructurableMemorySlot &slot,
|
||||
Attribute index) {
|
||||
auto subelementIndexMap =
|
||||
slot.elemType.cast<DestructurableTypeInterface>().getSubelementIndexMap();
|
||||
cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap();
|
||||
if (!subelementIndexMap)
|
||||
return {};
|
||||
assert(!subelementIndexMap->empty());
|
||||
@ -913,8 +913,7 @@ bool LLVM::MemsetOp::canRewire(const DestructurableMemorySlot &slot,
|
||||
if (getIsVolatile())
|
||||
return false;
|
||||
|
||||
if (!slot.elemType.cast<DestructurableTypeInterface>()
|
||||
.getSubelementIndexMap())
|
||||
if (!cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap())
|
||||
return false;
|
||||
|
||||
if (!areAllIndicesI32(slot))
|
||||
@ -928,7 +927,7 @@ DeletionKind LLVM::MemsetOp::rewire(const DestructurableMemorySlot &slot,
|
||||
RewriterBase &rewriter,
|
||||
const DataLayout &dataLayout) {
|
||||
std::optional<DenseMap<Attribute, Type>> types =
|
||||
slot.elemType.cast<DestructurableTypeInterface>().getSubelementIndexMap();
|
||||
cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap();
|
||||
|
||||
IntegerAttr memsetLenAttr;
|
||||
bool successfulMatch =
|
||||
@ -1047,8 +1046,7 @@ static bool memcpyCanRewire(MemcpyLike op, const DestructurableMemorySlot &slot,
|
||||
if (op.getIsVolatile())
|
||||
return false;
|
||||
|
||||
if (!slot.elemType.cast<DestructurableTypeInterface>()
|
||||
.getSubelementIndexMap())
|
||||
if (!cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap())
|
||||
return false;
|
||||
|
||||
if (!areAllIndicesI32(slot))
|
||||
|
||||
@ -475,7 +475,7 @@ LogicalResult SplitStores::matchAndRewrite(StoreOp store,
|
||||
}
|
||||
}
|
||||
|
||||
auto destructurableType = typeHint.dyn_cast<DestructurableTypeInterface>();
|
||||
auto destructurableType = dyn_cast<DestructurableTypeInterface>(typeHint);
|
||||
if (!destructurableType)
|
||||
return failure();
|
||||
|
||||
|
||||
@ -202,9 +202,9 @@ DiagnosedSilenceableFailure transform::MatchStructuredBodyOp::matchOperation(
|
||||
body,
|
||||
[&](Operation *elem, Operation *red) {
|
||||
return elem->getName().getStringRef() ==
|
||||
(*contractionOps)[0].cast<StringAttr>().getValue() &&
|
||||
cast<StringAttr>((*contractionOps)[0]).getValue() &&
|
||||
red->getName().getStringRef() ==
|
||||
(*contractionOps)[1].cast<StringAttr>().getValue();
|
||||
cast<StringAttr>((*contractionOps)[1]).getValue();
|
||||
},
|
||||
os);
|
||||
if (result)
|
||||
@ -259,11 +259,11 @@ transform::MatchStructuredClassifyContractionDimsOp::matchOperation(
|
||||
return builder.getI64IntegerAttr(value);
|
||||
}));
|
||||
};
|
||||
results.setParams(getBatch().cast<OpResult>(),
|
||||
results.setParams(cast<OpResult>(getBatch()),
|
||||
makeI64Attrs(contractionDims->batch));
|
||||
results.setParams(getM().cast<OpResult>(), makeI64Attrs(contractionDims->m));
|
||||
results.setParams(getN().cast<OpResult>(), makeI64Attrs(contractionDims->n));
|
||||
results.setParams(getK().cast<OpResult>(), makeI64Attrs(contractionDims->k));
|
||||
results.setParams(cast<OpResult>(getM()), makeI64Attrs(contractionDims->m));
|
||||
results.setParams(cast<OpResult>(getN()), makeI64Attrs(contractionDims->n));
|
||||
results.setParams(cast<OpResult>(getK()), makeI64Attrs(contractionDims->k));
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
@ -288,17 +288,17 @@ transform::MatchStructuredClassifyConvolutionDimsOp::matchOperation(
|
||||
return builder.getI64IntegerAttr(value);
|
||||
}));
|
||||
};
|
||||
results.setParams(getBatch().cast<OpResult>(),
|
||||
results.setParams(cast<OpResult>(getBatch()),
|
||||
makeI64Attrs(convolutionDims->batch));
|
||||
results.setParams(getOutputImage().cast<OpResult>(),
|
||||
results.setParams(cast<OpResult>(getOutputImage()),
|
||||
makeI64Attrs(convolutionDims->outputImage));
|
||||
results.setParams(getOutputChannel().cast<OpResult>(),
|
||||
results.setParams(cast<OpResult>(getOutputChannel()),
|
||||
makeI64Attrs(convolutionDims->outputChannel));
|
||||
results.setParams(getFilterLoop().cast<OpResult>(),
|
||||
results.setParams(cast<OpResult>(getFilterLoop()),
|
||||
makeI64Attrs(convolutionDims->filterLoop));
|
||||
results.setParams(getInputChannel().cast<OpResult>(),
|
||||
results.setParams(cast<OpResult>(getInputChannel()),
|
||||
makeI64Attrs(convolutionDims->inputChannel));
|
||||
results.setParams(getDepth().cast<OpResult>(),
|
||||
results.setParams(cast<OpResult>(getDepth()),
|
||||
makeI64Attrs(convolutionDims->depth));
|
||||
|
||||
auto makeI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
|
||||
@ -307,9 +307,9 @@ transform::MatchStructuredClassifyConvolutionDimsOp::matchOperation(
|
||||
return builder.getI64IntegerAttr(value);
|
||||
}));
|
||||
};
|
||||
results.setParams(getStrides().cast<OpResult>(),
|
||||
results.setParams(cast<OpResult>(getStrides()),
|
||||
makeI64AttrsFromI64(convolutionDims->strides));
|
||||
results.setParams(getDilations().cast<OpResult>(),
|
||||
results.setParams(cast<OpResult>(getDilations()),
|
||||
makeI64AttrsFromI64(convolutionDims->dilations));
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
@ -1219,7 +1219,7 @@ transform::MatchOp::apply(transform::TransformRewriter &rewriter,
|
||||
// All the operands must must be equal to the specified type
|
||||
auto typeattr =
|
||||
dyn_cast<mlir::TypeAttr>(getFilterOperandTypes().value()[0]);
|
||||
Type t = typeattr.getValue().cast<::mlir::Type>();
|
||||
Type t = cast<::mlir::Type>(typeattr.getValue());
|
||||
if (!llvm::all_of(op->getOperandTypes(),
|
||||
[&](Type operandType) { return operandType == t; }))
|
||||
return;
|
||||
@ -1234,7 +1234,7 @@ transform::MatchOp::apply(transform::TransformRewriter &rewriter,
|
||||
for (auto [attr, operandType] :
|
||||
llvm::zip_equal(getFilterOperandTypes().value(), operandTypes)) {
|
||||
auto typeattr = cast<mlir::TypeAttr>(attr);
|
||||
Type type = typeattr.getValue().cast<::mlir::Type>();
|
||||
Type type = cast<::mlir::Type>(typeattr.getValue());
|
||||
|
||||
if (type != operandType)
|
||||
return;
|
||||
@ -2665,7 +2665,7 @@ transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
|
||||
if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
|
||||
if (scalableSizes[ofrIdx]) {
|
||||
auto val = b.create<arith::ConstantIndexOp>(
|
||||
getLoc(), attr.cast<IntegerAttr>().getInt());
|
||||
getLoc(), cast<IntegerAttr>(attr).getInt());
|
||||
Value vscale =
|
||||
b.create<vector::VectorScaleOp>(getLoc(), b.getIndexType());
|
||||
sizes.push_back(
|
||||
|
||||
@ -60,7 +60,7 @@ static void createMemcpy(OpBuilder &b, Location loc, Value tensorSource,
|
||||
const linalg::BufferizeToAllocationOptions &options) {
|
||||
auto tensorType = dyn_cast<RankedTensorType>(tensorSource.getType());
|
||||
assert(tensorType && "expected ranked tensor");
|
||||
assert(memrefDest.getType().isa<MemRefType>() && "expected ranked memref");
|
||||
assert(isa<MemRefType>(memrefDest.getType()) && "expected ranked memref");
|
||||
|
||||
switch (options.memcpyOp) {
|
||||
case linalg::BufferizeToAllocationOptions::MemcpyOp::
|
||||
@ -496,10 +496,10 @@ Value linalg::bufferizeToAllocation(
|
||||
if (op == nestedOp)
|
||||
return;
|
||||
if (llvm::any_of(nestedOp->getOperands(),
|
||||
[](Value v) { return v.getType().isa<TensorType>(); }))
|
||||
[](Value v) { return isa<TensorType>(v.getType()); }))
|
||||
llvm_unreachable("ops with nested tensor ops are not supported yet");
|
||||
if (llvm::any_of(nestedOp->getResults(),
|
||||
[](Value v) { return v.getType().isa<TensorType>(); }))
|
||||
[](Value v) { return isa<TensorType>(v.getType()); }))
|
||||
llvm_unreachable("ops with nested tensor ops are not supported yet");
|
||||
});
|
||||
}
|
||||
@ -508,7 +508,7 @@ Value linalg::bufferizeToAllocation(
|
||||
// Gather tensor results.
|
||||
SmallVector<OpResult> tensorResults;
|
||||
for (OpResult result : op->getResults()) {
|
||||
if (!result.getType().isa<TensorType>())
|
||||
if (!isa<TensorType>(result.getType()))
|
||||
continue;
|
||||
// Unranked tensors are not supported
|
||||
if (!isa<RankedTensorType>(result.getType()))
|
||||
|
||||
@ -49,7 +49,7 @@ LogicalResult linalg::linalgOpAnchoredEmptyTensorEliminationStep(
|
||||
|
||||
for (OpOperand *in : op.getDpsInputOperands()) {
|
||||
// Skip non-tensor operands.
|
||||
if (!in->get().getType().isa<RankedTensorType>())
|
||||
if (!isa<RankedTensorType>(in->get().getType()))
|
||||
continue;
|
||||
|
||||
// Find tensor.empty ops on the reverse SSA use-def chain. Only follow
|
||||
|
||||
@ -405,7 +405,7 @@ static FailureOr<ForallTilingResult> tileToForallOpImpl(
|
||||
for (OpOperand &outOperand : destinationStyleOp.getDpsInitsMutable()) {
|
||||
// Swap tensor inits with the corresponding block argument of the
|
||||
// scf.forall op. Memref inits remain as is.
|
||||
if (outOperand.get().getType().isa<TensorType>()) {
|
||||
if (isa<TensorType>(outOperand.get().getType())) {
|
||||
auto *it = llvm::find(dest, outOperand.get());
|
||||
assert(it != dest.end() && "could not find destination tensor");
|
||||
unsigned destNum = std::distance(dest.begin(), it);
|
||||
|
||||
@ -557,7 +557,7 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
|
||||
Value dest = tensor::PackOp::createDestinationTensor(
|
||||
rewriter, loc, operand, innerPackSizes, innerPos,
|
||||
/*outerDimsPerm=*/{});
|
||||
ShapedType operandType = operand.getType().cast<ShapedType>();
|
||||
ShapedType operandType = cast<ShapedType>(operand.getType());
|
||||
bool areConstantTiles =
|
||||
llvm::all_of(innerPackSizes, [](OpFoldResult tile) {
|
||||
return getConstantIntValue(tile).has_value();
|
||||
@ -565,7 +565,7 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
|
||||
if (areConstantTiles && operandType.hasStaticShape() &&
|
||||
!tensor::PackOp::requirePaddingValue(
|
||||
operandType.getShape(), innerPos,
|
||||
dest.getType().cast<ShapedType>().getShape(), {},
|
||||
cast<ShapedType>(dest.getType()).getShape(), {},
|
||||
innerPackSizes)) {
|
||||
packOps.push_back(rewriter.create<tensor::PackOp>(
|
||||
loc, operand, dest, innerPos, innerPackSizes));
|
||||
|
||||
@ -3410,8 +3410,8 @@ struct Conv1DGenerator
|
||||
// * shape_cast(broadcast(filter))
|
||||
// * broadcast(shuffle(filter))
|
||||
// Opt for the option without shape_cast to simplify the codegen.
|
||||
auto rhsSize = rhs.getType().cast<VectorType>().getShape()[0];
|
||||
auto resSize = res.getType().cast<VectorType>().getShape()[1];
|
||||
auto rhsSize = cast<VectorType>(rhs.getType()).getShape()[0];
|
||||
auto resSize = cast<VectorType>(res.getType()).getShape()[1];
|
||||
|
||||
SmallVector<int64_t, 16> indicies;
|
||||
for (int i = 0; i < resSize / rhsSize; ++i) {
|
||||
|
||||
@ -173,8 +173,8 @@ transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter,
|
||||
}
|
||||
|
||||
// Assemble results.
|
||||
results.set(getGlobal().cast<OpResult>(), globalOps);
|
||||
results.set(getGetGlobal().cast<OpResult>(), getGlobalOps);
|
||||
results.set(cast<OpResult>(getGlobal()), globalOps);
|
||||
results.set(cast<OpResult>(getGetGlobal()), getGlobalOps);
|
||||
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
@ -254,7 +254,7 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
|
||||
LogicalResult
|
||||
matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto convertedType = adaptor.getMemref().getType().cast<MemRefType>();
|
||||
auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
|
||||
auto convertedElementType = convertedType.getElementType();
|
||||
auto oldElementType = op.getMemRefType().getElementType();
|
||||
int srcBits = oldElementType.getIntOrFloatBitWidth();
|
||||
@ -351,7 +351,7 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
|
||||
LogicalResult
|
||||
matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto convertedType = adaptor.getMemref().getType().cast<MemRefType>();
|
||||
auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
|
||||
int srcBits = op.getMemRefType().getElementTypeBitWidth();
|
||||
int dstBits = convertedType.getElementTypeBitWidth();
|
||||
auto dstIntegerType = rewriter.getIntegerType(dstBits);
|
||||
|
||||
@ -68,7 +68,7 @@ struct ExpandReallocOpPattern : public OpRewritePattern<memref::ReallocOp> {
|
||||
|
||||
// Get the size of the original buffer.
|
||||
int64_t inputSize =
|
||||
op.getSource().getType().cast<BaseMemRefType>().getDimSize(0);
|
||||
cast<BaseMemRefType>(op.getSource().getType()).getDimSize(0);
|
||||
OpFoldResult currSize = rewriter.getIndexAttr(inputSize);
|
||||
if (ShapedType::isDynamic(inputSize)) {
|
||||
Value dimZero = getValueOrCreateConstantIndexOp(rewriter, loc,
|
||||
@ -79,7 +79,7 @@ struct ExpandReallocOpPattern : public OpRewritePattern<memref::ReallocOp> {
|
||||
|
||||
// Get the requested size that the new buffer should have.
|
||||
int64_t outputSize =
|
||||
op.getResult().getType().cast<BaseMemRefType>().getDimSize(0);
|
||||
cast<BaseMemRefType>(op.getResult().getType()).getDimSize(0);
|
||||
OpFoldResult targetSize = ShapedType::isDynamic(outputSize)
|
||||
? OpFoldResult{op.getDynamicResultSize()}
|
||||
: rewriter.getIndexAttr(outputSize);
|
||||
@ -127,7 +127,7 @@ struct ExpandReallocOpPattern : public OpRewritePattern<memref::ReallocOp> {
|
||||
// is already bigger than the requested size, the cast represents a
|
||||
// subview operation.
|
||||
Value casted = builder.create<memref::ReinterpretCastOp>(
|
||||
loc, op.getResult().getType().cast<MemRefType>(), op.getSource(),
|
||||
loc, cast<MemRefType>(op.getResult().getType()), op.getSource(),
|
||||
rewriter.getIndexAttr(0), ArrayRef<OpFoldResult>{targetSize},
|
||||
ArrayRef<OpFoldResult>{rewriter.getIndexAttr(1)});
|
||||
builder.create<scf::YieldOp>(loc, casted);
|
||||
|
||||
@ -169,7 +169,7 @@ ShapedType mesh::shardShapedType(ShapedType shape, MeshOp mesh,
|
||||
}
|
||||
|
||||
Type mesh::shardType(Type type, MeshOp mesh, MeshShardingAttr sharding) {
|
||||
RankedTensorType rankedTensorType = type.dyn_cast<RankedTensorType>();
|
||||
RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(type);
|
||||
if (rankedTensorType) {
|
||||
return shardShapedType(rankedTensorType, mesh, sharding);
|
||||
}
|
||||
@ -281,7 +281,8 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
|
||||
}
|
||||
|
||||
bool MeshShardingAttr::operator==(Attribute rhs) const {
|
||||
MeshShardingAttr rhsAsMeshShardingAttr = rhs.dyn_cast<MeshShardingAttr>();
|
||||
MeshShardingAttr rhsAsMeshShardingAttr =
|
||||
mlir::dyn_cast<MeshShardingAttr>(rhs);
|
||||
return rhsAsMeshShardingAttr && *this == rhsAsMeshShardingAttr;
|
||||
}
|
||||
|
||||
@ -484,15 +485,15 @@ static LogicalResult verifyDimensionCompatibility(Location loc,
|
||||
static LogicalResult verifyGatherOperandAndResultShape(
|
||||
Value operand, Value result, int64_t gatherAxis,
|
||||
ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
|
||||
auto resultRank = result.getType().template cast<ShapedType>().getRank();
|
||||
auto resultRank = cast<ShapedType>(result.getType()).getRank();
|
||||
if (gatherAxis < 0 || gatherAxis >= resultRank) {
|
||||
return emitError(result.getLoc())
|
||||
<< "Gather axis " << gatherAxis << " is out of bounds [0, "
|
||||
<< resultRank << ").";
|
||||
}
|
||||
|
||||
ShapedType operandType = operand.getType().cast<ShapedType>();
|
||||
ShapedType resultType = result.getType().cast<ShapedType>();
|
||||
ShapedType operandType = cast<ShapedType>(operand.getType());
|
||||
ShapedType resultType = cast<ShapedType>(result.getType());
|
||||
auto deviceGroupSize =
|
||||
DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
|
||||
for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
|
||||
@ -511,8 +512,8 @@ static LogicalResult verifyGatherOperandAndResultShape(
|
||||
static LogicalResult verifyAllToAllOperandAndResultShape(
|
||||
Value operand, Value result, int64_t splitAxis, int64_t concatAxis,
|
||||
ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
|
||||
ShapedType operandType = operand.getType().cast<ShapedType>();
|
||||
ShapedType resultType = result.getType().cast<ShapedType>();
|
||||
ShapedType operandType = cast<ShapedType>(operand.getType());
|
||||
ShapedType resultType = cast<ShapedType>(result.getType());
|
||||
for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
|
||||
if ((axis != splitAxis && axis != concatAxis) || splitAxis == concatAxis) {
|
||||
if (failed(verifyDimensionCompatibility(
|
||||
@ -556,8 +557,8 @@ static LogicalResult verifyAllToAllOperandAndResultShape(
|
||||
static LogicalResult verifyScatterOrSliceOperandAndResultShape(
|
||||
Value operand, Value result, int64_t tensorAxis,
|
||||
ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
|
||||
ShapedType operandType = operand.getType().cast<ShapedType>();
|
||||
ShapedType resultType = result.getType().cast<ShapedType>();
|
||||
ShapedType operandType = cast<ShapedType>(operand.getType());
|
||||
ShapedType resultType = cast<ShapedType>(result.getType());
|
||||
for (int64_t axis = 0; axis < operandType.getRank(); ++axis) {
|
||||
if (axis != tensorAxis) {
|
||||
if (failed(verifyDimensionCompatibility(
|
||||
|
||||
@ -97,7 +97,7 @@ checkOperandAffineExpr(AffineExpr expr, unsigned numDims) {
|
||||
|
||||
FailureOr<std::pair<bool, MeshShardingAttr>>
|
||||
mesh::getMeshShardingAttr(OpResult result) {
|
||||
Value val = result.cast<Value>();
|
||||
Value val = cast<Value>(result);
|
||||
bool anyShardedForDef = llvm::any_of(val.getUsers(), [](Operation *user) {
|
||||
auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
|
||||
if (!shardOp)
|
||||
@ -178,7 +178,7 @@ LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() {
|
||||
return failure();
|
||||
|
||||
for (OpResult result : op->getResults()) {
|
||||
auto resultType = result.getType().dyn_cast<RankedTensorType>();
|
||||
auto resultType = dyn_cast<RankedTensorType>(result.getType());
|
||||
if (!resultType)
|
||||
return failure();
|
||||
AffineMap map = maps[numOperands + result.getResultNumber()];
|
||||
@ -404,7 +404,7 @@ static LogicalResult addShardOp(OpBuilder &b, OpResult result,
|
||||
if (succeeded(maybeSharding) && !maybeSharding->first)
|
||||
return success();
|
||||
|
||||
auto resultType = result.getType().cast<RankedTensorType>();
|
||||
auto resultType = cast<RankedTensorType>(result.getType());
|
||||
SmallVector<SmallVector<MeshAxis>> splitAxes(resultType.getRank());
|
||||
SmallVector<MeshAxis> partialAxes;
|
||||
|
||||
@ -457,7 +457,7 @@ static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
|
||||
if (succeeded(maybeShardingAttr) && maybeShardingAttr->first)
|
||||
return success();
|
||||
Value operand = opOperand.get();
|
||||
auto operandType = operand.getType().cast<RankedTensorType>();
|
||||
auto operandType = cast<RankedTensorType>(operand.getType());
|
||||
SmallVector<SmallVector<MeshAxis>> splitAxes(operandType.getRank());
|
||||
unsigned numDims = map.getNumDims();
|
||||
for (auto it : llvm::enumerate(map.getResults())) {
|
||||
@ -526,7 +526,7 @@ LogicalResult mesh::detail::defaultAddShardingAnnotations(
|
||||
static bool
|
||||
isValueCompatibleWithFullReplicationSharding(Value value,
|
||||
MeshShardingAttr sharding) {
|
||||
if (value.getType().isa<RankedTensorType>()) {
|
||||
if (isa<RankedTensorType>(value.getType())) {
|
||||
return sharding && isFullReplication(sharding);
|
||||
}
|
||||
|
||||
|
||||
@ -86,14 +86,13 @@ handlePartialAxesDuringResharding(OpBuilder &builder,
|
||||
}
|
||||
|
||||
builder.setInsertionPointAfterValue(sourceShard);
|
||||
TypedValue<ShapedType> resultValue =
|
||||
TypedValue<ShapedType> resultValue = cast<TypedValue<ShapedType>>(
|
||||
builder
|
||||
.create<AllReduceOp>(sourceShard.getLoc(), sourceShard.getType(),
|
||||
sourceSharding.getMesh().getLeafReference(),
|
||||
allReduceMeshAxes, sourceShard,
|
||||
sourceSharding.getPartialType())
|
||||
.getResult()
|
||||
.cast<TypedValue<ShapedType>>();
|
||||
.getResult());
|
||||
|
||||
llvm::SmallVector<MeshAxis> remainingPartialAxes;
|
||||
llvm::copy_if(sourceShardingPartialAxesSet,
|
||||
@ -135,13 +134,12 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
|
||||
MeshShardingAttr sourceSharding,
|
||||
TypedValue<ShapedType> sourceShard, MeshOp mesh,
|
||||
int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
|
||||
TypedValue<ShapedType> targetShard =
|
||||
TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
|
||||
builder
|
||||
.create<AllSliceOp>(sourceShard, mesh,
|
||||
ArrayRef<MeshAxis>(splitMeshAxis),
|
||||
splitTensorAxis)
|
||||
.getResult()
|
||||
.cast<TypedValue<ShapedType>>();
|
||||
.getResult());
|
||||
MeshShardingAttr targetSharding = targetShardingInSplitLastAxis(
|
||||
builder.getContext(), sourceSharding, splitTensorAxis, splitMeshAxis);
|
||||
return {targetShard, targetSharding};
|
||||
@ -278,10 +276,8 @@ unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
|
||||
APInt(64, splitTensorAxis));
|
||||
ShapedType targetShape =
|
||||
shardShapedType(sourceUnshardedShape, mesh, targetSharding);
|
||||
TypedValue<ShapedType> targetShard =
|
||||
builder.create<tensor::CastOp>(targetShape, allGatherResult)
|
||||
.getResult()
|
||||
.cast<TypedValue<ShapedType>>();
|
||||
TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
|
||||
builder.create<tensor::CastOp>(targetShape, allGatherResult).getResult());
|
||||
return {targetShard, targetSharding};
|
||||
}
|
||||
|
||||
@ -413,10 +409,8 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
|
||||
APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
|
||||
ShapedType targetShape =
|
||||
shardShapedType(sourceUnshardedShape, mesh, targetSharding);
|
||||
TypedValue<ShapedType> targetShard =
|
||||
builder.create<tensor::CastOp>(targetShape, allToAllResult)
|
||||
.getResult()
|
||||
.cast<TypedValue<ShapedType>>();
|
||||
TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
|
||||
builder.create<tensor::CastOp>(targetShape, allToAllResult).getResult());
|
||||
return {targetShard, targetSharding};
|
||||
}
|
||||
|
||||
@ -505,7 +499,7 @@ TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
|
||||
ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
|
||||
return reshard(
|
||||
implicitLocOpBuilder, mesh, source.getShard(), target.getShard(),
|
||||
source.getSrc().cast<TypedValue<ShapedType>>(), sourceShardValue);
|
||||
cast<TypedValue<ShapedType>>(source.getSrc()), sourceShardValue);
|
||||
}
|
||||
|
||||
TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
|
||||
@ -533,23 +527,22 @@ SmallVector<Type>
|
||||
shardedBlockArgumentTypes(Block &block,
|
||||
SymbolTableCollection &symbolTableCollection) {
|
||||
SmallVector<Type> res;
|
||||
llvm::transform(block.getArguments(), std::back_inserter(res),
|
||||
[&symbolTableCollection](BlockArgument arg) {
|
||||
auto rankedTensorArg =
|
||||
arg.dyn_cast<TypedValue<RankedTensorType>>();
|
||||
if (!rankedTensorArg) {
|
||||
return arg.getType();
|
||||
}
|
||||
llvm::transform(
|
||||
block.getArguments(), std::back_inserter(res),
|
||||
[&symbolTableCollection](BlockArgument arg) {
|
||||
auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
|
||||
if (!rankedTensorArg) {
|
||||
return arg.getType();
|
||||
}
|
||||
|
||||
assert(rankedTensorArg.hasOneUse());
|
||||
Operation *useOp = *rankedTensorArg.getUsers().begin();
|
||||
ShardOp shardOp = llvm::dyn_cast<ShardOp>(useOp);
|
||||
assert(shardOp);
|
||||
MeshOp mesh = getMesh(shardOp, symbolTableCollection);
|
||||
return shardShapedType(rankedTensorArg.getType(), mesh,
|
||||
shardOp.getShardAttr())
|
||||
.cast<Type>();
|
||||
});
|
||||
assert(rankedTensorArg.hasOneUse());
|
||||
Operation *useOp = *rankedTensorArg.getUsers().begin();
|
||||
ShardOp shardOp = llvm::dyn_cast<ShardOp>(useOp);
|
||||
assert(shardOp);
|
||||
MeshOp mesh = getMesh(shardOp, symbolTableCollection);
|
||||
return cast<Type>(shardShapedType(rankedTensorArg.getType(), mesh,
|
||||
shardOp.getShardAttr()));
|
||||
});
|
||||
return res;
|
||||
}
|
||||
|
||||
@ -587,7 +580,7 @@ static SmallVector<MeshShardingAttr> getOperandShardings(Operation &op) {
|
||||
res.reserve(op.getNumOperands());
|
||||
llvm::transform(op.getOperands(), std::back_inserter(res), [](Value operand) {
|
||||
TypedValue<RankedTensorType> rankedTensor =
|
||||
operand.dyn_cast<TypedValue<RankedTensorType>>();
|
||||
dyn_cast<TypedValue<RankedTensorType>>(operand);
|
||||
if (!rankedTensor) {
|
||||
return MeshShardingAttr();
|
||||
}
|
||||
@ -608,7 +601,7 @@ static SmallVector<MeshShardingAttr> getResultShardings(Operation &op) {
|
||||
llvm::transform(op.getResults(), std::back_inserter(res),
|
||||
[](OpResult result) {
|
||||
TypedValue<RankedTensorType> rankedTensor =
|
||||
result.dyn_cast<TypedValue<RankedTensorType>>();
|
||||
dyn_cast<TypedValue<RankedTensorType>>(result);
|
||||
if (!rankedTensor) {
|
||||
return MeshShardingAttr();
|
||||
}
|
||||
@ -636,9 +629,8 @@ spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
|
||||
} else {
|
||||
// Insert resharding.
|
||||
assert(!srcShardOp.getAnnotateForUsers() && shardOp.getAnnotateForUsers());
|
||||
TypedValue<ShapedType> srcSpmdValue =
|
||||
spmdizationMap.lookup(srcShardOp.getOperand())
|
||||
.cast<TypedValue<ShapedType>>();
|
||||
TypedValue<ShapedType> srcSpmdValue = cast<TypedValue<ShapedType>>(
|
||||
spmdizationMap.lookup(srcShardOp.getOperand()));
|
||||
targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
|
||||
symbolTableCollection);
|
||||
}
|
||||
|
||||
@ -133,7 +133,7 @@ struct AllSliceOpLowering
|
||||
|
||||
// insert tensor.extract_slice
|
||||
RankedTensorType operandType =
|
||||
op.getOperand().getType().cast<RankedTensorType>();
|
||||
cast<RankedTensorType>(op.getOperand().getType());
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
for (int64_t i = 0; i < operandType.getRank(); ++i) {
|
||||
if (i == sliceAxis) {
|
||||
@ -202,10 +202,9 @@ createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
|
||||
ImplicitLocOpBuilder &builder) {
|
||||
Operation::result_range meshShape =
|
||||
builder.create<mesh::MeshShapeOp>(mesh, axes).getResults();
|
||||
return arith::createProduct(builder, builder.getLoc(),
|
||||
llvm::to_vector_of<Value>(meshShape),
|
||||
builder.getIndexType())
|
||||
.cast<TypedValue<IndexType>>();
|
||||
return cast<TypedValue<IndexType>>(arith::createProduct(
|
||||
builder, builder.getLoc(), llvm::to_vector_of<Value>(meshShape),
|
||||
builder.getIndexType()));
|
||||
}
|
||||
|
||||
TypedValue<IndexType> createProcessLinearIndex(StringRef mesh,
|
||||
|
||||
@ -651,7 +651,7 @@ private:
|
||||
template <typename ApplyFn, typename ReduceFn>
|
||||
static void foreachIndividualVectorElement(Value vector, ApplyFn applyFn,
|
||||
ReduceFn reduceFn) {
|
||||
VectorType vectorType = vector.getType().cast<VectorType>();
|
||||
VectorType vectorType = cast<VectorType>(vector.getType());
|
||||
auto vectorShape = vectorType.getShape();
|
||||
auto strides = computeStrides(vectorShape);
|
||||
for (int64_t idx = 0, e = vectorShape[0] * strides[0]; idx < e; ++idx) {
|
||||
@ -779,11 +779,11 @@ FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) {
|
||||
Value lhsMemRef = linalgOp.getDpsInputOperand(0)->get();
|
||||
Value rhsMemRef = linalgOp.getDpsInputOperand(1)->get();
|
||||
Value resMemRef = linalgOp.getDpsInitOperand(0)->get();
|
||||
assert(lhsMemRef.getType().cast<MemRefType>().getRank() == 2 &&
|
||||
assert(cast<MemRefType>(lhsMemRef.getType()).getRank() == 2 &&
|
||||
"expected lhs to be a 2D memref");
|
||||
assert(rhsMemRef.getType().cast<MemRefType>().getRank() == 2 &&
|
||||
assert(cast<MemRefType>(rhsMemRef.getType()).getRank() == 2 &&
|
||||
"expected rhs to be a 2D memref");
|
||||
assert(resMemRef.getType().cast<MemRefType>().getRank() == 2 &&
|
||||
assert(cast<MemRefType>(resMemRef.getType()).getRank() == 2 &&
|
||||
"expected res to be a 2D memref");
|
||||
|
||||
int64_t m = cast<MemRefType>(lhsMemRef.getType()).getShape()[0];
|
||||
|
||||
@ -1318,7 +1318,7 @@ static LogicalResult verifyPrivateVarList(OpType &op) {
|
||||
for (auto privateVarInfo : llvm::zip_equal(privateVars, privatizers)) {
|
||||
Type varType = std::get<0>(privateVarInfo).getType();
|
||||
SymbolRefAttr privatizerSym =
|
||||
std::get<1>(privateVarInfo).template cast<SymbolRefAttr>();
|
||||
cast<SymbolRefAttr>(std::get<1>(privateVarInfo));
|
||||
PrivateClauseOp privatizerOp =
|
||||
SymbolTable::lookupNearestSymbolFrom<PrivateClauseOp>(op,
|
||||
privatizerSym);
|
||||
|
||||
@ -145,9 +145,9 @@ void VarInfo::setNum(Var::Num n) {
|
||||
/// mismatches.
|
||||
LLVM_ATTRIBUTE_UNUSED static llvm::SMLoc
|
||||
minSMLoc(AsmParser &parser, llvm::SMLoc sm1, llvm::SMLoc sm2) {
|
||||
const auto loc1 = parser.getEncodedSourceLoc(sm1).dyn_cast<FileLineColLoc>();
|
||||
const auto loc1 = dyn_cast<FileLineColLoc>(parser.getEncodedSourceLoc(sm1));
|
||||
assert(loc1 && "Could not get `FileLineColLoc` for first `SMLoc`");
|
||||
const auto loc2 = parser.getEncodedSourceLoc(sm2).dyn_cast<FileLineColLoc>();
|
||||
const auto loc2 = dyn_cast<FileLineColLoc>(parser.getEncodedSourceLoc(sm2));
|
||||
assert(loc2 && "Could not get `FileLineColLoc` for second `SMLoc`");
|
||||
if (loc1.getFilename() != loc2.getFilename())
|
||||
return SMLoc();
|
||||
|
||||
@ -2078,7 +2078,7 @@ struct SparseTensorAsmDialectInterface : public OpAsmDialectInterface {
|
||||
using OpAsmDialectInterface::OpAsmDialectInterface;
|
||||
|
||||
AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
|
||||
if (attr.isa<SparseTensorEncodingAttr>()) {
|
||||
if (isa<SparseTensorEncodingAttr>(attr)) {
|
||||
os << "sparse";
|
||||
return AliasResult::OverridableAlias;
|
||||
}
|
||||
|
||||
@ -29,7 +29,7 @@ LogicalResult sparse_tensor::detail::stageWithSortImpl(
|
||||
|
||||
Location loc = op.getLoc();
|
||||
Type finalTp = op->getOpResult(0).getType();
|
||||
SparseTensorType dstStt(finalTp.cast<RankedTensorType>());
|
||||
SparseTensorType dstStt(cast<RankedTensorType>(finalTp));
|
||||
Type srcCOOTp = dstStt.getCOOType(/*ordered=*/false);
|
||||
|
||||
// Clones the original operation but changing the output to an unordered COO.
|
||||
|
||||
@ -25,7 +25,7 @@ DiagnosedSilenceableFailure transform::MatchSparseInOut::matchOperation(
|
||||
return emitSilenceableFailure(current->getLoc(),
|
||||
"operation has no sparse input or output");
|
||||
}
|
||||
results.set(getResult().cast<OpResult>(), state.getPayloadOps(getTarget()));
|
||||
results.set(cast<OpResult>(getResult()), state.getPayloadOps(getTarget()));
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
|
||||
@ -42,7 +42,7 @@ static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
|
||||
if (kind == SparseTensorFieldKind::PosMemRef ||
|
||||
kind == SparseTensorFieldKind::CrdMemRef ||
|
||||
kind == SparseTensorFieldKind::ValMemRef) {
|
||||
auto rtp = t.cast<ShapedType>();
|
||||
auto rtp = cast<ShapedType>(t);
|
||||
if (!directOut) {
|
||||
rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
|
||||
if (extraTypes)
|
||||
@ -97,7 +97,7 @@ static void convVals(OpBuilder &builder, Location loc, TypeRange types,
|
||||
mem = builder.create<sparse_tensor::ToValuesOp>(loc, inputs[0]);
|
||||
toVals.push_back(mem);
|
||||
} else {
|
||||
ShapedType rtp = t.cast<ShapedType>();
|
||||
ShapedType rtp = cast<ShapedType>(t);
|
||||
rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
|
||||
inputs.push_back(extraVals[extra++]);
|
||||
retTypes.push_back(rtp);
|
||||
|
||||
@ -502,7 +502,7 @@ private:
|
||||
for (const AffineExpr l : order.getResults()) {
|
||||
unsigned loopId = llvm::cast<AffineDimExpr>(l).getPosition();
|
||||
auto itTp =
|
||||
linalgOp.getIteratorTypes()[loopId].cast<linalg::IteratorTypeAttr>();
|
||||
cast<linalg::IteratorTypeAttr>(linalgOp.getIteratorTypes()[loopId]);
|
||||
if (linalg::isReductionIterator(itTp.getValue()))
|
||||
break; // terminate at first reduction
|
||||
nest++;
|
||||
|
||||
@ -476,8 +476,8 @@ private:
|
||||
if (!sel)
|
||||
return std::nullopt;
|
||||
|
||||
auto tVal = sel.getTrueValue().dyn_cast<BlockArgument>();
|
||||
auto fVal = sel.getFalseValue().dyn_cast<BlockArgument>();
|
||||
auto tVal = dyn_cast<BlockArgument>(sel.getTrueValue());
|
||||
auto fVal = dyn_cast<BlockArgument>(sel.getFalseValue());
|
||||
// TODO: For simplicity, we only handle cases where both true/false value
|
||||
// are directly loaded the input tensor. We can probably admit more cases
|
||||
// in theory.
|
||||
@ -487,7 +487,7 @@ private:
|
||||
// Helper lambda to determine whether the value is loaded from a dense input
|
||||
// or is a loop invariant.
|
||||
auto isValFromDenseInputOrInvariant = [&op](Value v) -> bool {
|
||||
if (auto bArg = v.dyn_cast<BlockArgument>();
|
||||
if (auto bArg = dyn_cast<BlockArgument>(v);
|
||||
bArg && !isSparseTensor(op.getDpsInputOperand(bArg.getArgNumber())))
|
||||
return true;
|
||||
// If the value is defined outside the loop, it is a loop invariant.
|
||||
|
||||
@ -165,7 +165,7 @@ Value sparse_tensor::genCast(OpBuilder &builder, Location loc, Value value,
|
||||
|
||||
Value sparse_tensor::genScalarToTensor(OpBuilder &builder, Location loc,
|
||||
Value elem, Type dstTp) {
|
||||
if (auto rtp = dstTp.dyn_cast<RankedTensorType>()) {
|
||||
if (auto rtp = dyn_cast<RankedTensorType>(dstTp)) {
|
||||
// Scalars can only be converted to 0-ranked tensors.
|
||||
assert(rtp.getRank() == 0);
|
||||
elem = sparse_tensor::genCast(builder, loc, elem, rtp.getElementType());
|
||||
|
||||
@ -157,8 +157,7 @@ IterationGraphSorter::IterationGraphSorter(
|
||||
// The number of results of the map should match the rank of the tensor.
|
||||
assert(llvm::all_of(llvm::zip(loop2InsLvl, ins), [](auto mvPair) {
|
||||
auto [m, v] = mvPair;
|
||||
return m.getNumResults() ==
|
||||
v.getType().template cast<ShapedType>().getRank();
|
||||
return m.getNumResults() == cast<ShapedType>(v.getType()).getRank();
|
||||
}));
|
||||
|
||||
itGraph.resize(getNumLoops(), std::vector<bool>(getNumLoops(), false));
|
||||
|
||||
@ -820,7 +820,7 @@ struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
|
||||
if (!destOp)
|
||||
return failure();
|
||||
|
||||
auto resultIndex = source.cast<OpResult>().getResultNumber();
|
||||
auto resultIndex = cast<OpResult>(source).getResultNumber();
|
||||
auto *initOperand = destOp.getDpsInitOperand(resultIndex);
|
||||
|
||||
rewriter.modifyOpInPlace(
|
||||
@ -3475,7 +3475,7 @@ SplatOp::reifyResultShapes(OpBuilder &builder,
|
||||
|
||||
OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
|
||||
auto constOperand = adaptor.getInput();
|
||||
if (!constOperand.isa_and_nonnull<IntegerAttr, FloatAttr>())
|
||||
if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
|
||||
return {};
|
||||
|
||||
// Do not fold if the splat is not statically shaped
|
||||
@ -4307,7 +4307,7 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
|
||||
/// unpack(destinationStyleOp(x)) -> unpack(x)
|
||||
if (auto dstStyleOp =
|
||||
unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
|
||||
auto destValue = unPackOp.getDest().cast<OpResult>();
|
||||
auto destValue = cast<OpResult>(unPackOp.getDest());
|
||||
Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
|
||||
rewriter.modifyOpInPlace(unPackOp,
|
||||
[&]() { unPackOp.setDpsInitOperand(0, newDest); });
|
||||
|
||||
@ -32,7 +32,7 @@ namespace {
|
||||
struct MatMulOpSharding
|
||||
: public ShardingInterface::ExternalModel<MatMulOpSharding, MatMulOp> {
|
||||
SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
|
||||
auto tensorType = op->getResult(0).getType().dyn_cast<RankedTensorType>();
|
||||
auto tensorType = dyn_cast<RankedTensorType>(op->getResult(0).getType());
|
||||
if (!tensorType)
|
||||
return {};
|
||||
|
||||
@ -48,7 +48,7 @@ struct MatMulOpSharding
|
||||
}
|
||||
|
||||
SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
|
||||
auto tensorType = op->getResult(0).getType().dyn_cast<RankedTensorType>();
|
||||
auto tensorType = dyn_cast<RankedTensorType>(op->getResult(0).getType());
|
||||
if (!tensorType)
|
||||
return {};
|
||||
MLIRContext *ctx = op->getContext();
|
||||
|
||||
@ -285,7 +285,7 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (inputElementType.isa<FloatType>()) {
|
||||
if (isa<FloatType>(inputElementType)) {
|
||||
// Unlike integer types, floating point types can represent infinity.
|
||||
auto minClamp = op.getMinFp();
|
||||
auto maxClamp = op.getMaxFp();
|
||||
|
||||
@ -168,7 +168,7 @@ ParseResult mlir::tosa::parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
|
||||
return parser.emitError(parser.getCurrentLocation())
|
||||
<< "expected attribute";
|
||||
}
|
||||
if (auto typedAttr = attr.dyn_cast<TypedAttr>()) {
|
||||
if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
|
||||
typeAttr = TypeAttr::get(typedAttr.getType());
|
||||
}
|
||||
return success();
|
||||
@ -186,7 +186,7 @@ ParseResult mlir::tosa::parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
|
||||
void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
|
||||
Attribute attr) {
|
||||
bool needsSpace = false;
|
||||
auto typedAttr = attr.dyn_cast_or_null<TypedAttr>();
|
||||
auto typedAttr = dyn_cast_or_null<TypedAttr>(attr);
|
||||
if (!typedAttr || typedAttr.getType() != type.getValue()) {
|
||||
p << ": ";
|
||||
p.printAttribute(type);
|
||||
|
||||
@ -371,7 +371,7 @@ struct ReduceConstantOptimization : public OpRewritePattern<OperationType> {
|
||||
auto reductionAxis = op.getAxis();
|
||||
const auto denseElementsAttr = constOp.getValue();
|
||||
const auto shapedOldElementsValues =
|
||||
denseElementsAttr.getType().cast<ShapedType>();
|
||||
cast<ShapedType>(denseElementsAttr.getType());
|
||||
|
||||
if (!llvm::isa<IntegerType>(shapedOldElementsValues.getElementType()))
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
||||
@ -357,7 +357,7 @@ private:
|
||||
bool levelCheckTransposeConv2d(Operation *op) {
|
||||
if (auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
|
||||
if (ShapedType filterType =
|
||||
transpose.getFilter().getType().dyn_cast<ShapedType>()) {
|
||||
dyn_cast<ShapedType>(transpose.getFilter().getType())) {
|
||||
auto shape = filterType.getShape();
|
||||
assert(shape.size() == 4);
|
||||
// level check kernel sizes for kH and KW
|
||||
|
||||
@ -21,16 +21,15 @@ DiagnosedSilenceableFailure
|
||||
transform::DebugEmitRemarkAtOp::apply(transform::TransformRewriter &rewriter,
|
||||
transform::TransformResults &results,
|
||||
transform::TransformState &state) {
|
||||
if (getAt().getType().isa<TransformHandleTypeInterface>()) {
|
||||
if (isa<TransformHandleTypeInterface>(getAt().getType())) {
|
||||
auto payload = state.getPayloadOps(getAt());
|
||||
for (Operation *op : payload)
|
||||
op->emitRemark() << getMessage();
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
assert(
|
||||
getAt().getType().isa<transform::TransformValueHandleTypeInterface>() &&
|
||||
"unhandled kind of transform type");
|
||||
assert(isa<transform::TransformValueHandleTypeInterface>(getAt().getType()) &&
|
||||
"unhandled kind of transform type");
|
||||
|
||||
auto describeValue = [](Diagnostic &os, Value value) {
|
||||
os << "value handle points to ";
|
||||
|
||||
@ -1615,7 +1615,7 @@ transform::GetTypeOp::apply(transform::TransformRewriter &rewriter,
|
||||
}
|
||||
params.push_back(TypeAttr::get(type));
|
||||
}
|
||||
results.setParams(getResult().cast<OpResult>(), params);
|
||||
results.setParams(cast<OpResult>(getResult()), params);
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
@ -2217,14 +2217,14 @@ transform::NumAssociationsOp::apply(transform::TransformRewriter &rewriter,
|
||||
llvm_unreachable("unknown kind of transform dialect type");
|
||||
return 0;
|
||||
});
|
||||
results.setParams(getNum().cast<OpResult>(),
|
||||
results.setParams(cast<OpResult>(getNum()),
|
||||
rewriter.getI64IntegerAttr(numAssociations));
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
LogicalResult transform::NumAssociationsOp::verify() {
|
||||
// Verify that the result type accepts an i64 attribute as payload.
|
||||
auto resultType = getNum().getType().cast<TransformParamTypeInterface>();
|
||||
auto resultType = cast<TransformParamTypeInterface>(getNum().getType());
|
||||
return resultType
|
||||
.checkPayload(getLoc(), {Builder(getContext()).getI64IntegerAttr(0)})
|
||||
.checkAndReport();
|
||||
|
||||
@ -44,7 +44,7 @@ DiagnosedSilenceableFailure
|
||||
transform::AffineMapParamType::checkPayload(Location loc,
|
||||
ArrayRef<Attribute> payload) const {
|
||||
for (Attribute attr : payload) {
|
||||
if (!attr.isa<AffineMapAttr>()) {
|
||||
if (!mlir::isa<AffineMapAttr>(attr)) {
|
||||
return emitSilenceableError(loc)
|
||||
<< "expected affine map attribute, got " << attr;
|
||||
}
|
||||
@ -144,7 +144,7 @@ DiagnosedSilenceableFailure
|
||||
transform::TypeParamType::checkPayload(Location loc,
|
||||
ArrayRef<Attribute> payload) const {
|
||||
for (Attribute attr : payload) {
|
||||
if (!attr.isa<TypeAttr>()) {
|
||||
if (!mlir::isa<TypeAttr>(attr)) {
|
||||
return emitSilenceableError(loc)
|
||||
<< "expected type attribute, got " << attr;
|
||||
}
|
||||
|
||||
@ -6169,7 +6169,7 @@ void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
|
||||
|
||||
OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
|
||||
auto constOperand = adaptor.getInput();
|
||||
if (!constOperand.isa_and_nonnull<IntegerAttr, FloatAttr>())
|
||||
if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
|
||||
return {};
|
||||
|
||||
// SplatElementsAttr::get treats single value for second arg as being a splat.
|
||||
|
||||
@ -57,7 +57,7 @@ static Value extendMaskRank(OpBuilder &builder, Location loc, Value vec,
|
||||
Value broadcasted = extendVectorRank(builder, loc, vec, addedRank);
|
||||
SmallVector<int64_t> permutation;
|
||||
for (int64_t i = addedRank,
|
||||
e = broadcasted.getType().cast<VectorType>().getRank();
|
||||
e = cast<VectorType>(broadcasted.getType()).getRank();
|
||||
i < e; ++i)
|
||||
permutation.push_back(i);
|
||||
for (int64_t i = 0; i < addedRank; ++i)
|
||||
|
||||
@ -403,7 +403,7 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
|
||||
// Such transposes do not materially effect the underlying vector and can
|
||||
// be omitted. EG: perm [1, 0, 2] applied to vector<1x1x8xi32>
|
||||
bool transposeNonOuterUnitDims = false;
|
||||
auto operandShape = operands[it.index()].getType().cast<ShapedType>();
|
||||
auto operandShape = cast<ShapedType>(operands[it.index()].getType());
|
||||
for (auto [index, dim] :
|
||||
llvm::enumerate(ArrayRef<int64_t>(perm).drop_back(1))) {
|
||||
if (dim != static_cast<int64_t>(index) &&
|
||||
|
||||
@ -63,7 +63,7 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
|
||||
// new mask index) only happens on the last dimension of the vectors.
|
||||
Operation *newMask = nullptr;
|
||||
SmallVector<int64_t> shape(
|
||||
maskOp->getResultTypes()[0].cast<VectorType>().getShape());
|
||||
cast<VectorType>(maskOp->getResultTypes()[0]).getShape());
|
||||
shape.back() = numElements;
|
||||
auto newMaskType = VectorType::get(shape, rewriter.getI1Type());
|
||||
if (createMaskOp) {
|
||||
|
||||
@ -171,7 +171,7 @@ static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
|
||||
/// is first inserted, followed by a `memref.cast`.
|
||||
static Value castToCompatibleMemRefType(OpBuilder &b, Value memref,
|
||||
MemRefType compatibleMemRefType) {
|
||||
MemRefType sourceType = memref.getType().cast<MemRefType>();
|
||||
MemRefType sourceType = cast<MemRefType>(memref.getType());
|
||||
Value res = memref;
|
||||
if (sourceType.getMemorySpace() != compatibleMemRefType.getMemorySpace()) {
|
||||
sourceType = MemRefType::get(
|
||||
|
||||
@ -127,7 +127,7 @@ LogicalResult CreateNdDescOp::verify() {
|
||||
|
||||
// check source type matches the rank if it is a memref.
|
||||
// It also should have the same ElementType as TensorDesc.
|
||||
auto memrefTy = getSourceType().dyn_cast<MemRefType>();
|
||||
auto memrefTy = dyn_cast<MemRefType>(getSourceType());
|
||||
if (memrefTy) {
|
||||
invalidRank |= (memrefTy.getRank() != rank);
|
||||
invalidElemTy |= memrefTy.getElementType() != getElementType();
|
||||
|
||||
@ -711,7 +711,7 @@ AffineMap mlir::foldAttributesIntoMap(Builder &b, AffineMap map,
|
||||
for (int64_t i = 0; i < map.getNumDims(); ++i) {
|
||||
if (auto attr = operands[i].dyn_cast<Attribute>()) {
|
||||
dimReplacements.push_back(
|
||||
b.getAffineConstantExpr(attr.cast<IntegerAttr>().getInt()));
|
||||
b.getAffineConstantExpr(cast<IntegerAttr>(attr).getInt()));
|
||||
} else {
|
||||
dimReplacements.push_back(b.getAffineDimExpr(numDims++));
|
||||
remainingValues.push_back(operands[i].get<Value>());
|
||||
@ -721,7 +721,7 @@ AffineMap mlir::foldAttributesIntoMap(Builder &b, AffineMap map,
|
||||
for (int64_t i = 0; i < map.getNumSymbols(); ++i) {
|
||||
if (auto attr = operands[i + map.getNumDims()].dyn_cast<Attribute>()) {
|
||||
symReplacements.push_back(
|
||||
b.getAffineConstantExpr(attr.cast<IntegerAttr>().getInt()));
|
||||
b.getAffineConstantExpr(cast<IntegerAttr>(attr).getInt()));
|
||||
} else {
|
||||
symReplacements.push_back(b.getAffineSymbolExpr(numSymbols++));
|
||||
remainingValues.push_back(operands[i + map.getNumDims()].get<Value>());
|
||||
|
||||
@ -1154,7 +1154,7 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultRank(Operation *op) {
|
||||
|
||||
// delegate function that returns rank of shaped type with known rank
|
||||
auto getRank = [](const Type type) {
|
||||
return type.cast<ShapedType>().getRank();
|
||||
return cast<ShapedType>(type).getRank();
|
||||
};
|
||||
|
||||
auto rank = !rankedOperandTypes.empty() ? getRank(*rankedOperandTypes.begin())
|
||||
|
||||
@ -2489,7 +2489,7 @@ static void genMapInfos(llvm::IRBuilderBase &builder,
|
||||
auto addDevInfos = [&, fail](auto devOperands, auto devOpType) -> void {
|
||||
for (const auto &devOp : devOperands) {
|
||||
// TODO: Only LLVMPointerTypes are handled.
|
||||
if (!devOp.getType().template isa<LLVM::LLVMPointerType>())
|
||||
if (!isa<LLVM::LLVMPointerType>(devOp.getType()))
|
||||
return fail();
|
||||
|
||||
llvm::Value *mapOpValue = moduleTranslation.lookupValue(devOp);
|
||||
@ -3083,10 +3083,9 @@ convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute,
|
||||
std::vector<llvm::GlobalVariable *> generatedRefs;
|
||||
|
||||
std::vector<llvm::Triple> targetTriple;
|
||||
auto targetTripleAttr =
|
||||
op->getParentOfType<mlir::ModuleOp>()
|
||||
->getAttr(LLVM::LLVMDialect::getTargetTripleAttrName())
|
||||
.dyn_cast_or_null<mlir::StringAttr>();
|
||||
auto targetTripleAttr = dyn_cast_or_null<mlir::StringAttr>(
|
||||
op->getParentOfType<mlir::ModuleOp>()->getAttr(
|
||||
LLVM::LLVMDialect::getTargetTripleAttrName()));
|
||||
if (targetTripleAttr)
|
||||
targetTriple.emplace_back(targetTripleAttr.data());
|
||||
|
||||
@ -3328,7 +3327,7 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
|
||||
attribute.getName())
|
||||
.Case("omp.is_target_device",
|
||||
[&](Attribute attr) {
|
||||
if (auto deviceAttr = attr.dyn_cast<BoolAttr>()) {
|
||||
if (auto deviceAttr = dyn_cast<BoolAttr>(attr)) {
|
||||
llvm::OpenMPIRBuilderConfig &config =
|
||||
moduleTranslation.getOpenMPBuilder()->Config;
|
||||
config.setIsTargetDevice(deviceAttr.getValue());
|
||||
@ -3338,7 +3337,7 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
|
||||
})
|
||||
.Case("omp.is_gpu",
|
||||
[&](Attribute attr) {
|
||||
if (auto gpuAttr = attr.dyn_cast<BoolAttr>()) {
|
||||
if (auto gpuAttr = dyn_cast<BoolAttr>(attr)) {
|
||||
llvm::OpenMPIRBuilderConfig &config =
|
||||
moduleTranslation.getOpenMPBuilder()->Config;
|
||||
config.setIsGPU(gpuAttr.getValue());
|
||||
@ -3348,7 +3347,7 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
|
||||
})
|
||||
.Case("omp.host_ir_filepath",
|
||||
[&](Attribute attr) {
|
||||
if (auto filepathAttr = attr.dyn_cast<StringAttr>()) {
|
||||
if (auto filepathAttr = dyn_cast<StringAttr>(attr)) {
|
||||
llvm::OpenMPIRBuilder *ompBuilder =
|
||||
moduleTranslation.getOpenMPBuilder();
|
||||
ompBuilder->loadOffloadInfoMetadata(filepathAttr.getValue());
|
||||
@ -3358,13 +3357,13 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
|
||||
})
|
||||
.Case("omp.flags",
|
||||
[&](Attribute attr) {
|
||||
if (auto rtlAttr = attr.dyn_cast<omp::FlagsAttr>())
|
||||
if (auto rtlAttr = dyn_cast<omp::FlagsAttr>(attr))
|
||||
return convertFlagsAttr(op, rtlAttr, moduleTranslation);
|
||||
return failure();
|
||||
})
|
||||
.Case("omp.version",
|
||||
[&](Attribute attr) {
|
||||
if (auto versionAttr = attr.dyn_cast<omp::VersionAttr>()) {
|
||||
if (auto versionAttr = dyn_cast<omp::VersionAttr>(attr)) {
|
||||
llvm::OpenMPIRBuilder *ompBuilder =
|
||||
moduleTranslation.getOpenMPBuilder();
|
||||
ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp",
|
||||
@ -3376,15 +3375,14 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
|
||||
.Case("omp.declare_target",
|
||||
[&](Attribute attr) {
|
||||
if (auto declareTargetAttr =
|
||||
attr.dyn_cast<omp::DeclareTargetAttr>())
|
||||
dyn_cast<omp::DeclareTargetAttr>(attr))
|
||||
return convertDeclareTargetAttr(op, declareTargetAttr,
|
||||
moduleTranslation);
|
||||
return failure();
|
||||
})
|
||||
.Case("omp.requires",
|
||||
[&](Attribute attr) {
|
||||
if (auto requiresAttr =
|
||||
attr.dyn_cast<omp::ClauseRequiresAttr>()) {
|
||||
if (auto requiresAttr = dyn_cast<omp::ClauseRequiresAttr>(attr)) {
|
||||
using Requires = omp::ClauseRequires;
|
||||
Requires flags = requiresAttr.getValue();
|
||||
llvm::OpenMPIRBuilderConfig &config =
|
||||
|
||||
@ -29,8 +29,8 @@ using mlir::LLVM::detail::createIntrinsicCall;
|
||||
/// option around.
|
||||
static llvm::Type *getXlenType(Attribute opcodeAttr,
|
||||
LLVM::ModuleTranslation &moduleTranslation) {
|
||||
auto intAttr = opcodeAttr.cast<IntegerAttr>();
|
||||
unsigned xlenWidth = intAttr.getType().cast<IntegerType>().getWidth();
|
||||
auto intAttr = cast<IntegerAttr>(opcodeAttr);
|
||||
unsigned xlenWidth = cast<IntegerType>(intAttr.getType()).getWidth();
|
||||
return llvm::Type::getIntNTy(moduleTranslation.getLLVMContext(), xlenWidth);
|
||||
}
|
||||
|
||||
|
||||
@ -25,7 +25,7 @@ namespace {
|
||||
/// according to LLVM's encoding:
|
||||
/// https://lists.llvm.org/pipermail/llvm-dev/2020-October/145850.html
|
||||
static std::pair<unsigned, VectorType> legalizeVectorType(const Type &type) {
|
||||
VectorType vt = type.cast<VectorType>();
|
||||
VectorType vt = cast<VectorType>(type);
|
||||
// To simplify test pass, avoid multi-dimensional vectors.
|
||||
if (!vt || vt.getRank() != 1)
|
||||
return {0, nullptr};
|
||||
@ -39,7 +39,7 @@ static std::pair<unsigned, VectorType> legalizeVectorType(const Type &type) {
|
||||
sew = 32;
|
||||
else if (eltTy.isF64())
|
||||
sew = 64;
|
||||
else if (auto intTy = eltTy.dyn_cast<IntegerType>())
|
||||
else if (auto intTy = dyn_cast<IntegerType>(eltTy))
|
||||
sew = intTy.getWidth();
|
||||
else
|
||||
return {0, nullptr};
|
||||
|
||||
@ -67,12 +67,11 @@ struct TestMeshReshardingRewritePattern : OpRewritePattern<ShardOp> {
|
||||
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
|
||||
ShapedType sourceShardShape =
|
||||
shardShapedType(op.getResult().getType(), mesh, op.getShard());
|
||||
TypedValue<ShapedType> sourceShard =
|
||||
TypedValue<ShapedType> sourceShard = cast<TypedValue<ShapedType>>(
|
||||
builder
|
||||
.create<UnrealizedConversionCastOp>(sourceShardShape,
|
||||
op.getOperand())
|
||||
->getResult(0)
|
||||
.cast<TypedValue<ShapedType>>();
|
||||
->getResult(0));
|
||||
TypedValue<ShapedType> targetShard =
|
||||
reshard(builder, mesh, op, targetShardOp, sourceShard);
|
||||
Value newTargetUnsharded =
|
||||
|
||||
@ -61,7 +61,7 @@ LogicalResult TestDialectLLVMIRTranslationInterface::amendOperation(
|
||||
}
|
||||
|
||||
bool createSymbol = false;
|
||||
if (auto boolAttr = attr.dyn_cast<BoolAttr>())
|
||||
if (auto boolAttr = dyn_cast<BoolAttr>(attr))
|
||||
createSymbol = boolAttr.getValue();
|
||||
|
||||
if (createSymbol) {
|
||||
|
||||
@ -44,7 +44,7 @@ void TestAffineWalk::runOnOperation() {
|
||||
// Test whether the walk is being correctly interrupted.
|
||||
m.walk([](Operation *op) {
|
||||
for (NamedAttribute attr : op->getAttrs()) {
|
||||
auto mapAttr = attr.getValue().dyn_cast<AffineMapAttr>();
|
||||
auto mapAttr = dyn_cast<AffineMapAttr>(attr.getValue());
|
||||
if (!mapAttr)
|
||||
return;
|
||||
checkMod(mapAttr.getAffineMap(), op->getLoc());
|
||||
|
||||
@ -51,7 +51,7 @@ struct TestElementsAttrInterface
|
||||
InFlightDiagnostic diag = op->emitError()
|
||||
<< "Test iterating `" << type << "`: ";
|
||||
|
||||
if (!attr.getElementType().isa<mlir::IntegerType>()) {
|
||||
if (!isa<mlir::IntegerType>(attr.getElementType())) {
|
||||
diag << "expected element type to be an integer type";
|
||||
return;
|
||||
}
|
||||
|
||||
@ -61,7 +61,7 @@ static LogicalResult customTypeRangeResultConstraint(PatternRewriter &rewriter,
|
||||
PDLResultList &results,
|
||||
ArrayRef<PDLValue> args) {
|
||||
auto *op = args[0].cast<Operation *>();
|
||||
int numTypes = args[1].cast<Attribute>().cast<IntegerAttr>().getInt();
|
||||
int numTypes = cast<IntegerAttr>(args[1].cast<Attribute>()).getInt();
|
||||
|
||||
if (op->getName().getStringRef() == "test.success_op") {
|
||||
SmallVector<Type> types;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user