[mlir][memref] Generalize dead store detection to all view-like ops (#168507)

The dead alloc elimination pass previously considered only subviews when
checking for dead stores. This change generalizes the logic to support
all view-like operations, ensuring broader coverage.
This commit is contained in:
Simone Pellegrini 2025-11-20 15:10:03 +01:00 committed by GitHub
parent 4544ff68dc
commit e0850825cc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 69 additions and 2 deletions

View File

@ -133,7 +133,7 @@ getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits,
}
/// Returns true if all the uses of op are not read/load.
/// There can be SubviewOp users as long as all its users are also
/// There can be view-like-op users as long as all its users are also
/// StoreOp/transfer_write. If return true it also fills out the uses, if it
/// returns false uses is unchanged.
static bool resultIsNotRead(Operation *op, std::vector<Operation *> &uses) {
@ -146,7 +146,7 @@ static bool resultIsNotRead(Operation *op, std::vector<Operation *> &uses) {
if (isa<memref::DeallocOp>(useOp) ||
(useOp->getNumResults() == 0 && useOp->getNumRegions() == 0 &&
!mlir::hasEffect<MemoryEffects::Read>(useOp)) ||
(isa<memref::SubViewOp>(useOp) && resultIsNotRead(useOp, opUses))) {
(isa<ViewLikeOpInterface>(useOp) && resultIsNotRead(useOp, opUses))) {
opUses.push_back(useOp);
continue;
}

View File

@ -395,6 +395,73 @@ module attributes {transform.with_named_sequence} {
// -----
// Test case: the allocation's only use is a memref.subview, and that subview's
// only use is a vector.transfer_write. Nothing in the function reads the
// buffer. The directives below expect that neither the alloc nor the write
// remains in the output.
// CHECK-LABEL: @dead_store_through_subview
// CHECK-SAME: (%[[ARG:.+]]: vector<4xf32>)
// CHECK-NOT: memref.alloc()
// CHECK-NOT: vector.transfer_write
func.func @dead_store_through_subview(%arg: vector<4xf32>) {
%c0 = arith.constant 0 : index
// 64-element buffer; it is written through a view but never loaded from.
%alloc = memref.alloc() {alignment = 64 : i64} : memref<64xf32>
// 4-element, stride-1 view of %alloc starting at offset %c0.
%subview = memref.subview %alloc[%c0] [4] [1] : memref<64xf32> to memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
// The only store in the function; it targets the view, not %alloc directly.
vector.transfer_write %arg, %subview[%c0] {in_bounds = [true]}
: vector<4xf32>, memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
return
}
// Transform sequence: matches the func.func ops in %arg1 and applies
// transform.memref.erase_dead_alloc_and_stores to the matched handle.
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
// %0 is a handle to the matched func.func ops.
%0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// The transform produces no results; it is applied for its effect on %0.
transform.memref.erase_dead_alloc_and_stores %0 : (!transform.any_op) -> ()
transform.yield
}
}
// -----
// Test case: the allocation's only use is a memref.expand_shape, and the
// expanded memref's only use is a vector.transfer_write. Nothing in the
// function reads the buffer. The directives below expect that neither the
// alloc nor the write remains in the output.
// CHECK-LABEL: @dead_store_through_expand
// CHECK-SAME: (%[[ARG:.+]]: vector<4xf32>)
// CHECK-NOT: memref.alloc()
// CHECK-NOT: vector.transfer_write
func.func @dead_store_through_expand(%arg: vector<4xf32>) {
%c0 = arith.constant 0 : index
// 64-element buffer; it is written through a reshaped view but never loaded from.
%alloc = memref.alloc() {alignment = 64 : i64} : memref<64xf32>
// Reinterprets the 1-D 64-element buffer as a 16x4 memref.
%expand = memref.expand_shape %alloc [[0, 1]] output_shape [16, 4] : memref<64xf32> into memref<16x4xf32>
// The only store in the function; it targets the expanded view at [%c0, %c0].
vector.transfer_write %arg, %expand[%c0, %c0] {in_bounds = [true]} : vector<4xf32>, memref<16x4xf32>
return
}
// Transform sequence: matches the func.func ops in %arg1 and applies
// transform.memref.erase_dead_alloc_and_stores to the matched handle.
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
// %0 is a handle to the matched func.func ops.
%0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// The transform produces no results; it is applied for its effect on %0.
transform.memref.erase_dead_alloc_and_stores %0 : (!transform.any_op) -> ()
transform.yield
}
}
// -----
// Test case: the allocation's only use is a memref.collapse_shape, and the
// collapsed memref's only use is a vector.transfer_write. Nothing in the
// function reads the buffer. The directives below expect that neither the
// alloc nor the write remains in the output.
// CHECK-LABEL: @dead_store_through_collapse
// CHECK-SAME: (%[[ARG:.+]]: vector<4xf32>)
// CHECK-NOT: memref.alloc()
// CHECK-NOT: vector.transfer_write
func.func @dead_store_through_collapse(%arg: vector<4xf32>) {
%c0 = arith.constant 0 : index
// 16x4 buffer; it is written through a reshaped view but never loaded from.
%alloc = memref.alloc() {alignment = 64 : i64} : memref<16x4xf32>
// Flattens both dimensions of the 16x4 buffer into a 1-D 64-element memref.
%collapse = memref.collapse_shape %alloc [[0, 1]] : memref<16x4xf32> into memref<64xf32>
// The only store in the function; it targets the collapsed view at [%c0].
vector.transfer_write %arg, %collapse[%c0] {in_bounds = [true]} : vector<4xf32>, memref<64xf32>
return
}
// Transform sequence: matches the func.func ops in %arg1 and applies
// transform.memref.erase_dead_alloc_and_stores to the matched handle.
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
// %0 is a handle to the matched func.func ops.
%0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// The transform produces no results; it is applied for its effect on %0.
transform.memref.erase_dead_alloc_and_stores %0 : (!transform.any_op) -> ()
transform.yield
}
}
// -----
// CHECK-LABEL: func @lower_to_llvm
// CHECK-NOT: memref.alloc
// CHECK: llvm.call @malloc