[mlir] add an example of using transform dialect standalone (#82623)

Transform dialect interpreter is designed to be usable outside of the
pass pipeline, as the main program transformation driver, e.g., for
languages with explicit schedules. Provide an example of such usage with
a couple of tests.
This commit is contained in:
Oleksandr "Alex" Zinenko 2024-02-28 09:48:15 +01:00 committed by GitHub
parent a4fff36b6c
commit 619ee20b39
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 541 additions and 0 deletions

View File

@ -1,3 +1,4 @@
add_subdirectory(toy)
add_subdirectory(transform)
add_subdirectory(transform-opt)
add_subdirectory(minimal-opt)

View File

@ -0,0 +1,26 @@
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS)
set(LIBS
MLIRAnalysis
MLIRIR
MLIRParser
MLIRSupport
MLIRTransformDialect
MLIRTransformDialectTransforms
MLIRTransforms
${dialect_libs}
${conversion_libs}
${extension_libs}
)
add_mlir_tool(mlir-transform-opt
mlir-transform-opt.cpp
DEPENDS
${LIBS}
)
target_link_libraries(mlir-transform-opt PRIVATE ${LIBS})
llvm_update_compile_flags(mlir-transform-opt)
mlir_check_all_link_libraries(mlir-transform-opt)

View File

