This is a new attempt at #69320. The transform dialect stores a "library module" that the preload pass can populate. Until now, each pass registered an additional module by simply pushing it to a vector; however, the interpreter only used the first of them. This commit turns the registration into "loading", i.e., each newly added module gets merged into the existing one. This allows the loading to be split into several passes, and the interpreter now takes all loaded modules into account when using the library. While this design avoids repeated merging every time the library is accessed, it requires that the implementation of merging modules lives in the TransformDialect target (since the merging code and the dialect depend on each other). This resolves https://github.com/llvm/llvm-project/issues/69111.
98 lines
3.7 KiB
C++
98 lines
3.7 KiB
C++
//===- Preload.cpp - Test transform dialect library preloading ------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
|
|
#include "mlir/Dialect/Transform/IR/Utils.h"
|
|
#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
|
|
#include "mlir/IR/AsmState.h"
|
|
#include "mlir/IR/DialectRegistry.h"
|
|
#include "mlir/IR/Verifier.h"
|
|
#include "mlir/Parser/Parser.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Pass/PassManager.h"
|
|
#include "mlir/Support/FileUtilities.h"
|
|
#include "mlir/Support/TypeID.h"
|
|
#include "mlir/Tools/mlir-opt/MlirOptMain.h"
|
|
#include "llvm/Support/MemoryBuffer.h"
|
|
#include "llvm/Support/raw_ostream.h"
|
|
#include "gtest/gtest.h"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace mlir {
|
|
namespace test {
|
|
std::unique_ptr<Pass> createTestTransformDialectInterpreterPass();
|
|
} // namespace test
|
|
} // namespace mlir
|
|
namespace test {
|
|
void registerTestTransformDialectExtension(DialectRegistry ®istry);
|
|
} // namespace test
|
|
|
|
// Transform library that the test loads into the dialect's library module. It
// defines `@__transform_main`, which emits a "from external symbol" remark on
// its operand via `transform.test_print_remark_at_operand` (an op presumably
// provided by the test transform dialect extension).
static constexpr llvm::StringLiteral library = R"MLIR(
module attributes {transform.with_named_sequence} {
  transform.named_sequence private @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    transform.test_print_remark_at_operand %arg0, "from external symbol" : !transform.any_op
    transform.yield
  }
})MLIR";
|
|
|
|
// Payload module for the test. It only *declares* `@__transform_main` (no
// body); the definition is expected to be supplied by merging the preloaded
// library's symbols into this module. The `transform.sequence` refers to the
// declared symbol through `include`.
static constexpr llvm::StringLiteral input = R"MLIR(
module attributes {transform.with_named_sequence} {
  transform.named_sequence private @__transform_main(%arg0: !transform.any_op {transform.readonly})

  transform.sequence failures(propagate) {
  ^bb0(%arg0: !transform.any_op):
    include @__transform_main failures(propagate) (%arg0) : (!transform.any_op) -> ()
  }
})MLIR";
|
|
|
|
// Parses a transform library and loads it into the transform dialect's library
// module. It then checks that the preloaded module can be retrieved from the
// context and merged into the payload, so that the payload's external
// `@__transform_main` declaration gets a definition. Finally it applies that
// entry point through the named-sequence interpreter.
TEST(Preload, ContextPreloadConstructedLibrary) {
  // NOTE(review): nothing below visibly depends on pass-manager CL options;
  // this registration may be removable -- confirm.
  registerPassManagerCLOptions();

  MLIRContext context;
  auto *dialect = context.getOrLoadDialect<transform::TransformDialect>();
  DialectRegistry registry;
  ::test::registerTestTransformDialectExtension(registry);
  registry.applyExtensions(&context);
  ParserConfig parserConfig(&context);

  // Each step below dereferences or consumes the result of the previous one,
  // so prerequisite checks are fatal (ASSERT_*): a non-fatal EXPECT_* would let
  // the test continue into a null dereference and crash instead of failing.
  OwningOpRef<ModuleOp> inputModule =
      parseSourceString<ModuleOp>(input, parserConfig, "<input>");
  ASSERT_TRUE(inputModule) << "failed to parse input module";

  OwningOpRef<ModuleOp> transformLibrary =
      parseSourceString<ModuleOp>(library, parserConfig, "<transform-library>");
  ASSERT_TRUE(transformLibrary) << "failed to parse transform module";
  LogicalResult diag =
      dialect->loadIntoLibraryModule(std::move(transformLibrary));
  ASSERT_TRUE(succeeded(diag)) << "failed to load transform library";

  ModuleOp retrievedTransformLibrary =
      transform::detail::getPreloadedTransformModule(&context);
  ASSERT_TRUE(retrievedTransformLibrary)
      << "failed to retrieve transform module";

  // mergeSymbolsInto takes ownership of the module it merges from, while the
  // preloaded library is still needed below, so merge a clone of it.
  OwningOpRef<Operation *> clonedTransformModule(
      retrievedTransformLibrary->clone());

  LogicalResult res = transform::detail::mergeSymbolsInto(
      inputModule->getOperation(), std::move(clonedTransformModule));
  ASSERT_TRUE(succeeded(res)) << "failed to define declared symbols";

  transform::TransformOpInterface entryPoint =
      transform::detail::findTransformEntryPoint(inputModule->getOperation(),
                                                 retrievedTransformLibrary);
  ASSERT_TRUE(entryPoint) << "failed to find entry point";

  transform::TransformOptions options;
  res = transform::applyTransformNamedSequence(
      inputModule->getOperation(), entryPoint, retrievedTransformLibrary,
      options);
  EXPECT_TRUE(succeeded(res)) << "failed to apply named sequence";
}
|