[mlir][py] Enable loading only specified dialects during creation. (#121421)

Gives option post as global list as well as arg to control which
dialects are loaded during context creation. This enables setting either
a good base set or skipping in individual cases.
This commit is contained in:
Jacques Pienaar 2025-01-02 14:40:15 -08:00 committed by GitHub
parent 4b57783003
commit c703b4645c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 80 additions and 4 deletions

View File

@ -58,6 +58,7 @@ def get_include_dirs() -> Sequence[str]:
# needs.
_dialect_registry = None
_load_on_create_dialects = None
def get_dialect_registry():
@ -71,6 +72,21 @@ def get_dialect_registry():
return _dialect_registry
def append_load_on_create_dialect(dialect: str):
global _load_on_create_dialects
if _load_on_create_dialects is None:
_load_on_create_dialects = [dialect]
else:
_load_on_create_dialects.append(dialect)
def get_load_on_create_dialects():
global _load_on_create_dialects
if _load_on_create_dialects is None:
_load_on_create_dialects = []
return _load_on_create_dialects
def _site_initialize():
import importlib
import itertools
@ -132,15 +148,35 @@ def _site_initialize():
break
class Context(ir._BaseContext):
def __init__(self, *args, **kwargs):
def __init__(self, load_on_create_dialects=None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.append_dialect_registry(get_dialect_registry())
for hook in post_init_hooks:
hook(self)
if not disable_multithreading:
self.enable_multithreading(True)
if not disable_load_all_available_dialects:
self.load_all_available_dialects()
if load_on_create_dialects is not None:
logger.debug(
"Loading all dialects from load_on_create_dialects arg %r",
load_on_create_dialects,
)
for dialect in load_on_create_dialects:
# This triggers loading the dialect into the context.
_ = self.dialects[dialect]
else:
if disable_load_all_available_dialects:
dialects = get_load_on_create_dialects()
if dialects:
logger.debug(
"Loading all dialects from global load_on_create_dialects %r",
dialects,
)
for dialect in dialects:
# This triggers loading the dialect into the context.
_ = self.dialects[dialect]
else:
logger.debug("Loading all available dialects")
self.load_all_available_dialects()
if init_module:
logger.debug(
"Registering translations from initializer %r", init_module

View File

@ -5,7 +5,11 @@
from ._mlir_libs._mlir.ir import *
from ._mlir_libs._mlir.ir import _GlobalDebug
from ._mlir_libs._mlir import register_type_caster, register_value_caster
from ._mlir_libs import get_dialect_registry
from ._mlir_libs import (
get_dialect_registry,
append_load_on_create_dialect,
get_load_on_create_dialects,
)
# Convenience decorator for registering user-friendly Attribute builders.

View File

@ -121,3 +121,39 @@ def testAppendPrefixSearchPath():
sys.path.append(".")
_cext.globals.append_dialect_search_prefix("custom_dialect")
assert _cext.globals._check_dialect_module_loaded("custom")
# CHECK-LABEL: TEST: testDialectLoadOnCreate
@run
def testDialectLoadOnCreate():
with Context(load_on_create_dialects=[]) as ctx:
ctx.emit_error_diagnostics = True
ctx.allow_unregistered_dialects = True
def callback(d):
# CHECK: DIAGNOSTIC
# CHECK-SAME: op created with unregistered dialect
print(f"DIAGNOSTIC={d.message}")
return True
handler = ctx.attach_diagnostic_handler(callback)
loc = Location.unknown(ctx)
try:
op = Operation.create("arith.addi", loc=loc)
ctx.allow_unregistered_dialects = False
op.verify()
except MLIRError as e:
pass
with Context(load_on_create_dialects=["func"]) as ctx:
loc = Location.unknown(ctx)
fn = Operation.create("func.func", loc=loc)
# TODO: This may require an update if a site wide policy is set.
# CHECK: Load on create: []
print(f"Load on create: {get_load_on_create_dialects()}")
append_load_on_create_dialect("func")
# CHECK: Load on create:
# CHECK-SAME: func
print(f"Load on create: {get_load_on_create_dialects()}")
print(get_load_on_create_dialects())