From 49f55f4991227f3c7a2b8161bbf45c74b7023944 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Wed, 5 Nov 2025 21:34:02 -0800 Subject: [PATCH] [mlir][ods] Enable granular pass registration. (#166532) Same as with pass def & decl. This doesn't change anything with registry and the big flag kept (e.g., GEN_PASS_REGISTRATION behaves like GEN_PASS_DECL and so too for sub ones). --- mlir/docs/PassManagement.md | 6 ++++++ mlir/tools/mlir-tblgen/PassGen.cpp | 25 ++++++++++++++++++++----- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/mlir/docs/PassManagement.md b/mlir/docs/PassManagement.md index a920d57c7cd2..8d20b496cd3a 100644 --- a/mlir/docs/PassManagement.md +++ b/mlir/docs/PassManagement.md @@ -835,6 +835,12 @@ each pass, the generator produces a `registerPassName` where generates a `registerGroupPasses`, where `Group` is the tag provided via the `-name` input parameter, that registers all of the passes present. +These declarations can be enabled for the whole group of passes by +defining the `GEN_PASS_REGISTRATION` macro, or on a per-pass basis by +defining `GEN_PASS_REGISTRATION_PASSNAME` where `PASSNAME` is the +uppercase version of the name of the pass (similar to pass def and +decls). + ```c++ // Tablegen options: -gen-pass-decls -name="Example" diff --git a/mlir/tools/mlir-tblgen/PassGen.cpp b/mlir/tools/mlir-tblgen/PassGen.cpp index f7134ce02b72..f4b8eb43b49b 100644 --- a/mlir/tools/mlir-tblgen/PassGen.cpp +++ b/mlir/tools/mlir-tblgen/PassGen.cpp @@ -57,19 +57,23 @@ const char *const passRegistrationCode = R"( //===----------------------------------------------------------------------===// // {0} Registration //===----------------------------------------------------------------------===// +#ifdef {1} inline void register{0}() {{ ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{ - return {1}; + return {2}; }); } // Old registration code, kept for temporary backwards compatibility. inline void register{0}Pass() {{ ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{ - return {1}; + return {2}; }); } + +#undef {1} +#endif // {1} )"; /// The code snippet used to generate a function to register all passes in a @@ -116,6 +120,10 @@ static std::string getPassDeclVarName(const Pass &pass) { return "GEN_PASS_DECL_" + pass.getDef()->getName().upper(); } +static std::string getPassRegistrationVarName(const Pass &pass) { + return "GEN_PASS_REGISTRATION_" + pass.getDef()->getName().upper(); +} + /// Emit the code to be included in the public header of the pass. static void emitPassDecls(const Pass &pass, raw_ostream &os) { StringRef passName = pass.getDef()->getName(); @@ -143,18 +151,25 @@ static void emitPassDecls(const Pass &pass, raw_ostream &os) { /// PassRegistry. static void emitRegistrations(llvm::ArrayRef passes, raw_ostream &os) { os << "#ifdef GEN_PASS_REGISTRATION\n"; + os << "// Generate registrations for all passes.\n"; + for (const Pass &pass : passes) + os << "#define " << getPassRegistrationVarName(pass) << "\n"; + os << "#endif // GEN_PASS_REGISTRATION\n"; for (const Pass &pass : passes) { + std::string passName = pass.getDef()->getName().str(); + std::string passEnableVarName = getPassRegistrationVarName(pass); + std::string constructorCall; if (StringRef constructor = pass.getConstructor(); !constructor.empty()) constructorCall = constructor.str(); else - constructorCall = formatv("create{0}()", pass.getDef()->getName()).str(); - - os << formatv(passRegistrationCode, pass.getDef()->getName(), + constructorCall = formatv("create{0}()", passName).str(); + os << formatv(passRegistrationCode, passName, passEnableVarName, constructorCall); } + os << "#ifdef GEN_PASS_REGISTRATION\n"; os << formatv(passGroupRegistrationCode, groupName); for (const Pass &pass : passes)