#include "llvm/ExecutionEngine/Orc/ReOptimizeLayer.h" #include "llvm/ExecutionEngine/Orc/Mangling.h" using namespace llvm; using namespace orc; bool ReOptimizeLayer::ReOptMaterializationUnitState::tryStartReoptimize() { std::unique_lock Lock(Mutex); if (Reoptimizing) return false; Reoptimizing = true; return true; } void ReOptimizeLayer::ReOptMaterializationUnitState::reoptimizeSucceeded() { std::unique_lock Lock(Mutex); assert(Reoptimizing && "Tried to mark unstarted reoptimization as done"); Reoptimizing = false; CurVersion++; } void ReOptimizeLayer::ReOptMaterializationUnitState::reoptimizeFailed() { std::unique_lock Lock(Mutex); assert(Reoptimizing && "Tried to mark unstarted reoptimization as done"); Reoptimizing = false; } static void orc_rt_lite_reoptimize_helper( shared::CWrapperFunctionBuffer (*JITDispatch)(void *Ctx, void *Tag, const char *Data, size_t Size), void *JITDispatchCtx, void *Tag, uint64_t MUID, uint32_t CurVersion) { // Serialize the arguments into a WrapperFunctionBuffer and call dispatch. using SPSArgs = shared::SPSArgList; auto ArgBytes = shared::WrapperFunctionBuffer::allocate(SPSArgs::size(MUID, CurVersion)); shared::SPSOutputBuffer OB(ArgBytes.data(), ArgBytes.size()); if (!SPSArgs::serialize(OB, MUID, CurVersion)) { errs() << "Reoptimization error: could not serialize reoptimization arguments"; abort(); } shared::WrapperFunctionBuffer Buf{ JITDispatch(JITDispatchCtx, Tag, ArgBytes.data(), ArgBytes.size())}; if (const char *ErrMsg = Buf.getOutOfBandError()) { errs() << "Reoptimization error: " << ErrMsg << "\naborting.\n"; abort(); } } Error ReOptimizeLayer::addOrcRTLiteSupport(JITDylib &PlatformJD, const DataLayout &DL) { auto Ctx = std::make_unique(); auto Mod = std::make_unique("orc-rt-lite-reoptimize.ll", *Ctx); Mod->setDataLayout(DL); IRBuilder<> Builder(*Ctx); // Create basic types portably Type *VoidTy = Type::getVoidTy(*Ctx); Type *Int8Ty = Type::getInt8Ty(*Ctx); Type *Int32Ty = Type::getInt32Ty(*Ctx); Type *Int64Ty = Type::getInt64Ty(*Ctx); Type *VoidPtrTy = PointerType::getUnqual(*Ctx); // Helper function type: void (void*, void*, void*, uint64_t, uint32_t) FunctionType *HelperFnTy = FunctionType::get( VoidTy, {VoidPtrTy, VoidPtrTy, VoidPtrTy, Int64Ty, Int32Ty}, false); // Define ReoptimizeTag with initializer = 0 GlobalVariable *ReoptimizeTag = new GlobalVariable( *Mod, Int8Ty, false, GlobalValue::ExternalLinkage, ConstantInt::get(Int8Ty, 0), "__orc_rt_reoptimize_tag"); // Define orc_rt_lite_reoptimize function: void (uint64_t, uint32_t) FunctionType *ReOptimizeFnTy = FunctionType::get(VoidTy, {Int64Ty, Int32Ty}, false); Function *ReOptimizeFn = Function::Create(ReOptimizeFnTy, Function::ExternalLinkage, "__orc_rt_reoptimize", Mod.get()); // Set parameter names auto ArgIt = ReOptimizeFn->arg_begin(); Value *MUID = &*ArgIt++; MUID->setName("MUID"); Value *CurVersion = &*ArgIt; CurVersion->setName("CurVersion"); // Build function body BasicBlock *Entry = BasicBlock::Create(*Ctx, "entry", ReOptimizeFn); Builder.SetInsertPoint(Entry); // Create absolute address constants auto &JDI = PlatformJD.getExecutionSession() .getExecutorProcessControl() .getJITDispatchInfo(); Type *IntPtrTy = DL.getIntPtrType(*Ctx); Constant *JITDispatchPtr = ConstantExpr::getIntToPtr( ConstantInt::get(IntPtrTy, JDI.JITDispatchFunction.getValue()), VoidPtrTy); Constant *JITDispatchCtxPtr = ConstantExpr::getIntToPtr( ConstantInt::get(IntPtrTy, JDI.JITDispatchContext.getValue()), VoidPtrTy); Constant *HelperFnAddr = ConstantExpr::getIntToPtr( ConstantInt::get(IntPtrTy, reinterpret_cast( &orc_rt_lite_reoptimize_helper)), PointerType::getUnqual(*Ctx)); // Cast ReoptimizeTag to void* Value *ReoptimizeTagPtr = Builder.CreatePointerCast(ReoptimizeTag, VoidPtrTy); // Call the helper function Builder.CreateCall( HelperFnTy, HelperFnAddr, {JITDispatchPtr, JITDispatchCtxPtr, ReoptimizeTagPtr, MUID, CurVersion}); // Return void Builder.CreateRetVoid(); return BaseLayer.add(PlatformJD, ThreadSafeModule(std::move(Mod), std::move(Ctx))); } Error ReOptimizeLayer::registerRuntimeFunctions(JITDylib &PlatformJD) { ExecutionSession::JITDispatchHandlerAssociationMap WFs; using ReoptimizeSPSSig = shared::SPSError(uint64_t, uint32_t); WFs[Mangle("__orc_rt_reoptimize_tag")] = ES.wrapAsyncWithSPS(this, &ReOptimizeLayer::rt_reoptimize); return ES.registerJITDispatchHandlers(PlatformJD, std::move(WFs)); } void ReOptimizeLayer::emit(std::unique_ptr R, ThreadSafeModule TSM) { auto &JD = R->getTargetJITDylib(); bool HasNonCallable = false; for (auto &KV : R->getSymbols()) { auto &Flags = KV.second; if (!Flags.isCallable()) HasNonCallable = true; } if (HasNonCallable) { BaseLayer.emit(std::move(R), std::move(TSM)); return; } auto &MUState = createMaterializationUnitState(TSM); if (auto Err = R->withResourceKeyDo([&](ResourceKey Key) { registerMaterializationUnitResource(Key, MUState); })) { ES.reportError(std::move(Err)); R->failMaterialization(); return; } if (auto Err = ProfilerFunc(*this, MUState.getID(), MUState.getCurVersion(), TSM)) { ES.reportError(std::move(Err)); R->failMaterialization(); return; } auto InitialDests = emitMUImplSymbols(MUState, MUState.getCurVersion(), JD, std::move(TSM)); if (!InitialDests) { ES.reportError(InitialDests.takeError()); R->failMaterialization(); return; } RSManager.emitRedirectableSymbols(std::move(R), std::move(*InitialDests)); } Error ReOptimizeLayer::reoptimizeIfCallFrequent(ReOptimizeLayer &Parent, ReOptMaterializationUnitID MUID, unsigned CurVersion, ThreadSafeModule &TSM) { return TSM.withModuleDo([&](Module &M) -> Error { Type *I64Ty = Type::getInt64Ty(M.getContext()); GlobalVariable *Counter = new GlobalVariable( M, I64Ty, false, GlobalValue::InternalLinkage, Constant::getNullValue(I64Ty), "__orc_reopt_counter"); for (auto &F : M) { if (F.isDeclaration()) continue; auto &BB = F.getEntryBlock(); auto *IP = &*BB.getFirstInsertionPt(); IRBuilder<> IRB(IP); Value *Threshold = ConstantInt::get(I64Ty, CallCountThreshold, true); Value *Cnt = IRB.CreateLoad(I64Ty, Counter); // Use EQ to prevent further reoptimize calls. Value *Cmp = IRB.CreateICmpEQ(Cnt, Threshold); Value *Added = IRB.CreateAdd(Cnt, ConstantInt::get(I64Ty, 1)); (void)IRB.CreateStore(Added, Counter); Instruction *SplitTerminator = SplitBlockAndInsertIfThen(Cmp, IP, false); createReoptimizeCall(M, *SplitTerminator, MUID, CurVersion); } return Error::success(); }); } Expected ReOptimizeLayer::emitMUImplSymbols(ReOptMaterializationUnitState &MUState, uint32_t Version, JITDylib &JD, ThreadSafeModule TSM) { DenseMap RenamedMap; cantFail(TSM.withModuleDo([&](Module &M) -> Error { MangleAndInterner Mangle(ES, M.getDataLayout()); for (auto &F : M) if (!F.isDeclaration()) { std::string NewName = (F.getName() + ".__def__." + Twine(Version)).str(); RenamedMap[Mangle(F.getName())] = Mangle(NewName); F.setName(NewName); } return Error::success(); })); auto RT = JD.createResourceTracker(); if (auto Err = JD.define(std::make_unique( BaseLayer, *getManglingOptions(), std::move(TSM)), RT)) return Err; MUState.setResourceTracker(RT); SymbolLookupSet LookupSymbols; for (auto [K, V] : RenamedMap) LookupSymbols.add(V); auto ImplSymbols = ES.lookup({{&JD, JITDylibLookupFlags::MatchAllSymbols}}, LookupSymbols, LookupKind::Static, SymbolState::Resolved); if (auto Err = ImplSymbols.takeError()) return Err; SymbolMap Result; for (auto [K, V] : RenamedMap) Result[K] = (*ImplSymbols)[V]; return Result; } void ReOptimizeLayer::rt_reoptimize(SendErrorFn SendResult, ReOptMaterializationUnitID MUID, uint32_t CurVersion) { auto &MUState = getMaterializationUnitState(MUID); if (CurVersion < MUState.getCurVersion() || !MUState.tryStartReoptimize()) { SendResult(Error::success()); return; } ThreadSafeModule TSM = cloneToNewContext(MUState.getThreadSafeModule()); auto OldRT = MUState.getResourceTracker(); auto &JD = OldRT->getJITDylib(); if (auto Err = ReOptFunc(*this, MUID, CurVersion + 1, OldRT, TSM)) { ES.reportError(std::move(Err)); MUState.reoptimizeFailed(); SendResult(Error::success()); return; } auto SymbolDests = emitMUImplSymbols(MUState, CurVersion + 1, JD, std::move(TSM)); if (!SymbolDests) { ES.reportError(SymbolDests.takeError()); MUState.reoptimizeFailed(); SendResult(Error::success()); return; } if (auto Err = RSManager.redirect(JD, std::move(*SymbolDests))) { ES.reportError(std::move(Err)); MUState.reoptimizeFailed(); SendResult(Error::success()); return; } MUState.reoptimizeSucceeded(); SendResult(Error::success()); } void ReOptimizeLayer::createReoptimizeCall(Module &M, Instruction &IP, ReOptMaterializationUnitID MUID, uint32_t CurVersion) { Type *MUIDTy = IntegerType::get(M.getContext(), 64); Type *VersionTy = IntegerType::get(M.getContext(), 32); Function *ReoptimizeFunc = M.getFunction("__orc_rt_reoptimize"); if (!ReoptimizeFunc) { std::vector ArgTys = {MUIDTy, VersionTy}; FunctionType *FuncTy = FunctionType::get(Type::getVoidTy(M.getContext()), ArgTys, false); ReoptimizeFunc = Function::Create(FuncTy, GlobalValue::ExternalLinkage, "__orc_rt_reoptimize", &M); } Constant *MUIDArg = ConstantInt::get(MUIDTy, MUID, false); Constant *CurVersionArg = ConstantInt::get(VersionTy, CurVersion, false); IRBuilder<> IRB(&IP); (void)IRB.CreateCall(ReoptimizeFunc, {MUIDArg, CurVersionArg}); } ReOptimizeLayer::ReOptMaterializationUnitState & ReOptimizeLayer::createMaterializationUnitState(const ThreadSafeModule &TSM) { std::unique_lock Lock(Mutex); ReOptMaterializationUnitID MUID = NextID; MUStates.emplace(MUID, ReOptMaterializationUnitState(MUID, cloneToNewContext(TSM))); ++NextID; return MUStates.at(MUID); } ReOptimizeLayer::ReOptMaterializationUnitState & ReOptimizeLayer::getMaterializationUnitState(ReOptMaterializationUnitID MUID) { std::unique_lock Lock(Mutex); return MUStates.at(MUID); } void ReOptimizeLayer::registerMaterializationUnitResource( ResourceKey Key, ReOptMaterializationUnitState &State) { std::unique_lock Lock(Mutex); MUResources[Key].insert(State.getID()); } Error ReOptimizeLayer::handleRemoveResources(JITDylib &JD, ResourceKey K) { std::unique_lock Lock(Mutex); for (auto MUID : MUResources[K]) MUStates.erase(MUID); MUResources.erase(K); return Error::success(); } void ReOptimizeLayer::handleTransferResources(JITDylib &JD, ResourceKey DstK, ResourceKey SrcK) { std::unique_lock Lock(Mutex); MUResources[DstK].insert_range(MUResources[SrcK]); MUResources.erase(SrcK); }