[mlir][scf] Fix crash in extractFixedOuterLoops with iter_args loops (#184106)

The stripmineSink helper splices loop body operations into a new inner
scf.for that has no iter_args. When the target loop carries iter_args,
values yielded by the spliced body are moved inside the inner loop, but
the outer loop's yield terminator still references those values,
creating an SSA invariant violation. In debug builds this triggers the
assertion
  use_empty() && "Cannot destroy a value that still has uses!"
when the outer RewriterBase tries to erase the now-broken operations.

Fix: in extractFixedOuterLoops, skip the strip-mining transformation if
any of the collected perfectly-nested loops have iter_args.
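
The guard can be sketched outside of MLIR as follows. This is a minimal,
self-contained model, not the actual patch: LoopModel, anyLoopHasIterArgs, and
tileIfSafe are hypothetical names introduced for illustration, and std::any_of
stands in for llvm::any_of.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for scf::ForOp: only the loop-carried initial values
// matter for this check.
struct LoopModel {
  std::string name;
  std::vector<std::string> initArgs; // empty when the loop has no iter_args
};

// Returns true if at least one loop in the nest carries iter_args, which is
// the condition under which the transformation is skipped.
bool anyLoopHasIterArgs(const std::vector<LoopModel> &loops) {
  return std::any_of(
      loops.begin(), loops.end(),
      [](const LoopModel &loop) { return !loop.initArgs.empty(); });
}

// Hypothetical driver: returns the loops that would be tiled, or an empty
// result when the guard fires (modeling the early `return {}`).
std::vector<LoopModel> tileIfSafe(const std::vector<LoopModel> &loops) {
  if (anyLoopHasIterArgs(loops))
    return {};
  return loops;
}
```

In this model a caller cannot tell "nothing to tile" apart from "tiling was
skipped", since both produce an empty result; that matches the silent skip
described above.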

Add a regression test to parametric-tiling.mlir.

Fixes #129044

Assisted-by: Claude Code
Mehdi Amini 2026-03-11 14:21:57 +01:00 committed by GitHub
parent 7beba38aa3
commit b78ceef43e
2 changed files with 28 additions and 0 deletions


@@ -1351,6 +1351,15 @@ TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp,
  if (forOps.size() < sizes.size())
    sizes = sizes.take_front(forOps.size());
  // The strip-mining transformation splices loop bodies into a new inner loop
  // without threading iter_args. If any of the collected loops carries
  // iter_args, the splice would produce invalid IR (yielded values from the
  // inner scope used in the outer terminator). Skip the transformation in
  // that case.
  if (llvm::any_of(forOps,
                   [](scf::ForOp op) { return !op.getInitArgs().empty(); }))
    return {};
  // Compute the tile sizes such that i-th outer loop executes size[i]
  // iterations. Given that the loop current executes
  //   numIterations = ceildiv((upperBound - lowerBound), step)


@@ -127,3 +127,22 @@ func.func @triangular(%arg0: memref<?x?xf32>) {
  }
  return
}
// Verify that extractFixedOuterLoops silently skips loops with iter_args
// instead of producing invalid IR (regression test for
// https://github.com/llvm/llvm-project/issues/129044).
// COMMON-LABEL: @loop_with_iter_args
// COMMON: scf.for {{.*}} iter_args({{.*}}) -> (f32)
func.func @loop_with_iter_args(%buffer: memref<1024xf32>, %lb: index,
                               %ub: index, %step: index) -> f32 {
  %initial_sum = arith.constant 0.0 : f32
  // This loop has iter_args; strip-mining must not be applied.
  %final_sum = scf.for %iv = %lb to %ub step %step
      iter_args(%sum_iter = %initial_sum) -> (f32) {
    %element = memref.load %buffer[%iv] : memref<1024xf32>
    %updated_sum = arith.addf %sum_iter, %element : f32
    scf.yield %updated_sum : f32
  }
  return %final_sum : f32
}