[mlir] Support DialectRegistry extension comparison (#101119)

`PassManager::run` loads the dependent dialects for each pass into the
current context prior to invoking the individual passes. If the
dependent dialect is already loaded into the context, this should be a
no-op. However, if there are extensions registered in the
`DialectRegistry`, the dependent dialects are unconditionally registered
into the context.

This poses a problem for dynamic pass pipelines, which will likely be
executing while the context is in an immutable state (because the parent
pass pipeline is running).

To solve this, we'll update the extension registration API on
`DialectRegistry` to require a type ID for each extension that is
registered. Then, instead of unconditionally registering dialects into a
context if extensions are present, we'll check against the extension
type IDs already present in the context's internal `DialectRegistry`.
The context will only be marked as dirty if there are net-new extension
types present in the `DialectRegistry` populated by
`PassManager::getDependentDialects`.

Note: this PR removes the `addExtension` overload that utilizes
`std::function` as the parameter. This is because `std::function` is
copyable and potentially allocates memory for the contained function so
we can't use the function pointer as the unique type ID for the
extension.

Downstream changes required:
- Existing `DialectExtension` subclasses will need a type ID to be
registered for each subclass. More details on how to register a type ID
can be found here:
`mlir/include/mlir/Support/TypeID.h`, line 30, at commit 8b68e06731.
- Existing uses of the `std::function` overload of `addExtension` will
need to be refactored into dedicated `DialectExtension` classes with
associated type IDs. The attached `std::function` can either be inlined
into or called directly from `DialectExtension::apply`.

---------

Co-authored-by: Mehdi Amini <joker.eph@gmail.com>
This commit is contained in:
Nikhil Kalra 2024-08-05 16:32:36 -07:00 committed by GitHub
parent 2fd2fd2c46
commit 84cc1865ef
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 167 additions and 32 deletions

View File

@ -29,6 +29,9 @@
class MyExtension
: public ::mlir::transform::TransformDialectExtension<MyExtension> {
public:
// The TypeID of this extension.
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MyExtension)
// The extension must derive the base constructor.
using Base::Base;

View File

@ -35,6 +35,9 @@
class MyExtension
: public ::mlir::transform::TransformDialectExtension<MyExtension> {
public:
// The TypeID of this extension.
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MyExtension)
// The extension must derive the base constructor.
using Base::Base;

View File

@ -31,6 +31,9 @@
class MyExtension
: public ::mlir::transform::TransformDialectExtension<MyExtension> {
public:
// The TypeID of this extension.
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MyExtension)
// The extension must derive the base constructor.
using Base::Base;

View File

@ -14,9 +14,9 @@
#define MLIR_IR_DIALECTREGISTRY_H
#include "mlir/IR/MLIRContext.h"
#include "mlir/Support/TypeID.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/MapVector.h"
#include <map>
#include <tuple>
@ -187,7 +187,8 @@ public:
nameAndRegistrationIt.second.second);
// Merge the extensions.
for (const auto &extension : extensions)
destination.extensions.push_back(extension->clone());
destination.extensions.try_emplace(extension.first,
extension.second->clone());
}
/// Return the names of dialects known to this registry.
@ -206,39 +207,37 @@ public:
void applyExtensions(MLIRContext *ctx) const;
/// Add the given extension to the registry.
void addExtension(std::unique_ptr<DialectExtensionBase> extension) {
extensions.push_back(std::move(extension));
bool addExtension(TypeID extensionID,
std::unique_ptr<DialectExtensionBase> extension) {
return extensions.try_emplace(extensionID, std::move(extension)).second;
}
/// Add the given extensions to the registry.
template <typename... ExtensionsT>
void addExtensions() {
(addExtension(std::make_unique<ExtensionsT>()), ...);
(addExtension(TypeID::get<ExtensionsT>(), std::make_unique<ExtensionsT>()),
...);
}
/// Add an extension function that requires the given dialects.
/// Note: This bare functor overload is provided in addition to the
/// std::function variant to enable dialect type deduction, e.g.:
/// registry.addExtension(+[](MLIRContext *ctx, MyDialect *dialect) { ... })
/// registry.addExtension(+[](MLIRContext *ctx, MyDialect *dialect) {
/// ... })
///
/// is equivalent to:
/// registry.addExtension<MyDialect>(
/// [](MLIRContext *ctx, MyDialect *dialect){ ... }
/// )
template <typename... DialectsT>
void addExtension(void (*extensionFn)(MLIRContext *, DialectsT *...)) {
addExtension<DialectsT...>(
std::function<void(MLIRContext *, DialectsT * ...)>(extensionFn));
}
template <typename... DialectsT>
void
addExtension(std::function<void(MLIRContext *, DialectsT *...)> extensionFn) {
using ExtensionFnT = std::function<void(MLIRContext *, DialectsT * ...)>;
bool addExtension(void (*extensionFn)(MLIRContext *, DialectsT *...)) {
using ExtensionFnT = void (*)(MLIRContext *, DialectsT *...);
struct Extension : public DialectExtension<Extension, DialectsT...> {
Extension(const Extension &) = default;
Extension(ExtensionFnT extensionFn)
: extensionFn(std::move(extensionFn)) {}
: DialectExtension<Extension, DialectsT...>(),
extensionFn(extensionFn) {}
~Extension() override = default;
void apply(MLIRContext *context, DialectsT *...dialects) const final {
@ -246,7 +245,9 @@ public:
}
ExtensionFnT extensionFn;
};
addExtension(std::make_unique<Extension>(std::move(extensionFn)));
return addExtension(TypeID::getFromOpaquePointer(
reinterpret_cast<const void *>(extensionFn)),
std::make_unique<Extension>(extensionFn));
}
/// Returns true if the current registry is a subset of 'rhs', i.e. if 'rhs'
@ -255,7 +256,7 @@ public:
private:
MapTy registry;
std::vector<std::unique_ptr<DialectExtensionBase>> extensions;
llvm::MapVector<TypeID, std::unique_ptr<DialectExtensionBase>> extensions;
};
} // namespace mlir

