diff --git a/mlir/lib/Tools/mlir-reduce/MlirReduceMain.cpp b/mlir/lib/Tools/mlir-reduce/MlirReduceMain.cpp index 34459b8564a2..6c1b6820707f 100644 --- a/mlir/lib/Tools/mlir-reduce/MlirReduceMain.cpp +++ b/mlir/lib/Tools/mlir-reduce/MlirReduceMain.cpp @@ -18,6 +18,7 @@ #include "mlir/Pass/PassManager.h" #include "mlir/Reducer/Passes.h" #include "mlir/Support/FileUtilities.h" +#include "mlir/Support/ToolUtilities.h" #include "mlir/Tools/ParseUtilities.h" #include "llvm/Support/InitLLVM.h" #include "llvm/Support/SourceMgr.h" @@ -25,23 +26,6 @@ using namespace mlir; -// Parse and verify the input MLIR file. Returns null on error. -static OwningOpRef loadModule(MLIRContext &context, - StringRef inputFilename, - bool insertImplictModule) { - // Set up the input file. - std::string errorMessage; - auto file = openInputFile(inputFilename, &errorMessage); - if (!file) { - llvm::errs() << errorMessage << "\n"; - return nullptr; - } - - auto sourceMgr = std::make_shared(); - sourceMgr->AddNewSourceBuffer(std::move(file), SMLoc()); - return parseSourceFileForTool(sourceMgr, &context, insertImplictModule); -} - LogicalResult mlir::mlirReduceMain(int argc, char **argv, MLIRContext &context) { // Override the default '-h' and use the default PrintHelpMessage() which @@ -70,6 +54,18 @@ LogicalResult mlir::mlirReduceMain(int argc, char **argv, llvm::cl::desc("Allow operation with no registered dialects"), llvm::cl::init(false)); + static llvm::cl::opt splitInputFile( + "split-input-file", llvm::cl::ValueOptional, + llvm::cl::callback([&](const std::string &str) { + // Implicit value: use default marker if flag was used without + // value. + if (str.empty()) + splitInputFile.setValue(kDefaultSplitMarker); + }), + llvm::cl::desc("Split the input file into chunks using the given or " + "default marker and process each chunk independently"), + llvm::cl::init("")); + llvm::cl::HideUnrelatedOptions(mlirReduceCategory); llvm::InitLLVM y(argc, argv); @@ -93,27 +89,44 @@ LogicalResult mlir::mlirReduceMain(int argc, char **argv, if (!output) return failure(); - OwningOpRef opRef = - loadModule(context, inputFilename, !noImplicitModule); - if (!opRef) + std::unique_ptr input = + openInputFile(inputFilename, &errorMessage); + if (!input) { + llvm::errs() << errorMessage << "\n"; return failure(); + } auto errorHandler = [&](const Twine &msg) { return emitError(UnknownLoc::get(&context)) << msg; }; - // Reduction pass pipeline. - PassManager pm(&context, opRef.get()->getName().getStringRef()); - if (failed(parser.addToPipeline(pm, errorHandler))) - return failure(); + auto chunkFn = [&](std::unique_ptr chunkBuffer, + raw_ostream &os) { + auto sourceMgr = std::make_shared(); + sourceMgr->AddNewSourceBuffer(std::move(chunkBuffer), SMLoc()); + OwningOpRef opRef = + parseSourceFileForTool(sourceMgr, &context, !noImplicitModule); + if (!opRef) + return failure(); + // Reduction pass pipeline. + PassManager pm(&context, opRef.get()->getName().getStringRef()); + if (failed(parser.addToPipeline(pm, errorHandler))) + return failure(); - OwningOpRef op = opRef.get()->clone(); + OwningOpRef op = opRef.get()->clone(); - if (failed(pm.run(op.get()))) - return failure(); + if (failed(pm.run(op.get()))) + return failure(); + op.get()->print(output->os()); + output->keep(); + return success(); + }; - op.get()->print(output->os()); - output->keep(); + auto &splitInputFileDelimiter = splitInputFile.getValue(); + if (!splitInputFileDelimiter.empty()) + return splitAndProcessBuffer(std::move(input), chunkFn, output->os(), + splitInputFileDelimiter, + splitInputFileDelimiter); - return success(); + return chunkFn(std::move(input), output->os()); }