[mlir][ods] Allow sharding of op definitions (#89423)
Adds an option to `mlir-tblgen -gen-op-defs` `op-shard-count=N` that divides the op class definitions and op list into N segments, e.g. ``` // mlir-tblgen -gen-op-defs -op-shard-count=2 void FooDialect::initialize() { addOperations< >(); addOperations< >(); } ``` When split across multiple source files, this can help significantly improve dialect compile time for dialects with a large opset.
This commit is contained in:
parent
c3def59d0f
commit
1b232fa0e9
@ -185,10 +185,13 @@ include_directories( ${MLIR_INCLUDE_DIR})
|
||||
add_subdirectory(tools/mlir-linalg-ods-gen)
|
||||
add_subdirectory(tools/mlir-pdll)
|
||||
add_subdirectory(tools/mlir-tblgen)
|
||||
add_subdirectory(tools/mlir-src-sharder)
|
||||
set(MLIR_TABLEGEN_EXE "${MLIR_TABLEGEN_EXE}" CACHE INTERNAL "")
|
||||
set(MLIR_TABLEGEN_TARGET "${MLIR_TABLEGEN_TARGET}" CACHE INTERNAL "")
|
||||
set(MLIR_PDLL_TABLEGEN_EXE "${MLIR_PDLL_TABLEGEN_EXE}" CACHE INTERNAL "")
|
||||
set(MLIR_PDLL_TABLEGEN_TARGET "${MLIR_PDLL_TABLEGEN_TARGET}" CACHE INTERNAL "")
|
||||
set(MLIR_SRC_SHARDER_TABLEGEN_EXE "${MLIR_SRC_SHARDER_TABLEGEN_EXE}" CACHE INTERNAL "")
|
||||
set(MLIR_SRC_SHARDER_TABLEGEN_TARGET "${MLIR_SRC_SHARDER_TABLEGEN_TARGET}" CACHE INTERNAL "")
|
||||
|
||||
add_subdirectory(include/mlir)
|
||||
add_subdirectory(lib)
|
||||
|
@ -5,6 +5,28 @@ function(mlir_tablegen ofn)
|
||||
tablegen(MLIR ${ARGV})
|
||||
set(TABLEGEN_OUTPUT ${TABLEGEN_OUTPUT} ${CMAKE_CURRENT_BINARY_DIR}/${ofn}
|
||||
PARENT_SCOPE)
|
||||
|
||||
# Get the current set of include paths for this td file.
|
||||
cmake_parse_arguments(ARG "" "" "DEPENDS;EXTRA_INCLUDES" ${ARGN})
|
||||
get_directory_property(tblgen_includes INCLUDE_DIRECTORIES)
|
||||
list(APPEND tblgen_includes ${ARG_EXTRA_INCLUDES})
|
||||
# Filter out any empty include items.
|
||||
list(REMOVE_ITEM tblgen_includes "")
|
||||
|
||||
# Build the absolute path for the current input file.
|
||||
if (IS_ABSOLUTE ${LLVM_TARGET_DEFINITIONS})
|
||||
set(LLVM_TARGET_DEFINITIONS_ABSOLUTE ${LLVM_TARGET_DEFINITIONS})
|
||||
else()
|
||||
set(LLVM_TARGET_DEFINITIONS_ABSOLUTE ${CMAKE_CURRENT_SOURCE_DIR}/${LLVM_TARGET_DEFINITIONS})
|
||||
endif()
|
||||
|
||||
# Append the includes used for this file to the tablegen_compile_commands
|
||||
# file.
|
||||
file(APPEND ${CMAKE_BINARY_DIR}/tablegen_compile_commands.yml
|
||||
"--- !FileInfo:\n"
|
||||
" filepath: \"${LLVM_TARGET_DEFINITIONS_ABSOLUTE}\"\n"
|
||||
" includes: \"${CMAKE_CURRENT_SOURCE_DIR};${tblgen_includes}\"\n"
|
||||
)
|
||||
endfunction()
|
||||
|
||||
# Clear out any pre-existing compile_commands file before processing. This
|
||||
@ -149,6 +171,22 @@ function(add_mlir_dialect dialect dialect_namespace)
|
||||
add_dependencies(mlir-headers MLIR${dialect}IncGen)
|
||||
endfunction()
|
||||
|
||||
# Declare sharded dialect operation declarations and definitions
|
||||
function(add_sharded_ops ops_target shard_count)
|
||||
set(LLVM_TARGET_DEFINITIONS ${ops_target}.td)
|
||||
mlir_tablegen(${ops_target}.h.inc -gen-op-decls -op-shard-count=${shard_count})
|
||||
mlir_tablegen(${ops_target}.cpp.inc -gen-op-defs -op-shard-count=${shard_count})
|
||||
set(LLVM_TARGET_DEFINITIONS ${ops_target}.cpp)
|
||||
foreach(index RANGE ${shard_count})
|
||||
set(SHARDED_SRC ${ops_target}.${index}.cpp)
|
||||
list(APPEND SHARDED_SRCS ${SHARDED_SRC})
|
||||
tablegen(MLIR_SRC_SHARDER ${SHARDED_SRC} -op-shard-index=${index})
|
||||
set(TABLEGEN_OUTPUT ${TABLEGEN_OUTPUT} ${CMAKE_CURRENT_BINARY_DIR}/${SHARDED_SRC})
|
||||
endforeach()
|
||||
add_public_tablegen_target(MLIR${ops_target}ShardGen)
|
||||
set(SHARDED_SRCS ${SHARDED_SRCS} PARENT_SCOPE)
|
||||
endfunction()
|
||||
|
||||
# Declare a dialect in the include directory
|
||||
function(add_mlir_interface interface)
|
||||
set(LLVM_TARGET_DEFINITIONS ${interface}.td)
|
||||
|
@ -39,6 +39,7 @@ set(MLIR_CONFIG_INCLUDE_DIRS
|
||||
# Refer to the best host mlir-tbgen, which might be a host-optimized version
|
||||
set(MLIR_CONFIG_TABLEGEN_EXE "${MLIR_TABLEGEN_EXE}")
|
||||
set(MLIR_CONFIG_PDLL_TABLEGEN_EXE "${MLIR_PDLL_TABLEGEN_EXE}")
|
||||
set(MLIR_CONFIG_SRC_SHARDER_TABLEGEN_EXE "${MLIR_SRC_SHARDER_TABLEGEN_EXE}")
|
||||
|
||||
configure_file(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/MLIRConfig.cmake.in
|
||||
@ -77,6 +78,7 @@ set(MLIR_CONFIG_INCLUDE_DIRS
|
||||
# if we're building with a host-optimized mlir-tblgen (with LLVM_OPTIMIZED_TABLEGEN).
|
||||
set(MLIR_CONFIG_TABLEGEN_EXE mlir-tblgen)
|
||||
set(MLIR_CONFIG_PDLL_TABLEGEN_EXE mlir-pdll)
|
||||
set(MLIR_CONFIG_SRC_SHARDER_TABLEGEN_EXE mlir-src-sharder)
|
||||
|
||||
configure_file(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/MLIRConfig.cmake.in
|
||||
|
@ -11,6 +11,7 @@ set(MLIR_CMAKE_DIR "@MLIR_CONFIG_CMAKE_DIR@")
|
||||
set(MLIR_INCLUDE_DIRS "@MLIR_CONFIG_INCLUDE_DIRS@")
|
||||
set(MLIR_TABLEGEN_EXE "@MLIR_CONFIG_TABLEGEN_EXE@")
|
||||
set(MLIR_PDLL_TABLEGEN_EXE "@MLIR_CONFIG_PDLL_TABLEGEN_EXE@")
|
||||
set(MLIR_SRC_SHARDER_TABLEGEN_EXE "@MLIR_CONFIG_SRC_SHARDER_TABLEGEN_EXE@")
|
||||
set(MLIR_INSTALL_AGGREGATE_OBJECTS "@MLIR_INSTALL_AGGREGATE_OBJECTS@")
|
||||
set(MLIR_ENABLE_BINDINGS_PYTHON "@MLIR_ENABLE_BINDINGS_PYTHON@")
|
||||
set(MLIR_ENABLE_EXECUTION_ENGINE "@MLIR_ENABLE_EXECUTION_ENGINE@")
|
||||
|
@ -99,8 +99,14 @@ private:
|
||||
///
|
||||
class StaticVerifierFunctionEmitter {
|
||||
public:
|
||||
/// Create a constraint uniquer with a unique prefix derived from the record
|
||||
/// keeper with an optional tag.
|
||||
StaticVerifierFunctionEmitter(raw_ostream &os,
|
||||
const llvm::RecordKeeper &records);
|
||||
const llvm::RecordKeeper &records,
|
||||
StringRef tag = "");
|
||||
|
||||
/// Collect and unique all the constraints used by operations.
|
||||
void collectOpConstraints(ArrayRef<llvm::Record *> opDefs);
|
||||
|
||||
/// Collect and unique all compatible type, attribute, successor, and region
|
||||
/// constraints from the operations in the file and emit them at the top of
|
||||
@ -108,7 +114,7 @@ public:
|
||||
///
|
||||
/// Constraints that do not meet the restriction that they can only reference
|
||||
/// `$_self` and `$_op` are not uniqued.
|
||||
void emitOpConstraints(ArrayRef<llvm::Record *> opDefs, bool emitDecl);
|
||||
void emitOpConstraints(ArrayRef<llvm::Record *> opDefs);
|
||||
|
||||
/// Unique all compatible type and attribute constraints from a pattern file
|
||||
/// and emit them at the top of the generated file.
|
||||
@ -177,8 +183,6 @@ private:
|
||||
/// Emit pattern constraints.
|
||||
void emitPatternConstraints();
|
||||
|
||||
/// Collect and unique all the constraints used by operations.
|
||||
void collectOpConstraints(ArrayRef<llvm::Record *> opDefs);
|
||||
/// Collect and unique all pattern constraints.
|
||||
void collectPatternConstraints(ArrayRef<DagLeaf> constraints);
|
||||
|
||||
|
@ -24,7 +24,8 @@ using namespace mlir::tblgen;
|
||||
|
||||
/// Generate a unique label based on the current file name to prevent name
|
||||
/// collisions if multiple generated files are included at once.
|
||||
static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records) {
|
||||
static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records,
|
||||
StringRef tag) {
|
||||
// Use the input file name when generating a unique name.
|
||||
std::string inputFilename = records.getInputFilename();
|
||||
|
||||
@ -33,7 +34,7 @@ static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records) {
|
||||
nameRef.consume_back(".td");
|
||||
|
||||
// Sanitize any invalid characters.
|
||||
std::string uniqueName;
|
||||
std::string uniqueName(tag);
|
||||
for (char c : nameRef) {
|
||||
if (llvm::isAlnum(c) || c == '_')
|
||||
uniqueName.push_back(c);
|
||||
@ -44,15 +45,11 @@ static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records) {
|
||||
}
|
||||
|
||||
StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
|
||||
raw_ostream &os, const llvm::RecordKeeper &records)
|
||||
: os(os), uniqueOutputLabel(getUniqueOutputLabel(records)) {}
|
||||
raw_ostream &os, const llvm::RecordKeeper &records, StringRef tag)
|
||||
: os(os), uniqueOutputLabel(getUniqueOutputLabel(records, tag)) {}
|
||||
|
||||
void StaticVerifierFunctionEmitter::emitOpConstraints(
|
||||
ArrayRef<llvm::Record *> opDefs, bool emitDecl) {
|
||||
collectOpConstraints(opDefs);
|
||||
if (emitDecl)
|
||||
return;
|
||||
|
||||
ArrayRef<llvm::Record *> opDefs) {
|
||||
NamespaceEmitter namespaceEmitter(os, Operator(*opDefs[0]).getCppNamespace());
|
||||
emitTypeConstraints();
|
||||
emitAttrConstraints();
|
||||
|
33
mlir/test/mlir-tblgen/shard-op-defs.td
Normal file
33
mlir/test/mlir-tblgen/shard-op-defs.td
Normal file
@ -0,0 +1,33 @@
|
||||
// RUN: mlir-tblgen -gen-op-defs -op-shard-count=2 -I %S/../../include %s | FileCheck %s --check-prefix=DEFS
|
||||
// RUN: mlir-tblgen -gen-op-decls -op-shard-count=2 -I %S/../../include %s | FileCheck %s --check-prefix=DECLS
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def Test_Dialect : Dialect {
|
||||
let name = "test";
|
||||
let cppNamespace = "test";
|
||||
}
|
||||
|
||||
class Test_Op<string mnemonic, list<Trait> traits = []>
|
||||
: Op<Test_Dialect, mnemonic, traits>;
|
||||
|
||||
def OpA : Test_Op<"a">;
|
||||
def OpB : Test_Op<"b">;
|
||||
def OpC : Test_Op<"c">;
|
||||
|
||||
// DECLS: OpA
|
||||
// DECLS: OpB
|
||||
// DECLS: OpC
|
||||
// DECLS: registerTestDialectOperations(
|
||||
// DECLS: registerTestDialectOperations0(
|
||||
// DECLS: registerTestDialectOperations1(
|
||||
|
||||
// DEFS-LABEL: GET_OP_DEFS_0
|
||||
// DEFS: void test::registerTestDialectOperations(
|
||||
// DEFS: void test::registerTestDialectOperations0(
|
||||
// DEFS: OpAAdaptor
|
||||
// DEFS: OpBAdaptor
|
||||
|
||||
// DEFS-LABEL: GET_OP_DEFS_1
|
||||
// DEFS: void test::registerTestDialectOperations1(
|
||||
// DEFS: OpCAdaptor
|
14
mlir/tools/mlir-src-sharder/CMakeLists.txt
Normal file
14
mlir/tools/mlir-src-sharder/CMakeLists.txt
Normal file
@ -0,0 +1,14 @@
|
||||
set(LLVM_LINK_COMPONENTS Support)
|
||||
set(LIBS MLIRSupport)
|
||||
|
||||
add_tablegen(mlir-src-sharder MLIR_SRC_SHARDER
|
||||
mlir-src-sharder.cpp
|
||||
|
||||
DEPENDS
|
||||
${LIBS}
|
||||
)
|
||||
|
||||
set_target_properties(mlir-src-sharder PROPERTIES FOLDER "Tablegenning")
|
||||
target_link_libraries(mlir-src-sharder PRIVATE ${LIBS})
|
||||
|
||||
mlir_check_all_link_libraries(mlir-src-sharder)
|
114
mlir/tools/mlir-src-sharder/mlir-src-sharder.cpp
Normal file
114
mlir/tools/mlir-src-sharder/mlir-src-sharder.cpp
Normal file
@ -0,0 +1,114 @@
|
||||
//===- mlir-src-sharder.cpp - A tool for sharder generated source files ---===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Support/FileUtilities.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/InitLLVM.h"
|
||||
#include "llvm/Support/MemoryBuffer.h"
|
||||
#include "llvm/Support/ToolOutputFile.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
/// Create a dependency file for `-d` option.
|
||||
///
|
||||
/// This functionality is generally only for the benefit of the build system,
|
||||
/// and is modeled after the same option in TableGen.
|
||||
static LogicalResult createDependencyFile(StringRef outputFilename,
|
||||
StringRef dependencyFile) {
|
||||
if (outputFilename == "-") {
|
||||
llvm::errs() << "error: the option -d must be used together with -o\n";
|
||||
return failure();
|
||||
}
|
||||
|
||||
std::string errorMessage;
|
||||
std::unique_ptr<llvm::ToolOutputFile> outputFile =
|
||||
openOutputFile(dependencyFile, &errorMessage);
|
||||
if (!outputFile) {
|
||||
llvm::errs() << errorMessage << "\n";
|
||||
return failure();
|
||||
}
|
||||
|
||||
outputFile->os() << outputFilename << ":\n";
|
||||
outputFile->keep();
|
||||
return success();
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
// FIXME: This is necessary because we link in TableGen, which defines its
|
||||
// options as static variables.. some of which overlap with our options.
|
||||
llvm::cl::ResetCommandLineParser();
|
||||
|
||||
llvm::cl::opt<unsigned> opShardIndex(
|
||||
"op-shard-index", llvm::cl::desc("The current shard index"));
|
||||
llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
|
||||
llvm::cl::desc("<input file>"),
|
||||
llvm::cl::init("-"));
|
||||
llvm::cl::opt<std::string> outputFilename(
|
||||
"o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
|
||||
llvm::cl::init("-"));
|
||||
llvm::cl::list<std::string> includeDirs(
|
||||
"I", llvm::cl::desc("Directory of include files"),
|
||||
llvm::cl::value_desc("directory"), llvm::cl::Prefix);
|
||||
llvm::cl::opt<std::string> dependencyFilename(
|
||||
"d", llvm::cl::desc("Dependency filename"),
|
||||
llvm::cl::value_desc("filename"), llvm::cl::init(""));
|
||||
llvm::cl::opt<bool> writeIfChanged(
|
||||
"write-if-changed",
|
||||
llvm::cl::desc("Only write to the output file if it changed"));
|
||||
|
||||
llvm::InitLLVM y(argc, argv);
|
||||
llvm::cl::ParseCommandLineOptions(argc, argv);
|
||||
|
||||
// Open the input file.
|
||||
std::string errorMessage;
|
||||
std::unique_ptr<llvm::MemoryBuffer> inputFile =
|
||||
openInputFile(inputFilename, &errorMessage);
|
||||
if (!inputFile) {
|
||||
llvm::errs() << errorMessage << "\n";
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Write the output to a buffer.
|
||||
std::string outputStr;
|
||||
llvm::raw_string_ostream os(outputStr);
|
||||
os << "#define GET_OP_DEFS_" << opShardIndex << "\n"
|
||||
<< inputFile->getBuffer();
|
||||
|
||||
// Determine whether we need to write the output file.
|
||||
bool shouldWriteOutput = true;
|
||||
if (writeIfChanged) {
|
||||
// Only update the real output file if there are any differences. This
|
||||
// prevents recompilation of all the files depending on it if there aren't
|
||||
// any.
|
||||
if (auto existingOrErr =
|
||||
llvm::MemoryBuffer::getFile(outputFilename, /*IsText=*/true))
|
||||
if (std::move(existingOrErr.get())->getBuffer() == os.str())
|
||||
shouldWriteOutput = false;
|
||||
}
|
||||
|
||||
// Populate the output file if necessary.
|
||||
if (shouldWriteOutput) {
|
||||
std::unique_ptr<llvm::ToolOutputFile> outputFile =
|
||||
openOutputFile(outputFilename, &errorMessage);
|
||||
if (!outputFile) {
|
||||
llvm::errs() << errorMessage << "\n";
|
||||
return 1;
|
||||
}
|
||||
outputFile->os() << os.str();
|
||||
outputFile->keep();
|
||||
}
|
||||
|
||||
// Always write the depfile, even if the main output hasn't changed. If it's
|
||||
// missing, Ninja considers the output dirty.
|
||||
if (!dependencyFilename.empty())
|
||||
if (failed(createDependencyFile(outputFilename, dependencyFilename)))
|
||||
return 1;
|
||||
|
||||
return 0;
|
||||
}
|
@ -4303,32 +4303,15 @@ void OpOperandAdaptorEmitter::emitDef(
|
||||
emitter.adaptor.writeDefTo(os);
|
||||
}
|
||||
|
||||
// Emits the opcode enum and op classes.
|
||||
static void emitOpClasses(const RecordKeeper &recordKeeper,
|
||||
/// Emit the class declarations or definitions for the given op defs.
|
||||
static void
|
||||
emitOpClasses(const RecordKeeper &recordKeeper,
|
||||
const std::vector<Record *> &defs, raw_ostream &os,
|
||||
const StaticVerifierFunctionEmitter &staticVerifierEmitter,
|
||||
bool emitDecl) {
|
||||
// First emit forward declaration for each class, this allows them to refer
|
||||
// to each others in traits for example.
|
||||
if (emitDecl) {
|
||||
os << "#if defined(GET_OP_CLASSES) || defined(GET_OP_FWD_DEFINES)\n";
|
||||
os << "#undef GET_OP_FWD_DEFINES\n";
|
||||
for (auto *def : defs) {
|
||||
Operator op(*def);
|
||||
NamespaceEmitter emitter(os, op.getCppNamespace());
|
||||
os << "class " << op.getCppClassName() << ";\n";
|
||||
}
|
||||
os << "#endif\n\n";
|
||||
}
|
||||
|
||||
IfDefScope scope("GET_OP_CLASSES", os);
|
||||
if (defs.empty())
|
||||
return;
|
||||
|
||||
// Generate all of the locally instantiated methods first.
|
||||
StaticVerifierFunctionEmitter staticVerifierEmitter(os, recordKeeper);
|
||||
os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
|
||||
staticVerifierEmitter.emitOpConstraints(defs, emitDecl);
|
||||
|
||||
for (auto *def : defs) {
|
||||
Operator op(*def);
|
||||
if (emitDecl) {
|
||||
@ -4358,34 +4341,145 @@ static void emitOpClasses(const RecordKeeper &recordKeeper,
|
||||
}
|
||||
}
|
||||
|
||||
// Emits a comma-separated list of the ops.
|
||||
static void emitOpList(const std::vector<Record *> &defs, raw_ostream &os) {
|
||||
IfDefScope scope("GET_OP_LIST", os);
|
||||
|
||||
interleave(
|
||||
// TODO: We are constructing the Operator wrapper instance just for
|
||||
// getting it's qualified class name here. Reduce the overhead by having a
|
||||
// lightweight version of Operator class just for that purpose.
|
||||
defs, [&os](Record *def) { os << Operator(def).getQualCppClassName(); },
|
||||
[&os]() { os << ",\n"; });
|
||||
/// Emit the declarations for the provided op classes.
|
||||
static void emitOpClassDecls(const RecordKeeper &recordKeeper,
|
||||
const std::vector<Record *> &defs,
|
||||
raw_ostream &os) {
|
||||
// First emit forward declaration for each class, this allows them to refer
|
||||
// to each others in traits for example.
|
||||
for (auto *def : defs) {
|
||||
Operator op(*def);
|
||||
NamespaceEmitter emitter(os, op.getCppNamespace());
|
||||
os << "class " << op.getCppClassName() << ";\n";
|
||||
}
|
||||
|
||||
// Emit the op class declarations.
|
||||
IfDefScope scope("GET_OP_CLASSES", os);
|
||||
if (defs.empty())
|
||||
return;
|
||||
StaticVerifierFunctionEmitter staticVerifierEmitter(os, recordKeeper);
|
||||
staticVerifierEmitter.collectOpConstraints(defs);
|
||||
emitOpClasses(recordKeeper, defs, os, staticVerifierEmitter,
|
||||
/*emitDecl=*/true);
|
||||
}
|
||||
|
||||
/// Emit the definitions for the provided op classes.
|
||||
static void emitOpClassDefs(const RecordKeeper &recordKeeper,
|
||||
ArrayRef<Record *> defs, raw_ostream &os,
|
||||
StringRef constraintPrefix = "") {
|
||||
if (defs.empty())
|
||||
return;
|
||||
|
||||
// Generate all of the locally instantiated methods first.
|
||||
StaticVerifierFunctionEmitter staticVerifierEmitter(os, recordKeeper,
|
||||
constraintPrefix);
|
||||
os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
|
||||
staticVerifierEmitter.collectOpConstraints(defs);
|
||||
staticVerifierEmitter.emitOpConstraints(defs);
|
||||
|
||||
// Emit the classes.
|
||||
emitOpClasses(recordKeeper, defs, os, staticVerifierEmitter,
|
||||
/*emitDecl=*/false);
|
||||
}
|
||||
|
||||
/// Emit op declarations for all op records.
|
||||
static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
|
||||
emitSourceFileHeader("Op Declarations", os, recordKeeper);
|
||||
|
||||
std::vector<Record *> defs = getRequestedOpDefinitions(recordKeeper);
|
||||
emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/true);
|
||||
emitOpClassDecls(recordKeeper, defs, os);
|
||||
|
||||
// If we are generating sharded op definitions, emit the sharded op
|
||||
// registration hooks.
|
||||
SmallVector<ArrayRef<Record *>, 4> shardedDefs;
|
||||
shardOpDefinitions(defs, shardedDefs);
|
||||
if (defs.empty() || shardedDefs.size() <= 1)
|
||||
return false;
|
||||
|
||||
Dialect dialect = Operator(defs.front()).getDialect();
|
||||
NamespaceEmitter ns(os, dialect);
|
||||
|
||||
const char *const opRegistrationHook =
|
||||
"void register{0}Operations{1}({2}::{0} *dialect);\n";
|
||||
os << formatv(opRegistrationHook, dialect.getCppClassName(), "",
|
||||
dialect.getCppNamespace());
|
||||
for (unsigned i = 0; i < shardedDefs.size(); ++i) {
|
||||
os << formatv(opRegistrationHook, dialect.getCppClassName(), i,
|
||||
dialect.getCppNamespace());
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Generate the dialect op registration hook and the op class definitions for a
|
||||
/// shard of ops.
|
||||
static void emitOpDefShard(const RecordKeeper &recordKeeper,
|
||||
ArrayRef<Record *> defs, const Dialect &dialect,
|
||||
unsigned shardIndex, unsigned shardCount,
|
||||
raw_ostream &os) {
|
||||
std::string shardGuard = "GET_OP_DEFS_";
|
||||
std::string indexStr = std::to_string(shardIndex);
|
||||
shardGuard += indexStr;
|
||||
IfDefScope scope(shardGuard, os);
|
||||
|
||||
// Emit the op registration hook in the first shard.
|
||||
const char *const opRegistrationHook =
|
||||
"void {0}::register{1}Operations{2}({0}::{1} *dialect) {{\n";
|
||||
if (shardIndex == 0) {
|
||||
os << formatv(opRegistrationHook, dialect.getCppNamespace(),
|
||||
dialect.getCppClassName(), "");
|
||||
for (unsigned i = 0; i < shardCount; ++i) {
|
||||
os << formatv(" {0}::register{1}Operations{2}(dialect);\n",
|
||||
dialect.getCppNamespace(), dialect.getCppClassName(), i);
|
||||
}
|
||||
os << "}\n";
|
||||
}
|
||||
|
||||
// Generate the per-shard op registration hook.
|
||||
os << formatv(opCommentHeader, dialect.getCppClassName(),
|
||||
"Op Registration Hook")
|
||||
<< formatv(opRegistrationHook, dialect.getCppNamespace(),
|
||||
dialect.getCppClassName(), shardIndex);
|
||||
for (Record *def : defs) {
|
||||
os << formatv(" ::mlir::RegisteredOperationName::insert<{0}>(*dialect);\n",
|
||||
Operator(def).getQualCppClassName());
|
||||
}
|
||||
os << "}\n";
|
||||
|
||||
// Generate the per-shard op definitions.
|
||||
emitOpClassDefs(recordKeeper, defs, os, indexStr);
|
||||
}
|
||||
|
||||
/// Emit op definitions for all op records.
|
||||
static bool emitOpDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
|
||||
emitSourceFileHeader("Op Definitions", os, recordKeeper);
|
||||
|
||||
std::vector<Record *> defs = getRequestedOpDefinitions(recordKeeper);
|
||||
emitOpList(defs, os);
|
||||
emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/false);
|
||||
SmallVector<ArrayRef<Record *>, 4> shardedDefs;
|
||||
shardOpDefinitions(defs, shardedDefs);
|
||||
|
||||
// If no shard was requested, emit the regular op list and class definitions.
|
||||
if (shardedDefs.size() == 1) {
|
||||
{
|
||||
IfDefScope scope("GET_OP_LIST", os);
|
||||
interleave(
|
||||
defs, os,
|
||||
[&](Record *def) { os << Operator(def).getQualCppClassName(); },
|
||||
",\n");
|
||||
}
|
||||
{
|
||||
IfDefScope scope("GET_OP_CLASSES", os);
|
||||
emitOpClassDefs(recordKeeper, defs, os);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if (defs.empty())
|
||||
return false;
|
||||
Dialect dialect = Operator(defs.front()).getDialect();
|
||||
for (auto [idx, value] : llvm::enumerate(shardedDefs)) {
|
||||
emitOpDefShard(recordKeeper, value, dialect, idx, shardedDefs.size(), os);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -31,6 +31,10 @@ static cl::opt<std::string> opExcFilter(
|
||||
"op-exclude-regex",
|
||||
cl::desc("Regex of name of op's to exclude (no filter if empty)"),
|
||||
cl::cat(opDefGenCat));
|
||||
static cl::opt<unsigned> opShardCount(
|
||||
"op-shard-count",
|
||||
cl::desc("The number of shards into which the op classes will be divided"),
|
||||
cl::cat(opDefGenCat), cl::init(1));
|
||||
|
||||
static std::string getOperationName(const Record &def) {
|
||||
auto prefix = def.getValueAsDef("opDialect")->getValueAsString("name");
|
||||
@ -80,3 +84,22 @@ bool mlir::tblgen::isPythonReserved(StringRef str) {
|
||||
reserved.insert("type");
|
||||
return reserved.contains(str);
|
||||
}
|
||||
|
||||
void mlir::tblgen::shardOpDefinitions(
|
||||
ArrayRef<llvm::Record *> defs,
|
||||
SmallVectorImpl<ArrayRef<llvm::Record *>> &shardedDefs) {
|
||||
assert(opShardCount > 0 && "expected a positive shard count");
|
||||
if (opShardCount == 1) {
|
||||
shardedDefs.push_back(defs);
|
||||
return;
|
||||
}
|
||||
|
||||
unsigned minShardSize = defs.size() / opShardCount;
|
||||
unsigned numMissing = defs.size() - minShardSize * opShardCount;
|
||||
shardedDefs.reserve(opShardCount);
|
||||
for (unsigned i = 0, start = 0; i < opShardCount; ++i) {
|
||||
unsigned size = minShardSize + (i < numMissing);
|
||||
shardedDefs.push_back(defs.slice(start, size));
|
||||
start += size;
|
||||
}
|
||||
}
|
||||
|
@ -13,6 +13,7 @@
|
||||
#ifndef MLIR_TOOLS_MLIRTBLGEN_OPGENHELPERS_H_
|
||||
#define MLIR_TOOLS_MLIRTBLGEN_OPGENHELPERS_H_
|
||||
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "llvm/TableGen/Record.h"
|
||||
#include <vector>
|
||||
|
||||
@ -28,6 +29,10 @@ getRequestedOpDefinitions(const llvm::RecordKeeper &recordKeeper);
|
||||
/// Regenerate using python -c"print(set(sorted(__import__('keyword').kwlist)))"
|
||||
bool isPythonReserved(llvm::StringRef str);
|
||||
|
||||
/// Shard the op defintions into the number of shards set by "op-shard-count".
|
||||
void shardOpDefinitions(ArrayRef<llvm::Record *> defs,
|
||||
SmallVectorImpl<ArrayRef<llvm::Record *>> &shardedDefs);
|
||||
|
||||
} // namespace tblgen
|
||||
} // namespace mlir
|
||||
|
||||
|
@ -9771,6 +9771,15 @@ cc_binary(
|
||||
],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "mlir-src-sharder",
|
||||
srcs = ["tools/mlir-src-sharder/mlir-src-sharder.cpp"],
|
||||
deps = [
|
||||
":Support",
|
||||
"//llvm:Support",
|
||||
],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "mlir-linalg-ods-yaml-gen",
|
||||
srcs = [
|
||||
|
@ -432,3 +432,136 @@ def gentbl_cc_library(
|
||||
copts = copts,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def _gentbl_shard_impl(ctx):
|
||||
args = ctx.actions.args()
|
||||
args.add(ctx.file.src_file)
|
||||
args.add("-op-shard-index", ctx.attr.index)
|
||||
args.add("-o", ctx.outputs.out.path)
|
||||
ctx.actions.run(
|
||||
outputs = [ctx.outputs.out],
|
||||
inputs = [ctx.file.src_file],
|
||||
executable = ctx.executable.sharder,
|
||||
arguments = [args],
|
||||
use_default_shell_env = True,
|
||||
mnemonic = "ShardGenerate",
|
||||
)
|
||||
|
||||
gentbl_shard_rule = rule(
|
||||
_gentbl_shard_impl,
|
||||
doc = "",
|
||||
output_to_genfiles = True,
|
||||
attrs = {
|
||||
"index": attr.int(mandatory = True, doc = ""),
|
||||
"sharder": attr.label(
|
||||
doc = "",
|
||||
executable = True,
|
||||
cfg = "exec",
|
||||
),
|
||||
"src_file": attr.label(
|
||||
doc = "",
|
||||
allow_single_file = True,
|
||||
mandatory = True,
|
||||
),
|
||||
"out": attr.output(
|
||||
doc = "",
|
||||
mandatory = True,
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
def gentbl_sharded_ops(
|
||||
name,
|
||||
tblgen,
|
||||
sharder,
|
||||
td_file,
|
||||
shard_count,
|
||||
src_file,
|
||||
src_out,
|
||||
hdr_out,
|
||||
test = False,
|
||||
includes = [],
|
||||
strip_include_prefix = None,
|
||||
deps = []):
|
||||
"""Generate sharded op declarations and definitions.
|
||||
|
||||
This special build rule shards op definitions in a TableGen file and generates multiple copies
|
||||
of a template source file for including and compiling each shard. The rule defines a filegroup
|
||||
consisting of the source shards, the generated source file, and the generated header file.
|
||||
|
||||
Args:
|
||||
name: The name of the filegroup.
|
||||
tblgen: The binary used to produce the output.
|
||||
sharder: The source file sharder to use.
|
||||
td_file: The primary table definitions file.
|
||||
shard_count: The number of op definition shards to produce.
|
||||
src_file: The source file template.
|
||||
src_out: The generated source file.
|
||||
hdr_out: The generated header file.
|
||||
test: Whether this is a test target.
|
||||
includes: See gentbl_rule.includes
|
||||
deps: See gentbl_rule.deps
|
||||
strip_include_prefix: Attribute to pass through to cc_library.
|
||||
"""
|
||||
cc_lib_name = name + "__gentbl_cc_lib"
|
||||
gentbl_cc_library(
|
||||
name = cc_lib_name,
|
||||
strip_include_prefix = strip_include_prefix,
|
||||
includes = includes,
|
||||
tbl_outs = [
|
||||
(
|
||||
[
|
||||
"-gen-op-defs",
|
||||
"-op-shard-count=" + str(shard_count),
|
||||
],
|
||||
src_out,
|
||||
),
|
||||
(
|
||||
[
|
||||
"-gen-op-decls",
|
||||
"-op-shard-count=" + str(shard_count),
|
||||
],
|
||||
hdr_out,
|
||||
),
|
||||
],
|
||||
tblgen = tblgen,
|
||||
td_file = td_file,
|
||||
test = test,
|
||||
deps = deps,
|
||||
)
|
||||
all_files = [hdr_out, src_out]
|
||||
for i in range(0, shard_count):
|
||||
out_file = "shard_copy_" + str(i) + "_" + src_file
|
||||
gentbl_shard_rule(
|
||||
index = i,
|
||||
name = name + "__src_shard" + str(i),
|
||||
testonly = test,
|
||||
out = out_file,
|
||||
sharder = sharder,
|
||||
src_file = src_file,
|
||||
)
|
||||
all_files.append(out_file)
|
||||
native.filegroup(name = name, srcs = all_files)
|
||||
|
||||
def gentbl_sharded_op_defs(name, source_file, shard_count):
|
||||
"""Generates multiple copies of a source file that includes sharded op definitions.
|
||||
|
||||
Args:
|
||||
name: The name of the rule.
|
||||
source_file: The source to copy.
|
||||
shard_count: The number of shards.
|
||||
|
||||
Returns:
|
||||
A list of the copied filenames to be included in the dialect library.
|
||||
"""
|
||||
copies = []
|
||||
for i in range(0, shard_count):
|
||||
out_file = "shard_copy_" + str(i) + "_" + source_file
|
||||
copies.append(out_file)
|
||||
native.genrule(
|
||||
name = name + "_shard_" + str(i),
|
||||
srcs = [source_file],
|
||||
outs = [out_file],
|
||||
cmd = "echo -e \"#define GET_OP_DEFS_" + str(i) + "\n$$(cat $(SRCS))\" > $(OUTS)",
|
||||
)
|
||||
return copies
|
||||
|
Loading…
x
Reference in New Issue
Block a user