Revert "[MLIR] Split ExecutionEngine Initialization out of ctor into an explicit method call" (#153477)

Reverts llvm/llvm-project#153373

Sanitizer bot is broken
This commit is contained in:
Mehdi Amini 2025-08-13 21:43:04 +02:00 committed by GitHub
parent 85cd3d9868
commit bfd490e0cd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 12 additions and 141 deletions

View File

@ -46,13 +46,6 @@ MLIR_CAPI_EXPORTED MlirExecutionEngine mlirExecutionEngineCreate(
MlirModule op, int optLevel, int numPaths, MlirModule op, int optLevel, int numPaths,
const MlirStringRef *sharedLibPaths, bool enableObjectDump); const MlirStringRef *sharedLibPaths, bool enableObjectDump);
/// Initialize the ExecutionEngine. Global constructors specified by
/// `llvm.mlir.global_ctors` will be run. One common scenario is that kernel
/// binary compiled from `gpu.module` gets loaded during initialization. Make
/// sure all symbols are resolvable before initialization by calling
/// `mlirExecutionEngineRegisterSymbol` or including shared libraries.
MLIR_CAPI_EXPORTED void mlirExecutionEngineInitialize(MlirExecutionEngine jit);
/// Destroy an ExecutionEngine instance. /// Destroy an ExecutionEngine instance.
MLIR_CAPI_EXPORTED void mlirExecutionEngineDestroy(MlirExecutionEngine jit); MLIR_CAPI_EXPORTED void mlirExecutionEngineDestroy(MlirExecutionEngine jit);

View File

@ -227,13 +227,6 @@ public:
llvm::function_ref<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)> llvm::function_ref<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)>
symbolMap); symbolMap);
/// Initialize the ExecutionEngine. Global constructors specified by
/// `llvm.mlir.global_ctors` will be run. One common scenario is that kernel
/// binary compiled from `gpu.module` gets loaded during initialization. Make
/// sure all symbols are resolvable before initialization by calling
/// `registerSymbols` or including shared libraries.
void initialize();
private: private:
/// Ordering of llvmContext and jit is important for destruction purposes: the /// Ordering of llvmContext and jit is important for destruction purposes: the
/// jit must be destroyed before the context. /// jit must be destroyed before the context.
@ -257,8 +250,6 @@ private:
/// Destroy functions in the libraries loaded by the ExecutionEngine that are /// Destroy functions in the libraries loaded by the ExecutionEngine that are
/// called when this ExecutionEngine is destructed. /// called when this ExecutionEngine is destructed.
SmallVector<LibraryDestroyFn> destroyFns; SmallVector<LibraryDestroyFn> destroyFns;
bool isInitialized = false;
}; };
} // namespace mlir } // namespace mlir

View File

