[mlir][NFC] Use getDefiningOp<OpTy>()
instead of dyn_cast<OpTy>(getDefiningOp())
(#150428)
This PR uses `val.getDefiningOp<OpTy>()` to replace `dyn_cast<OpTy>(val.getDefiningOp())` , `dyn_cast_or_null<OpTy>(val.getDefiningOp())` and `dyn_cast_if_present<OpTy>(val.getDefiningOp())`.
This commit is contained in:
parent
b16ef20626
commit
f047b735e9
@ -49,7 +49,7 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList,
|
||||
assert(isa<pdl::AttributeType>(val.getType()) && "expected attribute type");
|
||||
predList.emplace_back(pos, builder.getIsNotNull());
|
||||
|
||||
if (auto attr = dyn_cast<pdl::AttributeOp>(val.getDefiningOp())) {
|
||||
if (auto attr = val.getDefiningOp<pdl::AttributeOp>()) {
|
||||
// If the attribute has a type or value, add a constraint.
|
||||
if (Value type = attr.getValueType())
|
||||
getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
|
||||
|
@ -1322,7 +1322,7 @@ static bool isNeutralElementConst(arith::AtomicRMWKind reductionKind,
|
||||
return false;
|
||||
Attribute valueAttr = getIdentityValueAttr(reductionKind, scalarTy,
|
||||
state.builder, value.getLoc());
|
||||
if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(value.getDefiningOp()))
|
||||
if (auto constOp = value.getDefiningOp<arith::ConstantOp>())
|
||||
return constOp.getValue() == valueAttr;
|
||||
return false;
|
||||
}
|
||||
|
@ -2498,7 +2498,7 @@ OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
|
||||
matchPattern(adaptor.getFalseValue(), m_Zero()))
|
||||
return condition;
|
||||
|
||||
if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) {
|
||||
if (auto cmp = condition.getDefiningOp<arith::CmpIOp>()) {
|
||||
auto pred = cmp.getPredicate();
|
||||
if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
|
||||
auto cmpLhs = cmp.getLhs();
|
||||
|
@ -49,7 +49,7 @@ std::optional<Value> getExtOperand(Value v) {
|
||||
|
||||
// If the operand is not defined by an explicit extend operation of the
|
||||
// accepted operation type allow for an implicit sign-extension.
|
||||
auto extOp = dyn_cast_or_null<Op>(v.getDefiningOp());
|
||||
auto extOp = v.getDefiningOp<Op>();
|
||||
if (!extOp) {
|
||||
if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
|
||||
auto eltTy = cast<VectorType>(v.getType()).getElementType();
|
||||
|
@ -50,7 +50,7 @@ std::optional<Value> getExtOperand(Value v) {
|
||||
|
||||
// If the operand is not defined by an explicit extend operation of the
|
||||
// accepted operation type allow for an implicit sign-extension.
|
||||
auto extOp = dyn_cast_or_null<Op>(v.getDefiningOp());
|
||||
auto extOp = v.getDefiningOp<Op>();
|
||||
if (!extOp) {
|
||||
if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
|
||||
auto vTy = cast<VectorType>(v.getType());
|
||||
|
@ -62,9 +62,7 @@ struct FoldExpressionOp : public OpRewritePattern<ExpressionOp> {
|
||||
continue;
|
||||
|
||||
for (Value operand : op.getOperands()) {
|
||||
auto usedExpression =
|
||||
dyn_cast_if_present<ExpressionOp>(operand.getDefiningOp());
|
||||
|
||||
auto usedExpression = operand.getDefiningOp<ExpressionOp>();
|
||||
if (!usedExpression)
|
||||
continue;
|
||||
|
||||
|
@ -2707,7 +2707,7 @@ LogicalResult IFuncOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
||||
while (alias) {
|
||||
Block &initBlock = alias.getInitializerBlock();
|
||||
auto returnOp = cast<ReturnOp>(initBlock.getTerminator());
|
||||
auto addrOp = dyn_cast<AddressOfOp>(returnOp.getArg().getDefiningOp());
|
||||
auto addrOp = returnOp.getArg().getDefiningOp<AddressOfOp>();
|
||||
// FIXME: This is a best effort solution. The AliasOp body might be more
|
||||
// complex and in that case we bail out with success. To completely match
|
||||
// the LLVM IR logic it would be necessary to implement proper alias and
|
||||
|
@ -1852,7 +1852,7 @@ transform::PackTransposeOp::apply(transform::TransformRewriter &rewriter,
|
||||
assert(!packOp && "packOp must be null on entry when unPackOp is not null");
|
||||
OpOperand *packUse = linalgOp.getDpsInitOperand(
|
||||
cast<OpResult>(unPackOp.getSource()).getResultNumber());
|
||||
packOp = dyn_cast_or_null<linalg::PackOp>(packUse->get().getDefiningOp());
|
||||
packOp = packUse->get().getDefiningOp<linalg::PackOp>();
|
||||
if (!packOp || !packOp.getResult().hasOneUse())
|
||||
return emitSilenceableError() << "could not find matching pack op";
|
||||
}
|
||||
|
@ -757,8 +757,7 @@ static bool tracesBackToExpectedValue(tensor::ExtractSliceOp extractSliceOp,
|
||||
Value source = extractSliceOp.getSource();
|
||||
LLVM_DEBUG(DBGS() << "--with starting source: " << source << "\n");
|
||||
while (source && source != expectedSource) {
|
||||
auto destOp =
|
||||
dyn_cast_or_null<DestinationStyleOpInterface>(source.getDefiningOp());
|
||||
auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
|
||||
if (!destOp)
|
||||
break;
|
||||
LLVM_DEBUG(DBGS() << "--step dest op: " << destOp << "\n");
|
||||
|
@ -165,8 +165,7 @@ static bool noAliasingUseInLoop(vector::TransferReadOp transferRead,
|
||||
Value source = transferRead.getBase();
|
||||
|
||||
// Skip view-like Ops and retrive the actual soruce Operation
|
||||
while (auto srcOp =
|
||||
dyn_cast_or_null<ViewLikeOpInterface>(source.getDefiningOp()))
|
||||
while (auto srcOp = source.getDefiningOp<ViewLikeOpInterface>())
|
||||
source = srcOp.getViewSource();
|
||||
|
||||
llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
|
||||
|
@ -755,7 +755,7 @@ bool MeshSharding::operator!=(const MeshSharding &rhs) const {
|
||||
MeshSharding::MeshSharding(::mlir::FlatSymbolRefAttr mesh_) : mesh(mesh_) {}
|
||||
|
||||
MeshSharding::MeshSharding(Value rhs) {
|
||||
auto shardingOp = mlir::dyn_cast<ShardingOp>(rhs.getDefiningOp());
|
||||
auto shardingOp = rhs.getDefiningOp<ShardingOp>();
|
||||
assert(shardingOp && "expected sharding op");
|
||||
auto splitAxes = shardingOp.getSplitAxes().getAxes();
|
||||
// If splitAxes are empty, use "empty" constructor.
|
||||
|
@ -167,7 +167,7 @@ ReshardingRquirementKind getReshardingRquirementKind(
|
||||
|
||||
for (auto [operand, sharding] :
|
||||
llvm::zip_equal(op->getOperands(), operandShardings)) {
|
||||
ShardOp shardOp = llvm::dyn_cast_or_null<ShardOp>(operand.getDefiningOp());
|
||||
ShardOp shardOp = operand.getDefiningOp<ShardOp>();
|
||||
if (!shardOp) {
|
||||
continue;
|
||||
}
|
||||
@ -376,8 +376,7 @@ struct ShardingPropagation
|
||||
LLVM_DEBUG(
|
||||
DBGS() << "print all the ops' iterator types and indexing maps in the "
|
||||
"block.\n";
|
||||
for (Operation &op
|
||||
: block.getOperations()) {
|
||||
for (Operation &op : block.getOperations()) {
|
||||
if (auto shardingOp = llvm::dyn_cast<ShardingInterface>(&op))
|
||||
shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
|
||||
});
|
||||
|
@ -660,8 +660,7 @@ spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
|
||||
|
||||
// Check if 2 shard ops are chained. If not there is no need for resharding
|
||||
// as the source and target shared the same sharding.
|
||||
ShardOp srcShardOp =
|
||||
dyn_cast_or_null<ShardOp>(shardOp.getSrc().getDefiningOp());
|
||||
ShardOp srcShardOp = shardOp.getSrc().getDefiningOp<ShardOp>();
|
||||
if (!srcShardOp) {
|
||||
targetSpmdValue = spmdizationMap.lookup(shardOp.getSrc());
|
||||
} else {
|
||||
|
@ -1730,8 +1730,7 @@ static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars) {
|
||||
if (!mapOp.getDefiningOp())
|
||||
return emitError(op->getLoc(), "missing map operation");
|
||||
|
||||
if (auto mapInfoOp =
|
||||
mlir::dyn_cast<mlir::omp::MapInfoOp>(mapOp.getDefiningOp())) {
|
||||
if (auto mapInfoOp = mapOp.getDefiningOp<mlir::omp::MapInfoOp>()) {
|
||||
uint64_t mapTypeBits = mapInfoOp.getMapType();
|
||||
|
||||
bool to = mapTypeToBitFlag(
|
||||
|
@ -53,7 +53,7 @@ OpFoldResult FromPtrOp::fold(FoldAdaptor adaptor) {
|
||||
Value ptrLike;
|
||||
FromPtrOp fromPtr = *this;
|
||||
while (fromPtr != nullptr) {
|
||||
auto toPtr = dyn_cast_or_null<ToPtrOp>(fromPtr.getPtr().getDefiningOp());
|
||||
auto toPtr = fromPtr.getPtr().getDefiningOp<ToPtrOp>();
|
||||
// Cannot fold if it's not a `to_ptr` op or the initial and final types are
|
||||
// different.
|
||||
if (!toPtr || toPtr.getPtr().getType() != fromPtr.getType())
|
||||
@ -64,13 +64,12 @@ OpFoldResult FromPtrOp::fold(FoldAdaptor adaptor) {
|
||||
ptrLike = toPtr.getPtr();
|
||||
} else if (md) {
|
||||
// Fold if the metadata can be verified to be equal.
|
||||
if (auto mdOp = dyn_cast_or_null<GetMetadataOp>(md.getDefiningOp());
|
||||
if (auto mdOp = md.getDefiningOp<GetMetadataOp>();
|
||||
mdOp && mdOp.getPtr() == toPtr.getPtr())
|
||||
ptrLike = toPtr.getPtr();
|
||||
}
|
||||
// Check for a sequence of casts.
|
||||
fromPtr = dyn_cast_or_null<FromPtrOp>(ptrLike ? ptrLike.getDefiningOp()
|
||||
: nullptr);
|
||||
fromPtr = ptrLike ? ptrLike.getDefiningOp<FromPtrOp>() : nullptr;
|
||||
}
|
||||
return ptrLike;
|
||||
}
|
||||
@ -112,13 +111,13 @@ OpFoldResult ToPtrOp::fold(FoldAdaptor adaptor) {
|
||||
Value ptr;
|
||||
ToPtrOp toPtr = *this;
|
||||
while (toPtr != nullptr) {
|
||||
auto fromPtr = dyn_cast_or_null<FromPtrOp>(toPtr.getPtr().getDefiningOp());
|
||||
auto fromPtr = toPtr.getPtr().getDefiningOp<FromPtrOp>();
|
||||
// Cannot fold if it's not a `from_ptr` op.
|
||||
if (!fromPtr)
|
||||
return ptr;
|
||||
ptr = fromPtr.getPtr();
|
||||
// Check for chains of casts.
|
||||
toPtr = dyn_cast_or_null<ToPtrOp>(ptr.getDefiningOp());
|
||||
toPtr = ptr.getDefiningOp<ToPtrOp>();
|
||||
}
|
||||
return ptr;
|
||||
}
|
||||
|
@ -100,11 +100,10 @@ mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes,
|
||||
op.getStep(), tileSizeConstants)) {
|
||||
// Collect the statically known loop bounds
|
||||
auto lowerBoundConstant =
|
||||
dyn_cast_or_null<arith::ConstantIndexOp>(lowerBound.getDefiningOp());
|
||||
lowerBound.getDefiningOp<arith::ConstantIndexOp>();
|
||||
auto upperBoundConstant =
|
||||
dyn_cast_or_null<arith::ConstantIndexOp>(upperBound.getDefiningOp());
|
||||
auto stepConstant =
|
||||
dyn_cast_or_null<arith::ConstantIndexOp>(step.getDefiningOp());
|
||||
upperBound.getDefiningOp<arith::ConstantIndexOp>();
|
||||
auto stepConstant = step.getDefiningOp<arith::ConstantIndexOp>();
|
||||
auto tileSize =
|
||||
cast<arith::ConstantIndexOp>(tileSizeConstant.getDefiningOp()).value();
|
||||
// If the loop bounds and the loop step are constant and if the number of
|
||||
|
@ -1317,7 +1317,7 @@ public:
|
||||
|
||||
Value n = op.getN() ? op.getN() : constantIndex(rewriter, loc, 1);
|
||||
Value newSize = arith::AddIOp::create(rewriter, loc, size, n);
|
||||
auto nValue = dyn_cast_or_null<arith::ConstantIndexOp>(n.getDefiningOp());
|
||||
auto nValue = n.getDefiningOp<arith::ConstantIndexOp>();
|
||||
bool nIsOne = (nValue && nValue.value() == 1);
|
||||
|
||||
if (!op.getInbounds()) {
|
||||
|
@ -554,7 +554,7 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
|
||||
Value input = op.getInput();
|
||||
|
||||
// Check the input to the CLAMP op is itself a CLAMP.
|
||||
auto clampOp = dyn_cast_if_present<tosa::ClampOp>(input.getDefiningOp());
|
||||
auto clampOp = input.getDefiningOp<tosa::ClampOp>();
|
||||
if (!clampOp)
|
||||
return failure();
|
||||
|
||||
@ -1636,7 +1636,7 @@ OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
|
||||
for (Value operand : getOperands()) {
|
||||
concatOperands.emplace_back(operand);
|
||||
|
||||
auto producer = dyn_cast_or_null<ConcatOp>(operand.getDefiningOp());
|
||||
auto producer = operand.getDefiningOp<ConcatOp>();
|
||||
if (!producer)
|
||||
continue;
|
||||
|
||||
|
@ -2591,8 +2591,7 @@ class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {
|
||||
llvm::enumerate(fromElements.getElements())) {
|
||||
|
||||
// Check that the element is from a vector.extract operation.
|
||||
auto extractOp =
|
||||
dyn_cast_if_present<vector::ExtractOp>(element.getDefiningOp());
|
||||
auto extractOp = element.getDefiningOp<vector::ExtractOp>();
|
||||
if (!extractOp) {
|
||||
return rewriter.notifyMatchFailure(fromElements,
|
||||
"element not from vector.extract");
|
||||
|
@ -900,8 +900,7 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) {
|
||||
// inlined, and as such should be wrapped in parentheses in order to guarantee
|
||||
// its precedence and associativity.
|
||||
auto requiresParentheses = [&](Value value) {
|
||||
auto expressionOp =
|
||||
dyn_cast_if_present<ExpressionOp>(value.getDefiningOp());
|
||||
auto expressionOp = value.getDefiningOp<ExpressionOp>();
|
||||
if (!expressionOp)
|
||||
return false;
|
||||
return shouldBeInlined(expressionOp);
|
||||
@ -1542,7 +1541,7 @@ LogicalResult CppEmitter::emitOperand(Value value) {
|
||||
return success();
|
||||
}
|
||||
|
||||
auto expressionOp = dyn_cast_if_present<ExpressionOp>(value.getDefiningOp());
|
||||
auto expressionOp = value.getDefiningOp<ExpressionOp>();
|
||||
if (expressionOp && shouldBeInlined(expressionOp))
|
||||
return emitExpression(expressionOp);
|
||||
|
||||
|
@ -151,8 +151,7 @@ processDataOperands(llvm::IRBuilderBase &builder,
|
||||
// Copyin operands are handled as `to` call.
|
||||
llvm::SmallVector<mlir::Value> create, copyin;
|
||||
for (mlir::Value dataOp : op.getDataClauseOperands()) {
|
||||
if (auto createOp =
|
||||
mlir::dyn_cast_or_null<acc::CreateOp>(dataOp.getDefiningOp())) {
|
||||
if (auto createOp = dataOp.getDefiningOp<acc::CreateOp>()) {
|
||||
create.push_back(createOp.getVarPtr());
|
||||
} else if (auto copyinOp = mlir::dyn_cast_or_null<acc::CopyinOp>(
|
||||
dataOp.getDefiningOp())) {
|
||||
|
@ -3537,8 +3537,7 @@ getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp,
|
||||
}
|
||||
|
||||
static bool isDeclareTargetLink(mlir::Value value) {
|
||||
if (auto addressOfOp =
|
||||
llvm::dyn_cast_if_present<LLVM::AddressOfOp>(value.getDefiningOp())) {
|
||||
if (auto addressOfOp = value.getDefiningOp<LLVM::AddressOfOp>()) {
|
||||
auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>();
|
||||
Operation *gOp = modOp.lookupSymbol(addressOfOp.getGlobalName());
|
||||
if (auto declareTargetGlobal =
|
||||
@ -4498,8 +4497,7 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
|
||||
ifCond = moduleTranslation.lookupValue(ifVar);
|
||||
|
||||
if (auto devId = dataOp.getDevice())
|
||||
if (auto constOp =
|
||||
dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
|
||||
if (auto constOp = devId.getDefiningOp<LLVM::ConstantOp>())
|
||||
if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
|
||||
deviceID = intAttr.getInt();
|
||||
|
||||
@ -4516,8 +4514,7 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
|
||||
ifCond = moduleTranslation.lookupValue(ifVar);
|
||||
|
||||
if (auto devId = enterDataOp.getDevice())
|
||||
if (auto constOp =
|
||||
dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
|
||||
if (auto constOp = devId.getDefiningOp<LLVM::ConstantOp>())
|
||||
if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
|
||||
deviceID = intAttr.getInt();
|
||||
RTLFn =
|
||||
@ -4536,8 +4533,7 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
|
||||
ifCond = moduleTranslation.lookupValue(ifVar);
|
||||
|
||||
if (auto devId = exitDataOp.getDevice())
|
||||
if (auto constOp =
|
||||
dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
|
||||
if (auto constOp = devId.getDefiningOp<LLVM::ConstantOp>())
|
||||
if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
|
||||
deviceID = intAttr.getInt();
|
||||
|
||||
@ -4556,8 +4552,7 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
|
||||
ifCond = moduleTranslation.lookupValue(ifVar);
|
||||
|
||||
if (auto devId = updateDataOp.getDevice())
|
||||
if (auto constOp =
|
||||
dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
|
||||
if (auto constOp = devId.getDefiningOp<LLVM::ConstantOp>())
|
||||
if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
|
||||
deviceID = intAttr.getInt();
|
||||
|
||||
@ -5198,8 +5193,7 @@ static std::optional<int64_t> extractConstInteger(Value value) {
|
||||
if (!value)
|
||||
return std::nullopt;
|
||||
|
||||
if (auto constOp =
|
||||
dyn_cast_if_present<LLVM::ConstantOp>(value.getDefiningOp()))
|
||||
if (auto constOp = value.getDefiningOp<LLVM::ConstantOp>())
|
||||
if (auto constAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
|
||||
return constAttr.getInt();
|
||||
|
||||
|
@ -139,8 +139,7 @@ public:
|
||||
|
||||
LogicalResult matchAndRewrite(TestCommutative2Op op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto operand =
|
||||
dyn_cast_or_null<TestCommutative2Op>(op->getOperand(0).getDefiningOp());
|
||||
auto operand = op->getOperand(0).getDefiningOp<TestCommutative2Op>();
|
||||
if (!operand)
|
||||
return failure();
|
||||
Attribute constInput;
|
||||
|
Loading…
x
Reference in New Issue
Block a user