[mlir][python] bind block predecessors and successors (#145116)
bind `block.getSuccessor` and `block.getPredecessors`.
This commit is contained in:
parent
bc5e5c0114
commit
a2aa812a31
@ -986,6 +986,24 @@ MLIR_CAPI_EXPORTED MlirValue mlirBlockGetArgument(MlirBlock block,
|
||||
MLIR_CAPI_EXPORTED void
|
||||
mlirBlockPrint(MlirBlock block, MlirStringCallback callback, void *userData);
|
||||
|
||||
/// Returns the number of successor blocks of the block.
|
||||
MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumSuccessors(MlirBlock block);
|
||||
|
||||
/// Returns `pos`-th successor of the block.
|
||||
MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetSuccessor(MlirBlock block,
|
||||
intptr_t pos);
|
||||
|
||||
/// Returns the number of predecessor blocks of the block.
|
||||
MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumPredecessors(MlirBlock block);
|
||||
|
||||
/// Returns `pos`-th predecessor of the block.
|
||||
///
|
||||
/// WARNING: This getter is more expensive than the others here because
|
||||
/// the impl actually iterates the use-def chain (of block operands) anew for
|
||||
/// each indexed access.
|
||||
MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetPredecessor(MlirBlock block,
|
||||
intptr_t pos);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Value API.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -2626,6 +2626,88 @@ private:
|
||||
PyOperationRef operation;
|
||||
};
|
||||
|
||||
/// A list of block successors. Internally, these are stored as consecutive
|
||||
/// elements, random access is cheap. The (returned) successor list is
|
||||
/// associated with the operation and block whose successors these are, and thus
|
||||
/// extends the lifetime of this operation and block.
|
||||
class PyBlockSuccessors : public Sliceable<PyBlockSuccessors, PyBlock> {
|
||||
public:
|
||||
static constexpr const char *pyClassName = "BlockSuccessors";
|
||||
|
||||
PyBlockSuccessors(PyBlock block, PyOperationRef operation,
|
||||
intptr_t startIndex = 0, intptr_t length = -1,
|
||||
intptr_t step = 1)
|
||||
: Sliceable(startIndex,
|
||||
length == -1 ? mlirBlockGetNumSuccessors(block.get())
|
||||
: length,
|
||||
step),
|
||||
operation(operation), block(block) {}
|
||||
|
||||
private:
|
||||
/// Give the parent CRTP class access to hook implementations below.
|
||||
friend class Sliceable<PyBlockSuccessors, PyBlock>;
|
||||
|
||||
intptr_t getRawNumElements() {
|
||||
block.checkValid();
|
||||
return mlirBlockGetNumSuccessors(block.get());
|
||||
}
|
||||
|
||||
PyBlock getRawElement(intptr_t pos) {
|
||||
MlirBlock block = mlirBlockGetSuccessor(this->block.get(), pos);
|
||||
return PyBlock(operation, block);
|
||||
}
|
||||
|
||||
PyBlockSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
|
||||
return PyBlockSuccessors(block, operation, startIndex, length, step);
|
||||
}
|
||||
|
||||
PyOperationRef operation;
|
||||
PyBlock block;
|
||||
};
|
||||
|
||||
/// A list of block predecessors. The (returned) predecessor list is
|
||||
/// associated with the operation and block whose predecessors these are, and
|
||||
/// thus extends the lifetime of this operation and block.
|
||||
///
|
||||
/// WARNING: This Sliceable is more expensive than the others here because
|
||||
/// mlirBlockGetPredecessor actually iterates the use-def chain (of block
|
||||
/// operands) anew for each indexed access.
|
||||
class PyBlockPredecessors : public Sliceable<PyBlockPredecessors, PyBlock> {
|
||||
public:
|
||||
static constexpr const char *pyClassName = "BlockPredecessors";
|
||||
|
||||
PyBlockPredecessors(PyBlock block, PyOperationRef operation,
|
||||
intptr_t startIndex = 0, intptr_t length = -1,
|
||||
intptr_t step = 1)
|
||||
: Sliceable(startIndex,
|
||||
length == -1 ? mlirBlockGetNumPredecessors(block.get())
|
||||
: length,
|
||||
step),
|
||||
operation(operation), block(block) {}
|
||||
|
||||
private:
|
||||
/// Give the parent CRTP class access to hook implementations below.
|
||||
friend class Sliceable<PyBlockPredecessors, PyBlock>;
|
||||
|
||||
intptr_t getRawNumElements() {
|
||||
block.checkValid();
|
||||
return mlirBlockGetNumPredecessors(block.get());
|
||||
}
|
||||
|
||||
PyBlock getRawElement(intptr_t pos) {
|
||||
MlirBlock block = mlirBlockGetPredecessor(this->block.get(), pos);
|
||||
return PyBlock(operation, block);
|
||||
}
|
||||
|
||||
PyBlockPredecessors slice(intptr_t startIndex, intptr_t length,
|
||||
intptr_t step) {
|
||||
return PyBlockPredecessors(block, operation, startIndex, length, step);
|
||||
}
|
||||
|
||||
PyOperationRef operation;
|
||||
PyBlock block;
|
||||
};
|
||||
|
||||
/// A list of operation attributes. Can be indexed by name, producing
|
||||
/// attributes, or by index, producing named attributes.
|
||||
class PyOpAttributeMap {
|
||||
@ -3655,7 +3737,19 @@ void mlir::python::populateIRCore(nb::module_ &m) {
|
||||
},
|
||||
nb::arg("operation"),
|
||||
"Appends an operation to this block. If the operation is currently "
|
||||
"in another block, it will be moved.");
|
||||
"in another block, it will be moved.")
|
||||
.def_prop_ro(
|
||||
"successors",
|
||||
[](PyBlock &self) {
|
||||
return PyBlockSuccessors(self, self.getParentOperation());
|
||||
},
|
||||
"Returns the list of Block successors.")
|
||||
.def_prop_ro(
|
||||
"predecessors",
|
||||
[](PyBlock &self) {
|
||||
return PyBlockPredecessors(self, self.getParentOperation());
|
||||
},
|
||||
"Returns the list of Block predecessors.");
|
||||
|
||||
//----------------------------------------------------------------------------
|
||||
// Mapping of PyInsertionPoint.
|
||||
@ -4099,6 +4193,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
|
||||
PyBlockArgumentList::bind(m);
|
||||
PyBlockIterator::bind(m);
|
||||
PyBlockList::bind(m);
|
||||
PyBlockSuccessors::bind(m);
|
||||
PyBlockPredecessors::bind(m);
|
||||
PyOperationIterator::bind(m);
|
||||
PyOperationList::bind(m);
|
||||
PyOpAttributeMap::bind(m);
|
||||
|
@ -1059,6 +1059,26 @@ void mlirBlockPrint(MlirBlock block, MlirStringCallback callback,
|
||||
unwrap(block)->print(stream);
|
||||
}
|
||||
|
||||
intptr_t mlirBlockGetNumSuccessors(MlirBlock block) {
|
||||
return static_cast<intptr_t>(unwrap(block)->getNumSuccessors());
|
||||
}
|
||||
|
||||
MlirBlock mlirBlockGetSuccessor(MlirBlock block, intptr_t pos) {
|
||||
return wrap(unwrap(block)->getSuccessor(static_cast<unsigned>(pos)));
|
||||
}
|
||||
|
||||
intptr_t mlirBlockGetNumPredecessors(MlirBlock block) {
|
||||
Block *b = unwrap(block);
|
||||
return static_cast<intptr_t>(std::distance(b->pred_begin(), b->pred_end()));
|
||||
}
|
||||
|
||||
MlirBlock mlirBlockGetPredecessor(MlirBlock block, intptr_t pos) {
|
||||
Block *b = unwrap(block);
|
||||
Block::pred_iterator it = b->pred_begin();
|
||||
std::advance(it, pos);
|
||||
return wrap(*it);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Value API.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -2440,6 +2440,74 @@ void testDiagnostics(void) {
|
||||
mlirContextDestroy(ctx);
|
||||
}
|
||||
|
||||
int testBlockPredecessorsSuccessors(MlirContext ctx) {
|
||||
// CHECK-LABEL: @testBlockPredecessorsSuccessors
|
||||
fprintf(stderr, "@testBlockPredecessorsSuccessors\n");
|
||||
|
||||
const char *moduleString = "module {\n"
|
||||
" func.func @test(%arg0: i32, %arg1: i16) {\n"
|
||||
" cf.br ^bb1(%arg1 : i16)\n"
|
||||
" ^bb1(%0: i16): // pred: ^bb0\n"
|
||||
" cf.br ^bb2(%arg0 : i32)\n"
|
||||
" ^bb2(%1: i32): // pred: ^bb1\n"
|
||||
" return\n"
|
||||
" }\n"
|
||||
"}\n";
|
||||
|
||||
MlirModule module =
|
||||
mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString));
|
||||
|
||||
MlirOperation moduleOp = mlirModuleGetOperation(module);
|
||||
MlirRegion moduleRegion = mlirOperationGetRegion(moduleOp, 0);
|
||||
MlirBlock moduleBlock = mlirRegionGetFirstBlock(moduleRegion);
|
||||
MlirOperation function = mlirBlockGetFirstOperation(moduleBlock);
|
||||
MlirRegion funcRegion = mlirOperationGetRegion(function, 0);
|
||||
MlirBlock entryBlock = mlirRegionGetFirstBlock(funcRegion);
|
||||
MlirBlock middleBlock = mlirBlockGetNextInRegion(entryBlock);
|
||||
MlirBlock successorBlock = mlirBlockGetNextInRegion(middleBlock);
|
||||
|
||||
#define FPRINTF_OP(OP, FMT) fprintf(stderr, #OP ": " FMT "\n", OP)
|
||||
|
||||
// CHECK: mlirBlockGetNumPredecessors(entryBlock): 0
|
||||
FPRINTF_OP(mlirBlockGetNumPredecessors(entryBlock), "%ld");
|
||||
|
||||
// CHECK: mlirBlockGetNumSuccessors(entryBlock): 1
|
||||
FPRINTF_OP(mlirBlockGetNumSuccessors(entryBlock), "%ld");
|
||||
// CHECK: mlirBlockEqual(middleBlock, mlirBlockGetSuccessor(entryBlock, 0)): 1
|
||||
FPRINTF_OP(mlirBlockEqual(middleBlock, mlirBlockGetSuccessor(entryBlock, 0)),
|
||||
"%d");
|
||||
// CHECK: mlirBlockGetNumPredecessors(middleBlock): 1
|
||||
FPRINTF_OP(mlirBlockGetNumPredecessors(middleBlock), "%ld");
|
||||
// CHECK: mlirBlockEqual(entryBlock, mlirBlockGetPredecessor(middleBlock, 0))
|
||||
FPRINTF_OP(
|
||||
mlirBlockEqual(entryBlock, mlirBlockGetPredecessor(middleBlock, 0)),
|
||||
"%d");
|
||||
|
||||
// CHECK: mlirBlockGetNumSuccessors(middleBlock): 1
|
||||
FPRINTF_OP(mlirBlockGetNumSuccessors(middleBlock), "%ld");
|
||||
// CHECK: BlockEqual(successorBlock, mlirBlockGetSuccessor(middleBlock, 0)): 1
|
||||
fprintf(
|
||||
stderr,
|
||||
"BlockEqual(successorBlock, mlirBlockGetSuccessor(middleBlock, 0)): %d\n",
|
||||
mlirBlockEqual(successorBlock, mlirBlockGetSuccessor(middleBlock, 0)));
|
||||
// CHECK: mlirBlockGetNumPredecessors(successorBlock): 1
|
||||
FPRINTF_OP(mlirBlockGetNumPredecessors(successorBlock), "%ld");
|
||||
// CHECK: Equal(middleBlock, mlirBlockGetPredecessor(successorBlock, 0)): 1
|
||||
fprintf(
|
||||
stderr,
|
||||
"Equal(middleBlock, mlirBlockGetPredecessor(successorBlock, 0)): %d\n",
|
||||
mlirBlockEqual(middleBlock, mlirBlockGetPredecessor(successorBlock, 0)));
|
||||
|
||||
// CHECK: mlirBlockGetNumSuccessors(successorBlock): 0
|
||||
FPRINTF_OP(mlirBlockGetNumSuccessors(successorBlock), "%ld");
|
||||
|
||||
#undef FPRINTF_OP
|
||||
|
||||
mlirModuleDestroy(module);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int main(void) {
|
||||
MlirContext ctx = mlirContextCreate();
|
||||
registerAllUpstreamDialects(ctx);
|
||||
@ -2486,6 +2554,9 @@ int main(void) {
|
||||
testExplicitThreadPools();
|
||||
testDiagnostics();
|
||||
|
||||
if (testBlockPredecessorsSuccessors(ctx))
|
||||
return 17;
|
||||
|
||||
// CHECK: DESTROY MAIN CONTEXT
|
||||
// CHECK: reportResourceDelete: resource_i64_blob
|
||||
fprintf(stderr, "DESTROY MAIN CONTEXT\n");
|
||||
|
@ -1,12 +1,11 @@
|
||||
# RUN: %PYTHON %s | FileCheck %s
|
||||
|
||||
import gc
|
||||
import io
|
||||
import itertools
|
||||
from mlir.ir import *
|
||||
|
||||
from mlir.dialects import builtin
|
||||
from mlir.dialects import cf
|
||||
from mlir.dialects import func
|
||||
from mlir.ir import *
|
||||
|
||||
|
||||
def run(f):
|
||||
@ -54,10 +53,25 @@ def testBlockCreation():
|
||||
with InsertionPoint(middle_block) as middle_ip:
|
||||
assert middle_ip.block == middle_block
|
||||
cf.BranchOp([i32_arg], dest=successor_block)
|
||||
|
||||
module.print(enable_debug_info=True)
|
||||
# Ensure region back references are coherent.
|
||||
assert entry_block.region == middle_block.region == successor_block.region
|
||||
|
||||
assert len(entry_block.predecessors) == 0
|
||||
|
||||
assert len(entry_block.successors) == 1
|
||||
assert middle_block == entry_block.successors[0]
|
||||
assert len(middle_block.predecessors) == 1
|
||||
assert entry_block == middle_block.predecessors[0]
|
||||
|
||||
assert len(middle_block.successors) == 1
|
||||
assert successor_block == middle_block.successors[0]
|
||||
assert len(successor_block.predecessors) == 1
|
||||
assert middle_block == successor_block.predecessors[0]
|
||||
|
||||
assert len(successor_block.successors) == 0
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testBlockCreationArgLocs
|
||||
@run
|
||||
|
Loading…
x
Reference in New Issue
Block a user