[mlir][py] Enable disabling loading all registered (#117643)

There is a pending todo about always eagerly loading or not. Make this
behavior optional and give the control to the user in a backwards
compatible manner. This is made optional as there were arguments for
both forms, kept it in form that is backwards compatible.
This commit is contained in:
Jacques Pienaar 2024-11-25 15:39:55 -08:00 committed by GitHub
parent 32432a6a02
commit 1ea7ced7ee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -80,9 +80,16 @@ def _site_initialize():
logger = logging.getLogger(__name__)
post_init_hooks = []
disable_multithreading = False
# This flag disables eagerly loading all dialects. Eagerly loading is often
# not the desired behavior (see
# https://github.com/llvm/llvm-project/issues/56037), and the logic is that
# if any module has this attribute set, then we don't load all (e.g., it's
# being used in a solution where the loading is controlled).
disable_load_all_available_dialects = False
def process_initializer_module(module_name):
nonlocal disable_multithreading
nonlocal disable_load_all_available_dialects
try:
m = importlib.import_module(f".{module_name}", __name__)
except ModuleNotFoundError:
@ -107,6 +114,8 @@ def _site_initialize():
if bool(m.disable_multithreading):
logger.debug("Disabling multi-threading for context")
disable_multithreading = True
if hasattr(m, "disable_load_all_available_dialects"):
disable_load_all_available_dialects = True
return True
# If _mlirRegisterEverything is built, then include it as an initializer
@ -130,10 +139,8 @@ def _site_initialize():
hook(self)
if not disable_multithreading:
self.enable_multithreading(True)
# TODO: There is some debate about whether we should eagerly load
# all dialects. It is being done here in order to preserve existing
# behavior. See: https://github.com/llvm/llvm-project/issues/56037
self.load_all_available_dialects()
if not disable_load_all_available_dialects:
self.load_all_available_dialects()
if init_module:
logger.debug(
"Registering translations from initializer %r", init_module