[ORC] Replace ThreadSafeContext::getContext with withContextDo. (#146819)

This removes ThreadSafeContext::Lock, ThreadSafeContext::getLock, and
ThreadSafeContext::getContext, and replaces them with a
ThreadSafeContext::withContextDo method (and const override).

The new method can be used to access an existing
ThreadSafeContext-wrapped LLVMContext in a safe way:

ThreadSafeContext TSCtx = ... ;
TSCtx.withContextDo([](LLVMContext *Ctx) {
  // this closure has exclusive access to Ctx.
});

The new API enforces correct locking, whereas the old APIs relied on
manual locking (which almost no in-tree code preformed, relying instead
on incidental exclusive access to the ThreadSafeContext).
This commit is contained in:
Lang Hames 2025-07-03 17:03:39 +10:00 committed by GitHub
parent 9234d07752
commit 0bfa0bcd79
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 151 additions and 164 deletions

View File

@ -373,8 +373,11 @@ Interpreter::Interpreter(std::unique_ptr<CompilerInstance> Instance,
auto LLVMCtx = std::make_unique<llvm::LLVMContext>(); auto LLVMCtx = std::make_unique<llvm::LLVMContext>();
TSCtx = std::make_unique<llvm::orc::ThreadSafeContext>(std::move(LLVMCtx)); TSCtx = std::make_unique<llvm::orc::ThreadSafeContext>(std::move(LLVMCtx));
Act = std::make_unique<IncrementalAction>(*CI, *TSCtx->getContext(), ErrOut, Act = TSCtx->withContextDo([&](llvm::LLVMContext *Ctx) {
*this, std::move(Consumer)); return std::make_unique<IncrementalAction>(*CI, *Ctx, ErrOut, *this,
std::move(Consumer));
});
if (ErrOut) if (ErrOut)
return; return;
CI->ExecuteAction(*Act); CI->ExecuteAction(*Act);
@ -495,10 +498,10 @@ Interpreter::createWithCUDA(std::unique_ptr<CompilerInstance> CI,
std::unique_ptr<Interpreter> Interp = std::move(*InterpOrErr); std::unique_ptr<Interpreter> Interp = std::move(*InterpOrErr);
llvm::Error Err = llvm::Error::success(); llvm::Error Err = llvm::Error::success();
llvm::LLVMContext &LLVMCtx = *Interp->TSCtx->getContext();
auto DeviceAct = auto DeviceAct = Interp->TSCtx->withContextDo([&](llvm::LLVMContext *Ctx) {
std::make_unique<IncrementalAction>(*DCI, LLVMCtx, Err, *Interp); return std::make_unique<IncrementalAction>(*DCI, *Ctx, Err, *Interp);
});
if (Err) if (Err)
return std::move(Err); return std::move(Err);

View File

@ -169,7 +169,9 @@ Expected<ThreadSafeModule> loadModule(StringRef Path,
MemoryBufferRef BitcodeBufferRef = (**BitcodeBuffer).getMemBufferRef(); MemoryBufferRef BitcodeBufferRef = (**BitcodeBuffer).getMemBufferRef();
Expected<std::unique_ptr<Module>> M = Expected<std::unique_ptr<Module>> M =
parseBitcodeFile(BitcodeBufferRef, *TSCtx.getContext()); TSCtx.withContextDo([&](LLVMContext *Ctx) {
return parseBitcodeFile(BitcodeBufferRef, *Ctx);
});
if (!M) if (!M)
return M.takeError(); return M.takeError();

View File

@ -22,11 +22,8 @@ int handleError(LLVMErrorRef Err) {
} }
LLVMOrcThreadSafeModuleRef createDemoModule(void) { LLVMOrcThreadSafeModuleRef createDemoModule(void) {
// Create a new ThreadSafeContext and underlying LLVMContext. // Create an LLVMContext.
LLVMOrcThreadSafeContextRef TSCtx = LLVMOrcCreateNewThreadSafeContext(); LLVMContextRef Ctx = LLVMContextCreate();
// Get a reference to the underlying LLVMContext.
LLVMContextRef Ctx = LLVMOrcThreadSafeContextGetContext(TSCtx);
// Create a new LLVM module. // Create a new LLVM module.
LLVMModuleRef M = LLVMModuleCreateWithNameInContext("demo", Ctx); LLVMModuleRef M = LLVMModuleCreateWithNameInContext("demo", Ctx);
@ -57,6 +54,9 @@ LLVMOrcThreadSafeModuleRef createDemoModule(void) {
// - Free the builder. // - Free the builder.
LLVMDisposeBuilder(Builder); LLVMDisposeBuilder(Builder);
// Create a new ThreadSafeContext to hold the context.
LLVMOrcThreadSafeContextRef TSCtx = LLVMOrcCreateNewThreadSafeContext();
// Our demo module is now complete. Wrap it and our ThreadSafeContext in a // Our demo module is now complete. Wrap it and our ThreadSafeContext in a
// ThreadSafeModule. // ThreadSafeModule.
LLVMOrcThreadSafeModuleRef TSM = LLVMOrcCreateNewThreadSafeModule(M, TSCtx); LLVMOrcThreadSafeModuleRef TSM = LLVMOrcCreateNewThreadSafeModule(M, TSCtx);

View File

@ -31,8 +31,7 @@ int handleError(LLVMErrorRef Err) {
} }
LLVMOrcThreadSafeModuleRef createDemoModule(void) { LLVMOrcThreadSafeModuleRef createDemoModule(void) {
LLVMOrcThreadSafeContextRef TSCtx = LLVMOrcCreateNewThreadSafeContext(); LLVMContextRef Ctx = LLVMContextCreate();
LLVMContextRef Ctx = LLVMOrcThreadSafeContextGetContext(TSCtx);
LLVMModuleRef M = LLVMModuleCreateWithNameInContext("demo", Ctx); LLVMModuleRef M = LLVMModuleCreateWithNameInContext("demo", Ctx);
LLVMTypeRef ParamTypes[] = {LLVMInt32Type(), LLVMInt32Type()}; LLVMTypeRef ParamTypes[] = {LLVMInt32Type(), LLVMInt32Type()};
LLVMTypeRef SumFunctionType = LLVMTypeRef SumFunctionType =
@ -45,6 +44,8 @@ LLVMOrcThreadSafeModuleRef createDemoModule(void) {
LLVMValueRef SumArg1 = LLVMGetParam(SumFunction, 1); LLVMValueRef SumArg1 = LLVMGetParam(SumFunction, 1);
LLVMValueRef Result = LLVMBuildAdd(Builder, SumArg0, SumArg1, "result"); LLVMValueRef Result = LLVMBuildAdd(Builder, SumArg0, SumArg1, "result");
LLVMBuildRet(Builder, Result); LLVMBuildRet(Builder, Result);
LLVMOrcThreadSafeContextRef TSCtx =
LLVMOrcCreateNewThreadSafeContextFromLLVMContext(Ctx);
LLVMOrcThreadSafeModuleRef TSM = LLVMOrcCreateNewThreadSafeModule(M, TSCtx); LLVMOrcThreadSafeModuleRef TSM = LLVMOrcCreateNewThreadSafeModule(M, TSCtx);
LLVMOrcDisposeThreadSafeContext(TSCtx); LLVMOrcDisposeThreadSafeContext(TSCtx);
return TSM; return TSM;

View File

@ -32,8 +32,7 @@ int handleError(LLVMErrorRef Err) {
} }
LLVMOrcThreadSafeModuleRef createDemoModule(void) { LLVMOrcThreadSafeModuleRef createDemoModule(void) {
LLVMOrcThreadSafeContextRef TSCtx = LLVMOrcCreateNewThreadSafeContext(); LLVMContextRef Ctx = LLVMContextCreate();
LLVMContextRef Ctx = LLVMOrcThreadSafeContextGetContext(TSCtx);
LLVMModuleRef M = LLVMModuleCreateWithNameInContext("demo", Ctx); LLVMModuleRef M = LLVMModuleCreateWithNameInContext("demo", Ctx);
LLVMTypeRef ParamTypes[] = {LLVMInt32Type(), LLVMInt32Type()}; LLVMTypeRef ParamTypes[] = {LLVMInt32Type(), LLVMInt32Type()};
LLVMTypeRef SumFunctionType = LLVMTypeRef SumFunctionType =
@ -47,6 +46,7 @@ LLVMOrcThreadSafeModuleRef createDemoModule(void) {
LLVMValueRef Result = LLVMBuildAdd(Builder, SumArg0, SumArg1, "result"); LLVMValueRef Result = LLVMBuildAdd(Builder, SumArg0, SumArg1, "result");
LLVMBuildRet(Builder, Result); LLVMBuildRet(Builder, Result);
LLVMDisposeBuilder(Builder); LLVMDisposeBuilder(Builder);
LLVMOrcThreadSafeContextRef TSCtx = LLVMOrcCreateNewThreadSafeContext();
LLVMOrcThreadSafeModuleRef TSM = LLVMOrcCreateNewThreadSafeModule(M, TSCtx); LLVMOrcThreadSafeModuleRef TSM = LLVMOrcCreateNewThreadSafeModule(M, TSCtx);
LLVMOrcDisposeThreadSafeContext(TSCtx); LLVMOrcDisposeThreadSafeContext(TSCtx);
return TSM; return TSM;

View File

@ -67,11 +67,9 @@ const char MainMod[] =
LLVMErrorRef parseExampleModule(const char *Source, size_t Len, LLVMErrorRef parseExampleModule(const char *Source, size_t Len,
const char *Name, const char *Name,
LLVMOrcThreadSafeModuleRef *TSM) { LLVMOrcThreadSafeModuleRef *TSM) {
// Create a new ThreadSafeContext and underlying LLVMContext.
LLVMOrcThreadSafeContextRef TSCtx = LLVMOrcCreateNewThreadSafeContext();
// Get a reference to the underlying LLVMContext. // Create an LLVMContext for the Module.
LLVMContextRef Ctx = LLVMOrcThreadSafeContextGetContext(TSCtx); LLVMContextRef Ctx = LLVMContextCreate();
// Wrap Source in a MemoryBuffer // Wrap Source in a MemoryBuffer
LLVMMemoryBufferRef MB = LLVMMemoryBufferRef MB =
@ -85,6 +83,10 @@ LLVMErrorRef parseExampleModule(const char *Source, size_t Len,
// TODO: LLVMDisposeMessage(ErrMsg); // TODO: LLVMDisposeMessage(ErrMsg);
} }
// Create a new ThreadSafeContext to hold the context.
LLVMOrcThreadSafeContextRef TSCtx =
LLVMOrcCreateNewThreadSafeContextFromLLVMContext(Ctx);
// Our module is now complete. Wrap it and our ThreadSafeContext in a // Our module is now complete. Wrap it and our ThreadSafeContext in a
// ThreadSafeModule. // ThreadSafeModule.
*TSM = LLVMOrcCreateNewThreadSafeModule(M, TSCtx); *TSM = LLVMOrcCreateNewThreadSafeModule(M, TSCtx);

View File

@ -150,11 +150,8 @@ int handleError(LLVMErrorRef Err) {
} }
LLVMOrcThreadSafeModuleRef createDemoModule(void) { LLVMOrcThreadSafeModuleRef createDemoModule(void) {
// Create a new ThreadSafeContext and underlying LLVMContext. // Create an LLVMContext.
LLVMOrcThreadSafeContextRef TSCtx = LLVMOrcCreateNewThreadSafeContext(); LLVMContextRef Ctx = LLVMContextCreate();
// Get a reference to the underlying LLVMContext.
LLVMContextRef Ctx = LLVMOrcThreadSafeContextGetContext(TSCtx);
// Create a new LLVM module. // Create a new LLVM module.
LLVMModuleRef M = LLVMModuleCreateWithNameInContext("demo", Ctx); LLVMModuleRef M = LLVMModuleCreateWithNameInContext("demo", Ctx);
@ -182,6 +179,10 @@ LLVMOrcThreadSafeModuleRef createDemoModule(void) {
// - Build the return instruction. // - Build the return instruction.
LLVMBuildRet(Builder, Result); LLVMBuildRet(Builder, Result);
// Create a new ThreadSafeContext to hold the context.
LLVMOrcThreadSafeContextRef TSCtx =
LLVMOrcCreateNewThreadSafeContextFromLLVMContext(Ctx);
// Our demo module is now complete. Wrap it and our ThreadSafeContext in a // Our demo module is now complete. Wrap it and our ThreadSafeContext in a
// ThreadSafeModule. // ThreadSafeModule.
LLVMOrcThreadSafeModuleRef TSM = LLVMOrcCreateNewThreadSafeModule(M, TSCtx); LLVMOrcThreadSafeModuleRef TSM = LLVMOrcCreateNewThreadSafeModule(M, TSCtx);

View File

@ -22,11 +22,8 @@ int handleError(LLVMErrorRef Err) {
} }
LLVMOrcThreadSafeModuleRef createDemoModule(void) { LLVMOrcThreadSafeModuleRef createDemoModule(void) {
// Create a new ThreadSafeContext and underlying LLVMContext. // Create an LLVMContext.
LLVMOrcThreadSafeContextRef TSCtx = LLVMOrcCreateNewThreadSafeContext(); LLVMContextRef Ctx = LLVMContextCreate();
// Get a reference to the underlying LLVMContext.
LLVMContextRef Ctx = LLVMOrcThreadSafeContextGetContext(TSCtx);
// Create a new LLVM module. // Create a new LLVM module.
LLVMModuleRef M = LLVMModuleCreateWithNameInContext("demo", Ctx); LLVMModuleRef M = LLVMModuleCreateWithNameInContext("demo", Ctx);
@ -57,6 +54,10 @@ LLVMOrcThreadSafeModuleRef createDemoModule(void) {
// - Free the builder. // - Free the builder.
LLVMDisposeBuilder(Builder); LLVMDisposeBuilder(Builder);
// Create a new ThreadSafeContext to hold the context.
LLVMOrcThreadSafeContextRef TSCtx =
LLVMOrcCreateNewThreadSafeContextFromLLVMContext(Ctx);
// Our demo module is now complete. Wrap it and our ThreadSafeContext in a // Our demo module is now complete. Wrap it and our ThreadSafeContext in a
// ThreadSafeModule. // ThreadSafeModule.
LLVMOrcThreadSafeModuleRef TSM = LLVMOrcCreateNewThreadSafeModule(M, TSCtx); LLVMOrcThreadSafeModuleRef TSM = LLVMOrcCreateNewThreadSafeModule(M, TSCtx);

View File

@ -74,11 +74,8 @@ LLVMErrorRef applyDataLayout(void *Ctx, LLVMModuleRef M) {
LLVMErrorRef parseExampleModule(const char *Source, size_t Len, LLVMErrorRef parseExampleModule(const char *Source, size_t Len,
const char *Name, const char *Name,
LLVMOrcThreadSafeModuleRef *TSM) { LLVMOrcThreadSafeModuleRef *TSM) {
// Create a new ThreadSafeContext and underlying LLVMContext. // Create an LLVMContext.
LLVMOrcThreadSafeContextRef TSCtx = LLVMOrcCreateNewThreadSafeContext(); LLVMContextRef Ctx = LLVMContextCreate();
// Get a reference to the underlying LLVMContext.
LLVMContextRef Ctx = LLVMOrcThreadSafeContextGetContext(TSCtx);
// Wrap Source in a MemoryBuffer // Wrap Source in a MemoryBuffer
LLVMMemoryBufferRef MB = LLVMMemoryBufferRef MB =
@ -93,6 +90,9 @@ LLVMErrorRef parseExampleModule(const char *Source, size_t Len,
return Err; return Err;
} }
// Create a new ThreadSafeContext to hold the context.
LLVMOrcThreadSafeContextRef TSCtx = LLVMOrcCreateNewThreadSafeContext();
// Our module is now complete. Wrap it and our ThreadSafeContext in a // Our module is now complete. Wrap it and our ThreadSafeContext in a
// ThreadSafeModule. // ThreadSafeModule.
*TSM = LLVMOrcCreateNewThreadSafeModule(M, TSCtx); *TSM = LLVMOrcCreateNewThreadSafeModule(M, TSCtx);

View File

@ -1062,20 +1062,32 @@ LLVMErrorRef LLVMOrcCreateStaticLibrarySearchGeneratorForPath(
const char *FileName); const char *FileName);
/** /**
* Create a ThreadSafeContext containing a new LLVMContext. * Create a ThreadSafeContextRef containing a new LLVMContext.
* *
* Ownership of the underlying ThreadSafeContext data is shared: Clients * Ownership of the underlying ThreadSafeContext data is shared: Clients
* can and should dispose of their ThreadSafeContext as soon as they no longer * can and should dispose of their ThreadSafeContextRef as soon as they no
* need to refer to it directly. Other references (e.g. from ThreadSafeModules) * longer need to refer to it directly. Other references (e.g. from
* will keep the data alive as long as it is needed. * ThreadSafeModules) will keep the underlying data alive as long as it is
* needed.
*/ */
LLVMOrcThreadSafeContextRef LLVMOrcCreateNewThreadSafeContext(void); LLVMOrcThreadSafeContextRef LLVMOrcCreateNewThreadSafeContext(void);
/** /**
* Get a reference to the wrapped LLVMContext. * Create a ThreadSafeContextRef from a given LLVMContext, which must not be
* associated with any existing ThreadSafeContext.
*
* The underlying ThreadSafeContext will take ownership of the LLVMContext
* object, so clients should not free the LLVMContext passed to this
* function.
*
* Ownership of the underlying ThreadSafeContext data is shared: Clients
* can and should dispose of their ThreadSafeContextRef as soon as they no
* longer need to refer to it directly. Other references (e.g. from
* ThreadSafeModules) will keep the underlying data alive as long as it is
* needed.
*/ */
LLVMContextRef LLVMOrcThreadSafeContextRef
LLVMOrcThreadSafeContextGetContext(LLVMOrcThreadSafeContextRef TSCtx); LLVMOrcCreateNewThreadSafeContextFromLLVMContext(LLVMContextRef Ctx);
/** /**
* Dispose of a ThreadSafeContext. * Dispose of a ThreadSafeContext.

View File

@ -36,16 +36,6 @@ private:
}; };
public: public:
// RAII based lock for ThreadSafeContext.
class [[nodiscard]] Lock {
public:
Lock(std::shared_ptr<State> S) : S(std::move(S)), L(this->S->Mutex) {}
private:
std::shared_ptr<State> S;
std::unique_lock<std::recursive_mutex> L;
};
/// Construct a null context. /// Construct a null context.
ThreadSafeContext() = default; ThreadSafeContext() = default;
@ -56,17 +46,20 @@ public:
"Can not construct a ThreadSafeContext from a nullptr"); "Can not construct a ThreadSafeContext from a nullptr");
} }
/// Returns a pointer to the LLVMContext that was used to construct this template <typename Func> decltype(auto) withContextDo(Func &&F) {
/// instance, or null if the instance was default constructed. if (auto TmpS = S) {
LLVMContext *getContext() { return S ? S->Ctx.get() : nullptr; } std::lock_guard<std::recursive_mutex> Lock(TmpS->Mutex);
return F(TmpS->Ctx.get());
} else
return F((LLVMContext *)nullptr);
}
/// Returns a pointer to the LLVMContext that was used to construct this template <typename Func> decltype(auto) withContextDo(Func &&F) const {
/// instance, or null if the instance was default constructed. if (auto TmpS = S) {
const LLVMContext *getContext() const { return S ? S->Ctx.get() : nullptr; } std::lock_guard<std::recursive_mutex> Lock(TmpS->Mutex);
return F(const_cast<const LLVMContext *>(TmpS->Ctx.get()));
Lock getLock() const { } else
assert(S && "Can not lock an empty ThreadSafeContext"); return F((const LLVMContext *)nullptr);
return Lock(S);
} }
private: private:
@ -89,10 +82,7 @@ public:
// *before* the context that it depends on. // *before* the context that it depends on.
// We also need to lock the context to make sure the module tear-down // We also need to lock the context to make sure the module tear-down
// does not overlap any other work on the context. // does not overlap any other work on the context.
if (M) { TSCtx.withContextDo([this](LLVMContext *Ctx) { M = nullptr; });
auto L = TSCtx.getLock();
M = nullptr;
}
M = std::move(Other.M); M = std::move(Other.M);
TSCtx = std::move(Other.TSCtx); TSCtx = std::move(Other.TSCtx);
return *this; return *this;
@ -111,45 +101,39 @@ public:
~ThreadSafeModule() { ~ThreadSafeModule() {
// We need to lock the context while we destruct the module. // We need to lock the context while we destruct the module.
if (M) { TSCtx.withContextDo([this](LLVMContext *Ctx) { M = nullptr; });
auto L = TSCtx.getLock();
M = nullptr;
}
} }
/// Boolean conversion: This ThreadSafeModule will evaluate to true if it /// Boolean conversion: This ThreadSafeModule will evaluate to true if it
/// wraps a non-null module. /// wraps a non-null module.
explicit operator bool() const { explicit operator bool() const { return !!M; }
if (M) {
assert(TSCtx.getContext() &&
"Non-null module must have non-null context");
return true;
}
return false;
}
/// Locks the associated ThreadSafeContext and calls the given function /// Locks the associated ThreadSafeContext and calls the given function
/// on the contained Module. /// on the contained Module.
template <typename Func> decltype(auto) withModuleDo(Func &&F) { template <typename Func> decltype(auto) withModuleDo(Func &&F) {
return TSCtx.withContextDo([&](LLVMContext *) {
assert(M && "Can not call on null module"); assert(M && "Can not call on null module");
auto Lock = TSCtx.getLock();
return F(*M); return F(*M);
});
} }
/// Locks the associated ThreadSafeContext and calls the given function /// Locks the associated ThreadSafeContext and calls the given function
/// on the contained Module. /// on the contained Module.
template <typename Func> decltype(auto) withModuleDo(Func &&F) const { template <typename Func> decltype(auto) withModuleDo(Func &&F) const {
return TSCtx.withContextDo([&](const LLVMContext *) {
assert(M && "Can not call on null module"); assert(M && "Can not call on null module");
auto Lock = TSCtx.getLock();
return F(*M); return F(*M);
});
} }
/// Locks the associated ThreadSafeContext and calls the given function, /// Locks the associated ThreadSafeContext and calls the given function,
/// passing the contained std::unique_ptr<Module>. The given function should /// passing the contained std::unique_ptr<Module>. The given function should
/// consume the Module. /// consume the Module.
template <typename Func> decltype(auto) consumingModuleDo(Func &&F) { template <typename Func> decltype(auto) consumingModuleDo(Func &&F) {
auto Lock = TSCtx.getLock(); return TSCtx.withContextDo([&](LLVMContext *) {
assert(M && "Can not call on null module");
return F(std::move(M)); return F(std::move(M));
});
} }
/// Get a raw pointer to the contained module without locking the context. /// Get a raw pointer to the contained module without locking the context.

View File

@ -729,9 +729,9 @@ LLVMOrcThreadSafeContextRef LLVMOrcCreateNewThreadSafeContext(void) {
return wrap(new ThreadSafeContext(std::make_unique<LLVMContext>())); return wrap(new ThreadSafeContext(std::make_unique<LLVMContext>()));
} }
LLVMContextRef LLVMOrcThreadSafeContextRef
LLVMOrcThreadSafeContextGetContext(LLVMOrcThreadSafeContextRef TSCtx) { LLVMOrcCreateNewThreadSafeContextFromLLVMContext(LLVMContextRef Ctx) {
return wrap(unwrap(TSCtx)->getContext()); return wrap(new ThreadSafeContext(std::unique_ptr<LLVMContext>(unwrap(Ctx))));
} }
void LLVMOrcDisposeThreadSafeContext(LLVMOrcThreadSafeContextRef TSCtx) { void LLVMOrcDisposeThreadSafeContext(LLVMOrcThreadSafeContextRef TSCtx) {

View File

@ -60,8 +60,6 @@ void IRSpeculationLayer::emit(std::unique_ptr<MaterializationResponsibility> R,
ThreadSafeModule TSM) { ThreadSafeModule TSM) {
assert(TSM && "Speculation Layer received Null Module ?"); assert(TSM && "Speculation Layer received Null Module ?");
assert(TSM.getContext().getContext() != nullptr &&
"Module with null LLVMContext?");
// Instrumentation of runtime calls, lock the Module // Instrumentation of runtime calls, lock the Module
TSM.withModuleDo([this, &R](Module &M) { TSM.withModuleDo([this, &R](Module &M) {

View File

@ -53,9 +53,11 @@ ThreadSafeModule cloneToNewContext(const ThreadSafeModule &TSM,
"cloned module buffer"); "cloned module buffer");
ThreadSafeContext NewTSCtx(std::make_unique<LLVMContext>()); ThreadSafeContext NewTSCtx(std::make_unique<LLVMContext>());
auto ClonedModule = cantFail( auto ClonedModule = NewTSCtx.withContextDo([&](LLVMContext *Ctx) {
parseBitcodeFile(ClonedModuleBufferRef, *NewTSCtx.getContext())); auto TmpM = cantFail(parseBitcodeFile(ClonedModuleBufferRef, *Ctx));
ClonedModule->setModuleIdentifier(M.getName()); TmpM->setModuleIdentifier(M.getName());
return TmpM;
});
return ThreadSafeModule(std::move(ClonedModule), std::move(NewTSCtx)); return ThreadSafeModule(std::move(ClonedModule), std::move(NewTSCtx));
}); });
} }

View File

@ -878,7 +878,8 @@ static void exitOnLazyCallThroughFailure() { exit(1); }
Expected<orc::ThreadSafeModule> Expected<orc::ThreadSafeModule>
loadModule(StringRef Path, orc::ThreadSafeContext TSCtx) { loadModule(StringRef Path, orc::ThreadSafeContext TSCtx) {
SMDiagnostic Err; SMDiagnostic Err;
auto M = parseIRFile(Path, Err, *TSCtx.getContext()); auto M = TSCtx.withContextDo(
[&](LLVMContext *Ctx) { return parseIRFile(Path, Err, *Ctx); });
if (!M) { if (!M) {
std::string ErrMsg; std::string ErrMsg;
{ {

View File

@ -128,29 +128,25 @@ TEST(RTDyldObjectLinkingLayerTest, TestOverrideObjectFlags) {
}; };
// Create a module with two void() functions: foo and bar. // Create a module with two void() functions: foo and bar.
ThreadSafeContext TSCtx(std::make_unique<LLVMContext>());
ThreadSafeModule M; ThreadSafeModule M;
{ {
ModuleBuilder MB(*TSCtx.getContext(), TM->getTargetTriple().str(), "dummy"); auto Ctx = std::make_unique<LLVMContext>();
ModuleBuilder MB(*Ctx, TM->getTargetTriple().str(), "dummy");
MB.getModule()->setDataLayout(TM->createDataLayout()); MB.getModule()->setDataLayout(TM->createDataLayout());
Function *FooImpl = MB.createFunctionDecl( Function *FooImpl = MB.createFunctionDecl(
FunctionType::get(Type::getVoidTy(*TSCtx.getContext()), {}, false), FunctionType::get(Type::getVoidTy(*Ctx), {}, false), "foo");
"foo"); BasicBlock *FooEntry = BasicBlock::Create(*Ctx, "entry", FooImpl);
BasicBlock *FooEntry =
BasicBlock::Create(*TSCtx.getContext(), "entry", FooImpl);
IRBuilder<> B1(FooEntry); IRBuilder<> B1(FooEntry);
B1.CreateRetVoid(); B1.CreateRetVoid();
Function *BarImpl = MB.createFunctionDecl( Function *BarImpl = MB.createFunctionDecl(
FunctionType::get(Type::getVoidTy(*TSCtx.getContext()), {}, false), FunctionType::get(Type::getVoidTy(*Ctx), {}, false), "bar");
"bar"); BasicBlock *BarEntry = BasicBlock::Create(*Ctx, "entry", BarImpl);
BasicBlock *BarEntry =
BasicBlock::Create(*TSCtx.getContext(), "entry", BarImpl);
IRBuilder<> B2(BarEntry); IRBuilder<> B2(BarEntry);
B2.CreateRetVoid(); B2.CreateRetVoid();
M = ThreadSafeModule(MB.takeModule(), std::move(TSCtx)); M = ThreadSafeModule(MB.takeModule(), std::move(Ctx));
} }
// Create a simple stack and set the override flags option. // Create a simple stack and set the override flags option.
@ -207,21 +203,19 @@ TEST(RTDyldObjectLinkingLayerTest, TestAutoClaimResponsibilityForSymbols) {
}; };
// Create a module with two void() functions: foo and bar. // Create a module with two void() functions: foo and bar.
ThreadSafeContext TSCtx(std::make_unique<LLVMContext>());
ThreadSafeModule M; ThreadSafeModule M;
{ {
ModuleBuilder MB(*TSCtx.getContext(), TM->getTargetTriple().str(), "dummy"); auto Ctx = std::make_unique<LLVMContext>();
ModuleBuilder MB(*Ctx, TM->getTargetTriple().str(), "dummy");
MB.getModule()->setDataLayout(TM->createDataLayout()); MB.getModule()->setDataLayout(TM->createDataLayout());
Function *FooImpl = MB.createFunctionDecl( Function *FooImpl = MB.createFunctionDecl(
FunctionType::get(Type::getVoidTy(*TSCtx.getContext()), {}, false), FunctionType::get(Type::getVoidTy(*Ctx), {}, false), "foo");
"foo"); BasicBlock *FooEntry = BasicBlock::Create(*Ctx, "entry", FooImpl);
BasicBlock *FooEntry =
BasicBlock::Create(*TSCtx.getContext(), "entry", FooImpl);
IRBuilder<> B(FooEntry); IRBuilder<> B(FooEntry);
B.CreateRetVoid(); B.CreateRetVoid();
M = ThreadSafeModule(MB.takeModule(), std::move(TSCtx)); M = ThreadSafeModule(MB.takeModule(), std::move(Ctx));
} }
// Create a simple stack and set the override flags option. // Create a simple stack and set the override flags option.
@ -258,21 +252,19 @@ TEST(RTDyldObjectLinkingLayerTest, TestMemoryBufferNamePropagation) {
GTEST_SKIP(); GTEST_SKIP();
// Create a module with two void() functions: foo and bar. // Create a module with two void() functions: foo and bar.
ThreadSafeContext TSCtx(std::make_unique<LLVMContext>());
ThreadSafeModule M; ThreadSafeModule M;
{ {
ModuleBuilder MB(*TSCtx.getContext(), TM->getTargetTriple().str(), "dummy"); auto Ctx = std::make_unique<LLVMContext>();
ModuleBuilder MB(*Ctx, TM->getTargetTriple().str(), "dummy");
MB.getModule()->setDataLayout(TM->createDataLayout()); MB.getModule()->setDataLayout(TM->createDataLayout());
Function *FooImpl = MB.createFunctionDecl( Function *FooImpl = MB.createFunctionDecl(
FunctionType::get(Type::getVoidTy(*TSCtx.getContext()), {}, false), FunctionType::get(Type::getVoidTy(*Ctx), {}, false), "foo");
"foo"); BasicBlock *FooEntry = BasicBlock::Create(*Ctx, "entry", FooImpl);
BasicBlock *FooEntry =
BasicBlock::Create(*TSCtx.getContext(), "entry", FooImpl);
IRBuilder<> B1(FooEntry); IRBuilder<> B1(FooEntry);
B1.CreateRetVoid(); B1.CreateRetVoid();
M = ThreadSafeModule(MB.takeModule(), std::move(TSCtx)); M = ThreadSafeModule(MB.takeModule(), std::move(Ctx));
} }
ExecutionSession ES{std::make_unique<UnsupportedExecutorProcessControl>()}; ExecutionSession ES{std::make_unique<UnsupportedExecutorProcessControl>()};

View File

@ -172,8 +172,8 @@ TEST_F(ReOptimizeLayerTest, BasicReOptimization) {
}); });
EXPECT_THAT_ERROR(ROLayer->reigsterRuntimeFunctions(*JD), Succeeded()); EXPECT_THAT_ERROR(ROLayer->reigsterRuntimeFunctions(*JD), Succeeded());
ThreadSafeContext Ctx(std::make_unique<LLVMContext>()); auto Ctx = std::make_unique<LLVMContext>();
auto M = std::make_unique<Module>("<main>", *Ctx.getContext()); auto M = std::make_unique<Module>("<main>", *Ctx);
M->setTargetTriple(Triple(sys::getProcessTriple())); M->setTargetTriple(Triple(sys::getProcessTriple()));
(void)createRetFunction(M.get(), "main", 42); (void)createRetFunction(M.get(), "main", 42);

View File

@ -21,20 +21,21 @@ namespace {
TEST(ThreadSafeModuleTest, ContextWhollyOwnedByOneModule) { TEST(ThreadSafeModuleTest, ContextWhollyOwnedByOneModule) {
// Test that ownership of a context can be transferred to a single // Test that ownership of a context can be transferred to a single
// ThreadSafeModule. // ThreadSafeModule.
ThreadSafeContext TSCtx(std::make_unique<LLVMContext>()); auto Ctx = std::make_unique<LLVMContext>();
auto M = std::make_unique<Module>("M", *TSCtx.getContext()); auto M = std::make_unique<Module>("M", *Ctx);
ThreadSafeModule TSM(std::move(M), std::move(TSCtx)); ThreadSafeModule TSM(std::move(M), std::move(Ctx));
} }
TEST(ThreadSafeModuleTest, ContextOwnershipSharedByTwoModules) { TEST(ThreadSafeModuleTest, ContextOwnershipSharedByTwoModules) {
// Test that ownership of a context can be shared between more than one // Test that ownership of a context can be shared between more than one
// ThreadSafeModule. // ThreadSafeModule.
ThreadSafeContext TSCtx(std::make_unique<LLVMContext>()); auto Ctx = std::make_unique<LLVMContext>();
auto M1 = std::make_unique<Module>("M1", *TSCtx.getContext()); auto M1 = std::make_unique<Module>("M1", *Ctx);
auto M2 = std::make_unique<Module>("M2", *Ctx);
ThreadSafeContext TSCtx(std::move(Ctx));
ThreadSafeModule TSM1(std::move(M1), TSCtx); ThreadSafeModule TSM1(std::move(M1), TSCtx);
auto M2 = std::make_unique<Module>("M2", *TSCtx.getContext());
ThreadSafeModule TSM2(std::move(M2), std::move(TSCtx)); ThreadSafeModule TSM2(std::move(M2), std::move(TSCtx));
} }
@ -45,12 +46,14 @@ TEST(ThreadSafeModuleTest, ContextOwnershipSharedWithClient) {
{ {
// Create and destroy a module. // Create and destroy a module.
auto M1 = std::make_unique<Module>("M1", *TSCtx.getContext()); auto M1 = TSCtx.withContextDo(
[](LLVMContext *Ctx) { return std::make_unique<Module>("M1", *Ctx); });
ThreadSafeModule TSM1(std::move(M1), TSCtx); ThreadSafeModule TSM1(std::move(M1), TSCtx);
} }
// Verify that the context is still available for re-use. // Verify that the context is still available for re-use.
auto M2 = std::make_unique<Module>("M2", *TSCtx.getContext()); auto M2 = TSCtx.withContextDo(
[](LLVMContext *Ctx) { return std::make_unique<Module>("M2", *Ctx); });
ThreadSafeModule TSM2(std::move(M2), std::move(TSCtx)); ThreadSafeModule TSM2(std::move(M2), std::move(TSCtx));
} }
@ -59,59 +62,44 @@ TEST(ThreadSafeModuleTest, ThreadSafeModuleMoveAssignment) {
// to the field order) to ensure that overwriting with an empty // to the field order) to ensure that overwriting with an empty
// ThreadSafeModule does not destroy the context early. // ThreadSafeModule does not destroy the context early.
ThreadSafeContext TSCtx(std::make_unique<LLVMContext>()); ThreadSafeContext TSCtx(std::make_unique<LLVMContext>());
auto M = std::make_unique<Module>("M", *TSCtx.getContext()); auto M = TSCtx.withContextDo(
[](LLVMContext *Ctx) { return std::make_unique<Module>("M", *Ctx); });
ThreadSafeModule TSM(std::move(M), std::move(TSCtx)); ThreadSafeModule TSM(std::move(M), std::move(TSCtx));
TSM = ThreadSafeModule(); TSM = ThreadSafeModule();
} }
TEST(ThreadSafeModuleTest, BasicContextLockAPI) { TEST(ThreadSafeModuleTest, WithContextDoPreservesContext) {
// Test that basic lock API calls work. // Test that withContextDo passes through the LLVMContext that was used
ThreadSafeContext TSCtx(std::make_unique<LLVMContext>()); // to create the ThreadSafeContext.
auto M = std::make_unique<Module>("M", *TSCtx.getContext());
ThreadSafeModule TSM(std::move(M), TSCtx);
{ auto L = TSCtx.getLock(); } auto Ctx = std::make_unique<LLVMContext>();
LLVMContext *OriginalCtx = Ctx.get();
{ auto L = TSM.getContext().getLock(); } ThreadSafeContext TSCtx(std::move(Ctx));
} TSCtx.withContextDo(
[&](LLVMContext *ClosureCtx) { EXPECT_EQ(ClosureCtx, OriginalCtx); });
TEST(ThreadSafeModuleTest, ContextLockPreservesContext) {
// Test that the existence of a context lock preserves the attached
// context.
// The trick to verify this is a bit of a hack: We attach a Module
// (without the ThreadSafeModule wrapper) to the context, then verify
// that this Module destructs safely (which it will not if its context
// has been destroyed) even though all references to the context have
// been thrown away (apart from the lock).
ThreadSafeContext TSCtx(std::make_unique<LLVMContext>());
auto L = TSCtx.getLock();
auto &Ctx = *TSCtx.getContext();
auto M = std::make_unique<Module>("M", Ctx);
TSCtx = ThreadSafeContext();
} }
TEST(ThreadSafeModuleTest, WithModuleDo) { TEST(ThreadSafeModuleTest, WithModuleDo) {
// Test non-const version of withModuleDo. // Test non-const version of withModuleDo.
ThreadSafeContext TSCtx(std::make_unique<LLVMContext>()); auto Ctx = std::make_unique<LLVMContext>();
ThreadSafeModule TSM(std::make_unique<Module>("M", *TSCtx.getContext()), auto M = std::make_unique<Module>("M", *Ctx);
TSCtx); ThreadSafeModule TSM(std::move(M), std::move(Ctx));
TSM.withModuleDo([](Module &M) {}); TSM.withModuleDo([](Module &M) {});
} }
TEST(ThreadSafeModuleTest, WithModuleDoConst) { TEST(ThreadSafeModuleTest, WithModuleDoConst) {
// Test const version of withModuleDo. // Test const version of withModuleDo.
ThreadSafeContext TSCtx(std::make_unique<LLVMContext>()); auto Ctx = std::make_unique<LLVMContext>();
const ThreadSafeModule TSM(std::make_unique<Module>("M", *TSCtx.getContext()), auto M = std::make_unique<Module>("M", *Ctx);
TSCtx); const ThreadSafeModule TSM(std::move(M), std::move(Ctx));
TSM.withModuleDo([](const Module &M) {}); TSM.withModuleDo([](const Module &M) {});
} }
TEST(ThreadSafeModuleTest, ConsumingModuleDo) { TEST(ThreadSafeModuleTest, ConsumingModuleDo) {
// Test consumingModuleDo. // Test consumingModuleDo.
ThreadSafeContext TSCtx(std::make_unique<LLVMContext>()); auto Ctx = std::make_unique<LLVMContext>();
ThreadSafeModule TSM(std::make_unique<Module>("M", *TSCtx.getContext()), auto M = std::make_unique<Module>("M", *Ctx);
TSCtx); ThreadSafeModule TSM(std::move(M), std::move(Ctx));
TSM.consumingModuleDo([](std::unique_ptr<Module> M) {}); TSM.consumingModuleDo([](std::unique_ptr<Module> M) {});
} }