@ -0,0 +1,40 @@
# Standalone Transform Dialect Interpreter
This is an example of using the Transform dialect interpreter functionality standalone, that is, outside of the regular pass pipeline. The example is a
binary capable of processing MLIR source files similar to `mlir-opt` and other
optimizer drivers, with the entire transformation process driven by a Transform
dialect script. This script can be embedded into the source file or provided in
a separate MLIR source file.
Either the input module or the transform module must contain a top-level symbol
named `__transform_main`, which is used as the entry point to the transformation
script.
```sh
mlir-transform-opt payload_with_embedded_transform.mlir
mlir-transform-opt payload.mlir -transform=transform.mlir
```
The name of the entry point can be overridden using command-line options.
```sh
mlir-transform-opt payload-mlir -transform-entry-point=another_entry_point
```
Transform scripts can reference symbols defined in other source files, called
libraries, which can be supplied to the binary through command-line options.
Libraries will be embedded into the main transformation module by the tool and
the interpreter will process everything as a single module. A debug option is
available to see the contents of the transform module before it goes into the interpreter.
```sh
mlir-transform-opt payload.mlir -transform=transform.mlir \
-transform-library=external_definitions_1.mlir \
-transform-library=external_definitions_2.mlir \
-dump-library-module
```
Check out the [Transform dialect
tutorial](https://mlir.llvm.org/docs/Tutorials/transform/) as well as
[documentation](https://mlir.llvm.org/docs/Dialects/Transform/) to learn more
about the dialect.

View File

@ -0,0 +1,389 @@
//===- mlir-transform-opt.cpp -----------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/Utils.h"
#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/InitAllDialects.h"
#include "mlir/InitAllExtensions.h"
#include "mlir/InitAllPasses.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Tools/mlir-opt/MlirOptMain.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/ToolOutputFile.h"
#include <cstdlib>
namespace {
using namespace llvm;
/// Structure containing command line options for the tool, these will get
/// initialized when an instance is created.
struct MlirTransformOptCLOptions {
cl::opt<bool> allowUnregisteredDialects{
"allow-unregistered-dialect",
cl::desc("Allow operations coming from an unregistered dialect"),
cl::init(false)};
cl::opt<bool> verifyDiagnostics{
"verify-diagnostics",
cl::desc("Check that emitted diagnostics match expected-* lines "
"on the corresponding line"),
cl::init(false)};
cl::opt<std::string> payloadFilename{cl::Positional, cl::desc("<input file>"),
cl::init("-")};
cl::opt<std::string> outputFilename{"o", cl::desc("Output filename"),
cl::value_desc("filename"),
cl::init("-")};
cl::opt<std::string> transformMainFilename{
"transform",
cl::desc("File containing entry point of the transform script, if "
"different from the input file"),
cl::value_desc("filename"), cl::init("")};
cl::list<std::string> transformLibraryFilenames{
"transform-library", cl::desc("File(s) containing definitions of "
"additional transform script symbols")};
cl::opt<std::string> transformEntryPoint{
"transform-entry-point",
cl::desc("Name of the entry point transform symbol"),
cl::init(mlir::transform::TransformDialect::kTransformEntryPointSymbolName
.str())};
cl::opt<bool> disableExpensiveChecks{
"disable-expensive-checks",
cl::desc("Disables potentially expensive checks in the transform "
"interpreter, providing more speed at the expense of "
"potential memory problems and silent corruptions"),
cl::init(false)};
cl::opt<bool> dumpLibraryModule{
"dump-library-module",
cl::desc("Prints the combined library module before the output"),
cl::init(false)};
};
} // namespace
/// "Managed" static instance of the command-line options structure. This makes
/// them locally-scoped and explicitly initialized/deinitialized. While this is
/// not strictly necessary in the tool source file that is not being used as a
/// library (where the options would pollute the global list of options), it is
/// good practice to follow this.
static llvm::ManagedStatic<MlirTransformOptCLOptions> clOptions;
/// Explicitly registers command-line options.
static void registerCLOptions() { *clOptions; }
namespace {
/// A wrapper class for source managers diagnostic. This provides both unique
/// ownership and virtual function-like overload for a pair of
/// inheritance-related classes that do not use virtual functions.
class DiagnosticHandlerWrapper {
public:
/// Kind of the diagnostic handler to use.
enum class Kind { EmitDiagnostics, VerifyDiagnostics };
/// Constructs the diagnostic handler of the specified kind of the given
/// source manager and context.
DiagnosticHandlerWrapper(Kind kind, llvm::SourceMgr &mgr,
mlir::MLIRContext *context) {
if (kind == Kind::EmitDiagnostics)
handler = new mlir::SourceMgrDiagnosticHandler(mgr, context);
else
handler = new mlir::SourceMgrDiagnosticVerifierHandler(mgr, context);
}
/// This object is non-copyable but movable.
DiagnosticHandlerWrapper(const DiagnosticHandlerWrapper &) = delete;
DiagnosticHandlerWrapper(DiagnosticHandlerWrapper &&other) = default;
DiagnosticHandlerWrapper &
operator=(const DiagnosticHandlerWrapper &) = delete;
DiagnosticHandlerWrapper &operator=(DiagnosticHandlerWrapper &&) = default;
/// Verifies the captured "expected-*" diagnostics if required.
mlir::LogicalResult verify() const {
if (auto *ptr =
handler.dyn_cast<mlir::SourceMgrDiagnosticVerifierHandler *>()) {
return ptr->verify();
}
return mlir::success();
}
/// Destructs the object of the same type as allocated.
~DiagnosticHandlerWrapper() {
if (auto *ptr = handler.dyn_cast<mlir::SourceMgrDiagnosticHandler *>()) {
delete ptr;
} else {
delete handler.get<mlir::SourceMgrDiagnosticVerifierHandler *>();
}
}
private:
/// Internal storage is a type-safe union.
llvm::PointerUnion<mlir::SourceMgrDiagnosticHandler *,
mlir::SourceMgrDiagnosticVerifierHandler *>
handler;
};
/// MLIR has deeply rooted expectations that the LLVM source manager contains
/// exactly one buffer, until at least the lexer level. This class wraps
/// multiple LLVM source managers each managing a buffer to match MLIR's
/// expectations while still providing a centralized handling mechanism.
class TransformSourceMgr {
public:
/// Constructs the source manager indicating whether diagnostic messages will
/// be verified later on.
explicit TransformSourceMgr(bool verifyDiagnostics)
: verifyDiagnostics(verifyDiagnostics) {}
/// Deconstructs the source manager. Note that `checkResults` must have been
/// called on this instance before deconstructing it.
~TransformSourceMgr() {
assert(resultChecked && "must check the result of diagnostic handlers by "
"running TransformSourceMgr::checkResult");
}
/// Parses the given buffer and creates the top-level operation of the kind
/// specified as template argument in the given context. Additional parsing
/// options may be provided.
template <typename OpTy = mlir::Operation *>
mlir::OwningOpRef<OpTy> parseBuffer(std::unique_ptr<MemoryBuffer> buffer,
mlir::MLIRContext &context,
const mlir::ParserConfig &config) {
// Create a single-buffer LLVM source manager. Note that `unique_ptr` allows
// the code below to capture a reference to the source manager in such a way
// that it is not invalidated when the vector contents is eventually
// reallocated.
llvm::SourceMgr &mgr =
*sourceMgrs.emplace_back(std::make_unique<llvm::SourceMgr>());
mgr.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc());
// Choose the type of diagnostic handler depending on whether diagnostic
// verification needs to happen and store it.
if (verifyDiagnostics) {
diagHandlers.emplace_back(
DiagnosticHandlerWrapper::Kind::VerifyDiagnostics, mgr, &context);
} else {
diagHandlers.emplace_back(DiagnosticHandlerWrapper::Kind::EmitDiagnostics,
mgr, &context);
}
// Defer to MLIR's parser.
return mlir::parseSourceFile<OpTy>(mgr, config);
}
/// If diagnostic message verification has been requested upon construction of
/// this source manager, performs the verification, reports errors and returns
/// the result of the verification. Otherwise passes through the given value.
mlir::LogicalResult checkResult(mlir::LogicalResult result) {
resultChecked = true;
if (!verifyDiagnostics)
return result;
return mlir::failure(llvm::any_of(diagHandlers, [](const auto &handler) {
return mlir::failed(handler.verify());
}));
}
private:
/// Indicates whether diagnostic message verification is requested.
const bool verifyDiagnostics;
/// Indicates that diagnostic message verification has taken place, and the
/// deconstruction is therefore safe.
bool resultChecked = false;
/// Storage for per-buffer source managers and diagnostic handlers. These are
/// wrapped into unique pointers in order to make it safe to capture
/// references to these objects: if the vector is reallocated, the unique
/// pointer objects are moved by the pointer addresses won't change. Also, for
/// handlers, this allows to store the pointer to the base class.
SmallVector<std::unique_ptr<llvm::SourceMgr>> sourceMgrs;
SmallVector<DiagnosticHandlerWrapper> diagHandlers;
};
} // namespace
/// Trivial wrapper around `applyTransforms` that doesn't support extra mapping
/// and doesn't enforce the entry point transform ops being top-level.
static mlir::LogicalResult
applyTransforms(mlir::Operation *payloadRoot,
mlir::transform::TransformOpInterface transformRoot,
const mlir::transform::TransformOptions &options) {
return applyTransforms(payloadRoot, transformRoot, {}, options,
/*enforceToplevelTransformOp=*/false);
}
/// Applies transforms indicated in the transform dialect script to the input
/// buffer. The transform script may be embedded in the input buffer or as a
/// separate buffer. The transform script may have external symbols, the
/// definitions of which must be provided in transform library buffers. If the
/// application is successful, prints the transformed input buffer into the
/// given output stream. Additional configuration options are derived from
/// command-line options.
static mlir::LogicalResult processPayloadBuffer(
raw_ostream &os, std::unique_ptr<MemoryBuffer> inputBuffer,
std::unique_ptr<llvm::MemoryBuffer> transformBuffer,
MutableArrayRef<std::unique_ptr<MemoryBuffer>> transformLibraries,
mlir::DialectRegistry &registry) {
// Initialize the MLIR context, and various configurations.
mlir::MLIRContext context(registry, mlir::MLIRContext::Threading::DISABLED);
context.allowUnregisteredDialects(clOptions->allowUnregisteredDialects);
mlir::ParserConfig config(&context);
TransformSourceMgr sourceMgr(
/*verifyDiagnostics=*/clOptions->verifyDiagnostics);
// Parse the input buffer that will be used as transform payload.
mlir::OwningOpRef<mlir::Operation *> payloadRoot =
sourceMgr.parseBuffer(std::move(inputBuffer), context, config);
if (!payloadRoot)
return sourceMgr.checkResult(mlir::failure());
// Identify the module containing the transform script entry point. This may
// be the same module as the input or a separate module. In the former case,
// make a copy of the module so it can be modified freely. Modification may
// happen in the script itself (at which point it could be rewriting itself
// during interpretation, leading to tricky memory errors) or by embedding
// library modules in the script.
mlir::OwningOpRef<mlir::ModuleOp> transformRoot;
if (transformBuffer) {
transformRoot = sourceMgr.parseBuffer<mlir::ModuleOp>(
std::move(transformBuffer), context, config);
if (!transformRoot)
return sourceMgr.checkResult(mlir::failure());
} else {
transformRoot = cast<mlir::ModuleOp>(payloadRoot->clone());
}
// Parse and merge the libraries into the main transform module.
for (auto &&transformLibrary : transformLibraries) {
mlir::OwningOpRef<mlir::ModuleOp> libraryModule =
sourceMgr.parseBuffer<mlir::ModuleOp>(std::move(transformLibrary),
context, config);
if (!libraryModule ||
mlir::failed(mlir::transform::detail::mergeSymbolsInto(
*transformRoot, std::move(libraryModule))))
return sourceMgr.checkResult(mlir::failure());
}
// If requested, dump the combined transform module.
if (clOptions->dumpLibraryModule)
transformRoot->dump();
// Find the entry point symbol. Even if it had originally been in the payload
// module, it was cloned into the transform module so only look there.
mlir::transform::TransformOpInterface entryPoint =
mlir::transform::detail::findTransformEntryPoint(
*transformRoot, mlir::ModuleOp(), clOptions->transformEntryPoint);
if (!entryPoint)
return sourceMgr.checkResult(mlir::failure());
// Apply the requested transformations.
mlir::transform::TransformOptions transformOptions;
transformOptions.enableExpensiveChecks(!clOptions->disableExpensiveChecks);
if (mlir::failed(applyTransforms(*payloadRoot, entryPoint, transformOptions)))
return sourceMgr.checkResult(mlir::failure());
// Print the transformed result and check the captured diagnostics if
// requested.
payloadRoot->print(os);
return sourceMgr.checkResult(mlir::success());
}
/// Tool entry point.
static mlir::LogicalResult runMain(int argc, char **argv) {
// Register all upstream dialects and extensions. Specific uses are advised
// not to register all dialects indiscriminately but rather hand-pick what is
// necessary for their use case.
mlir::DialectRegistry registry;
mlir::registerAllDialects(registry);
mlir::registerAllExtensions(registry);
mlir::registerAllPasses();
// Explicitly register the transform dialect. This is not strictly necessary
// since it has been already registered as part of the upstream dialect list,
// but useful for example purposes for cases when dialects to register are
// hand-picked. The transform dialect must be registered.
registry.insert<mlir::transform::TransformDialect>();
// Register various command-line options. Note that the LLVM initializer
// object is a RAII that ensures correct deconstruction of command-line option
// objects inside ManagedStatic.
llvm::InitLLVM y(argc, argv);
mlir::registerAsmPrinterCLOptions();
mlir::registerMLIRContextCLOptions();
registerCLOptions();
llvm::cl::ParseCommandLineOptions(argc, argv,
"Minimal Transform dialect driver\n");
// Try opening the main input file.
std::string errorMessage;
std::unique_ptr<llvm::MemoryBuffer> payloadFile =
mlir::openInputFile(clOptions->payloadFilename, &errorMessage);
if (!payloadFile) {
llvm::errs() << errorMessage << "\n";
return mlir::failure();
}
// Try opening the output file.
std::unique_ptr<llvm::ToolOutputFile> outputFile =
mlir::openOutputFile(clOptions->outputFilename, &errorMessage);
if (!outputFile) {
llvm::errs() << errorMessage << "\n";
return mlir::failure();
}
// Try opening the main transform file if provided.
std::unique_ptr<llvm::MemoryBuffer> transformRootFile;
if (!clOptions->transformMainFilename.empty()) {
if (clOptions->transformMainFilename == clOptions->payloadFilename) {
llvm::errs() << "warning: " << clOptions->payloadFilename
<< " is provided as both payload and transform file\n";
} else {
transformRootFile =
mlir::openInputFile(clOptions->transformMainFilename, &errorMessage);
if (!transformRootFile) {
llvm::errs() << errorMessage << "\n";
return mlir::failure();
}
}
}
// Try opening transform library files if provided.
SmallVector<std::unique_ptr<llvm::MemoryBuffer>> transformLibraries;
transformLibraries.reserve(clOptions->transformLibraryFilenames.size());
for (llvm::StringRef filename : clOptions->transformLibraryFilenames) {
transformLibraries.emplace_back(
mlir::openInputFile(filename, &errorMessage));
if (!transformLibraries.back()) {
llvm::errs() << errorMessage << "\n";
return mlir::failure();
}
}
return processPayloadBuffer(outputFile->os(), std::move(payloadFile),
std::move(transformRootFile), transformLibraries,
registry);
}
int main(int argc, char **argv) {
return mlir::asMainReturnCode(runMain(argc, argv));
}

View File

@ -173,6 +173,7 @@ if(LLVM_BUILD_EXAMPLES)
transform-opt-ch3
transform-opt-ch4
mlir-minimal-opt
mlir-transform-opt
)
if(MLIR_ENABLE_EXECUTION_ENGINE)
list(APPEND MLIR_TEST_DEPENDS

View File

@ -0,0 +1,12 @@
// RUN: mlir-transform-opt %s --transform=%p/self-contained.mlir | FileCheck %s
// RUN: mlir-transform-opt %s --transform=%p/external-decl.mlir --verify-diagnostics
// RUN: mlir-transform-opt %s --transform=%p/external-def.mlir --transform-entry-point=external_def | FileCheck %s --check-prefix=EXTERNAL
// RUN: mlir-transform-opt %s --transform=%p/external-decl.mlir --transform-library=%p/external-def.mlir | FileCheck %s --check-prefix=EXTERNAL
// RUN: mlir-transform-opt %s --transform=%p/syntax-error.mlir --verify-diagnostics
// RUN: mlir-transform-opt %s --transform=%p/self-contained.mlir --transform-library=%p/syntax-error.mlir --verify-diagnostics
// RUN: mlir-transform-opt %s --transform=%p/self-contained.mlir --transform-library=%p/external-def.mlir --transform-library=%p/syntax-error.mlir --verify-diagnostics
// CHECK: IR printer: in self-contained
// EXTERNAL: IR printer: external_def
// CHECK-NOT: @__transform_main
module {}

View File

@ -0,0 +1,18 @@
// This test just needs to parse. Note that the diagnostic message below will
// be produced in *another* multi-file test, do *not* -verify-diagnostics here.
// RUN: mlir-opt %s
// RUN: mlir-transform-opt %s --transform-library=%p/external-def.mlir | FileCheck %s
module attributes {transform.with_named_sequence} {
// The definition should not be printed here.
// CHECK: @external_def
// CHECK-NOT: transform.print
transform.named_sequence private @external_def(%root: !transform.any_op {transform.readonly})
transform.named_sequence private @__transform_main(%root: !transform.any_op) {
// expected-error @below {{unresolved external named sequence}}
transform.include @external_def failures(propagate) (%root) : (!transform.any_op) -> ()
transform.yield
}
}

View File

@ -0,0 +1,8 @@
// RUN: mlir-opt %s
module attributes {transform.with_named_sequence} {
transform.named_sequence @external_def(%root: !transform.any_op {transform.readonly}) {
transform.print %root { name = "external_def" } : !transform.any_op
transform.yield
}
}

View File

@ -0,0 +1,19 @@
// RUN: mlir-transform-opt %s | FileCheck %s
module attributes {transform.with_named_sequence} {
// CHECK-LABEL: @return_42
// CHECK: %[[C42:.+]] = arith.constant 42
// CHECK: return %[[C42]]
func.func @return_42() -> i32 {
%0 = arith.constant 21 : i32
%1 = arith.constant 2 : i32
%2 = arith.muli %0, %1 : i32
return %2 : i32
}
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
%arg1 = transform.apply_registered_pass "canonicalize" to %arg0 : (!transform.any_op) -> !transform.any_op
transform.print %arg1 : !transform.any_op
transform.yield
}
}

