[MLIR][Affine] Fix dead store elimination for vector stores with different types (#189248)

affine-scalrep's findUnusedStore incorrectly classified an
affine.vector_store as dead when a subsequent store wrote to the same
base index but with a smaller vector type. A vector<1xi64> store at
[0,0] does not fully overwrite a vector<5xi64> store at [0,0], so the
first store must be preserved.

The loadCSE function in the same file already had the correct
type-equality check for loads; this patch adds the analogous check for
stores in findUnusedStore.

Fixes #113687

Assisted-by: Claude Code
This commit is contained in:
Mehdi Amini 2026-04-01 12:40:53 +02:00 committed by GitHub
parent d6cd15901a
commit f6ffdbcbae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 50 additions and 2 deletions

View File

@ -908,8 +908,9 @@ mlir::affine::hasNoInterveningEffect<mlir::MemoryEffects::Read,
// This attempts to find stores which have no impact on the final result.
// A writing op writeA will be eliminated if there exists an op writeB such that:
// 1) writeA and writeB have mathematically equivalent affine access functions.
// 2) writeB postdominates writeA.
// 3) There is no potential read between writeA and writeB.
// 2) writeB writes the same type as writeA (so it fully covers writeA's bytes).
// 3) writeB postdominates writeA.
// 4) There is no potential read between writeA and writeB.
static void findUnusedStore(AffineWriteOpInterface writeA,
SmallVectorImpl<Operation *> &opsToErase,
PostDominanceInfo &postDominanceInfo,
@ -936,6 +937,16 @@ static void findUnusedStore(AffineWriteOpInterface writeA,
if (srcAccess != destAccess)
continue;
// Check that the store types match. If types differ, writeB may not cover
// all bytes written by writeA (e.g. a narrower vector type), so
// conservatively assume writeA is not dead.
// One could be tempted to check whether writeA's type is smaller than
// writeB's; however, that becomes tricky with cases like vector<4xi6> vs
// vector<3xi8> due to padding that can be datalayout dependent.
if (writeA.getValueToStore().getType() !=
writeB.getValueToStore().getType())
continue;
// writeB must postdominate writeA.
if (!postDominanceInfo.postDominates(writeB, writeA))
continue;

View File

@ -997,3 +997,40 @@ func.func @zero_d_memrefs() {
}
return
}
// CHECK-LABEL: func @vector_store_dead_elim_different_types
// A vector_store with a different vector type at the same base index must NOT
// cause the earlier vector_store to be treated as a dead store.
// (GitHub issue #113687)
func.func @vector_store_dead_elim_different_types(%arg0: memref<20x1xi64>) {
%c0 = arith.constant 0 : index
// CHECK-DAG: %[[CST1:.+]] = arith.constant dense<1>
// CHECK-DAG: %[[CST2:.+]] = arith.constant dense<2>
%cst1 = arith.constant dense<1> : vector<5xi64>
%cst2 = arith.constant dense<2> : vector<1xi64>
// Both stores must be preserved: the second (narrow) store does not fully
// overwrite the first (wide) store.
// CHECK: affine.vector_store %[[CST1]], %arg0[%c0, %c0] : memref<20x1xi64>, vector<5xi64>
affine.vector_store %cst1, %arg0[%c0, %c0] : memref<20x1xi64>, vector<5xi64>
// CHECK: affine.vector_store %[[CST2]], %arg0[%c0, %c0] : memref<20x1xi64>, vector<1xi64>
affine.vector_store %cst2, %arg0[%c0, %c0] : memref<20x1xi64>, vector<1xi64>
return
}
// CHECK-LABEL: func @vector_store_dead_elim_same_type
// A vector_store with the same type at the same base index should still cause
// the earlier store to be treated as a dead store.
func.func @vector_store_dead_elim_same_type(%arg0: memref<20x1xi64>) {
%c0 = arith.constant 0 : index
%cst1 = arith.constant dense<1> : vector<5xi64>
%cst2 = arith.constant dense<2> : vector<5xi64>
// CHECK-DAG: %[[CST2:.+]] = arith.constant dense<2>
// The first store is dead: the second store (same type, same index)
// fully overwrites it.
// CHECK-NOT: arith.constant dense<1>
// CHECK-NOT: affine.vector_store {{.*}} dense<1>
// CHECK: affine.vector_store %[[CST2]], %arg0[%c0, %c0] : memref<20x1xi64>, vector<5xi64>
affine.vector_store %cst1, %arg0[%c0, %c0] : memref<20x1xi64>, vector<5xi64>
affine.vector_store %cst2, %arg0[%c0, %c0] : memref<20x1xi64>, vector<5xi64>
return
}