@ -7,8 +7,8 @@
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "mlir-c/ExecutionEngine.h" #include "mlir-c/ExecutionEngine.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h" #include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir/Bindings/Python/Nanobind.h"
namespace nb = nanobind; namespace nb = nanobind;
using namespace mlir; using namespace mlir;
@ -124,17 +124,6 @@ NB_MODULE(_mlirExecutionEngine, m) {
}, },
nb::arg("name"), nb::arg("callback"), nb::arg("name"), nb::arg("callback"),
"Register `callback` as the runtime symbol `name`.") "Register `callback` as the runtime symbol `name`.")
.def(
"initialize",
[](PyExecutionEngine &executionEngine) {
mlirExecutionEngineInitialize(executionEngine.get());
},
"Initialize the ExecutionEngine. Global constructors specified by "
"`llvm.mlir.global_ctors` will be run. One common scenario is that "
"kernel binary compiled from `gpu.module` gets loaded during "
"initialization. Make sure all symbols are resolvable before "
"initialization by calling `raw_register_runtime` or including "
"shared libraries.")
.def( .def(
"dump_to_object_file", "dump_to_object_file",
[](PyExecutionEngine &executionEngine, const std::string &fileName) { [](PyExecutionEngine &executionEngine, const std::string &fileName) {

View File

@ -68,10 +68,6 @@ mlirExecutionEngineCreate(MlirModule op, int optLevel, int numPaths,
return wrap(jitOrError->release()); return wrap(jitOrError->release());
} }
extern "C" void mlirExecutionEngineInitialize(MlirExecutionEngine jit) {
unwrap(jit)->initialize();
}
extern "C" void mlirExecutionEngineDestroy(MlirExecutionEngine jit) { extern "C" void mlirExecutionEngineDestroy(MlirExecutionEngine jit) {
delete (unwrap(jit)); delete (unwrap(jit));
} }
@ -110,8 +106,9 @@ extern "C" void mlirExecutionEngineRegisterSymbol(MlirExecutionEngine jit,
void *sym) { void *sym) {
unwrap(jit)->registerSymbols([&](llvm::orc::MangleAndInterner interner) { unwrap(jit)->registerSymbols([&](llvm::orc::MangleAndInterner interner) {
llvm::orc::SymbolMap symbolMap; llvm::orc::SymbolMap symbolMap;
symbolMap[interner(unwrap(name))] = {llvm::orc::ExecutorAddr::fromPtr(sym), symbolMap[interner(unwrap(name))] =
llvm::JITSymbolFlags::Exported}; { llvm::orc::ExecutorAddr::fromPtr(sym),
llvm::JITSymbolFlags::Exported };
return symbolMap; return symbolMap;
}); });
} }

View File

@ -400,6 +400,13 @@ ExecutionEngine::create(Operation *m, const ExecutionEngineOptions &options,
return symbolMap; return symbolMap;
}; };
engine->registerSymbols(runtimeSymbolMap); engine->registerSymbols(runtimeSymbolMap);
// Execute the global constructors from the module being processed.
// TODO: Allow JIT initialize for AArch64. Currently there's a bug causing a
// crash for AArch64 see related issue #71963.
if (!engine->jit->getTargetTriple().isAArch64())
cantFail(engine->jit->initialize(engine->jit->getMainJITDylib()));
return std::move(engine); return std::move(engine);
} }
@ -435,7 +442,6 @@ Expected<void *> ExecutionEngine::lookup(StringRef name) const {
Error ExecutionEngine::invokePacked(StringRef name, Error ExecutionEngine::invokePacked(StringRef name,
MutableArrayRef<void *> args) { MutableArrayRef<void *> args) {
initialize();
auto expectedFPtr = lookupPacked(name); auto expectedFPtr = lookupPacked(name);
if (!expectedFPtr) if (!expectedFPtr)
return expectedFPtr.takeError(); return expectedFPtr.takeError();
@ -445,13 +451,3 @@ Error ExecutionEngine::invokePacked(StringRef name,
return Error::success(); return Error::success();
} }
void ExecutionEngine::initialize() {
if (isInitialized)
return;
// TODO: Allow JIT initialize for AArch64. Currently there's a bug causing a
// crash for AArch64 see related issue #71963.
if (!jit->getTargetTriple().isAArch64())
cantFail(jit->initialize(jit->getMainJITDylib()));
isInitialized = true;
}

View File

@ -202,8 +202,6 @@ compileAndExecute(Options &options, Operation *module, StringRef entryPoint,
auto engine = std::move(*expectedEngine); auto engine = std::move(*expectedEngine);
engine->initialize();
auto expectedFPtr = engine->lookupPacked(entryPoint); auto expectedFPtr = engine->lookupPacked(entryPoint);
if (!expectedFPtr) if (!expectedFPtr)
return expectedFPtr.takeError(); return expectedFPtr.takeError();

View File

@ -137,60 +137,6 @@ void testOmpCreation(void) {
mlirContextDestroy(ctx); mlirContextDestroy(ctx);
} }
// Helper variable to track callback invocations
static int initCnt = 0;
// Callback function that will be called during JIT initialization
static void initCallback(void) { initCnt += 1; }
// CHECK-LABEL: Running test 'testGlobalCtorJitCallback'
void testGlobalCtorJitCallback(void) {
MlirContext ctx = mlirContextCreate();
registerAllUpstreamDialects(ctx);
// Create module with global constructor that calls our callback
MlirModule module = mlirModuleCreateParse(
ctx, mlirStringRefCreateFromCString(
// clang-format off
"module { \n"
" llvm.mlir.global_ctors ctors = [@ctor], priorities = [0 : i32], data = [#llvm.zero] \n"
" llvm.func @ctor() { \n"
" func.call @init_callback() : () -> () \n"
" llvm.return \n"
" } \n"
" func.func private @init_callback() attributes { llvm.emit_c_interface } \n"
"} \n"
// clang-format on
));
lowerModuleToLLVM(ctx, module);
mlirRegisterAllLLVMTranslations(ctx);
// Create execution engine with initialization disabled
MlirExecutionEngine jit = mlirExecutionEngineCreate(
module, /*optLevel=*/2, /*numPaths=*/0, /*sharedLibPaths=*/NULL,
/*enableObjectDump=*/false);
if (mlirExecutionEngineIsNull(jit)) {
fprintf(stderr, "Execution engine creation failed");
exit(2);
}
// Register callback symbol before initialization
mlirExecutionEngineRegisterSymbol(
jit, mlirStringRefCreateFromCString("_mlir_ciface_init_callback"),
(void *)(uintptr_t)initCallback);
mlirExecutionEngineInitialize(jit);
// CHECK: Init count: 1
printf("Init count: %d\n", initCnt);
mlirExecutionEngineDestroy(jit);
mlirModuleDestroy(module);
mlirContextDestroy(ctx);
}
int main(void) { int main(void) {
#define _STRINGIFY(x) #x #define _STRINGIFY(x) #x
@ -201,6 +147,5 @@ int main(void) {
TEST(testSimpleExecution); TEST(testSimpleExecution);
TEST(testOmpCreation); TEST(testOmpCreation);
TEST(testGlobalCtorJitCallback);
return 0; return 0;
} }

View File

@ -322,42 +322,4 @@ TEST(NativeMemRefJit, MAYBE_JITCallback) {
ASSERT_EQ(elt, coefficient * count++); ASSERT_EQ(elt, coefficient * count++);
} }
static int initCnt = 0;
// A helper function that will be called during the JIT's initialization.
static void initCallback() { initCnt += 1; }
TEST(GlobalCtorJit, MAYBE_JITCallback) {
std::string moduleStr = R"mlir(
llvm.mlir.global_ctors ctors = [@ctor], priorities = [0 : i32], data = [#llvm.zero]
llvm.func @ctor() {
func.call @init_callback() : () -> ()
llvm.return
}
func.func private @init_callback() attributes { llvm.emit_c_interface }
)mlir";
DialectRegistry registry;
registerAllDialects(registry);
registerBuiltinDialectTranslation(registry);
registerLLVMDialectTranslation(registry);
MLIRContext context(registry);
auto module = parseSourceString<ModuleOp>(moduleStr, &context);
ASSERT_TRUE(!!module);
ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
ExecutionEngineOptions jitOptions;
auto jitOrError = ExecutionEngine::create(*module, jitOptions);
ASSERT_TRUE(!!jitOrError);
auto jit = std::move(jitOrError.get());
// Define any extra symbols so they're available at initialization.
jit->registerSymbols([&](llvm::orc::MangleAndInterner interner) {
llvm::orc::SymbolMap symbolMap;
symbolMap[interner("_mlir_ciface_init_callback")] = {
llvm::orc::ExecutorAddr::fromPtr(initCallback),
llvm::JITSymbolFlags::Exported};
return symbolMap;
});
jit->initialize();
ASSERT_EQ(initCnt, 1);
}
#endif // _WIN32 #endif // _WIN32