View File

@ -0,0 +1,21 @@
// RUN: mlir-transform-opt %s | FileCheck %s
// RUN: mlir-transform-opt %s --transform=%s | FileCheck %s
// RUN: mlir-transform-opt %s --transform=%p/external-decl.mlir --verify-diagnostics
// RUN: mlir-transform-opt %s --transform=%p/external-def.mlir --transform-entry-point=external_def | FileCheck %s --check-prefix=EXTERNAL
// RUN: mlir-transform-opt %s --transform=%p/external-decl.mlir --transform-library=%p/external-def.mlir | FileCheck %s --check-prefix=EXTERNAL
// RUN: mlir-transform-opt %s --transform=%p/syntax-error.mlir --verify-diagnostics
// CHECK: IR printer: in self-contained
// EXTERNAL: IR printer: external_def
// The first occurrence comes from the print operation and the second is the
// roundtrip output. However, we shouldn't have the symbol duplicated because
// of library merging.
// CHECK-COUNT-2: @__transform_main
// CHECK-NOT: @__transform_main
module attributes {transform.with_named_sequence} {
transform.named_sequence private @__transform_main(%root: !transform.any_op) {
transform.print %root { name = "in self-contained" } : !transform.any_op
transform.yield
}
}

View File

@ -0,0 +1,5 @@
// RUN: mlir-opt %s --verify-diagnostics
// This file is used as additional input.
// expected-error @below {{expected operation name in quotes}}
module {

View File

@ -161,6 +161,7 @@ tools.extend(
ToolSubst("transform-opt-ch2", unresolved="ignore"),
ToolSubst("transform-opt-ch3", unresolved="ignore"),
ToolSubst("transform-opt-ch4", unresolved="ignore"),
ToolSubst("mlir-transform-opt", unresolved="ignore"),
ToolSubst("%mlir_lib_dir", config.mlir_lib_dir, unresolved="ignore"),
ToolSubst("%mlir_src_dir", config.mlir_src_root, unresolved="ignore"),
]