[mlir][py] Enable loading only specified dialects during creation. (#121421)
Gives option post as global list as well as arg to control which dialects are loaded during context creation. This enables setting either a good base set or skipping in individual cases.
This commit is contained in:
parent
4b57783003
commit
c703b4645c
@ -58,6 +58,7 @@ def get_include_dirs() -> Sequence[str]:
|
||||
# needs.
|
||||
|
||||
_dialect_registry = None
|
||||
_load_on_create_dialects = None
|
||||
|
||||
|
||||
def get_dialect_registry():
|
||||
@ -71,6 +72,21 @@ def get_dialect_registry():
|
||||
return _dialect_registry
|
||||
|
||||
|
||||
def append_load_on_create_dialect(dialect: str):
|
||||
global _load_on_create_dialects
|
||||
if _load_on_create_dialects is None:
|
||||
_load_on_create_dialects = [dialect]
|
||||
else:
|
||||
_load_on_create_dialects.append(dialect)
|
||||
|
||||
|
||||
def get_load_on_create_dialects():
|
||||
global _load_on_create_dialects
|
||||
if _load_on_create_dialects is None:
|
||||
_load_on_create_dialects = []
|
||||
return _load_on_create_dialects
|
||||
|
||||
|
||||
def _site_initialize():
|
||||
import importlib
|
||||
import itertools
|
||||
@ -132,15 +148,35 @@ def _site_initialize():
|
||||
break
|
||||
|
||||
class Context(ir._BaseContext):
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, load_on_create_dialects=None, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.append_dialect_registry(get_dialect_registry())
|
||||
for hook in post_init_hooks:
|
||||
hook(self)
|
||||
if not disable_multithreading:
|
||||
self.enable_multithreading(True)
|
||||
if not disable_load_all_available_dialects:
|
||||
self.load_all_available_dialects()
|
||||
if load_on_create_dialects is not None:
|
||||
logger.debug(
|
||||
"Loading all dialects from load_on_create_dialects arg %r",
|
||||
load_on_create_dialects,
|
||||
)
|
||||
for dialect in load_on_create_dialects:
|
||||
# This triggers loading the dialect into the context.
|
||||
_ = self.dialects[dialect]
|
||||
else:
|
||||
if disable_load_all_available_dialects:
|
||||
dialects = get_load_on_create_dialects()
|
||||
if dialects:
|
||||
logger.debug(
|
||||
"Loading all dialects from global load_on_create_dialects %r",
|
||||
dialects,
|
||||
)
|
||||
for dialect in dialects:
|
||||
# This triggers loading the dialect into the context.
|
||||
_ = self.dialects[dialect]
|
||||
else:
|
||||
logger.debug("Loading all available dialects")
|
||||
self.load_all_available_dialects()
|
||||
if init_module:
|
||||
logger.debug(
|
||||
"Registering translations from initializer %r", init_module
|
||||
|
@ -5,7 +5,11 @@
|
||||
from ._mlir_libs._mlir.ir import *
|
||||
from ._mlir_libs._mlir.ir import _GlobalDebug
|
||||
from ._mlir_libs._mlir import register_type_caster, register_value_caster
|
||||
from ._mlir_libs import get_dialect_registry
|
||||
from ._mlir_libs import (
|
||||
get_dialect_registry,
|
||||
append_load_on_create_dialect,
|
||||
get_load_on_create_dialects,
|
||||
)
|
||||
|
||||
|
||||
# Convenience decorator for registering user-friendly Attribute builders.
|
||||
|
@ -121,3 +121,39 @@ def testAppendPrefixSearchPath():
|
||||
sys.path.append(".")
|
||||
_cext.globals.append_dialect_search_prefix("custom_dialect")
|
||||
assert _cext.globals._check_dialect_module_loaded("custom")
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testDialectLoadOnCreate
|
||||
@run
|
||||
def testDialectLoadOnCreate():
|
||||
with Context(load_on_create_dialects=[]) as ctx:
|
||||
ctx.emit_error_diagnostics = True
|
||||
ctx.allow_unregistered_dialects = True
|
||||
|
||||
def callback(d):
|
||||
# CHECK: DIAGNOSTIC
|
||||
# CHECK-SAME: op created with unregistered dialect
|
||||
print(f"DIAGNOSTIC={d.message}")
|
||||
return True
|
||||
|
||||
handler = ctx.attach_diagnostic_handler(callback)
|
||||
loc = Location.unknown(ctx)
|
||||
try:
|
||||
op = Operation.create("arith.addi", loc=loc)
|
||||
ctx.allow_unregistered_dialects = False
|
||||
op.verify()
|
||||
except MLIRError as e:
|
||||
pass
|
||||
|
||||
with Context(load_on_create_dialects=["func"]) as ctx:
|
||||
loc = Location.unknown(ctx)
|
||||
fn = Operation.create("func.func", loc=loc)
|
||||
|
||||
# TODO: This may require an update if a site wide policy is set.
|
||||
# CHECK: Load on create: []
|
||||
print(f"Load on create: {get_load_on_create_dialects()}")
|
||||
append_load_on_create_dialect("func")
|
||||
# CHECK: Load on create:
|
||||
# CHECK-SAME: func
|
||||
print(f"Load on create: {get_load_on_create_dialects()}")
|
||||
print(get_load_on_create_dialects())
|
||||
|
Loading…
x
Reference in New Issue
Block a user