diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp index e43ab54a048b..3aaa38272935 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp @@ -101,12 +101,14 @@ static FuncOp getCalledFunction(CallOpInterface callOp, /// Return the FuncOp called by `callOp`. static FuncOp getCalledFunction(CallOpInterface callOp, const AnalysisState &state) { - auto &oneShotAnalysisState = static_cast(state); - - if (auto *funcAnalysisState = - oneShotAnalysisState.getExtension()) { - // Use the cached symbol tables. - return getCalledFunction(callOp, funcAnalysisState->symbolTables); + if (isa(state)) { + auto &oneShotAnalysisState = + static_cast(state); + if (auto *funcAnalysisState = + oneShotAnalysisState.getExtension()) { + // Use the cached symbol tables. + return getCalledFunction(callOp, funcAnalysisState->symbolTables); + } } SymbolTableCollection symbolTables; diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-call-copy-before-write.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-call-copy-before-write.mlir new file mode 100644 index 000000000000..7addca2c9d6a --- /dev/null +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-call-copy-before-write.mlir @@ -0,0 +1,16 @@ +// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries=1 copy-before-write=1" | FileCheck %s + +// Regression test for https://github.com/llvm/llvm-project/issues/163052 +// copy-before-write=1 + bufferize-function-boundaries=1 with a call to a +// private (declaration-only) function used to crash with a stack overflow due +// to an invalid cast of AnalysisState to OneShotAnalysisState inside +// getCalledFunction(). + +// CHECK-LABEL: func.func private @callee(memref<64xf32 +// CHECK-LABEL: func.func @caller +// CHECK: call @callee +func.func private @callee(tensor<64xf32>) +func.func @caller(%A : tensor<64xf32>) { + call @callee(%A) : (tensor<64xf32>) -> () + return +}