diff --git a/mlir/docs/PassManagement.md b/mlir/docs/PassManagement.md index a920d57c7cd2..8d20b496cd3a 100644 --- a/mlir/docs/PassManagement.md +++ b/mlir/docs/PassManagement.md @@ -835,6 +835,12 @@ each pass, the generator produces a `registerPassName` where generates a `registerGroupPasses`, where `Group` is the tag provided via the `-name` input parameter, that registers all of the passes present. +These declarations can be enabled for the whole group of passes by +defining the `GEN_PASS_REGISTRATION` macro, or on a per-pass basis by +defining `GEN_PASS_REGISTRATION_PASSNAME` where `PASSNAME` is the +uppercase version of the name of the pass (similar to pass def and +decls). + ```c++ // Tablegen options: -gen-pass-decls -name="Example" diff --git a/mlir/tools/mlir-tblgen/PassGen.cpp b/mlir/tools/mlir-tblgen/PassGen.cpp index f7134ce02b72..f4b8eb43b49b 100644 --- a/mlir/tools/mlir-tblgen/PassGen.cpp +++ b/mlir/tools/mlir-tblgen/PassGen.cpp @@ -57,19 +57,23 @@ const char *const passRegistrationCode = R"( //===----------------------------------------------------------------------===// // {0} Registration //===----------------------------------------------------------------------===// +#ifdef {1} inline void register{0}() {{ ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{ - return {1}; + return {2}; }); } // Old registration code, kept for temporary backwards compatibility. inline void register{0}Pass() {{ ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{ - return {1}; + return {2}; }); } + +#undef {1} +#endif // {1} )"; /// The code snippet used to generate a function to register all passes in a @@ -116,6 +120,10 @@ static std::string getPassDeclVarName(const Pass &pass) { return "GEN_PASS_DECL_" + pass.getDef()->getName().upper(); } +static std::string getPassRegistrationVarName(const Pass &pass) { + return "GEN_PASS_REGISTRATION_" + pass.getDef()->getName().upper(); +} + /// Emit the code to be included in the public header of the pass. static void emitPassDecls(const Pass &pass, raw_ostream &os) { StringRef passName = pass.getDef()->getName(); @@ -143,18 +151,25 @@ static void emitPassDecls(const Pass &pass, raw_ostream &os) { /// PassRegistry. static void emitRegistrations(llvm::ArrayRef passes, raw_ostream &os) { os << "#ifdef GEN_PASS_REGISTRATION\n"; + os << "// Generate registrations for all passes.\n"; + for (const Pass &pass : passes) + os << "#define " << getPassRegistrationVarName(pass) << "\n"; + os << "#endif // GEN_PASS_REGISTRATION\n"; for (const Pass &pass : passes) { + std::string passName = pass.getDef()->getName().str(); + std::string passEnableVarName = getPassRegistrationVarName(pass); + std::string constructorCall; if (StringRef constructor = pass.getConstructor(); !constructor.empty()) constructorCall = constructor.str(); else - constructorCall = formatv("create{0}()", pass.getDef()->getName()).str(); - - os << formatv(passRegistrationCode, pass.getDef()->getName(), + constructorCall = formatv("create{0}()", passName).str(); + os << formatv(passRegistrationCode, passName, passEnableVarName, constructorCall); } + os << "#ifdef GEN_PASS_REGISTRATION\n"; os << formatv(passGroupRegistrationCode, groupName); for (const Pass &pass : passes)