[mlir] Extend moveValueDefinitions/moveOperationDependencies with cross-region support (#176343)

Extends `moveValueDefinitions` and `moveOperationDependencies` to
support moving operations across basic blocks and out of nested regions
This commit is contained in:
Jorn Tuyls 2026-02-02 11:39:02 +01:00 committed by GitHub
parent f288f463ad
commit f84c3672c3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 390 additions and 48 deletions

View File

@ -71,11 +71,18 @@ SmallVector<Value> makeRegionIsolatedFromAbove(
llvm::function_ref<bool(Operation *)> cloneOperationIntoRegion =
[](Operation *) { return false; });
/// Move SSA values used within an operation before an insertion point,
/// so that the operation itself (or its replacement) can be moved to
/// the insertion point. Current support is only for movement of
/// dependencies of `op` before `insertionPoint` in the same basic block.
/// Any side-effecting operations in the dependency chain pessimistically
/// Move the operation dependencies (producers) of `op` before `insertionPoint`,
/// so that `op` itself can subsequently be moved. This includes transitive
/// dependencies. Supports movement within the same block or from nested regions
/// to an outer block.
///
/// The following conditions cause the move to fail:
/// - `insertionPoint` does not dominate `op`.
/// - Movement across an isolated-from-above region boundary.
/// - A dependency uses a block argument that wouldn't dominate
/// `insertionPoint`.
/// - `insertionPoint` is itself a dependency of `op` (cycle).
/// - Any side-effecting operations in the dependency chain pessimistically
/// blocks movement.
LogicalResult moveOperationDependencies(RewriterBase &rewriter, Operation *op,
Operation *insertionPoint,
@ -83,11 +90,19 @@ LogicalResult moveOperationDependencies(RewriterBase &rewriter, Operation *op,
LogicalResult moveOperationDependencies(RewriterBase &rewriter, Operation *op,
Operation *insertionPoint);
/// Move definitions of `values` before an insertion point. Current support is
/// only for movement of definitions within the same basic block. Note that this
/// is an all-or-nothing approach. Either definitions of all values are moved
/// before insertion point, or none of them are. Any side-effecting operations
/// in the producer chain pessimistically blocks movement.
/// Move definitions of `values` (and their transitive dependencies) before
/// `insertionPoint`. Supports movement within the same block or from nested
/// regions to an outer block.
///
/// This is all-or-nothing: either all definitions are moved, or none are.
///
/// The following conditions cause the move to fail:
/// - Any value is a block argument (cannot be moved).
/// - Any side-effecting operations in the dependency chain.
/// - Movement across an isolated-from-above region boundary.
/// - A dependency uses a block argument that wouldn't dominate
/// `insertionPoint`.
/// - `insertionPoint` is itself a dependency (cycle).
LogicalResult moveValueDefinitions(RewriterBase &rewriter, ValueRange values,
Operation *insertionPoint,
DominanceInfo &dominance);

View File

@ -1096,31 +1096,105 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
// Move operation dependencies
//===---------------------------------------------------------------------===//
/// Check if moving operations in the slice before `insertionPoint` would break
/// dominance due to block argument operands. Returns true if all block args
/// dominate the insertion point (no issue), false otherwise. If `failingOp` is
/// provided, it will be set to the first problematic op.
///
/// For operands defined by ops: either the defining op is in the slice (so
/// dominance preserved), or it already dominates insertionPoint (otherwise it
/// would be in the slice). So we only need to check block argument operands,
/// both as direct operands and as values captured inside regions.
static bool blockArgsDominateInsertionPoint(
const llvm::SetVector<Operation *> &slice, Operation *insertionPoint,
DominanceInfo &dominance, Operation **failingOp = nullptr) {
Block *insertionBlock = insertionPoint->getBlock();
// Returns true if the block arg dominates, false otherwise. Sets failingOp
// on failure.
auto argDominates = [&](BlockArgument arg, Operation *op) {
Block *argBlock = arg.getOwner();
bool dominates = argBlock == insertionBlock ||
dominance.dominates(argBlock, insertionBlock);
if (!dominates && failingOp)
*failingOp = op;
return dominates;
};
for (Operation *op : slice) {
// Check direct operands.
for (Value operand : op->getOperands()) {
auto arg = dyn_cast<BlockArgument>(operand);
if (!arg)
continue;
if (!argDominates(arg, op))
return false;
}
// Check block arguments captured inside regions. Process one region at a
// time to enable early exit without collecting values from all regions.
for (Region &region : op->getRegions()) {
SetVector<Value> capturedValues;
getUsedValuesDefinedAbove(region, region, capturedValues);
for (Value val : capturedValues) {
auto arg = dyn_cast<BlockArgument>(val);
if (!arg)
continue;
if (!argDominates(arg, op))
return false;
}
}
}
return true;
}
/// Check if any region between an operation and an ancestor block is
/// isolated from above. If so, moving the operation out would break
/// the isolation semantics.
static bool hasIsolatedRegionBetween(Operation *op, Block *ancestorBlock) {
Region *ancestorRegion = ancestorBlock->getParent();
// Walk up from the op's region to find if there's an isolated region
// between the op and the ancestor.
Region *region = op->getParentRegion();
while (region && region != ancestorRegion) {
Operation *parentOp = region->getParentOp();
if (!parentOp)
break;
if (parentOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
return true;
region = parentOp->getParentRegion();
}
return false;
}
LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter,
Operation *op,
Operation *insertionPoint,
DominanceInfo &dominance) {
// Currently unsupported case where the op and insertion point are
// in different basic blocks.
if (op->getBlock() != insertionPoint->getBlock()) {
return rewriter.notifyMatchFailure(
op, "unsupported case where operation and insertion point are not in "
"the same basic block");
}
// If `insertionPoint` does not dominate `op`, do nothing
Block *insertionBlock = insertionPoint->getBlock();
// If `insertionPoint` does not dominate `op`, do nothing.
if (!dominance.properlyDominates(insertionPoint, op)) {
return rewriter.notifyMatchFailure(op,
"insertion point does not dominate op");
}
// Verify we're not crossing an isolated region.
if (hasIsolatedRegionBetween(op, insertionBlock)) {
return rewriter.notifyMatchFailure(
op, "cannot move operation across isolated-from-above region");
}
// Find the backward slice of operation for each `Value` the operation
// depends on. Prune the slice to only include operations not already
// dominated by the `insertionPoint`.
BackwardSliceOptions options;
options.inclusive = false;
options.omitUsesFromAbove = false;
// Since current support is to only move within a same basic block,
// the slices dont need to look past block arguments.
// Block arguments cannot be moved; dominance check handles this case.
options.omitBlockArguments = true;
bool dependsOnSideEffectingOp = false;
options.filter = [&](Operation *sliceBoundaryOp) {
@ -1158,6 +1232,15 @@ LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter,
"cannot move dependencies before operation in backward slice of op");
}
// Verify no operation in the slice uses a block argument that wouldn't
// dominate at the new location.
Operation *badOp = nullptr;
if (!blockArgsDominateInsertionPoint(slice, insertionPoint, dominance,
&badOp)) {
return rewriter.notifyMatchFailure(
badOp, "moving op would break dominance for block argument operand");
}
// We should move the slice in topological order, but `getBackwardSlice`
// already does that. So no need to sort again.
for (Operation *op : slice) {
@ -1188,13 +1271,26 @@ LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
insertionPoint,
"unsupported case of moving block argument before insertion point");
}
// Check for currently unsupported case if the insertion point is in a
// different block.
if (value.getDefiningOp()->getBlock() != insertionPoint->getBlock()) {
Block *insertionBlock = insertionPoint->getBlock();
Operation *definingOp = value.getDefiningOp();
Block *definingBlock = definingOp->getBlock();
// Verify we're not crossing an isolated region.
if (hasIsolatedRegionBetween(definingOp, insertionBlock)) {
return rewriter.notifyMatchFailure(
insertionPoint,
"unsupported case of moving definition of value before an insertion "
"point in a different basic block");
"cannot move value definition across isolated-from-above region");
}
// Verify the insertion point's block dominates the defining block,
// otherwise we're trying to move "backwards" in the CFG which doesn't
// make sense.
if (!dominance.dominates(insertionBlock, definingBlock)) {
return rewriter.notifyMatchFailure(
insertionPoint,
"insertion point block does not dominate the value's defining "
"block");
}
prunedValues.push_back(value);
}
@ -1205,8 +1301,9 @@ LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
BackwardSliceOptions options;
options.inclusive = true;
options.omitUsesFromAbove = false;
// Since current support is to only move within a same basic block,
// the slices dont need to look past block arguments.
// Block arguments cannot be moved, so we stop the slice computation there.
// If an op uses a block argument that wouldn't dominate at the new location,
// the dominance check will catch it.
options.omitBlockArguments = true;
bool dependsOnSideEffectingOp = false;
options.filter = [&](Operation *sliceBoundaryOp) {
@ -1243,9 +1340,20 @@ LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
"cannot move dependencies before operation in backward slice of op");
}
// Sort operations topologically before moving.
// Sort operations topologically. This is needed because we call
// getBackwardSlice multiple times (once per value), and the combined slice
// may not be in topological order when independent subgraphs interleave.
mlir::topologicalSort(slice);
// Verify no operation in the slice uses a block argument that wouldn't
// dominate at the new location.
Operation *badOp = nullptr;
if (!blockArgsDominateInsertionPoint(slice, insertionPoint, dominance,
&badOp)) {
return rewriter.notifyMatchFailure(
badOp, "moving op would break dominance for block argument operand");
}
for (Operation *op : slice)
rewriter.moveOpBefore(op, insertionPoint);
return success();

