[mlir][SliceAnalysis] Fix stack overflow in graph regions (#139694)
This analysis currently just crashes when applied to a graph region that has a use-def cycle. This PR fixes that by keeping track of the operations the DFS has already visited when following use-def edges and stopping once we visit an operation again.
This commit is contained in:
parent
6296ebd45d
commit
18624ae54b
@ -65,8 +65,9 @@ using ForwardSliceOptions = SliceOptions;
|
||||
///
|
||||
/// The implementation traverses the use chains in postorder traversal for
|
||||
/// efficiency reasons: if an operation is already in `forwardSlice`, no
|
||||
/// need to traverse its uses again. Since use-def chains form a DAG, this
|
||||
/// terminates.
|
||||
/// need to traverse its uses again. In the presence of use-def cycles in a
|
||||
/// graph region, the traversal stops at the first operation that was already
|
||||
/// visited (which is not added to the slice anymore).
|
||||
///
|
||||
/// Upon return to the root call, `forwardSlice` is filled with a
|
||||
/// postorder list of uses (i.e. a reverse topological order). To get a proper
|
||||
@ -114,8 +115,9 @@ void getForwardSlice(Value root, SetVector<Operation *> *forwardSlice,
|
||||
///
|
||||
/// The implementation traverses the def chains in postorder traversal for
|
||||
/// efficiency reasons: if an operation is already in `backwardSlice`, no
|
||||
/// need to traverse its definitions again. Since useuse-def chains form a DAG,
|
||||
/// this terminates.
|
||||
/// need to traverse its definitions again. In the presence of use-def cycles
|
||||
/// in a graph region, the traversal stops at the first operation that was
|
||||
/// already visited (which is not added to the slice anymore).
|
||||
///
|
||||
/// Upon return to the root call, `backwardSlice` is filled with a
|
||||
/// postorder list of defs. This happens to be a topological order, from the
|
||||
|
@ -26,7 +26,8 @@
|
||||
using namespace mlir;
|
||||
|
||||
static void
|
||||
getForwardSliceImpl(Operation *op, SetVector<Operation *> *forwardSlice,
|
||||
getForwardSliceImpl(Operation *op, DenseSet<Operation *> &visited,
|
||||
SetVector<Operation *> *forwardSlice,
|
||||
const SliceOptions::TransitiveFilter &filter = nullptr) {
|
||||
if (!op)
|
||||
return;
|
||||
@ -40,12 +41,31 @@ getForwardSliceImpl(Operation *op, SetVector<Operation *> *forwardSlice,
|
||||
for (Region ®ion : op->getRegions())
|
||||
for (Block &block : region)
|
||||
for (Operation &blockOp : block)
|
||||
if (forwardSlice->count(&blockOp) == 0)
|
||||
getForwardSliceImpl(&blockOp, forwardSlice, filter);
|
||||
for (Value result : op->getResults()) {
|
||||
for (Operation *userOp : result.getUsers())
|
||||
if (forwardSlice->count(userOp) == 0)
|
||||
getForwardSliceImpl(userOp, forwardSlice, filter);
|
||||
if (forwardSlice->count(&blockOp) == 0) {
|
||||
// We don't have to check if the 'blockOp' is already visited because
|
||||
// there cannot be a traversal path from this nested op to the parent
|
||||
// and thus a cycle cannot be closed here. We still have to mark it
|
||||
// as visited to stop before visiting this operation again if it is
|
||||
// part of a cycle.
|
||||
visited.insert(&blockOp);
|
||||
getForwardSliceImpl(&blockOp, visited, forwardSlice, filter);
|
||||
visited.erase(&blockOp);
|
||||
}
|
||||
|
||||
for (Value result : op->getResults())
|
||||
for (Operation *userOp : result.getUsers()) {
|
||||
// A cycle can only occur within a basic block (not across regions or
|
||||
// basic blocks) because the parent region must be a graph region, graph
|
||||
// regions are restricted to always have 0 or 1 blocks, and there cannot
|
||||
// be a def-use edge from a nested operation to an operation in an
|
||||
// ancestor region. Therefore, we don't have to but may use the same
|
||||
// 'visited' set across regions/blocks as long as we remove operations
|
||||
// from the set again when the DFS traverses back from the leaf to the
|
||||
// root.
|
||||
if (forwardSlice->count(userOp) == 0 && visited.insert(userOp).second)
|
||||
getForwardSliceImpl(userOp, visited, forwardSlice, filter);
|
||||
|
||||
visited.erase(userOp);
|
||||
}
|
||||
|
||||
forwardSlice->insert(op);
|
||||
@ -53,7 +73,9 @@ getForwardSliceImpl(Operation *op, SetVector<Operation *> *forwardSlice,
|
||||
|
||||
void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice,
|
||||
const ForwardSliceOptions &options) {
|
||||
getForwardSliceImpl(op, forwardSlice, options.filter);
|
||||
DenseSet<Operation *> visited;
|
||||
visited.insert(op);
|
||||
getForwardSliceImpl(op, visited, forwardSlice, options.filter);
|
||||
if (!options.inclusive) {
|
||||
// Don't insert the top level operation, we just queried on it and don't
|
||||
// want it in the results.
|
||||
@ -69,8 +91,12 @@ void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice,
|
||||
|
||||
void mlir::getForwardSlice(Value root, SetVector<Operation *> *forwardSlice,
|
||||
const SliceOptions &options) {
|
||||
for (Operation *user : root.getUsers())
|
||||
getForwardSliceImpl(user, forwardSlice, options.filter);
|
||||
DenseSet<Operation *> visited;
|
||||
for (Operation *user : root.getUsers()) {
|
||||
visited.insert(user);
|
||||
getForwardSliceImpl(user, visited, forwardSlice, options.filter);
|
||||
visited.erase(user);
|
||||
}
|
||||
|
||||
// Reverse to get back the actual topological order.
|
||||
// std::reverse does not work out of the box on SetVector and I want an
|
||||
@ -80,6 +106,7 @@ void mlir::getForwardSlice(Value root, SetVector<Operation *> *forwardSlice,
|
||||
}
|
||||
|
||||
static LogicalResult getBackwardSliceImpl(Operation *op,
|
||||
DenseSet<Operation *> &visited,
|
||||
SetVector<Operation *> *backwardSlice,
|
||||
const BackwardSliceOptions &options) {
|
||||
if (!op || op->hasTrait<OpTrait::IsIsolatedFromAbove>())
|
||||
@ -93,8 +120,12 @@ static LogicalResult getBackwardSliceImpl(Operation *op,
|
||||
|
||||
auto processValue = [&](Value value) {
|
||||
if (auto *definingOp = value.getDefiningOp()) {
|
||||
if (backwardSlice->count(definingOp) == 0)
|
||||
return getBackwardSliceImpl(definingOp, backwardSlice, options);
|
||||
if (backwardSlice->count(definingOp) == 0 &&
|
||||
visited.insert(definingOp).second)
|
||||
return getBackwardSliceImpl(definingOp, visited, backwardSlice,
|
||||
options);
|
||||
|
||||
visited.erase(definingOp);
|
||||
} else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
|
||||
if (options.omitBlockArguments)
|
||||
return success();
|
||||
@ -107,7 +138,8 @@ static LogicalResult getBackwardSliceImpl(Operation *op,
|
||||
if (parentOp && backwardSlice->count(parentOp) == 0) {
|
||||
if (parentOp->getNumRegions() == 1 &&
|
||||
llvm::hasSingleElement(parentOp->getRegion(0).getBlocks())) {
|
||||
return getBackwardSliceImpl(parentOp, backwardSlice, options);
|
||||
return getBackwardSliceImpl(parentOp, visited, backwardSlice,
|
||||
options);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@ -145,7 +177,10 @@ static LogicalResult getBackwardSliceImpl(Operation *op,
|
||||
LogicalResult mlir::getBackwardSlice(Operation *op,
|
||||
SetVector<Operation *> *backwardSlice,
|
||||
const BackwardSliceOptions &options) {
|
||||
LogicalResult result = getBackwardSliceImpl(op, backwardSlice, options);
|
||||
DenseSet<Operation *> visited;
|
||||
visited.insert(op);
|
||||
LogicalResult result =
|
||||
getBackwardSliceImpl(op, visited, backwardSlice, options);
|
||||
|
||||
if (!options.inclusive) {
|
||||
// Don't insert the top level operation, we just queried on it and don't
|
||||
|
@ -292,3 +292,26 @@ func.func @slicing_test_multiple_return(%arg0: index) -> (index, index) {
|
||||
%0:2 = "slicing-test-op"(%arg0, %arg0): (index, index) -> (index, index)
|
||||
return %0#0, %0#1 : index, index
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// FWD-LABEL: graph_region_with_cycle
|
||||
// BWD-LABEL: graph_region_with_cycle
|
||||
// FWDBWD-LABEL: graph_region_with_cycle
|
||||
func.func @graph_region_with_cycle() {
|
||||
test.isolated_graph_region {
|
||||
// FWD: matched: [[V0:%.+]] = "slicing-test-op"([[V1:%.+]]) : (i1) -> i1 forward static slice:
|
||||
// FWD: [[V1]] = "slicing-test-op"([[V0]]) : (i1) -> i1
|
||||
// FWD: matched: [[V1]] = "slicing-test-op"([[V0]]) : (i1) -> i1 forward static slice:
|
||||
// FWD: [[V0]] = "slicing-test-op"([[V1]]) : (i1) -> i1
|
||||
|
||||
// BWD: matched: [[V0:%.+]] = "slicing-test-op"([[V1:%.+]]) : (i1) -> i1 backward static slice:
|
||||
// BWD: [[V1]] = "slicing-test-op"([[V0]]) : (i1) -> i1
|
||||
// BWD: matched: [[V1]] = "slicing-test-op"([[V0]]) : (i1) -> i1 backward static slice:
|
||||
// BWD: [[V0]] = "slicing-test-op"([[V1]]) : (i1) -> i1
|
||||
%0 = "slicing-test-op"(%1) : (i1) -> i1
|
||||
%1 = "slicing-test-op"(%0) : (i1) -> i1
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user