[mlir][python] bind block predecessors and successors (#145116)

bind `block.getSuccessor` and `block.getPredecessors`.
This commit is contained in:
Maksim Levental 2025-06-23 18:59:03 -05:00 committed by GitHub
parent bc5e5c0114
commit a2aa812a31
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 223 additions and 4 deletions

View File

@ -986,6 +986,24 @@ MLIR_CAPI_EXPORTED MlirValue mlirBlockGetArgument(MlirBlock block,
MLIR_CAPI_EXPORTED void
mlirBlockPrint(MlirBlock block, MlirStringCallback callback, void *userData);
/// Returns the number of successor blocks of the block.
MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumSuccessors(MlirBlock block);
/// Returns `pos`-th successor of the block.
MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetSuccessor(MlirBlock block,
intptr_t pos);
/// Returns the number of predecessor blocks of the block.
MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumPredecessors(MlirBlock block);
/// Returns `pos`-th predecessor of the block.
///
/// WARNING: This getter is more expensive than the others here because
/// the impl actually iterates the use-def chain (of block operands) anew for
/// each indexed access.
MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetPredecessor(MlirBlock block,
intptr_t pos);
//===----------------------------------------------------------------------===//
// Value API.
//===----------------------------------------------------------------------===//

View File

@ -2626,6 +2626,88 @@ private:
PyOperationRef operation;
};
/// A list of block successors. Internally, these are stored as consecutive
/// elements, random access is cheap. The (returned) successor list is
/// associated with the operation and block whose successors these are, and thus
/// extends the lifetime of this operation and block.
class PyBlockSuccessors : public Sliceable<PyBlockSuccessors, PyBlock> {
public:
static constexpr const char *pyClassName = "BlockSuccessors";
PyBlockSuccessors(PyBlock block, PyOperationRef operation,
intptr_t startIndex = 0, intptr_t length = -1,
intptr_t step = 1)
: Sliceable(startIndex,
length == -1 ? mlirBlockGetNumSuccessors(block.get())
: length,
step),
operation(operation), block(block) {}
private:
/// Give the parent CRTP class access to hook implementations below.
friend class Sliceable<PyBlockSuccessors, PyBlock>;
intptr_t getRawNumElements() {
block.checkValid();
return mlirBlockGetNumSuccessors(block.get());
}
PyBlock getRawElement(intptr_t pos) {
MlirBlock block = mlirBlockGetSuccessor(this->block.get(), pos);
return PyBlock(operation, block);
}
PyBlockSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
return PyBlockSuccessors(block, operation, startIndex, length, step);
}
PyOperationRef operation;
PyBlock block;
};
/// A list of block predecessors. The (returned) predecessor list is
/// associated with the operation and block whose predecessors these are, and
/// thus extends the lifetime of this operation and block.
///
/// WARNING: This Sliceable is more expensive than the others here because
/// mlirBlockGetPredecessor actually iterates the use-def chain (of block
/// operands) anew for each indexed access.
class PyBlockPredecessors : public Sliceable<PyBlockPredecessors, PyBlock> {
public:
static constexpr const char *pyClassName = "BlockPredecessors";
PyBlockPredecessors(PyBlock block, PyOperationRef operation,
intptr_t startIndex = 0, intptr_t length = -1,
intptr_t step = 1)
: Sliceable(startIndex,
length == -1 ? mlirBlockGetNumPredecessors(block.get())
: length,
step),
operation(operation), block(block) {}
private:
/// Give the parent CRTP class access to hook implementations below.
friend class Sliceable<PyBlockPredecessors, PyBlock>;
intptr_t getRawNumElements() {
block.checkValid();
return mlirBlockGetNumPredecessors(block.get());
}
PyBlock getRawElement(intptr_t pos) {
MlirBlock block = mlirBlockGetPredecessor(this->block.get(), pos);
return PyBlock(operation, block);
}
PyBlockPredecessors slice(intptr_t startIndex, intptr_t length,
intptr_t step) {
return PyBlockPredecessors(block, operation, startIndex, length, step);
}
PyOperationRef operation;
PyBlock block;
};
/// A list of operation attributes. Can be indexed by name, producing
/// attributes, or by index, producing named attributes.
class PyOpAttributeMap {
@ -3655,7 +3737,19 @@ void mlir::python::populateIRCore(nb::module_ &m) {
},
nb::arg("operation"),
"Appends an operation to this block. If the operation is currently "
"in another block, it will be moved.");
"in another block, it will be moved.")
.def_prop_ro(
"successors",
[](PyBlock &self) {
return PyBlockSuccessors(self, self.getParentOperation());
},
"Returns the list of Block successors.")
.def_prop_ro(
"predecessors",
[](PyBlock &self) {
return PyBlockPredecessors(self, self.getParentOperation());
},
"Returns the list of Block predecessors.");
//----------------------------------------------------------------------------
// Mapping of PyInsertionPoint.
@ -4099,6 +4193,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
PyBlockArgumentList::bind(m);
PyBlockIterator::bind(m);
PyBlockList::bind(m);
PyBlockSuccessors::bind(m);
PyBlockPredecessors::bind(m);
PyOperationIterator::bind(m);
PyOperationList::bind(m);
PyOpAttributeMap::bind(m);

