
Introduce a new extension for simple print-debugging of transform dialect scripts. The initial version of this extension consists of two ops that print the payload objects associated with transform dialect values. Similar ops were already available in the test extension and several downstream projects, and were extensively used for testing.
96 lines
3.6 KiB
C++
96 lines
3.6 KiB
C++
//===- Preload.cpp - Test transform library preloading --------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Transform/DebugExtension/DebugExtension.h"
|
|
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
|
|
#include "mlir/Dialect/Transform/IR/Utils.h"
|
|
#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
|
|
#include "mlir/IR/AsmState.h"
|
|
#include "mlir/IR/DialectRegistry.h"
|
|
#include "mlir/IR/Verifier.h"
|
|
#include "mlir/Parser/Parser.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Pass/PassManager.h"
|
|
#include "mlir/Support/FileUtilities.h"
|
|
#include "mlir/Support/TypeID.h"
|
|
#include "mlir/Tools/mlir-opt/MlirOptMain.h"
|
|
#include "llvm/Support/MemoryBuffer.h"
|
|
#include "llvm/Support/raw_ostream.h"
|
|
#include "gtest/gtest.h"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace mlir {
|
|
namespace test {
|
|
std::unique_ptr<Pass> createTestTransformDialectInterpreterPass();
|
|
} // namespace test
|
|
} // namespace mlir
|
|
|
|
/// Transform library that the test preloads into the TransformDialect.
/// It defines `@__transform_main`, which emits the remark
/// "from external symbol" on its argument via the debug extension's
/// `transform.debug.emit_remark_at` op. The payload text must be plain MLIR:
/// any stray non-MLIR characters would make the parser reject it.
static constexpr llvm::StringLiteral library = R"MLIR(
module attributes {transform.with_named_sequence} {
  transform.named_sequence private @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    transform.debug.emit_remark_at %arg0, "from external symbol" : !transform.any_op
    transform.yield
  }
})MLIR";
|
|
|
|
/// Payload module for the test. It only *declares* `@__transform_main`
/// (no body); the test expects the definition to be supplied by merging the
/// preloaded library's symbols into this module. The `transform.sequence`
/// refers to the declared symbol through `include`.
static constexpr llvm::StringLiteral input = R"MLIR(
module attributes {transform.with_named_sequence} {
  transform.named_sequence private @__transform_main(%arg0: !transform.any_op {transform.readonly})

  transform.sequence failures(propagate) {
  ^bb0(%arg0: !transform.any_op):
    include @__transform_main failures(propagate) (%arg0) : (!transform.any_op) -> ()
  }
})MLIR";
|
|
|
|
/// Loads `library` into the TransformDialect's library module, retrieves it
/// from the context, merges its symbols into `input` (which only declares
/// `@__transform_main`), and applies the resulting entry point to the payload.
TEST(Preload, ContextPreloadConstructedLibrary) {
  registerPassManagerCLOptions();

  MLIRContext context;
  auto *dialect = context.getOrLoadDialect<transform::TransformDialect>();
  // `library` uses `transform.debug.emit_remark_at`, so the debug extension is
  // registered on the context before anything is parsed.
  DialectRegistry registry;
  mlir::transform::registerDebugExtension(registry);
  registry.applyExtensions(&context);
  ParserConfig parserConfig(&context);

  // Checks guarding values that are dereferenced later are ASSERTs: with
  // EXPECT the test would continue and crash on a null module/op instead of
  // reporting the failure.
  OwningOpRef<ModuleOp> inputModule =
      parseSourceString<ModuleOp>(input, parserConfig, "<input>");
  ASSERT_TRUE(inputModule) << "failed to parse input module";

  OwningOpRef<ModuleOp> transformLibrary =
      parseSourceString<ModuleOp>(library, parserConfig, "<transform-library>");
  ASSERT_TRUE(transformLibrary) << "failed to parse transform module";
  LogicalResult loaded =
      dialect->loadIntoLibraryModule(std::move(transformLibrary));
  ASSERT_TRUE(succeeded(loaded)) << "failed to load transform library";

  ModuleOp retrievedTransformLibrary =
      transform::detail::getPreloadedTransformModule(&context);
  ASSERT_TRUE(retrievedTransformLibrary)
      << "failed to retrieve transform module";

  // mergeSymbolsInto takes ownership of its second argument, so hand it a
  // clone rather than the module owned by the dialect.
  OwningOpRef<Operation *> clonedTransformModule(
      retrievedTransformLibrary->clone());
  LogicalResult res = transform::detail::mergeSymbolsInto(
      inputModule->getOperation(), std::move(clonedTransformModule));
  ASSERT_TRUE(succeeded(res)) << "failed to define declared symbols";

  transform::TransformOpInterface entryPoint =
      transform::detail::findTransformEntryPoint(inputModule->getOperation(),
                                                 retrievedTransformLibrary);
  ASSERT_TRUE(entryPoint) << "failed to find entry point";

  transform::TransformOptions options;
  res = transform::applyTransformNamedSequence(
      inputModule->getOperation(), entryPoint, retrievedTransformLibrary,
      options);
  EXPECT_TRUE(succeeded(res)) << "failed to apply named sequence";
}
|