Fix ownership based deallocation pass crash (#179357)
The `OwnershipBasedBufferDeallocation` pass crashes when the IR contains memrefs that are live in the same Block but are defined in different Blocks. During this pass, live memrefs in a given block are sorted according to the comparison function `ValueComparator`. This causes an assertion to be triggered when sorting memref values using `ValueComparator` as the comparison function. The assertion triggered is found in `Operation::isBeforeInBlock`, which requires `this` and `other` to reside in the same block. (See the definition [here](https://github.com/llvm/llvm-project/blob/main/mlir/lib/IR/Operation.cpp#L385-L386).) The fix is to handle values from different blocks in the `ValueComparator` by sorting based on Block number if the compared ops aren't in the same block. While `computeBlockNumber` is intended for debugging and error messages, it is a convenient utility that can provide a sufficient weak ordering for `llvm::sort` while handling operations from different parent blocks. I'm not aware of another ordering relation for Blocks that would be appropriate as well as cheap to compute here. I've added a test to exercise this that would fail otherwise. As I was already editing the test file, I thought I would refactor it according to the recommendations of the [MLIR Testing Guide](https://mlir.llvm.org/getting_started/TestingGuide/#contributor-guidelines) Fixes #137342 Fixes #116363
This commit is contained in:
parent
760f70711a
commit
47331ae735
@ -248,7 +248,12 @@ bool ValueComparator::operator()(const Value &lhs, const Value &rhs) const {
|
||||
lhsRegion = lhs.getDefiningOp()->getParentRegion();
|
||||
rhsRegion = rhs.getDefiningOp()->getParentRegion();
|
||||
if (lhsRegion == rhsRegion) {
|
||||
return lhs.getDefiningOp()->isBeforeInBlock(rhs.getDefiningOp());
|
||||
Block *lhsBlock = lhs.getDefiningOp()->getBlock();
|
||||
Block *rhsBlock = rhs.getDefiningOp()->getBlock();
|
||||
if (lhsBlock == rhsBlock) {
|
||||
return lhs.getDefiningOp()->isBeforeInBlock(rhs.getDefiningOp());
|
||||
}
|
||||
return lhsBlock->computeBlockNumber() < rhsBlock->computeBlockNumber();
|
||||
}
|
||||
}
|
||||
|
||||
@ -262,8 +267,14 @@ bool ValueComparator::operator()(const Value &lhs, const Value &rhs) const {
|
||||
return lhsRegion->getRegionNumber() < rhsRegion->getRegionNumber();
|
||||
}
|
||||
if (lhsRegion->getParentRegion() == rhsRegion->getParentRegion()) {
|
||||
return lhsRegion->getParentOp()->isBeforeInBlock(
|
||||
rhsRegion->getParentOp());
|
||||
Block *lhsParentOpBlock = lhsRegion->getParentOp()->getBlock();
|
||||
Block *rhsParentOpBlock = rhsRegion->getParentOp()->getBlock();
|
||||
if (lhsParentOpBlock == rhsParentOpBlock) {
|
||||
return lhsRegion->getParentOp()->isBeforeInBlock(
|
||||
rhsRegion->getParentOp());
|
||||
}
|
||||
return lhsParentOpBlock->computeBlockNumber() <
|
||||
rhsParentOpBlock->computeBlockNumber();
|
||||
}
|
||||
lhsRegion = lhsRegion->getParentRegion();
|
||||
rhsRegion = rhsRegion->getParentRegion();
|
||||
|
||||
@ -1,28 +1,72 @@
|
||||
// RUN: mlir-opt -verify-diagnostics -ownership-based-buffer-deallocation -split-input-file %s
|
||||
// RUN: mlir-opt -ownership-based-buffer-deallocation -split-input-file %s
|
||||
|
||||
// Test Case: ownership-based-buffer-deallocation should not fail
|
||||
// with cf.assert op
|
||||
|
||||
// CHECK-LABEL: func @func_with_assert(
|
||||
// CHECK: %0 = arith.cmpi slt, %arg0, %arg1 : index
|
||||
// CHECK: cf.assert %0, "%arg0 must be less than %arg1"
|
||||
// CHECK-SAME: %[[ARG0:.*]]: index,
|
||||
// CHECK-SAME: %[[ARG1:.*]]: index
|
||||
// CHECK: %[[CMPI:.*]] = arith.cmpi slt, %[[ARG0]], %[[ARG1]]
|
||||
// CHECK: cf.assert %[[CMPI]]
|
||||
func.func @func_with_assert(%arg0: index, %arg1: index) {
|
||||
%0 = arith.cmpi slt, %arg0, %arg1 : index
|
||||
cf.assert %0, "%arg0 must be less than %arg1"
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @func_with_assume_alignment(
|
||||
// CHECK: %0 = memref.assume_alignment %arg0, 64 : memref<128xi8>
|
||||
// CHECK: %[[ARG0:.*]]: memref
|
||||
// CHECK: memref.assume_alignment %[[ARG0]], 64
|
||||
func.func @func_with_assume_alignment(%arg0: memref<128xi8>) {
|
||||
%0 = memref.assume_alignment %arg0, 64 : memref<128xi8>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @func_with_prefetch(
|
||||
// CHECK: memref.prefetch %arg0[%c0, %c0], read, locality<1>, data : memref<4x8xf32>
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref
|
||||
// CHECK: %[[ZERO:.*]] = arith.constant 0
|
||||
// CHECK: memref.prefetch %[[ARG0]][%[[ZERO]], %[[ZERO]]], read, locality<1>, data
|
||||
func.func @func_with_prefetch(%arg0: memref<4x8xf32>) {
|
||||
%c0 = arith.constant 0 : index
|
||||
memref.prefetch %arg0[%c0, %c0], read, locality<1>, data : memref<4x8xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test Case: ownership-based-buffer-deallocation should not fail
|
||||
// with basic blocks that contain live memrefs defined
|
||||
// in other blocks
|
||||
|
||||
// CHECK-LABEL: func @func_with_multi_block_memref_liveness(
|
||||
// CHECK: %[[FIRST_ALLOC:.*]] = memref.alloc()
|
||||
// CHECK: %[[BASE_0:[^,]+]], {{.*}} = memref.extract_strided_metadata %[[FIRST_ALLOC]]
|
||||
// CHECK: bufferization.dealloc (%[[BASE_0]]
|
||||
// CHECK: ^bb1:
|
||||
// CHECK: %[[SECOND_ALLOC:.*]] = memref.alloc()
|
||||
// CHECK: %[[BASE_1:[^,]+]], {{.*}} = memref.extract_strided_metadata %[[FIRST_ALLOC]]
|
||||
// CHECK: %[[BASE_2:[^,]+]], {{.*}} = memref.extract_strided_metadata %[[SECOND_ALLOC]]
|
||||
// CHECK: bufferization.dealloc (%[[BASE_1]], %[[BASE_2]]
|
||||
// CHECK: ^bb2:
|
||||
// CHECK: "test.read_buffer"(%[[FIRST_ALLOC]])
|
||||
// CHECK: "test.read_buffer"(%[[SECOND_ALLOC]])
|
||||
// CHECK: %[[BASE_3:[^,]+]], {{.*}} = memref.extract_strided_metadata %[[FIRST_ALLOC]]
|
||||
// CHECK: %[[BASE_4:[^,]+]], {{.*}} = memref.extract_strided_metadata %[[SECOND_ALLOC]]
|
||||
// CHECK: bufferization.dealloc (%[[BASE_3]], %[[BASE_4]]
|
||||
module {
|
||||
func.func @func_with_multi_block_memref_liveness() {
|
||||
%alloc = memref.alloc() : memref<3x3xf32>
|
||||
cf.br ^bb1
|
||||
^bb1: // pred: ^bb0
|
||||
%alloc_1 = memref.alloc() : memref<4x4xf32>
|
||||
cf.br ^bb2
|
||||
^bb2: // 1 pred: ^bb1
|
||||
"test.read_buffer"(%alloc) : (memref<3x3xf32>) -> ()
|
||||
"test.read_buffer"(%alloc_1) : (memref<4x4xf32>) -> ()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user