These passes generally don't rely on any special aspects of FuncOp, and moving them off of FuncOp allows these passes to be used in many more situations. The passes that obviously weren't relying on invariants guaranteed by a "function" were updated to be generic passes; the rest were updated to be FunctionOpInterface InterfacePasses. The test updates are NFC, switching from the implicit nesting form (-pass -pass2) to the -pass-pipeline form (generic passes do not implicitly nest as op-specific passes do). Differential Revision: https://reviews.llvm.org/D121190
69 lines
2.4 KiB
C++
69 lines
2.4 KiB
C++
//===- ConvertShapeConstraints.cpp - Conversion of shape constraints ------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
|
|
|
|
#include "../PassDetail.h"
|
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
|
#include "mlir/Dialect/SCF/SCF.h"
|
|
#include "mlir/Dialect/Shape/IR/Shape.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Pass/PassRegistry.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
|
|
using namespace mlir;
|
|
namespace {
// Declarative rewrite patterns from the generated include. This presumably
// defines `CstrBroadcastableToRequire` and `CstrEqToRequire`, which are used
// by populateConvertShapeConstraintsConversionPatterns() but are not defined
// in this file (TODO confirm against ShapeToStandard.td).
#include "ShapeToStandard.cpp.inc"

/// Lowers `shape.cstr_require` to eager error handling:
///   1. emit a `cf.assert` on the same predicate with the same message, then
///   2. replace the constraint with `shape.const_witness true`, because the
///      assertion now enforces the requirement at runtime.
struct ConvertCstrRequireOp final
    : public OpRewritePattern<shape::CstrRequireOp> {
  using OpRewritePattern<shape::CstrRequireOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::CstrRequireOp requireOp,
                                PatternRewriter &rewriter) const override {
    Location loc = requireOp.getLoc();
    rewriter.create<cf::AssertOp>(loc, requireOp.getPred(),
                                  requireOp.getMsgAttr());
    rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(requireOp, true);
    return success();
  }
};
} // namespace
|
|
|
|
/// Adds every pattern used by the shape-constraint conversion to `patterns`.
/// This covers the generated `CstrBroadcastableToRequire` and
/// `CstrEqToRequire` rewrites, plus the hand-written `ConvertCstrRequireOp`
/// lowering. All patterns are constructed with the set's own MLIRContext.
void mlir::populateConvertShapeConstraintsConversionPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  // The variadic form registers the listed patterns in order, one per type.
  patterns.add<CstrBroadcastableToRequire, CstrEqToRequire,
               ConvertCstrRequireOp>(ctx);
}
|
|
|
|
namespace {
|
|
// This pass eliminates shape constraints from the program, converting them to
|
|
// eager (side-effecting) error handling code. After eager error handling code
|
|
// is emitted, witnesses are satisfied, so they are replace with
|
|
// `shape.const_witness true`.
|
|
class ConvertShapeConstraints
|
|
: public ConvertShapeConstraintsBase<ConvertShapeConstraints> {
|
|
void runOnOperation() override {
|
|
auto func = getOperation();
|
|
auto *context = &getContext();
|
|
|
|
RewritePatternSet patterns(context);
|
|
populateConvertShapeConstraintsConversionPatterns(patterns);
|
|
|
|
if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns))))
|
|
return signalPassFailure();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
std::unique_ptr<Pass> mlir::createConvertShapeConstraintsPass() {
|
|
return std::make_unique<ConvertShapeConstraints>();
|
|
}
|