View File

@ -1059,6 +1059,26 @@ void mlirBlockPrint(MlirBlock block, MlirStringCallback callback,
unwrap(block)->print(stream);
}
intptr_t mlirBlockGetNumSuccessors(MlirBlock block) {
return static_cast<intptr_t>(unwrap(block)->getNumSuccessors());
}
MlirBlock mlirBlockGetSuccessor(MlirBlock block, intptr_t pos) {
return wrap(unwrap(block)->getSuccessor(static_cast<unsigned>(pos)));
}
intptr_t mlirBlockGetNumPredecessors(MlirBlock block) {
Block *b = unwrap(block);
return static_cast<intptr_t>(std::distance(b->pred_begin(), b->pred_end()));
}
MlirBlock mlirBlockGetPredecessor(MlirBlock block, intptr_t pos) {
Block *b = unwrap(block);
Block::pred_iterator it = b->pred_begin();
std::advance(it, pos);
return wrap(*it);
}
//===----------------------------------------------------------------------===//
// Value API.
//===----------------------------------------------------------------------===//

View File

@ -2440,6 +2440,74 @@ void testDiagnostics(void) {
mlirContextDestroy(ctx);
}
int testBlockPredecessorsSuccessors(MlirContext ctx) {
// CHECK-LABEL: @testBlockPredecessorsSuccessors
fprintf(stderr, "@testBlockPredecessorsSuccessors\n");
const char *moduleString = "module {\n"
" func.func @test(%arg0: i32, %arg1: i16) {\n"
" cf.br ^bb1(%arg1 : i16)\n"
" ^bb1(%0: i16): // pred: ^bb0\n"
" cf.br ^bb2(%arg0 : i32)\n"
" ^bb2(%1: i32): // pred: ^bb1\n"
" return\n"
" }\n"
"}\n";
MlirModule module =
mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString));
MlirOperation moduleOp = mlirModuleGetOperation(module);
MlirRegion moduleRegion = mlirOperationGetRegion(moduleOp, 0);
MlirBlock moduleBlock = mlirRegionGetFirstBlock(moduleRegion);
MlirOperation function = mlirBlockGetFirstOperation(moduleBlock);
MlirRegion funcRegion = mlirOperationGetRegion(function, 0);
MlirBlock entryBlock = mlirRegionGetFirstBlock(funcRegion);
MlirBlock middleBlock = mlirBlockGetNextInRegion(entryBlock);
MlirBlock successorBlock = mlirBlockGetNextInRegion(middleBlock);
#define FPRINTF_OP(OP, FMT) fprintf(stderr, #OP ": " FMT "\n", OP)
// CHECK: mlirBlockGetNumPredecessors(entryBlock): 0
FPRINTF_OP(mlirBlockGetNumPredecessors(entryBlock), "%ld");
// CHECK: mlirBlockGetNumSuccessors(entryBlock): 1
FPRINTF_OP(mlirBlockGetNumSuccessors(entryBlock), "%ld");
// CHECK: mlirBlockEqual(middleBlock, mlirBlockGetSuccessor(entryBlock, 0)): 1
FPRINTF_OP(mlirBlockEqual(middleBlock, mlirBlockGetSuccessor(entryBlock, 0)),
"%d");
// CHECK: mlirBlockGetNumPredecessors(middleBlock): 1
FPRINTF_OP(mlirBlockGetNumPredecessors(middleBlock), "%ld");
// CHECK: mlirBlockEqual(entryBlock, mlirBlockGetPredecessor(middleBlock, 0))
FPRINTF_OP(
mlirBlockEqual(entryBlock, mlirBlockGetPredecessor(middleBlock, 0)),
"%d");
// CHECK: mlirBlockGetNumSuccessors(middleBlock): 1
FPRINTF_OP(mlirBlockGetNumSuccessors(middleBlock), "%ld");
// CHECK: BlockEqual(successorBlock, mlirBlockGetSuccessor(middleBlock, 0)): 1
fprintf(
stderr,
"BlockEqual(successorBlock, mlirBlockGetSuccessor(middleBlock, 0)): %d\n",
mlirBlockEqual(successorBlock, mlirBlockGetSuccessor(middleBlock, 0)));
// CHECK: mlirBlockGetNumPredecessors(successorBlock): 1
FPRINTF_OP(mlirBlockGetNumPredecessors(successorBlock), "%ld");
// CHECK: Equal(middleBlock, mlirBlockGetPredecessor(successorBlock, 0)): 1
fprintf(
stderr,
"Equal(middleBlock, mlirBlockGetPredecessor(successorBlock, 0)): %d\n",
mlirBlockEqual(middleBlock, mlirBlockGetPredecessor(successorBlock, 0)));
// CHECK: mlirBlockGetNumSuccessors(successorBlock): 0
FPRINTF_OP(mlirBlockGetNumSuccessors(successorBlock), "%ld");
#undef FPRINTF_OP
mlirModuleDestroy(module);
return 0;
}
int main(void) {
MlirContext ctx = mlirContextCreate();
registerAllUpstreamDialects(ctx);
@ -2486,6 +2554,9 @@ int main(void) {
testExplicitThreadPools();
testDiagnostics();
if (testBlockPredecessorsSuccessors(ctx))
return 17;
// CHECK: DESTROY MAIN CONTEXT
// CHECK: reportResourceDelete: resource_i64_blob
fprintf(stderr, "DESTROY MAIN CONTEXT\n");

View File

@ -1,12 +1,11 @@
# RUN: %PYTHON %s | FileCheck %s
import gc
import io
import itertools
from mlir.ir import *
from mlir.dialects import builtin
from mlir.dialects import cf
from mlir.dialects import func
from mlir.ir import *
def run(f):
@ -54,10 +53,25 @@ def testBlockCreation():
with InsertionPoint(middle_block) as middle_ip:
assert middle_ip.block == middle_block
cf.BranchOp([i32_arg], dest=successor_block)
module.print(enable_debug_info=True)
# Ensure region back references are coherent.
assert entry_block.region == middle_block.region == successor_block.region
assert len(entry_block.predecessors) == 0
assert len(entry_block.successors) == 1
assert middle_block == entry_block.successors[0]
assert len(middle_block.predecessors) == 1
assert entry_block == middle_block.predecessors[0]
assert len(middle_block.successors) == 1
assert successor_block == middle_block.successors[0]
assert len(successor_block.predecessors) == 1
assert middle_block == successor_block.predecessors[0]
assert len(successor_block.successors) == 0
# CHECK-LABEL: TEST: testBlockCreationArgLocs
@run