llvm-project/mlir/test/Dialect/MLProgram/one-shot-bufferize.mlir
Nathan Malimban 689f9788d3
[ml_program] fix bufferizesToMemoryRead for ml_program.global_store (#177387)
This is a fix for the `BufferizableOpInterface` implementation for
`ml_program.global_store`.

`bufferizesToMemoryRead` currently returns false for
`GlobalStoreOpInterface`, but I believe it should return true, as
`ml_program.global_store` needs to read its input buffer to know what
value to store to the global.

This manifested as a bug where `one-shot-bufferize` would produce MLIR
that copies uninitialized data to the global variable instead of the value
intended to be stored.

For the following MLIR:

```
module {
  ml_program.global private mutable @"state_tensor"(dense<0.0> : tensor<4x75xf32>) : tensor<4x75xf32>
  func.func @main() -> tensor<4x75xf32> {
    %c0 = arith.constant 0 : index
    %cst_val = arith.constant 1.0 : f32
    %initial_state = ml_program.global_load @"state_tensor" : tensor<4x75xf32>
    %val = tensor.extract %initial_state[%c0, %c0] : tensor<4x75xf32>
    %next_val = arith.addf %val, %cst_val : f32
    %updated_tensor = tensor.insert %next_val into %initial_state[%c0, %c0] : tensor<4x75xf32>
    ml_program.global_store @"state_tensor" = %updated_tensor : tensor<4x75xf32>
    return %updated_tensor : tensor<4x75xf32>
  }
}
```
`one-shot-bufferize` produces this incorrect MLIR
```
module {
  memref.global "private" @state_tensor : memref<4x75xf32> = dense<0.000000e+00>
  func.func @main() -> tensor<4x75xf32> {
    %c0 = arith.constant 0 : index
    %cst = arith.constant 1.000000e+00 : f32
    %0 = memref.get_global @state_tensor : memref<4x75xf32>
    %1 = memref.load %0[%c0, %c0] : memref<4x75xf32>
    %2 = arith.addf %1, %cst : f32
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x75xf32>
    memref.copy %0, %alloc : memref<4x75xf32> to memref<4x75xf32>
    memref.store %2, %alloc[%c0, %c0] : memref<4x75xf32>
    %3 = bufferization.to_tensor %alloc : memref<4x75xf32> to tensor<4x75xf32>
    %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<4x75xf32>
    %4 = memref.get_global @state_tensor : memref<4x75xf32>
    memref.copy %alloc_0, %4 : memref<4x75xf32> to memref<4x75xf32>
    return %3 : tensor<4x75xf32>
  }
}
```
Note that the final `memref.copy` copies the uninitialized buffer `%alloc_0`
into the global variable.

After the change, `one-shot-bufferize` produces the following MLIR:
```
module {
  memref.global "private" @state_tensor : memref<4x75xf32> = dense<0.000000e+00>
  func.func @main() -> tensor<4x75xf32> {
    %c0 = arith.constant 0 : index
    %cst = arith.constant 1.000000e+00 : f32
    %0 = memref.get_global @state_tensor : memref<4x75xf32>
    %1 = memref.load %0[%c0, %c0] : memref<4x75xf32>
    %2 = arith.addf %1, %cst : f32
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<4x75xf32>
    memref.copy %0, %alloc : memref<4x75xf32> to memref<4x75xf32>
    memref.store %2, %alloc[%c0, %c0] : memref<4x75xf32>
    %3 = bufferization.to_tensor %alloc : memref<4x75xf32> to tensor<4x75xf32>
    %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<4x75xf32>
    memref.copy %alloc, %alloc_0 : memref<4x75xf32> to memref<4x75xf32>
    %4 = memref.get_global @state_tensor : memref<4x75xf32>
    memref.copy %alloc_0, %4 : memref<4x75xf32> to memref<4x75xf32>
    return %3 : tensor<4x75xf32>
  }
}
```
The updated data in `%alloc` is now copied into `%alloc_0` before
`%alloc_0` is stored to the global.

Co-authored-by: Nathan Malimban <nmalimba@ah-nmalimba-l.dhcp.mathworks.com>
2026-01-23 13:20:30 +01:00

84 lines
3.8 KiB
MLIR

// RUN: mlir-opt %s -one-shot-bufferize -split-input-file | FileCheck %s
// Read-modify-write of a rank-0 global: load it, scale the scalar, insert the
// result back, and store the updated tensor to the same global.
// CHECK-LABEL: memref.global "private" @global
ml_program.global private mutable @global(dense<0> : tensor<i64>) : tensor<i64>

// CHECK-LABEL: func.func @global_load_store
func.func @global_load_store() -> i64 {
  // CHECK-DAG: %[[C127:.*]] = arith.constant 127
  // CHECK-DAG: %[[G_LOAD:.*]] = memref.get_global @global
  // CHECK:     %[[OLD:.*]] = memref.load %[[G_LOAD]][]
  // CHECK:     %[[PRODUCT:.*]] = arith.muli %[[OLD]], %[[C127]]
  // The insert is expected out of place: a fresh buffer copied from the global.
  // CHECK:     %[[BUF:.*]] = memref.alloc()
  // CHECK:     memref.copy %[[G_LOAD]], %[[BUF]]
  // CHECK:     memref.store %[[PRODUCT]], %[[BUF]][]
  // The store copies the updated buffer back into the global.
  // CHECK:     %[[G_STORE:.*]] = memref.get_global @global
  // CHECK:     memref.copy %[[BUF]], %[[G_STORE]]
  // CHECK:     return %[[PRODUCT]]
  %factor = arith.constant 127 : i64
  %state = ml_program.global_load @global : tensor<i64>
  %old = tensor.extract %state[] : tensor<i64>
  %product = arith.muli %old, %factor : i64
  %new_state = tensor.insert %product into %state[] : tensor<i64>
  ml_program.global_store @global = %new_state : tensor<i64>
  return %product : i64
}
// -----
// Two loads of the same global: the first is updated via tensor.insert and
// stored back, the second is only read. The insert must not write into the
// global that the second load reads before the store happens.
// CHECK-LABEL: memref.global "private" @global
ml_program.global private mutable @global(dense<0> : tensor<i64>) : tensor<i64>

// CHECK-LABEL: func.func @raw_hazard
func.func @raw_hazard() -> i64 {
  // CHECK-DAG: %[[C127:.*]] = arith.constant 127
  // CHECK-DAG: %[[G_WRITE_SRC:.*]] = memref.get_global @global
  // CHECK-DAG: %[[G_READ_SRC:.*]] = memref.get_global @global
  // CHECK-DAG: %[[BUF:.*]] = memref.alloc()
  // The insert goes to a private copy rather than the global itself.
  // CHECK:     memref.copy %[[G_WRITE_SRC]], %[[BUF]]
  // CHECK:     memref.store %[[C127]], %[[BUF]][]
  // The read-only load still reads directly from the global.
  // CHECK:     %[[OBSERVED:.*]] = memref.load %[[G_READ_SRC]][]
  // CHECK:     %[[G_DEST:.*]] = memref.get_global @global
  // CHECK:     memref.copy %[[BUF]], %[[G_DEST]]
  // CHECK:     return %[[OBSERVED]]
  %cst = arith.constant 127 : i64
  %to_write = ml_program.global_load @global : tensor<i64>
  %to_read = ml_program.global_load @global : tensor<i64>
  %written = tensor.insert %cst into %to_write[] : tensor<i64>
  %read = tensor.extract %to_read[] : tensor<i64>
  ml_program.global_store @global = %written : tensor<i64>
  return %read : i64
}
// -----
// Regression test: ml_program.global_store must be treated as reading its
// operand (bufferizesToMemoryRead). Previously, for this input, where the
// stored tensor is also returned, the stored value was routed through a
// fresh buffer that was never filled, so uninitialized data was copied into
// the global.
// CHECK-LABEL: memref.global "private" @state_tensor
ml_program.global private mutable @state_tensor(dense<0.0> : tensor<4x75xf32>) : tensor<4x75xf32>

// CHECK-LABEL: func.func @global_load_store_tensor
func.func @global_load_store_tensor() -> tensor<4x75xf32> {
  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
  // CHECK-DAG: %[[CST:.*]] = arith.constant 1.000000e+00
  // CHECK-DAG: %[[GLOB:.*]] = memref.get_global @state_tensor
  // CHECK:     %[[VAL:.*]] = memref.load %[[GLOB]][%[[C0]], %[[C0]]]
  // CHECK:     %[[ADD:.*]] = arith.addf %[[VAL]], %[[CST]]
  // The insert is out of place: copy the global, then update the copy.
  // Alignment attributes are intentionally not matched (incidental detail).
  // CHECK:     %[[ALLOC1:.*]] = memref.alloc()
  // CHECK:     memref.copy %[[GLOB]], %[[ALLOC1]]
  // CHECK:     memref.store %[[ADD]], %[[ALLOC1]][%[[C0]], %[[C0]]]
  // CHECK:     %[[TENSOR:.*]] = bufferization.to_tensor %[[ALLOC1]]
  // The stored value goes through a second buffer, which must be filled from
  // the updated buffer before it is copied into the global (this is the fix).
  // CHECK:     %[[ALLOC2:.*]] = memref.alloc()
  // CHECK:     memref.copy %[[ALLOC1]], %[[ALLOC2]]
  // CHECK:     %[[GLOB_REF:.*]] = memref.get_global @state_tensor
  // CHECK:     memref.copy %[[ALLOC2]], %[[GLOB_REF]]
  // CHECK:     return %[[TENSOR]]
  %c0 = arith.constant 0 : index
  %cst_val = arith.constant 1.0 : f32
  %initial_state = ml_program.global_load @state_tensor : tensor<4x75xf32>
  %val = tensor.extract %initial_state[%c0, %c0] : tensor<4x75xf32>
  %next_val = arith.addf %val, %cst_val : f32
  %updated_tensor = tensor.insert %next_val into %initial_state[%c0, %c0] : tensor<4x75xf32>
  ml_program.global_store @state_tensor = %updated_tensor : tensor<4x75xf32>
  return %updated_tensor : tensor<4x75xf32>
}