diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp index 5406a51d2ab7..a87cb1562513 100644 --- a/mlir/lib/Transforms/Utils/RegionUtils.cpp +++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp @@ -157,8 +157,9 @@ SmallVector mlir::makeRegionIsolatedFromAbove( // Create a mapping between the captured values and the new arguments added. IRMapping map; auto replaceIfFn = [&](OpOperand &use) { - return use.getOwner()->getBlock()->getParent() == ®ion; + return region.isAncestor(use.getOwner()->getParentRegion()); }; + for (auto [arg, capturedVal] : llvm::zip(newEntryBlockArgs.take_back(finalCapturedValues.size()), finalCapturedValues)) { diff --git a/mlir/test/Transforms/make-isolated-from-above.mlir b/mlir/test/Transforms/make-isolated-from-above.mlir index a9d4325944fd..3b0084d6e000 100644 --- a/mlir/test/Transforms/make-isolated-from-above.mlir +++ b/mlir/test/Transforms/make-isolated-from-above.mlir @@ -113,3 +113,28 @@ func.func @make_isolated_from_above_multiple_blocks(%arg0 : index, %arg1 : index // CLONE2-NEXT: cf.br ^bb1 // CLONE2: ^bb1: // CLONE2: "foo.yield"(%[[C0]], %[[C1]], %[[D0]], %[[D1]], %[[B0]]) + + +// ----- + +// CHECK-LABEL: func @make_isolated_from_above_nested_region +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<8xindex> +// CHECK: %[[C1:.+]] = arith.constant 1 : index +// CHECK: %[[C8:.+]] = arith.constant 8 : index +// CHECK: test.isolated_one_region_op %[[C1]], %[[ARG0]], %[[C8]] +// CHECK: ^bb0(%[[B0:[a-zA-Z0-9]+]]: index, %[[B1:[a-zA-Z0-9]+]]: memref<8xindex>, %[[B2:[a-zA-Z0-9]+]]: index) +// CHECK: scf.for %arg4 = %[[B0]] to %[[B2]] step %[[B0]] +// CHECK: memref.store %[[B0]], %[[B1]][%arg4] : memref<8xindex> +// CHECK: "foo.yield"() : () -> () + +func.func @make_isolated_from_above_nested_region(%arg0 : memref<8xindex>) { + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + "test.one_region_with_operands_op"() ({ + scf.for %arg1 = %c1 to %c8 step %c1 { + memref.store %c1, %arg0[%arg1] : memref<8xindex> + } + "foo.yield"() : () -> () + }) : () -> () + return +}