[MLIR][SCF] Actually use conversion interface in scf-to-cf conversion
This commit is contained in:
parent
790bee99de
commit
cdd31610fd
@ -100,11 +100,13 @@ struct SCFToControlFlowPass
|
||||
// | <%init visible by dominance> |
|
||||
// +--------------------------------+
|
||||
//
|
||||
struct ForLowering : public OpRewritePattern<ForOp> {
|
||||
using OpRewritePattern<ForOp>::OpRewritePattern;
|
||||
struct ForLowering : public OpConversionPattern<ForOp> {
|
||||
using OpConversionPattern<ForOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(ForOp forOp,
|
||||
PatternRewriter &rewriter) const override;
|
||||
LogicalResult
|
||||
matchAndRewrite(ForOp forOp,
|
||||
typename OpConversionPattern<ForOp>::OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
// Create a CFG subgraph for the scf.if operation (including its "then" and
|
||||
@ -193,25 +195,31 @@ struct ForLowering : public OpRewritePattern<ForOp> {
|
||||
// | <code after the IfOp> |
|
||||
// +--------------------------------+
|
||||
//
|
||||
struct IfLowering : public OpRewritePattern<IfOp> {
|
||||
using OpRewritePattern<IfOp>::OpRewritePattern;
|
||||
struct IfLowering : public OpConversionPattern<IfOp> {
|
||||
using OpConversionPattern<IfOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(IfOp ifOp,
|
||||
PatternRewriter &rewriter) const override;
|
||||
LogicalResult
|
||||
matchAndRewrite(IfOp ifOp,
|
||||
typename OpConversionPattern<IfOp>::OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
struct ExecuteRegionLowering : public OpRewritePattern<ExecuteRegionOp> {
|
||||
using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
|
||||
struct ExecuteRegionLowering : public OpConversionPattern<ExecuteRegionOp> {
|
||||
using OpConversionPattern<ExecuteRegionOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(ExecuteRegionOp op,
|
||||
PatternRewriter &rewriter) const override;
|
||||
LogicalResult matchAndRewrite(
|
||||
ExecuteRegionOp op,
|
||||
typename OpConversionPattern<ExecuteRegionOp>::OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
|
||||
using OpRewritePattern<mlir::scf::ParallelOp>::OpRewritePattern;
|
||||
struct ParallelLowering : public OpConversionPattern<mlir::scf::ParallelOp> {
|
||||
using OpConversionPattern<mlir::scf::ParallelOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(mlir::scf::ParallelOp parallelOp,
|
||||
PatternRewriter &rewriter) const override;
|
||||
LogicalResult matchAndRewrite(
|
||||
mlir::scf::ParallelOp parallelOp,
|
||||
typename OpConversionPattern<mlir::scf::ParallelOp>::OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
/// Create a CFG subgraph for this loop construct. The regions of the loop need
|
||||
@ -273,41 +281,49 @@ struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
|
||||
/// the results of the WhileOp are defined in the 'before' region, which is
|
||||
/// required to have a single existing block, and are therefore accessible in
|
||||
/// the continuation block due to dominance.
|
||||
struct WhileLowering : public OpRewritePattern<WhileOp> {
|
||||
using OpRewritePattern<WhileOp>::OpRewritePattern;
|
||||
struct WhileLowering : public OpConversionPattern<WhileOp> {
|
||||
using OpConversionPattern<WhileOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(WhileOp whileOp,
|
||||
PatternRewriter &rewriter) const override;
|
||||
LogicalResult
|
||||
matchAndRewrite(WhileOp whileOp,
|
||||
typename OpConversionPattern<WhileOp>::OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
/// Optimized version of the above for the case of the "after" region merely
|
||||
/// forwarding its arguments back to the "before" region (i.e., a "do-while"
|
||||
/// loop). This avoid inlining the "after" region completely and branches back
|
||||
/// to the "before" entry instead.
|
||||
struct DoWhileLowering : public OpRewritePattern<WhileOp> {
|
||||
using OpRewritePattern<WhileOp>::OpRewritePattern;
|
||||
struct DoWhileLowering : public OpConversionPattern<WhileOp> {
|
||||
using OpConversionPattern<WhileOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(WhileOp whileOp,
|
||||
PatternRewriter &rewriter) const override;
|
||||
LogicalResult
|
||||
matchAndRewrite(WhileOp whileOp,
|
||||
typename OpConversionPattern<WhileOp>::OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
/// Lower an `scf.index_switch` operation to a `cf.switch` operation.
|
||||
struct IndexSwitchLowering : public OpRewritePattern<IndexSwitchOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
struct IndexSwitchLowering : public OpConversionPattern<IndexSwitchOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(IndexSwitchOp op,
|
||||
PatternRewriter &rewriter) const override;
|
||||
LogicalResult matchAndRewrite(
|
||||
IndexSwitchOp op,
|
||||
typename OpConversionPattern<IndexSwitchOp>::OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
/// Lower an `scf.forall` operation to an `scf.parallel` op, assuming that it
|
||||
/// has no shared outputs. Ops with shared outputs should be bufferized first.
|
||||
/// Specialized lowerings for `scf.forall` (e.g., for GPUs) exist in other
|
||||
/// dialects/passes.
|
||||
struct ForallLowering : public OpRewritePattern<mlir::scf::ForallOp> {
|
||||
using OpRewritePattern<mlir::scf::ForallOp>::OpRewritePattern;
|
||||
struct ForallLowering : public OpConversionPattern<mlir::scf::ForallOp> {
|
||||
using OpConversionPattern<mlir::scf::ForallOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(mlir::scf::ForallOp forallOp,
|
||||
PatternRewriter &rewriter) const override;
|
||||
LogicalResult matchAndRewrite(
|
||||
mlir::scf::ForallOp forallOp,
|
||||
typename OpConversionPattern<mlir::scf::ForallOp>::OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
@ -325,8 +341,9 @@ static void propagateLoopAttrs(Operation *scfOp, Operation *brOp) {
|
||||
brOp->setDiscardableAttrs(llvmAttrs);
|
||||
}
|
||||
|
||||
LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
|
||||
PatternRewriter &rewriter) const {
|
||||
LogicalResult ForLowering::matchAndRewrite(
|
||||
ForOp forOp, typename OpConversionPattern<ForOp>::OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Location loc = forOp.getLoc();
|
||||
|
||||
// Start by splitting the block containing the 'scf.for' into two parts.
|
||||
@ -397,8 +414,9 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
|
||||
PatternRewriter &rewriter) const {
|
||||
LogicalResult IfLowering::matchAndRewrite(
|
||||
IfOp ifOp, typename OpConversionPattern<IfOp>::OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto loc = ifOp.getLoc();
|
||||
|
||||
// Start by splitting the block containing the 'scf.if' into two parts.
|
||||
@ -453,9 +471,10 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op,
|
||||
PatternRewriter &rewriter) const {
|
||||
LogicalResult ExecuteRegionLowering::matchAndRewrite(
|
||||
ExecuteRegionOp op,
|
||||
typename OpConversionPattern<ExecuteRegionOp>::OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto loc = op.getLoc();
|
||||
|
||||
auto *condBlock = rewriter.getInsertionBlock();
|
||||
@ -487,9 +506,10 @@ ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op,
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
|
||||
PatternRewriter &rewriter) const {
|
||||
LogicalResult ParallelLowering::matchAndRewrite(
|
||||
ParallelOp parallelOp,
|
||||
typename OpConversionPattern<ParallelOp>::OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Location loc = parallelOp.getLoc();
|
||||
auto reductionOp = dyn_cast<ReduceOp>(parallelOp.getBody()->getTerminator());
|
||||
if (!reductionOp) {
|
||||
@ -563,8 +583,9 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
|
||||
PatternRewriter &rewriter) const {
|
||||
LogicalResult WhileLowering::matchAndRewrite(
|
||||
WhileOp whileOp, typename OpConversionPattern<WhileOp>::OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
Location loc = whileOp.getLoc();
|
||||
|
||||
@ -606,9 +627,9 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
DoWhileLowering::matchAndRewrite(WhileOp whileOp,
|
||||
PatternRewriter &rewriter) const {
|
||||
LogicalResult DoWhileLowering::matchAndRewrite(
|
||||
WhileOp whileOp, typename OpConversionPattern<WhileOp>::OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
Block &afterBlock = *whileOp.getAfterBody();
|
||||
if (!llvm::hasSingleElement(afterBlock))
|
||||
return rewriter.notifyMatchFailure(whileOp,
|
||||
@ -652,9 +673,10 @@ DoWhileLowering::matchAndRewrite(WhileOp whileOp,
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
|
||||
PatternRewriter &rewriter) const {
|
||||
LogicalResult IndexSwitchLowering::matchAndRewrite(
|
||||
IndexSwitchOp op,
|
||||
typename OpConversionPattern<IndexSwitchOp>::OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
// Split the block at the op.
|
||||
Block *condBlock = rewriter.getInsertionBlock();
|
||||
Block *continueBlock = rewriter.splitBlock(condBlock, Block::iterator(op));
|
||||
@ -714,8 +736,10 @@ IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult ForallLowering::matchAndRewrite(ForallOp forallOp,
|
||||
PatternRewriter &rewriter) const {
|
||||
LogicalResult ForallLowering::matchAndRewrite(
|
||||
ForallOp forallOp,
|
||||
typename OpConversionPattern<ForallOp>::OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
return scf::forallToParallelLoop(rewriter, forallOp);
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user