Enable pass instrumentation to signal failures. (#163126)
Enables adding instrumentation to pass manager that can track/flag invariants. This would be useful for cases where one some tighter requirements than the general dialects or for a phase of conversion that elsewhere. It would enable making verify also just a regular instrumentation I believe, but also a non-goal as that is a first class concept and baseline for the ops and passes. Would have enabled some of the requirements of https://discourse.llvm.org/t/pre-verification-logic-before-running-conversion-pass-in-mlir/88318/10 .
This commit is contained in:
parent
0f2f9e1c80
commit
6b7b0ab530
@ -17,6 +17,7 @@
|
||||
#include <optional>
|
||||
|
||||
namespace mlir {
|
||||
class PassInstrumentation;
|
||||
namespace detail {
|
||||
class OpToOpPassAdaptor;
|
||||
struct OpPassManagerImpl;
|
||||
@ -341,6 +342,9 @@ private:
|
||||
|
||||
/// Allow access to 'passOptions'.
|
||||
friend class PassInfo;
|
||||
|
||||
/// Allow access to 'signalPassFailure'.
|
||||
friend class PassInstrumentation;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@ -80,6 +80,10 @@ public:
|
||||
/// name of the analysis that was computed, its TypeID, as well as the
|
||||
/// current operation being analyzed.
|
||||
virtual void runAfterAnalysis(StringRef name, TypeID id, Operation *op) {}
|
||||
|
||||
/// Helper method to enable analysis to signal pass failure. Used, for
|
||||
/// example, when pre- or post-conditions fail.
|
||||
void signalPassFailure(Pass *pass);
|
||||
};
|
||||
|
||||
/// This class holds a collection of PassInstrumentation objects, and invokes
|
||||
|
||||
@ -599,17 +599,21 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
|
||||
if (pi)
|
||||
pi->runBeforePass(pass, op);
|
||||
|
||||
bool passFailed = false;
|
||||
op->getContext()->executeAction<PassExecutionAction>(
|
||||
[&]() {
|
||||
// Invoke the virtual runOnOperation method.
|
||||
if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(pass))
|
||||
adaptor->runOnOperation(verifyPasses);
|
||||
else
|
||||
pass->runOnOperation();
|
||||
passFailed = pass->passState->irAndPassFailed.getInt();
|
||||
},
|
||||
{op}, *pass);
|
||||
// Pass instrumentation can use pass failure to flag unmet invariants
|
||||
// (preconditions) of the pass. Skip running pass if in failure state.
|
||||
bool passFailed = pass->passState->irAndPassFailed.getInt();
|
||||
if (!passFailed) {
|
||||
op->getContext()->executeAction<PassExecutionAction>(
|
||||
[&]() {
|
||||
// Invoke the virtual runOnOperation method.
|
||||
if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(pass))
|
||||
adaptor->runOnOperation(verifyPasses);
|
||||
else
|
||||
pass->runOnOperation();
|
||||
passFailed = pass->passState->irAndPassFailed.getInt();
|
||||
},
|
||||
{op}, *pass);
|
||||
}
|
||||
|
||||
// Invalidate any non preserved analyses.
|
||||
am.invalidate(pass->passState->preservedAnalyses);
|
||||
@ -640,10 +644,12 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
|
||||
|
||||
// Instrument after the pass has run.
|
||||
if (pi) {
|
||||
if (passFailed)
|
||||
if (passFailed) {
|
||||
pi->runAfterPassFailed(pass, op);
|
||||
else
|
||||
} else {
|
||||
pi->runAfterPass(pass, op);
|
||||
passFailed = passFailed || pass->passState->irAndPassFailed.getInt();
|
||||
}
|
||||
}
|
||||
|
||||
// Return if the pass signaled a failure.
|
||||
@ -1198,6 +1204,10 @@ void PassInstrumentation::runBeforePipeline(
|
||||
void PassInstrumentation::runAfterPipeline(
|
||||
std::optional<OperationName> name, const PipelineParentInfo &parentInfo) {}
|
||||
|
||||
void PassInstrumentation::signalPassFailure(Pass *pass) {
|
||||
pass->signalPassFailure();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PassInstrumentor
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@ -14,6 +14,7 @@
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassInstrumentation.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include <memory>
|
||||
@ -117,6 +118,106 @@ struct AddSecondAttrFunctionPass
|
||||
}
|
||||
};
|
||||
|
||||
/// PassInstrumentation to count pass callbacks and signal pass failures.
|
||||
struct TestPassInstrumentation : public PassInstrumentation {
|
||||
int beforePassCallbackCount = 0;
|
||||
int afterPassCallbackCount = 0;
|
||||
int afterPassFailedCallbackCount = 0;
|
||||
|
||||
bool failBeforePass = false;
|
||||
bool failAfterPass = false;
|
||||
|
||||
void runBeforePass(Pass *pass, Operation *op) override {
|
||||
if (pass->getTypeID() != TypeID::get<AddAttrFunctionPass>())
|
||||
return;
|
||||
|
||||
++beforePassCallbackCount;
|
||||
if (failBeforePass)
|
||||
signalPassFailure(pass);
|
||||
}
|
||||
void runAfterPass(Pass *pass, Operation *op) override {
|
||||
if (pass->getTypeID() != TypeID::get<AddAttrFunctionPass>())
|
||||
return;
|
||||
|
||||
++afterPassCallbackCount;
|
||||
if (failAfterPass)
|
||||
signalPassFailure(pass);
|
||||
}
|
||||
void runAfterPassFailed(Pass *pass, Operation *op) override {
|
||||
if (pass->getTypeID() != TypeID::get<AddAttrFunctionPass>())
|
||||
return;
|
||||
|
||||
++afterPassFailedCallbackCount;
|
||||
}
|
||||
};
|
||||
|
||||
TEST(PassManagerTest, PassInstrumentation) {
|
||||
MLIRContext context;
|
||||
context.loadDialect<func::FuncDialect>();
|
||||
Builder b(&context);
|
||||
|
||||
// Create a module with 1 function.
|
||||
OwningOpRef<ModuleOp> module(ModuleOp::create(UnknownLoc::get(&context)));
|
||||
auto func = func::FuncOp::create(b.getUnknownLoc(), "test_func",
|
||||
b.getFunctionType({}, {}));
|
||||
func.setPrivate();
|
||||
module->push_back(func);
|
||||
|
||||
struct InstrumentationCounts {
|
||||
int beforePass;
|
||||
int afterPass;
|
||||
int afterPassFailed;
|
||||
};
|
||||
|
||||
auto runInstrumentation =
|
||||
[&](bool failBefore,
|
||||
bool failAfter) -> std::pair<LogicalResult, InstrumentationCounts> {
|
||||
// Instantiate and run our pass.
|
||||
auto pm = PassManager::on<ModuleOp>(&context);
|
||||
auto instrumentation = std::make_unique<TestPassInstrumentation>();
|
||||
auto *instrumentationPtr = instrumentation.get();
|
||||
instrumentation->failBeforePass = failBefore;
|
||||
instrumentation->failAfterPass = failAfter;
|
||||
pm.addInstrumentation(std::move(instrumentation));
|
||||
pm.addNestedPass<func::FuncOp>(std::make_unique<AddAttrFunctionPass>());
|
||||
LogicalResult result = pm.run(module.get());
|
||||
|
||||
InstrumentationCounts counts = {
|
||||
instrumentationPtr->beforePassCallbackCount,
|
||||
instrumentationPtr->afterPassCallbackCount,
|
||||
instrumentationPtr->afterPassFailedCallbackCount};
|
||||
return {result, counts};
|
||||
};
|
||||
|
||||
for (bool failBefore : {false, true}) {
|
||||
for (bool failAfter : {false, true}) {
|
||||
auto [result, counts] = runInstrumentation(failBefore, failAfter);
|
||||
|
||||
InstrumentationCounts expected;
|
||||
if (failBefore) {
|
||||
EXPECT_TRUE(failed(result))
|
||||
<< "failBefore=" << failBefore << ", failAfter=" << failAfter;
|
||||
expected = {/*beforePass=*/1, /*afterPass=*/0, /*afterPassFailed=*/1};
|
||||
} else if (failAfter) {
|
||||
EXPECT_TRUE(failed(result))
|
||||
<< "failBefore=" << failBefore << ", failAfter=" << failAfter;
|
||||
expected = {/*beforePass=*/1, /*afterPass=*/1, /*afterPassFailed=*/0};
|
||||
} else {
|
||||
EXPECT_TRUE(succeeded(result))
|
||||
<< "failBefore=" << failBefore << ", failAfter=" << failAfter;
|
||||
expected = {/*beforePass=*/1, /*afterPass=*/1, /*afterPassFailed=*/0};
|
||||
}
|
||||
|
||||
EXPECT_EQ(counts.beforePass, expected.beforePass)
|
||||
<< "failBefore=" << failBefore << ", failAfter=" << failAfter;
|
||||
EXPECT_EQ(counts.afterPass, expected.afterPass)
|
||||
<< "failBefore=" << failBefore << ", failAfter=" << failAfter;
|
||||
EXPECT_EQ(counts.afterPassFailed, expected.afterPassFailed)
|
||||
<< "failBefore=" << failBefore << ", failAfter=" << failAfter;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(PassManagerTest, ExecutionAction) {
|
||||
MLIRContext context;
|
||||
context.loadDialect<func::FuncDialect>();
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user