llvm-project/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
Miloš Poletanović d9b021d33a
[MLIR][Transform] Safely erase transform ops by collecting first (#172016)
Avoids runtime crashes caused by deleting operations inside a walk.
Operations are gathered during the walk and then erased in the correct
dependency order after the walk finishes.

Issue: [#130023](https://github.com/llvm/llvm-project/issues/130023)

---------

Co-authored-by: Milos Poletanovic <mpoletanovic@syrmia.com>
Co-authored-by: Milos Poletanovic <milos.poletanovic@htecgroup.com>
2025-12-22 11:51:22 +01:00

63 lines
2.0 KiB
C++

//===- TestTransformDialectInterpreter.cpp --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines a test pass that interprets Transform dialect operations in
// the module.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
using namespace mlir;
namespace {
/// Test pass that strips the transform dialect "schedule" from the module by
/// erasing every operation that implements `transform::TransformOpInterface`.
///
/// Erasing operations from inside the walk callback mutates the IR that is
/// being traversed and has led to crashes. Candidates are therefore collected
/// during the walk and only erased once the walk has finished.
///
/// NOTE(review): this assumes no non-transform operation uses a result of a
/// collected transform op; such a use would still make `erase()` fail.
struct TestTransformDialectEraseSchedulePass
    : public PassWrapper<TestTransformDialectEraseSchedulePass,
                         OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestTransformDialectEraseSchedulePass)

  StringRef getArgument() const final {
    return "test-transform-dialect-erase-schedule";
  }
  StringRef getDescription() const final {
    return "erase transform dialect schedule from the IR";
  }

  void runOnOperation() override {
    SmallVector<Operation *> opsToDelete;
    getOperation()->walk<WalkOrder::PreOrder>([&](Operation *nestedOp) {
      if (isa<transform::TransformOpInterface>(nestedOp)) {
        opsToDelete.push_back(nestedOp);
        // Do not descend: erasing this op also erases its regions, so
        // collecting transform ops nested inside it would erase them twice.
        return WalkResult::skip();
      }
      return WalkResult::advance();
    });

    // Sever all use-def edges held by the collected ops (and by anything
    // nested in them) before erasing anything. Reverse walk order alone is not
    // a valid dependency order in graph regions such as the module body, where
    // a use may appear before its definition. Once references are dropped, the
    // erasure order among the collected ops no longer matters.
    for (Operation *op : opsToDelete)
      op->dropAllReferences();
    for (Operation *op : llvm::reverse(opsToDelete))
      op->erase();
  }
};
} // namespace
namespace mlir {
namespace test {
/// Registers the test pass for erasing transform dialect ops.
void registerTestTransformDialectEraseSchedulePass() {
PassRegistration<TestTransformDialectEraseSchedulePass> reg;
}
} // namespace test
} // namespace mlir