View File

@ -35,6 +35,8 @@ namespace {
/// starting a pass pipeline that involves dialect conversion to LLVM.
class LoadDependentDialectExtension : public DialectExtensionBase {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LoadDependentDialectExtension)
LoadDependentDialectExtension() : DialectExtensionBase(/*dialectNames=*/{}) {}
void apply(MLIRContext *context,

View File

@ -157,6 +157,8 @@ class AffineTransformDialectExtension
: public transform::TransformDialectExtension<
AffineTransformDialectExtension> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AffineTransformDialectExtension)
using Base::Base;
void init() {

View File

@ -150,6 +150,9 @@ class BufferizationTransformDialectExtension
: public transform::TransformDialectExtension<
BufferizationTransformDialectExtension> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
BufferizationTransformDialectExtension)
using Base::Base;
void init() {

View File

@ -236,6 +236,8 @@ class FuncTransformDialectExtension
: public transform::TransformDialectExtension<
FuncTransformDialectExtension> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FuncTransformDialectExtension)
using Base::Base;
void init() {

View File

@ -924,6 +924,8 @@ class GPUTransformDialectExtension
: public transform::TransformDialectExtension<
GPUTransformDialectExtension> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GPUTransformDialectExtension)
GPUTransformDialectExtension() {
declareGeneratedDialect<scf::SCFDialect>();
declareGeneratedDialect<arith::ArithDialect>();

View File

@ -30,6 +30,8 @@ class LinalgTransformDialectExtension
: public transform::TransformDialectExtension<
LinalgTransformDialectExtension> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LinalgTransformDialectExtension)
using Base::Base;
void init() {

View File

@ -309,6 +309,8 @@ class MemRefTransformDialectExtension
: public transform::TransformDialectExtension<
MemRefTransformDialectExtension> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MemRefTransformDialectExtension)
using Base::Base;
void init() {

View File

@ -1135,6 +1135,8 @@ class NVGPUTransformDialectExtension
: public transform::TransformDialectExtension<
NVGPUTransformDialectExtension> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NVGPUTransformDialectExtension)
NVGPUTransformDialectExtension() {
declareGeneratedDialect<arith::ArithDialect>();
declareGeneratedDialect<affine::AffineDialect>();

View File

@ -613,6 +613,8 @@ class SCFTransformDialectExtension
: public transform::TransformDialectExtension<
SCFTransformDialectExtension> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SCFTransformDialectExtension)
using Base::Base;
void init() {

View File

@ -38,6 +38,9 @@ class SparseTensorTransformDialectExtension
: public transform::TransformDialectExtension<
SparseTensorTransformDialectExtension> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
SparseTensorTransformDialectExtension)
SparseTensorTransformDialectExtension() {
declareGeneratedDialect<sparse_tensor::SparseTensorDialect>();
registerTransformOps<

View File

@ -236,6 +236,8 @@ class TensorTransformDialectExtension
: public transform::TransformDialectExtension<
TensorTransformDialectExtension> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TensorTransformDialectExtension)
using Base::Base;
void init() {

View File

@ -20,6 +20,8 @@ namespace {
class DebugExtension
: public transform::TransformDialectExtension<DebugExtension> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DebugExtension)
void init() {
registerTransformOps<
#define GET_OP_LIST

View File

@ -18,6 +18,8 @@ namespace {
class IRDLExtension
: public transform::TransformDialectExtension<IRDLExtension> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(IRDLExtension)
void init() {
registerTransformOps<
#define GET_OP_LIST

View File

@ -20,6 +20,8 @@ namespace {
class LoopExtension
: public transform::TransformDialectExtension<LoopExtension> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LoopExtension)
void init() {
registerTransformOps<
#define GET_OP_LIST

View File

@ -38,6 +38,8 @@ namespace {
/// with Transform dialect operations.
class PDLExtension : public transform::TransformDialectExtension<PDLExtension> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PDLExtension)
void init() {
registerTransformOps<
#define GET_OP_LIST

View File

@ -212,6 +212,8 @@ class VectorTransformDialectExtension
: public transform::TransformDialectExtension<
VectorTransformDialectExtension> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VectorTransformDialectExtension)
VectorTransformDialectExtension() {
declareGeneratedDialect<vector::VectorDialect>();
declareGeneratedDialect<LLVM::LLVMDialect>();

View File

@ -11,14 +11,20 @@
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/DialectInterface.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/ExtensibleDialect.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/TypeID.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/Regex.h"
#include <memory>
#define DEBUG_TYPE "dialect"
@ -173,6 +179,40 @@ bool dialect_extension_detail::hasPromisedInterface(Dialect &dialect,
// DialectRegistry
//===----------------------------------------------------------------------===//
namespace {
/// Invokes `applyExtension` on every extension held in `extensions`, including
/// any extensions that are added to the map while earlier ones are being
/// applied.
template <typename Fn>
void applyExtensionsFn(
    Fn &&applyExtension,
    const llvm::MapVector<TypeID, std::unique_ptr<DialectExtensionBase>>
        &extensions) {
  // Applying an extension may register further extensions, which invalidates
  // iterators into `extensions`. Each round therefore snapshots the raw
  // pointers of the not-yet-applied tail before invoking anything, and the
  // iterators for the next round are recomputed from scratch.
  size_t numApplied = 0;
  for (auto first = extensions.begin(), last = extensions.end(); first != last;
       first = extensions.begin() + numApplied, last = extensions.end()) {
    // Private copy of the extensions handled in this round.
    const auto batch = llvm::map_to_vector(
        llvm::make_range(first, last),
        [](const auto &idAndExtension) -> DialectExtensionBase * {
          return idAndExtension.second.get();
        });
    // Remember how far we got so the next round resumes after this batch.
    numApplied += batch.size();
    for (const DialectExtensionBase *extension : batch)
      applyExtension(*extension);
  }
}
} // namespace
DialectRegistry::DialectRegistry() { insert<BuiltinDialect>(); }
DialectAllocatorFunctionRef
@ -258,9 +298,7 @@ void DialectRegistry::applyExtensions(Dialect *dialect) const {
extension.apply(ctx, requiredDialects);
};
// Note: Additional extensions may be added while applying an extension.
for (int i = 0; i < static_cast<int>(extensions.size()); ++i)
applyExtension(*extensions[i]);
applyExtensionsFn(applyExtension, extensions);
}
void DialectRegistry::applyExtensions(MLIRContext *ctx) const {
@ -285,15 +323,17 @@ void DialectRegistry::applyExtensions(MLIRContext *ctx) const {
extension.apply(ctx, requiredDialects);
};
// Note: Additional extensions may be added while applying an extension.
for (int i = 0; i < static_cast<int>(extensions.size()); ++i)
applyExtension(*extensions[i]);
applyExtensionsFn(applyExtension, extensions);
}
bool DialectRegistry::isSubsetOf(const DialectRegistry &rhs) const {
// Treat any extensions conservatively.
if (!extensions.empty())
// Check that all extension keys are present in 'rhs'.
const auto hasExtension = [&](const auto &key) {
return rhs.extensions.contains(key);
};
if (!llvm::all_of(make_first_range(extensions), hasExtension))
return false;
// Check that the current dialects fully overlap with the dialects in 'rhs'.
return llvm::all_of(
registry, [&](const auto &it) { return rhs.registry.count(it.first); });

View File

@ -874,6 +874,8 @@ class TestTransformDialectExtension
: public transform::TransformDialectExtension<
TestTransformDialectExtension> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformDialectExtension)
using Base::Base;
void init() {

View File

@ -382,6 +382,9 @@ class TestTilingInterfaceDialectExtension
: public transform::TransformDialectExtension<
TestTilingInterfaceDialectExtension> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
TestTilingInterfaceDialectExtension)
using Base::Base;
void init() {

View File

@ -18,6 +18,8 @@ using namespace mlir::transform;
namespace {
class Extension : public TransformDialectExtension<Extension> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(Extension)
using Base::Base;
void init() { declareGeneratedDialect<func::FuncDialect>(); }
};

View File

@ -8,6 +8,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectInterface.h"
#include "mlir/Support/TypeID.h"
#include "gtest/gtest.h"
using namespace mlir;
@ -140,15 +141,22 @@ namespace {
/// A dummy extension that increases a counter when being applied and
/// recursively adds additional extensions.
struct DummyExtension : DialectExtension<DummyExtension, TestDialect> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DummyExtension)
DummyExtension(int *counter, int numRecursive)
: DialectExtension(), counter(counter), numRecursive(numRecursive) {}
void apply(MLIRContext *ctx, TestDialect *dialect) const final {
++(*counter);
DialectRegistry nestedRegistry;
for (int i = 0; i < numRecursive; ++i)
nestedRegistry.addExtension(
std::make_unique<DummyExtension>(counter, /*numRecursive=*/0));
for (int i = 0; i < numRecursive; ++i) {
// Create unique TypeIDs for these recursive extensions so they don't get
// de-duplicated.
auto extension =
std::make_unique<DummyExtension>(counter, /*numRecursive=*/0);
auto typeID = TypeID::getFromOpaquePointer(extension.get());
nestedRegistry.addExtension(typeID, std::move(extension));
}
// Adding additional extensions may trigger a reallocation of the
// `extensions` vector in the dialect registry.
ctx->appendDialectRegistry(nestedRegistry);
@ -166,20 +174,56 @@ TEST(Dialect, NestedDialectExtension) {
// Add an extension that adds 100 more extensions.
int counter1 = 0;
registry.addExtension(std::make_unique<DummyExtension>(&counter1, 100));
registry.addExtension(TypeID::get<DummyExtension>(),
std::make_unique<DummyExtension>(&counter1, 100));
// Add one more extension. This should not crash.
int counter2 = 0;
registry.addExtension(std::make_unique<DummyExtension>(&counter2, 0));
registry.addExtension(TypeID::getFromOpaquePointer(&counter2),
std::make_unique<DummyExtension>(&counter2, 0));
// Load dialect and apply extensions.
MLIRContext context(registry);
Dialect *testDialect = context.getOrLoadDialect<TestDialect>();
ASSERT_TRUE(testDialect != nullptr);
// Extensions may be applied multiple times. Make sure that each expected
// Extensions are de-duplicated by typeID. Make sure that each expected
// extension was applied at least once.
EXPECT_GE(counter1, 101);
EXPECT_GE(counter2, 1);
}
// Checks that DialectRegistry::isSubsetOf takes registered extensions into
// account, and that an extension registered under the same TypeID in two
// registries is applied only once when both registries reach one context.
TEST(Dialect, SubsetWithExtensions) {
DialectRegistry registry1, registry2;
registry1.insert<TestDialect>();
registry2.insert<TestDialect>();
// Validate that the registries are equivalent: neither has extensions yet, so
// each must be a subset of the other.
ASSERT_TRUE(registry1.isSubsetOf(registry2));
ASSERT_TRUE(registry2.isSubsetOf(registry1));
// Add an extension to registry2 only.
int counter = 0;
registry2.addExtension(TypeID::get<DummyExtension>(),
std::make_unique<DummyExtension>(&counter, 0));
// Expect that (1) is a subset of (2) but not the other way around.
ASSERT_TRUE(registry1.isSubsetOf(registry2));
ASSERT_FALSE(registry2.isSubsetOf(registry1));
// Add an extension with the same TypeID to registry1.
registry1.addExtension(TypeID::get<DummyExtension>(),
std::make_unique<DummyExtension>(&counter, 0));
// Expect that (1) and (2) are equivalent.
ASSERT_TRUE(registry1.isSubsetOf(registry2));
ASSERT_TRUE(registry2.isSubsetOf(registry1));
// Load dialect and apply extensions. registry2 is appended afterwards; both
// extensions share `counter`, so a second application would be visible below.
MLIRContext context(registry1);
context.getOrLoadDialect<TestDialect>();
context.appendDialectRegistry(registry2);
// Expect that the extension was only invoked once.
ASSERT_EQ(counter, 1);
}
} // namespace