[mlir][memref] Generalize dead store detection to all view-like ops (#168507)
The dead alloc elimination pass previously considered only subviews when checking for dead stores. This change generalizes the logic to support all view-like operations, ensuring broader coverage.
This commit is contained in:
parent
4544ff68dc
commit
e0850825cc
@ -133,7 +133,7 @@ getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits,
|
||||
}
|
||||
|
||||
/// Returns true if all the uses of op are not read/load.
|
||||
/// There can be SubviewOp users as long as all its users are also
|
||||
/// There can be view-like-op users as long as all its users are also
|
||||
/// StoreOp/transfer_write. If return true it also fills out the uses, if it
|
||||
/// returns false uses is unchanged.
|
||||
static bool resultIsNotRead(Operation *op, std::vector<Operation *> &uses) {
|
||||
@ -146,7 +146,7 @@ static bool resultIsNotRead(Operation *op, std::vector<Operation *> &uses) {
|
||||
if (isa<memref::DeallocOp>(useOp) ||
|
||||
(useOp->getNumResults() == 0 && useOp->getNumRegions() == 0 &&
|
||||
!mlir::hasEffect<MemoryEffects::Read>(useOp)) ||
|
||||
(isa<memref::SubViewOp>(useOp) && resultIsNotRead(useOp, opUses))) {
|
||||
(isa<ViewLikeOpInterface>(useOp) && resultIsNotRead(useOp, opUses))) {
|
||||
opUses.push_back(useOp);
|
||||
continue;
|
||||
}
|
||||
|
||||
@ -395,6 +395,73 @@ module attributes {transform.with_named_sequence} {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @dead_store_through_subview
|
||||
// CHECK-SAME: (%[[ARG:.+]]: vector<4xf32>)
|
||||
// CHECK-NOT: memref.alloc()
|
||||
// CHECK-NOT: vector.transfer_write
|
||||
func.func @dead_store_through_subview(%arg: vector<4xf32>) {
|
||||
%c0 = arith.constant 0 : index
|
||||
%alloc = memref.alloc() {alignment = 64 : i64} : memref<64xf32>
|
||||
%subview = memref.subview %alloc[%c0] [4] [1] : memref<64xf32> to memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
|
||||
vector.transfer_write %arg, %subview[%c0] {in_bounds = [true]}
|
||||
: vector<4xf32>, memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
|
||||
return
|
||||
}
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||
%0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||
transform.memref.erase_dead_alloc_and_stores %0 : (!transform.any_op) -> ()
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @dead_store_through_expand
|
||||
// CHECK-SAME: (%[[ARG:.+]]: vector<4xf32>)
|
||||
// CHECK-NOT: memref.alloc()
|
||||
// CHECK-NOT: vector.transfer_write
|
||||
func.func @dead_store_through_expand(%arg: vector<4xf32>) {
|
||||
%c0 = arith.constant 0 : index
|
||||
%alloc = memref.alloc() {alignment = 64 : i64} : memref<64xf32>
|
||||
%expand = memref.expand_shape %alloc [[0, 1]] output_shape [16, 4] : memref<64xf32> into memref<16x4xf32>
|
||||
vector.transfer_write %arg, %expand[%c0, %c0] {in_bounds = [true]} : vector<4xf32>, memref<16x4xf32>
|
||||
return
|
||||
}
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||
%0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||
transform.memref.erase_dead_alloc_and_stores %0 : (!transform.any_op) -> ()
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @dead_store_through_collapse
|
||||
// CHECK-SAME: (%[[ARG:.+]]: vector<4xf32>)
|
||||
// CHECK-NOT: memref.alloc()
|
||||
// CHECK-NOT: vector.transfer_write
|
||||
func.func @dead_store_through_collapse(%arg: vector<4xf32>) {
|
||||
%c0 = arith.constant 0 : index
|
||||
%alloc = memref.alloc() {alignment = 64 : i64} : memref<16x4xf32>
|
||||
%collapse = memref.collapse_shape %alloc [[0, 1]] : memref<16x4xf32> into memref<64xf32>
|
||||
vector.transfer_write %arg, %collapse[%c0] {in_bounds = [true]} : vector<4xf32>, memref<64xf32>
|
||||
return
|
||||
}
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||
%0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||
transform.memref.erase_dead_alloc_and_stores %0 : (!transform.any_op) -> ()
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @lower_to_llvm
|
||||
// CHECK-NOT: memref.alloc
|
||||
// CHECK: llvm.call @malloc
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user