From 02ab6f358cf26f00342acc7c1c306dd61cd4f10e Mon Sep 17 00:00:00 2001 From: yanming Date: Wed, 13 Aug 2025 14:56:02 +0800 Subject: [PATCH] [flang][fir][NFC] unify flang's code style with the rest. --- flang/lib/Optimizer/Transforms/FIRToSCF.cpp | 80 +++++++++++---------- 1 file changed, 41 insertions(+), 39 deletions(-) diff --git a/flang/lib/Optimizer/Transforms/FIRToSCF.cpp b/flang/lib/Optimizer/Transforms/FIRToSCF.cpp index 1902757e83bf..79ed85fa6060 100644 --- a/flang/lib/Optimizer/Transforms/FIRToSCF.cpp +++ b/flang/lib/Optimizer/Transforms/FIRToSCF.cpp @@ -16,29 +16,27 @@ namespace fir { #include "flang/Optimizer/Transforms/Passes.h.inc" } // namespace fir -using namespace fir; -using namespace mlir; - namespace { class FIRToSCFPass : public fir::impl::FIRToSCFPassBase { public: void runOnOperation() override; }; -struct DoLoopConversion : public OpRewritePattern { +struct DoLoopConversion : public mlir::OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(fir::DoLoopOp doLoopOp, - PatternRewriter &rewriter) const override { - auto loc = doLoopOp.getLoc(); + mlir::LogicalResult + matchAndRewrite(fir::DoLoopOp doLoopOp, + mlir::PatternRewriter &rewriter) const override { + mlir::Location loc = doLoopOp.getLoc(); bool hasFinalValue = doLoopOp.getFinalValue().has_value(); // Get loop values from the DoLoopOp - auto low = doLoopOp.getLowerBound(); - auto high = doLoopOp.getUpperBound(); + mlir::Value low = doLoopOp.getLowerBound(); + mlir::Value high = doLoopOp.getUpperBound(); assert(low && high && "must be a Value"); - auto step = doLoopOp.getStep(); - llvm::SmallVector iterArgs; + mlir::Value step = doLoopOp.getStep(); + llvm::SmallVector iterArgs; if (hasFinalValue) iterArgs.push_back(low); iterArgs.append(doLoopOp.getIterOperands().begin(), @@ -49,31 +47,33 @@ struct DoLoopConversion : public OpRewritePattern { // must be a positive value. // For easier conversion, we calculate the trip count and use a canonical // induction variable. - auto diff = arith::SubIOp::create(rewriter, loc, high, low); - auto distance = arith::AddIOp::create(rewriter, loc, diff, step); - auto tripCount = arith::DivSIOp::create(rewriter, loc, distance, step); - auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0); - auto one = arith::ConstantIndexOp::create(rewriter, loc, 1); + auto diff = mlir::arith::SubIOp::create(rewriter, loc, high, low); + auto distance = mlir::arith::AddIOp::create(rewriter, loc, diff, step); + auto tripCount = + mlir::arith::DivSIOp::create(rewriter, loc, distance, step); + auto zero = mlir::arith::ConstantIndexOp::create(rewriter, loc, 0); + auto one = mlir::arith::ConstantIndexOp::create(rewriter, loc, 1); auto scfForOp = - scf::ForOp::create(rewriter, loc, zero, tripCount, one, iterArgs); + mlir::scf::ForOp::create(rewriter, loc, zero, tripCount, one, iterArgs); auto &loopOps = doLoopOp.getBody()->getOperations(); - auto resultOp = cast(doLoopOp.getBody()->getTerminator()); + auto resultOp = + mlir::cast(doLoopOp.getBody()->getTerminator()); auto results = resultOp.getOperands(); - Block *loweredBody = scfForOp.getBody(); + mlir::Block *loweredBody = scfForOp.getBody(); loweredBody->getOperations().splice(loweredBody->begin(), loopOps, loopOps.begin(), std::prev(loopOps.end())); rewriter.setInsertionPointToStart(loweredBody); - Value iv = - arith::MulIOp::create(rewriter, loc, scfForOp.getInductionVar(), step); - iv = arith::AddIOp::create(rewriter, loc, low, iv); + mlir::Value iv = mlir::arith::MulIOp::create( + rewriter, loc, scfForOp.getInductionVar(), step); + iv = mlir::arith::AddIOp::create(rewriter, loc, low, iv); if (!results.empty()) { rewriter.setInsertionPointToEnd(loweredBody); - scf::YieldOp::create(rewriter, resultOp->getLoc(), results); + mlir::scf::YieldOp::create(rewriter, resultOp->getLoc(), results); } doLoopOp.getInductionVar().replaceAllUsesWith(iv); rewriter.replaceAllUsesWith(doLoopOp.getRegionIterArgs(), @@ -84,34 +84,36 @@ struct DoLoopConversion : public OpRewritePattern { // Copy all the attributes from the old to new op. scfForOp->setAttrs(doLoopOp->getAttrs()); rewriter.replaceOp(doLoopOp, scfForOp); - return success(); + return mlir::success(); } }; -void copyBlockAndTransformResult(PatternRewriter &rewriter, Block &srcBlock, - Block &dstBlock) { - Operation *srcTerminator = srcBlock.getTerminator(); - auto resultOp = cast(srcTerminator); +void copyBlockAndTransformResult(mlir::PatternRewriter &rewriter, + mlir::Block &srcBlock, mlir::Block &dstBlock) { + mlir::Operation *srcTerminator = srcBlock.getTerminator(); + auto resultOp = mlir::cast(srcTerminator); dstBlock.getOperations().splice(dstBlock.begin(), srcBlock.getOperations(), srcBlock.begin(), std::prev(srcBlock.end())); if (!resultOp->getOperands().empty()) { rewriter.setInsertionPointToEnd(&dstBlock); - scf::YieldOp::create(rewriter, resultOp->getLoc(), resultOp->getOperands()); + mlir::scf::YieldOp::create(rewriter, resultOp->getLoc(), + resultOp->getOperands()); } rewriter.eraseOp(srcTerminator); } -struct IfConversion : public OpRewritePattern { +struct IfConversion : public mlir::OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(fir::IfOp ifOp, - PatternRewriter &rewriter) const override { + mlir::LogicalResult + matchAndRewrite(fir::IfOp ifOp, + mlir::PatternRewriter &rewriter) const override { bool hasElse = !ifOp.getElseRegion().empty(); auto scfIfOp = - scf::IfOp::create(rewriter, ifOp.getLoc(), ifOp.getResultTypes(), - ifOp.getCondition(), hasElse); + mlir::scf::IfOp::create(rewriter, ifOp.getLoc(), ifOp.getResultTypes(), + ifOp.getCondition(), hasElse); copyBlockAndTransformResult(rewriter, ifOp.getThenRegion().front(), scfIfOp.getThenRegion().front()); @@ -123,22 +125,22 @@ struct IfConversion : public OpRewritePattern { scfIfOp->setAttrs(ifOp->getAttrs()); rewriter.replaceOp(ifOp, scfIfOp); - return success(); + return mlir::success(); } }; } // namespace void FIRToSCFPass::runOnOperation() { - RewritePatternSet patterns(&getContext()); + mlir::RewritePatternSet patterns(&getContext()); patterns.add(patterns.getContext()); - ConversionTarget target(getContext()); + mlir::ConversionTarget target(getContext()); target.addIllegalOp(); - target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); + target.markUnknownOpDynamicallyLegal([](mlir::Operation *) { return true; }); if (failed( applyPartialConversion(getOperation(), target, std::move(patterns)))) signalPassFailure(); } -std::unique_ptr fir::createFIRToSCFPass() { +std::unique_ptr fir::createFIRToSCFPass() { return std::make_unique(); }