[flang][fir][NFC] unify flang's code style with the rest.
This commit is contained in:
parent
8f3254aa4a
commit
02ab6f358c
@ -16,29 +16,27 @@ namespace fir {
|
||||
#include "flang/Optimizer/Transforms/Passes.h.inc"
|
||||
} // namespace fir
|
||||
|
||||
using namespace fir;
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
class FIRToSCFPass : public fir::impl::FIRToSCFPassBase<FIRToSCFPass> {
|
||||
public:
|
||||
void runOnOperation() override;
|
||||
};
|
||||
|
||||
struct DoLoopConversion : public OpRewritePattern<fir::DoLoopOp> {
|
||||
struct DoLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> {
|
||||
using OpRewritePattern<fir::DoLoopOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(fir::DoLoopOp doLoopOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto loc = doLoopOp.getLoc();
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(fir::DoLoopOp doLoopOp,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
mlir::Location loc = doLoopOp.getLoc();
|
||||
bool hasFinalValue = doLoopOp.getFinalValue().has_value();
|
||||
|
||||
// Get loop values from the DoLoopOp
|
||||
auto low = doLoopOp.getLowerBound();
|
||||
auto high = doLoopOp.getUpperBound();
|
||||
mlir::Value low = doLoopOp.getLowerBound();
|
||||
mlir::Value high = doLoopOp.getUpperBound();
|
||||
assert(low && high && "must be a Value");
|
||||
auto step = doLoopOp.getStep();
|
||||
llvm::SmallVector<Value> iterArgs;
|
||||
mlir::Value step = doLoopOp.getStep();
|
||||
llvm::SmallVector<mlir::Value> iterArgs;
|
||||
if (hasFinalValue)
|
||||
iterArgs.push_back(low);
|
||||
iterArgs.append(doLoopOp.getIterOperands().begin(),
|
||||
@ -49,31 +47,33 @@ struct DoLoopConversion : public OpRewritePattern<fir::DoLoopOp> {
|
||||
// must be a positive value.
|
||||
// For easier conversion, we calculate the trip count and use a canonical
|
||||
// induction variable.
|
||||
auto diff = arith::SubIOp::create(rewriter, loc, high, low);
|
||||
auto distance = arith::AddIOp::create(rewriter, loc, diff, step);
|
||||
auto tripCount = arith::DivSIOp::create(rewriter, loc, distance, step);
|
||||
auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
|
||||
auto one = arith::ConstantIndexOp::create(rewriter, loc, 1);
|
||||
auto diff = mlir::arith::SubIOp::create(rewriter, loc, high, low);
|
||||
auto distance = mlir::arith::AddIOp::create(rewriter, loc, diff, step);
|
||||
auto tripCount =
|
||||
mlir::arith::DivSIOp::create(rewriter, loc, distance, step);
|
||||
auto zero = mlir::arith::ConstantIndexOp::create(rewriter, loc, 0);
|
||||
auto one = mlir::arith::ConstantIndexOp::create(rewriter, loc, 1);
|
||||
auto scfForOp =
|
||||
scf::ForOp::create(rewriter, loc, zero, tripCount, one, iterArgs);
|
||||
mlir::scf::ForOp::create(rewriter, loc, zero, tripCount, one, iterArgs);
|
||||
|
||||
auto &loopOps = doLoopOp.getBody()->getOperations();
|
||||
auto resultOp = cast<fir::ResultOp>(doLoopOp.getBody()->getTerminator());
|
||||
auto resultOp =
|
||||
mlir::cast<fir::ResultOp>(doLoopOp.getBody()->getTerminator());
|
||||
auto results = resultOp.getOperands();
|
||||
Block *loweredBody = scfForOp.getBody();
|
||||
mlir::Block *loweredBody = scfForOp.getBody();
|
||||
|
||||
loweredBody->getOperations().splice(loweredBody->begin(), loopOps,
|
||||
loopOps.begin(),
|
||||
std::prev(loopOps.end()));
|
||||
|
||||
rewriter.setInsertionPointToStart(loweredBody);
|
||||
Value iv =
|
||||
arith::MulIOp::create(rewriter, loc, scfForOp.getInductionVar(), step);
|
||||
iv = arith::AddIOp::create(rewriter, loc, low, iv);
|
||||
mlir::Value iv = mlir::arith::MulIOp::create(
|
||||
rewriter, loc, scfForOp.getInductionVar(), step);
|
||||
iv = mlir::arith::AddIOp::create(rewriter, loc, low, iv);
|
||||
|
||||
if (!results.empty()) {
|
||||
rewriter.setInsertionPointToEnd(loweredBody);
|
||||
scf::YieldOp::create(rewriter, resultOp->getLoc(), results);
|
||||
mlir::scf::YieldOp::create(rewriter, resultOp->getLoc(), results);
|
||||
}
|
||||
doLoopOp.getInductionVar().replaceAllUsesWith(iv);
|
||||
rewriter.replaceAllUsesWith(doLoopOp.getRegionIterArgs(),
|
||||
@ -84,34 +84,36 @@ struct DoLoopConversion : public OpRewritePattern<fir::DoLoopOp> {
|
||||
// Copy all the attributes from the old to new op.
|
||||
scfForOp->setAttrs(doLoopOp->getAttrs());
|
||||
rewriter.replaceOp(doLoopOp, scfForOp);
|
||||
return success();
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
void copyBlockAndTransformResult(PatternRewriter &rewriter, Block &srcBlock,
|
||||
Block &dstBlock) {
|
||||
Operation *srcTerminator = srcBlock.getTerminator();
|
||||
auto resultOp = cast<fir::ResultOp>(srcTerminator);
|
||||
void copyBlockAndTransformResult(mlir::PatternRewriter &rewriter,
|
||||
mlir::Block &srcBlock, mlir::Block &dstBlock) {
|
||||
mlir::Operation *srcTerminator = srcBlock.getTerminator();
|
||||
auto resultOp = mlir::cast<fir::ResultOp>(srcTerminator);
|
||||
|
||||
dstBlock.getOperations().splice(dstBlock.begin(), srcBlock.getOperations(),
|
||||
srcBlock.begin(), std::prev(srcBlock.end()));
|
||||
|
||||
if (!resultOp->getOperands().empty()) {
|
||||
rewriter.setInsertionPointToEnd(&dstBlock);
|
||||
scf::YieldOp::create(rewriter, resultOp->getLoc(), resultOp->getOperands());
|
||||
mlir::scf::YieldOp::create(rewriter, resultOp->getLoc(),
|
||||
resultOp->getOperands());
|
||||
}
|
||||
|
||||
rewriter.eraseOp(srcTerminator);
|
||||
}
|
||||
|
||||
struct IfConversion : public OpRewritePattern<fir::IfOp> {
|
||||
struct IfConversion : public mlir::OpRewritePattern<fir::IfOp> {
|
||||
using OpRewritePattern<fir::IfOp>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(fir::IfOp ifOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(fir::IfOp ifOp,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
bool hasElse = !ifOp.getElseRegion().empty();
|
||||
auto scfIfOp =
|
||||
scf::IfOp::create(rewriter, ifOp.getLoc(), ifOp.getResultTypes(),
|
||||
ifOp.getCondition(), hasElse);
|
||||
mlir::scf::IfOp::create(rewriter, ifOp.getLoc(), ifOp.getResultTypes(),
|
||||
ifOp.getCondition(), hasElse);
|
||||
|
||||
copyBlockAndTransformResult(rewriter, ifOp.getThenRegion().front(),
|
||||
scfIfOp.getThenRegion().front());
|
||||
@ -123,22 +125,22 @@ struct IfConversion : public OpRewritePattern<fir::IfOp> {
|
||||
|
||||
scfIfOp->setAttrs(ifOp->getAttrs());
|
||||
rewriter.replaceOp(ifOp, scfIfOp);
|
||||
return success();
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void FIRToSCFPass::runOnOperation() {
|
||||
RewritePatternSet patterns(&getContext());
|
||||
mlir::RewritePatternSet patterns(&getContext());
|
||||
patterns.add<DoLoopConversion, IfConversion>(patterns.getContext());
|
||||
ConversionTarget target(getContext());
|
||||
mlir::ConversionTarget target(getContext());
|
||||
target.addIllegalOp<fir::DoLoopOp, fir::IfOp>();
|
||||
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
|
||||
target.markUnknownOpDynamicallyLegal([](mlir::Operation *) { return true; });
|
||||
if (failed(
|
||||
applyPartialConversion(getOperation(), target, std::move(patterns))))
|
||||
signalPassFailure();
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass> fir::createFIRToSCFPass() {
|
||||
std::unique_ptr<mlir::Pass> fir::createFIRToSCFPass() {
|
||||
return std::make_unique<FIRToSCFPass>();
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user