[mlir][bufferization] Fix crash with copy-before-write + bufferize-function-boundaries (#186446)

When `copy-before-write=1` is combined with
`bufferize-function-boundaries=1`, `bufferizeOp` creates a plain
`AnalysisState` (not `OneShotAnalysisState`) and passes it to
`insertTensorCopies`. Walking `CallOp`s during conflict resolution
called `getCalledFunction(callOp, state)`, which unconditionally cast
the `AnalysisState` to `OneShotAnalysisState` via `static_cast`, causing
UB and a stack overflow crash.

Fix by guarding the cast with `isa<OneShotAnalysisState>()` so that when
the state is a plain `AnalysisState`, the function falls through to
building a fresh `SymbolTableCollection` — the same safe fallback
already present.

Fixes https://github.com/llvm/llvm-project/issues/163052

Assisted-by: Claude Code
This commit is contained in:
Mehdi Amini 2026-03-17 12:56:48 +01:00 committed by GitHub
parent a26077ee5f
commit 04cc7523ed
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 24 additions and 6 deletions

View File

@ -101,12 +101,14 @@ static FuncOp getCalledFunction(CallOpInterface callOp,
/// Return the FuncOp called by `callOp`.
static FuncOp getCalledFunction(CallOpInterface callOp,
const AnalysisState &state) {
auto &oneShotAnalysisState = static_cast<const OneShotAnalysisState &>(state);
if (auto *funcAnalysisState =
oneShotAnalysisState.getExtension<FuncAnalysisState>()) {
// Use the cached symbol tables.
return getCalledFunction(callOp, funcAnalysisState->symbolTables);
if (isa<OneShotAnalysisState>(state)) {
auto &oneShotAnalysisState =
static_cast<const OneShotAnalysisState &>(state);
if (auto *funcAnalysisState =
oneShotAnalysisState.getExtension<FuncAnalysisState>()) {
// Use the cached symbol tables.
return getCalledFunction(callOp, funcAnalysisState->symbolTables);
}
}
SymbolTableCollection symbolTables;

View File

@ -0,0 +1,16 @@
// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries=1 copy-before-write=1" | FileCheck %s
// Regression test for https://github.com/llvm/llvm-project/issues/163052
// copy-before-write=1 + bufferize-function-boundaries=1 with a call to a
// private (declaration-only) function used to crash with a stack overflow due
// to an invalid cast of AnalysisState to OneShotAnalysisState inside
// getCalledFunction().
// CHECK-LABEL: func.func private @callee(memref<64xf32
// CHECK-LABEL: func.func @caller
// CHECK: call @callee
func.func private @callee(tensor<64xf32>)
func.func @caller(%A : tensor<64xf32>) {
call @callee(%A) : (tensor<64xf32>) -> ()
return
}