diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp index 37cfc9f2c23e..d9ec93224477 100644 --- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp +++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp @@ -100,11 +100,13 @@ struct SCFToControlFlowPass // | <%init visible by dominance> | // +--------------------------------+ // -struct ForLowering : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct ForLowering : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(ForOp forOp, - PatternRewriter &rewriter) const override; + LogicalResult + matchAndRewrite(ForOp forOp, + typename OpConversionPattern::OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; }; // Create a CFG subgraph for the scf.if operation (including its "then" and @@ -193,25 +195,31 @@ struct ForLowering : public OpRewritePattern { // | | // +--------------------------------+ // -struct IfLowering : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct IfLowering : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(IfOp ifOp, - PatternRewriter &rewriter) const override; + LogicalResult + matchAndRewrite(IfOp ifOp, + typename OpConversionPattern::OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; }; -struct ExecuteRegionLowering : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct ExecuteRegionLowering : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(ExecuteRegionOp op, - PatternRewriter &rewriter) const override; + LogicalResult matchAndRewrite( + ExecuteRegionOp op, + typename OpConversionPattern::OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; }; -struct ParallelLowering : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct ParallelLowering : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(mlir::scf::ParallelOp parallelOp, - PatternRewriter &rewriter) const override; + LogicalResult matchAndRewrite( + mlir::scf::ParallelOp parallelOp, + typename OpConversionPattern::OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; }; /// Create a CFG subgraph for this loop construct. The regions of the loop need @@ -273,41 +281,49 @@ struct ParallelLowering : public OpRewritePattern { /// the results of the WhileOp are defined in the 'before' region, which is /// required to have a single existing block, and are therefore accessible in /// the continuation block due to dominance. -struct WhileLowering : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct WhileLowering : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(WhileOp whileOp, - PatternRewriter &rewriter) const override; + LogicalResult + matchAndRewrite(WhileOp whileOp, + typename OpConversionPattern::OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; }; /// Optimized version of the above for the case of the "after" region merely /// forwarding its arguments back to the "before" region (i.e., a "do-while" /// loop). This avoid inlining the "after" region completely and branches back /// to the "before" entry instead. -struct DoWhileLowering : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct DoWhileLowering : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(WhileOp whileOp, - PatternRewriter &rewriter) const override; + LogicalResult + matchAndRewrite(WhileOp whileOp, + typename OpConversionPattern::OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; }; /// Lower an `scf.index_switch` operation to a `cf.switch` operation. -struct IndexSwitchLowering : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct IndexSwitchLowering : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(IndexSwitchOp op, - PatternRewriter &rewriter) const override; + LogicalResult matchAndRewrite( + IndexSwitchOp op, + typename OpConversionPattern::OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; }; /// Lower an `scf.forall` operation to an `scf.parallel` op, assuming that it /// has no shared outputs. Ops with shared outputs should be bufferized first. /// Specialized lowerings for `scf.forall` (e.g., for GPUs) exist in other /// dialects/passes. -struct ForallLowering : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct ForallLowering : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(mlir::scf::ForallOp forallOp, - PatternRewriter &rewriter) const override; + LogicalResult matchAndRewrite( + mlir::scf::ForallOp forallOp, + typename OpConversionPattern::OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; }; } // namespace @@ -325,8 +341,9 @@ static void propagateLoopAttrs(Operation *scfOp, Operation *brOp) { brOp->setDiscardableAttrs(llvmAttrs); } -LogicalResult ForLowering::matchAndRewrite(ForOp forOp, - PatternRewriter &rewriter) const { +LogicalResult ForLowering::matchAndRewrite( + ForOp forOp, typename OpConversionPattern::OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { Location loc = forOp.getLoc(); // Start by splitting the block containing the 'scf.for' into two parts. @@ -397,8 +414,9 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp, return success(); } -LogicalResult IfLowering::matchAndRewrite(IfOp ifOp, - PatternRewriter &rewriter) const { +LogicalResult IfLowering::matchAndRewrite( + IfOp ifOp, typename OpConversionPattern::OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { auto loc = ifOp.getLoc(); // Start by splitting the block containing the 'scf.if' into two parts. @@ -453,9 +471,10 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp, return success(); } -LogicalResult -ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op, - PatternRewriter &rewriter) const { +LogicalResult ExecuteRegionLowering::matchAndRewrite( + ExecuteRegionOp op, + typename OpConversionPattern::OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { auto loc = op.getLoc(); auto *condBlock = rewriter.getInsertionBlock(); @@ -487,9 +506,10 @@ ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op, return success(); } -LogicalResult -ParallelLowering::matchAndRewrite(ParallelOp parallelOp, - PatternRewriter &rewriter) const { +LogicalResult ParallelLowering::matchAndRewrite( + ParallelOp parallelOp, + typename OpConversionPattern::OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { Location loc = parallelOp.getLoc(); auto reductionOp = dyn_cast(parallelOp.getBody()->getTerminator()); if (!reductionOp) { @@ -563,8 +583,9 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp, return success(); } -LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp, - PatternRewriter &rewriter) const { +LogicalResult WhileLowering::matchAndRewrite( + WhileOp whileOp, typename OpConversionPattern::OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { OpBuilder::InsertionGuard guard(rewriter); Location loc = whileOp.getLoc(); @@ -606,9 +627,9 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp, return success(); } -LogicalResult -DoWhileLowering::matchAndRewrite(WhileOp whileOp, - PatternRewriter &rewriter) const { +LogicalResult DoWhileLowering::matchAndRewrite( + WhileOp whileOp, typename OpConversionPattern::OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { Block &afterBlock = *whileOp.getAfterBody(); if (!llvm::hasSingleElement(afterBlock)) return rewriter.notifyMatchFailure(whileOp, @@ -652,9 +673,10 @@ DoWhileLowering::matchAndRewrite(WhileOp whileOp, return success(); } -LogicalResult -IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op, - PatternRewriter &rewriter) const { +LogicalResult IndexSwitchLowering::matchAndRewrite( + IndexSwitchOp op, + typename OpConversionPattern::OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { // Split the block at the op. Block *condBlock = rewriter.getInsertionBlock(); Block *continueBlock = rewriter.splitBlock(condBlock, Block::iterator(op)); @@ -714,8 +736,10 @@ IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op, return success(); } -LogicalResult ForallLowering::matchAndRewrite(ForallOp forallOp, - PatternRewriter &rewriter) const { +LogicalResult ForallLowering::matchAndRewrite( + ForallOp forallOp, + typename OpConversionPattern::OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { return scf::forallToParallelLoop(rewriter, forallOp); }