Enable pass instrumentation to signal failures. (#163126)

Enables adding instrumentation to pass manager that can track/flag
invariants. This would be useful for cases where one some tighter
requirements than the general dialects or for a phase of conversion that
elsewhere.

It would enable making verify also just a regular instrumentation I
believe, but also a non-goal as that is a first class concept and
baseline for the ops and passes.

Would have enabled some of the requirements of
https://discourse.llvm.org/t/pre-verification-logic-before-running-conversion-pass-in-mlir/88318/10
.
This commit is contained in:
Jacques Pienaar 2025-12-11 14:26:10 +02:00 committed by GitHub
parent 0f2f9e1c80
commit 6b7b0ab530
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 132 additions and 13 deletions

View File

@ -17,6 +17,7 @@
#include <optional>
namespace mlir {
class PassInstrumentation;
namespace detail {
class OpToOpPassAdaptor;
struct OpPassManagerImpl;
@ -341,6 +342,9 @@ private:
/// Allow access to 'passOptions'.
friend class PassInfo;
/// Allow access to 'signalPassFailure'.
friend class PassInstrumentation;
};
//===----------------------------------------------------------------------===//

View File

@ -80,6 +80,10 @@ public:
/// name of the analysis that was computed, its TypeID, as well as the
/// current operation being analyzed.
virtual void runAfterAnalysis(StringRef name, TypeID id, Operation *op) {}
/// Helper method to enable analysis to signal pass failure. Used, for
/// example, when pre- or post-conditions fail.
void signalPassFailure(Pass *pass);
};
/// This class holds a collection of PassInstrumentation objects, and invokes

View File

@ -599,17 +599,21 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
if (pi)
pi->runBeforePass(pass, op);
bool passFailed = false;
op->getContext()->executeAction<PassExecutionAction>(
[&]() {
// Invoke the virtual runOnOperation method.
if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(pass))
adaptor->runOnOperation(verifyPasses);
else
pass->runOnOperation();
passFailed = pass->passState->irAndPassFailed.getInt();
},
{op}, *pass);
// Pass instrumentation can use pass failure to flag unmet invariants
// (preconditions) of the pass. Skip running pass if in failure state.
bool passFailed = pass->passState->irAndPassFailed.getInt();
if (!passFailed) {
op->getContext()->executeAction<PassExecutionAction>(
[&]() {
// Invoke the virtual runOnOperation method.
if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(pass))
adaptor->runOnOperation(verifyPasses);
else
pass->runOnOperation();
passFailed = pass->passState->irAndPassFailed.getInt();
},
{op}, *pass);
}
// Invalidate any non preserved analyses.
am.invalidate(pass->passState->preservedAnalyses);
@ -640,10 +644,12 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
// Instrument after the pass has run.
if (pi) {
if (passFailed)
if (passFailed) {
pi->runAfterPassFailed(pass, op);
else
} else {
pi->runAfterPass(pass, op);
passFailed = passFailed || pass->passState->irAndPassFailed.getInt();
}
}
// Return if the pass signaled a failure.
@ -1198,6 +1204,10 @@ void PassInstrumentation::runBeforePipeline(
void PassInstrumentation::runAfterPipeline(
std::optional<OperationName> name, const PipelineParentInfo &parentInfo) {}
void PassInstrumentation::signalPassFailure(Pass *pass) {
pass->signalPassFailure();
}
//===----------------------------------------------------------------------===//
// PassInstrumentor
//===----------------------------------------------------------------------===//

View File

@ -14,6 +14,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassInstrumentation.h"
#include "gtest/gtest.h"
#include <memory>
@ -117,6 +118,106 @@ struct AddSecondAttrFunctionPass
}
};
/// PassInstrumentation to count pass callbacks and signal pass failures.
struct TestPassInstrumentation : public PassInstrumentation {
int beforePassCallbackCount = 0;
int afterPassCallbackCount = 0;
int afterPassFailedCallbackCount = 0;
bool failBeforePass = false;
bool failAfterPass = false;
void runBeforePass(Pass *pass, Operation *op) override {
if (pass->getTypeID() != TypeID::get<AddAttrFunctionPass>())
return;
++beforePassCallbackCount;
if (failBeforePass)
signalPassFailure(pass);
}
void runAfterPass(Pass *pass, Operation *op) override {
if (pass->getTypeID() != TypeID::get<AddAttrFunctionPass>())
return;
++afterPassCallbackCount;
if (failAfterPass)
signalPassFailure(pass);
}
void runAfterPassFailed(Pass *pass, Operation *op) override {
if (pass->getTypeID() != TypeID::get<AddAttrFunctionPass>())
return;
++afterPassFailedCallbackCount;
}
};
TEST(PassManagerTest, PassInstrumentation) {
MLIRContext context;
context.loadDialect<func::FuncDialect>();
Builder b(&context);
// Create a module with 1 function.
OwningOpRef<ModuleOp> module(ModuleOp::create(UnknownLoc::get(&context)));
auto func = func::FuncOp::create(b.getUnknownLoc(), "test_func",
b.getFunctionType({}, {}));
func.setPrivate();
module->push_back(func);
struct InstrumentationCounts {
int beforePass;
int afterPass;
int afterPassFailed;
};
auto runInstrumentation =
[&](bool failBefore,
bool failAfter) -> std::pair<LogicalResult, InstrumentationCounts> {
// Instantiate and run our pass.
auto pm = PassManager::on<ModuleOp>(&context);
auto instrumentation = std::make_unique<TestPassInstrumentation>();
auto *instrumentationPtr = instrumentation.get();
instrumentation->failBeforePass = failBefore;
instrumentation->failAfterPass = failAfter;
pm.addInstrumentation(std::move(instrumentation));
pm.addNestedPass<func::FuncOp>(std::make_unique<AddAttrFunctionPass>());
LogicalResult result = pm.run(module.get());
InstrumentationCounts counts = {
instrumentationPtr->beforePassCallbackCount,
instrumentationPtr->afterPassCallbackCount,
instrumentationPtr->afterPassFailedCallbackCount};
return {result, counts};
};
for (bool failBefore : {false, true}) {
for (bool failAfter : {false, true}) {
auto [result, counts] = runInstrumentation(failBefore, failAfter);
InstrumentationCounts expected;
if (failBefore) {
EXPECT_TRUE(failed(result))
<< "failBefore=" << failBefore << ", failAfter=" << failAfter;
expected = {/*beforePass=*/1, /*afterPass=*/0, /*afterPassFailed=*/1};
} else if (failAfter) {
EXPECT_TRUE(failed(result))
<< "failBefore=" << failBefore << ", failAfter=" << failAfter;
expected = {/*beforePass=*/1, /*afterPass=*/1, /*afterPassFailed=*/0};
} else {
EXPECT_TRUE(succeeded(result))
<< "failBefore=" << failBefore << ", failAfter=" << failAfter;
expected = {/*beforePass=*/1, /*afterPass=*/1, /*afterPassFailed=*/0};
}
EXPECT_EQ(counts.beforePass, expected.beforePass)
<< "failBefore=" << failBefore << ", failAfter=" << failAfter;
EXPECT_EQ(counts.afterPass, expected.afterPass)
<< "failBefore=" << failBefore << ", failAfter=" << failAfter;
EXPECT_EQ(counts.afterPassFailed, expected.afterPassFailed)
<< "failBefore=" << failBefore << ", failAfter=" << failAfter;
}
}
}
TEST(PassManagerTest, ExecutionAction) {
MLIRContext context;
context.loadDialect<func::FuncDialect>();