From 04cc7523ed6a84ddff4f1b08e53e408a3a259ea8 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Tue, 17 Mar 2026 12:56:48 +0100 Subject: [PATCH] [mlir][bufferization] Fix crash with copy-before-write + bufferize-function-boundaries (#186446) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When `copy-before-write=1` is combined with `bufferize-function-boundaries=1`, `bufferizeOp` creates a plain `AnalysisState` (not `OneShotAnalysisState`) and passes it to `insertTensorCopies`. Walking `CallOp`s during conflict resolution called `getCalledFunction(callOp, state)`, which unconditionally cast the `AnalysisState` to `OneShotAnalysisState` via `static_cast`, causing UB and a stack overflow crash. Fix by guarding the cast with `isa()` so that when the state is a plain `AnalysisState`, the function falls through to building a fresh `SymbolTableCollection` — the same safe fallback already present. Fixes https://github.com/llvm/llvm-project/issues/163052 Assisted-by: Claude Code --- .../FuncBufferizableOpInterfaceImpl.cpp | 14 ++++++++------ ...-module-bufferize-call-copy-before-write.mlir | 16 ++++++++++++++++ 2 files changed, 24 insertions(+), 6 deletions(-) create mode 100644 mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-call-copy-before-write.mlir diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp index e43ab54a048b..3aaa38272935 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp @@ -101,12 +101,14 @@ static FuncOp getCalledFunction(CallOpInterface callOp, /// Return the FuncOp called by `callOp`. static FuncOp getCalledFunction(CallOpInterface callOp, const AnalysisState &state) { - auto &oneShotAnalysisState = static_cast(state); - - if (auto *funcAnalysisState = - oneShotAnalysisState.getExtension()) { - // Use the cached symbol tables. - return getCalledFunction(callOp, funcAnalysisState->symbolTables); + if (isa(state)) { + auto &oneShotAnalysisState = + static_cast(state); + if (auto *funcAnalysisState = + oneShotAnalysisState.getExtension()) { + // Use the cached symbol tables. + return getCalledFunction(callOp, funcAnalysisState->symbolTables); + } } SymbolTableCollection symbolTables; diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-call-copy-before-write.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-call-copy-before-write.mlir new file mode 100644 index 000000000000..7addca2c9d6a --- /dev/null +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-call-copy-before-write.mlir @@ -0,0 +1,16 @@ +// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries=1 copy-before-write=1" | FileCheck %s + +// Regression test for https://github.com/llvm/llvm-project/issues/163052 +// copy-before-write=1 + bufferize-function-boundaries=1 with a call to a +// private (declaration-only) function used to crash with a stack overflow due +// to an invalid cast of AnalysisState to OneShotAnalysisState inside +// getCalledFunction(). + +// CHECK-LABEL: func.func private @callee(memref<64xf32 +// CHECK-LABEL: func.func @caller +// CHECK: call @callee +func.func private @callee(tensor<64xf32>) +func.func @caller(%A : tensor<64xf32>) { + call @callee(%A) : (tensor<64xf32>) -> () + return +}