View File

@ -454,24 +454,28 @@ module attributes {transform.with_named_sequence} {
// -----
// Do not move across basic blocks
func.func @no_move_across_basic_blocks() -> (index, index) {
%0 = "unmoved_op"() : () -> (index)
%1 = "before"() : () -> (index)
cf.br ^bb0(%0 : index)
^bb0(%arg0 : index) :
%2 = arith.addi %arg0, %arg0 {moved_op} : index
return %1, %2 : index, index
// Successfully move operation between blocks in the same region.
// The function argument %arg0 dominates all blocks, so the move is valid.
func.func @move_between_blocks_same_region(%arg0 : index, %cond : i1) -> index {
%0 = "before"() : () -> (index)
cf.cond_br %cond, ^bb1, ^bb2(%arg0 : index)
^bb1:
%1 = arith.addi %arg0, %arg0 {to_move} : index
cf.br ^bb2(%1 : index)
^bb2(%result : index):
return %result : index
}
// CHECK-LABEL: func @move_between_blocks_same_region
// CHECK: %[[MOVED:.+]] = arith.addi {{.*}} {to_move}
// CHECK: "before"
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
%op1 = transform.structured.match ops{["before"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%op2 = transform.structured.match ops{["arith.addi"]} in %arg0
%op2 = transform.structured.match attributes{to_move} in %arg0
: (!transform.any_op) -> !transform.any_op
%v1 = transform.get_result %op2[0] : (!transform.any_op) -> !transform.any_value
// expected-remark@+1{{unsupported case of moving definition of value before an insertion point in a different basic block}}
transform.test.move_value_defns %v1 before %op1
: (!transform.any_value), !transform.any_op
transform.yield
@ -480,22 +484,36 @@ module attributes {transform.with_named_sequence} {
// -----
func.func @move_isolated_from_above(%arg0 : index) -> () {
%1 = "before"() : () -> (index)
%2 = arith.addi %arg0, %arg0 {moved0} : index
%3 = arith.muli %2, %2 {moved1} : index
return
//===----------------------------------------------------------------------===//
// Cross-region move tests
//===----------------------------------------------------------------------===//
// Move multiple values with dependencies out of nested region
func.func @move_chain_out_of_region(%arg0 : index, %cond : i1) -> index {
%0 = "before"() : () -> (index)
%1 = scf.if %cond -> index {
%2 = arith.addi %arg0, %arg0 {dep1} : index
%3 = arith.muli %2, %2 {dep2} : index
%4 = arith.subi %3, %arg0 {to_move} : index
scf.yield %4 : index
} else {
scf.yield %arg0 : index
}
return %1 : index
}
// CHECK-LABEL: func @move_isolated_from_above(
// CHECK: %[[MOVED0:.+]] = arith.addi {{.*}} {moved0}
// CHECK: %[[MOVED1:.+]] = arith.muli %[[MOVED0]], %[[MOVED0]] {moved1}
// CHECK: %[[BEFORE:.+]] = "before"
// CHECK-LABEL: func @move_chain_out_of_region(
// CHECK: arith.addi {{.*}} {dep1}
// CHECK: arith.muli {{.*}} {dep2}
// CHECK: arith.subi {{.*}} {to_move}
// CHECK: "before"
// CHECK: scf.if
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
%op1 = transform.structured.match ops{["before"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%op2 = transform.structured.match ops{["arith.muli"]} in %arg0
%op2 = transform.structured.match attributes{to_move} in %arg0
: (!transform.any_op) -> !transform.any_op
%v1 = transform.get_result %op2[0] : (!transform.any_op) -> !transform.any_value
transform.test.move_value_defns %v1 before %op1
@ -551,3 +569,204 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
// -----
// Can move op using outer loop's IV when staying within outer loop
func.func @move_op_using_outer_loop_iv(%lb : index, %ub : index, %step : index) -> index {
%result = scf.for %outer_iv = %lb to %ub step %step iter_args(%acc = %lb) -> index {
%before = "before"() : () -> (index)
scf.for %inner_iv = %lb to %ub step %step {
// Uses outer_iv which dominates within the outer loop body
%x = arith.addi %outer_iv, %outer_iv {to_move} : index
"use"(%x) : (index) -> ()
}
scf.yield %acc : index
}
return %result : index
}
// CHECK-LABEL: func @move_op_using_outer_loop_iv(
// CHECK: scf.for %[[OUTER_IV:[a-zA-Z0-9_]+]] =
// CHECK: arith.addi %[[OUTER_IV]], %[[OUTER_IV]] {to_move}
// CHECK: "before"
// CHECK: scf.for
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
%op1 = transform.structured.match ops{["before"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%op2 = transform.structured.match attributes{to_move} in %arg0
: (!transform.any_op) -> !transform.any_op
%v1 = transform.get_result %op2[0] : (!transform.any_op) -> !transform.any_value
transform.test.move_value_defns %v1 before %op1
: (!transform.any_value), !transform.any_op
transform.yield
}
}
// -----
// Move out of doubly nested non-isolated region
func.func @move_out_of_doubly_nested_region(%arg0 : index, %cond1 : i1, %cond2 : i1) -> index {
%0 = "before"() : () -> (index)
%1 = scf.if %cond1 -> index {
%2 = scf.if %cond2 -> index {
%3 = arith.addi %arg0, %arg0 {to_move} : index
scf.yield %3 : index
} else {
scf.yield %arg0 : index
}
scf.yield %2 : index
} else {
scf.yield %arg0 : index
}
return %1 : index
}
// CHECK-LABEL: func @move_out_of_doubly_nested_region(
// CHECK: %[[MOVED:.+]] = arith.addi {{.*}} {to_move}
// CHECK: %[[BEFORE:.+]] = "before"
// CHECK: scf.if
// CHECK: scf.if
// CHECK: scf.yield %[[MOVED]]
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
%op1 = transform.structured.match ops{["before"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%op2 = transform.structured.match attributes{to_move} in %arg0
: (!transform.any_op) -> !transform.any_op
%v1 = transform.get_result %op2[0] : (!transform.any_op) -> !transform.any_value
transform.test.move_value_defns %v1 before %op1
: (!transform.any_value), !transform.any_op
transform.yield
}
}
// -----
// Move operand deps out of nested region
func.func @move_operand_deps_out_of_region(%arg0 : index, %cond : i1) -> index {
%0 = "before"() : () -> (index)
%1 = scf.if %cond -> index {
%2 = arith.addi %arg0, %arg0 {dep} : index
%3 = "foo"(%2) {target} : (index) -> (index)
scf.yield %3 : index
} else {
scf.yield %arg0 : index
}
return %1 : index
}
// CHECK-LABEL: func @move_operand_deps_out_of_region(
// CHECK: %[[DEP:.+]] = arith.addi {{.*}} {dep}
// CHECK: %[[BEFORE:.+]] = "before"
// CHECK: scf.if
// CHECK: "foo"(%[[DEP]]) {target}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
%op1 = transform.structured.match ops{["foo"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%op2 = transform.structured.match ops{["before"]} in %arg0
: (!transform.any_op) -> !transform.any_op
transform.test.move_operand_deps %op1 before %op2
: !transform.any_op, !transform.any_op
transform.yield
}
}
// -----
// Cannot move op that depends on loop induction variable (block argument)
func.func @cannot_move_op_using_loop_iv(%arg0 : index, %lb : index, %ub : index, %step : index) -> index {
%0 = "before"() : () -> (index)
%1 = scf.for %iv = %lb to %ub step %step iter_args(%acc = %arg0) -> index {
%2 = arith.addi %iv, %iv {to_move} : index
%3 = arith.addi %acc, %2 : index
scf.yield %3 : index
}
return %1 : index
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
%op1 = transform.structured.match ops{["before"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%op2 = transform.structured.match attributes{to_move} in %arg0
: (!transform.any_op) -> !transform.any_op
%v1 = transform.get_result %op2[0] : (!transform.any_op) -> !transform.any_value
// expected-remark@+1{{moving op would break dominance for block argument operand}}
transform.test.move_value_defns %v1 before %op1
: (!transform.any_value), !transform.any_op
transform.yield
}
}
// -----
// Cannot move out of an isolated-from-above region, even when op is in a
// non-isolated region nested inside the isolated region
func.func @cannot_move_out_of_isolated_region(%arg0 : index, %cond : i1) -> index {
%0 = "before"() : () -> (index)
%1 = "test.isolated_one_region_op"(%arg0, %cond) ({
^bb0(%inner_arg: index, %inner_cond: i1):
// scf.if is NOT isolated, but it's inside an isolated region
%2 = scf.if %inner_cond -> index {
%3 = arith.addi %inner_arg, %inner_arg {to_move} : index
scf.yield %3 : index
} else {
scf.yield %inner_arg : index
}
"test.region_yield"(%2) : (index) -> ()
}) : (index, i1) -> (index)
return %1 : index
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
%op1 = transform.structured.match ops{["before"]} in %arg0
: (!transform.any_op) -> !transform.any_op
%op2 = transform.structured.match attributes{to_move} in %arg0
: (!transform.any_op) -> !transform.any_op
%v1 = transform.get_result %op2[0] : (!transform.any_op) -> !transform.any_value
// expected-remark@+1{{cannot move value definition across isolated-from-above region}}
transform.test.move_value_defns %v1 before %op1
: (!transform.any_value), !transform.any_op
transform.yield
}
}
// -----
// Fail when trying to move an operation whose region captures a block argument
// that wouldn't dominate at the insertion point.
func.func @captured_block_arg_does_not_dominate(%arg0 : f32, %cond : i1) -> f32 {
%0 = arith.addf %arg0, %arg0 {before} : f32
cf.br ^bb1(%0 : f32)
^bb1(%bbArg : f32):
// scf.if will be part of the slice that needs to move.
// It has a region that captures %bbArg from bb1.
// Moving it before the {before} op in the entry block would be invalid
// because %bbArg (a block argument of bb1) doesn't dominate the entry block.
%1 = scf.if %cond -> f32 {
%inner = arith.addf %bbArg, %bbArg : f32
scf.yield %inner : f32
} else {
scf.yield %bbArg : f32
}
%2 = arith.mulf %1, %1 {target} : f32
return %2 : f32
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
%op1 = transform.structured.match attributes{before} in %arg0
: (!transform.any_op) -> !transform.any_op
%op2 = transform.structured.match attributes{target} in %arg0
: (!transform.any_op) -> !transform.any_op
%v1 = transform.get_result %op2[0] : (!transform.any_op) -> !transform.any_value
// expected-remark@+1{{moving op would break dominance for block argument operand}}
transform.test.move_value_defns %v1 before %op1
: (!transform.any_value), !transform.any_op
transform.yield
}
}