llvm-project/mlir/unittests/IR/DialectTest.cpp
Nikhil Kalra 84cc1865ef
[mlir] Support DialectRegistry extension comparison (#101119)
`PassManager::run` loads the dependent dialects for each pass into the
current context prior to invoking the individual passes. If the
dependent dialect is already loaded into the context, this should be a
no-op. However, if there are extensions registered in the
`DialectRegistry`, the dependent dialects are unconditionally registered
into the context.

This poses a problem for dynamic pass pipelines, however, because they
will likely be executing while the context is in an immutable state
(because of the parent pass pipeline being run).

To solve this, we'll update the extension registration API on
`DialectRegistry` to require a type ID for each extension that is
registered. Then, instead of unconditionally registered dialects into a
context if extensions are present, we'll check against the extension
type IDs already present in the context's internal `DialectRegistry`.
The context will only be marked as dirty if there are net-new extension
types present in the `DialectRegistry` populated by
`PassManager::getDependentDialects`.

Note: this PR removes the `addExtension` overload that utilizes
`std::function` as the parameter. This is because `std::function` is
copyable and potentially allocates memory for the contained function so
we can't use the function pointer as the unique type ID for the
extension.

Downstream changes required:
- Existing `DialectExtension` subclasses will need a type ID to be
registered for each subclass. More details on how to register a type ID
can be found here:
8b68e06731/mlir/include/mlir/Support/TypeID.h (L30)
- Existing uses of the `std::function` overload of `addExtension` will
need to be refactored into dedicated `DialectExtension` classes with
associated type IDs. The attached `std::function` can either be inlined
into or called directly from `DialectExtension::apply`.

---------

Co-authored-by: Mehdi Amini <joker.eph@gmail.com>
2024-08-06 01:32:36 +02:00

230 lines
8.3 KiB
C++

//===- DialectTest.cpp - Dialect unit tests -------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectInterface.h"
#include "mlir/Support/TypeID.h"
#include "gtest/gtest.h"
using namespace mlir;
using namespace mlir::detail;
namespace {
struct TestDialect : public Dialect {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialect)
static StringRef getDialectNamespace() { return "test"; };
TestDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context, TypeID::get<TestDialect>()) {}
};
struct AnotherTestDialect : public Dialect {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AnotherTestDialect)
static StringRef getDialectNamespace() { return "test"; };
AnotherTestDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context,
TypeID::get<AnotherTestDialect>()) {}
};
TEST(DialectDeathTest, MultipleDialectsWithSameNamespace) {
MLIRContext context;
// Registering a dialect with the same namespace twice should result in a
// failure.
context.loadDialect<TestDialect>();
ASSERT_DEATH(context.loadDialect<AnotherTestDialect>(), "");
}
struct SecondTestDialect : public Dialect {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SecondTestDialect)
static StringRef getDialectNamespace() { return "test2"; }
SecondTestDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context,
TypeID::get<SecondTestDialect>()) {}
};
struct TestDialectInterfaceBase
: public DialectInterface::Base<TestDialectInterfaceBase> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialectInterfaceBase)
TestDialectInterfaceBase(Dialect *dialect) : Base(dialect) {}
virtual int function() const { return 42; }
};
struct TestDialectInterface : public TestDialectInterfaceBase {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialectInterface)
using TestDialectInterfaceBase::TestDialectInterfaceBase;
int function() const final { return 56; }
};
struct SecondTestDialectInterface : public TestDialectInterfaceBase {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SecondTestDialectInterface)
using TestDialectInterfaceBase::TestDialectInterfaceBase;
int function() const final { return 78; }
};
TEST(Dialect, DelayedInterfaceRegistration) {
DialectRegistry registry;
registry.insert<TestDialect, SecondTestDialect>();
// Delayed registration of an interface for TestDialect.
registry.addExtension(+[](MLIRContext *ctx, TestDialect *dialect) {
dialect->addInterfaces<TestDialectInterface>();
});
MLIRContext context(registry);
// Load the TestDialect and check that the interface got registered for it.
Dialect *testDialect = context.getOrLoadDialect<TestDialect>();
ASSERT_TRUE(testDialect != nullptr);
auto *testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect);
EXPECT_TRUE(testDialectInterface != nullptr);
// Load the SecondTestDialect and check that the interface is not registered
// for it.
Dialect *secondTestDialect = context.getOrLoadDialect<SecondTestDialect>();
ASSERT_TRUE(secondTestDialect != nullptr);
auto *secondTestDialectInterface =
dyn_cast<SecondTestDialectInterface>(secondTestDialect);
EXPECT_TRUE(secondTestDialectInterface == nullptr);
// Use the same mechanism as for delayed registration but for an already
// loaded dialect and check that the interface is now registered.
DialectRegistry secondRegistry;
secondRegistry.insert<SecondTestDialect>();
secondRegistry.addExtension(
+[](MLIRContext *ctx, SecondTestDialect *dialect) {
dialect->addInterfaces<SecondTestDialectInterface>();
});
context.appendDialectRegistry(secondRegistry);
secondTestDialectInterface =
dyn_cast<SecondTestDialectInterface>(secondTestDialect);
EXPECT_TRUE(secondTestDialectInterface != nullptr);
}
TEST(Dialect, RepeatedDelayedRegistration) {
// Set up the delayed registration.
DialectRegistry registry;
registry.insert<TestDialect>();
registry.addExtension(+[](MLIRContext *ctx, TestDialect *dialect) {
dialect->addInterfaces<TestDialectInterface>();
});
MLIRContext context(registry);
// Load the TestDialect and check that the interface got registered for it.
Dialect *testDialect = context.getOrLoadDialect<TestDialect>();
ASSERT_TRUE(testDialect != nullptr);
auto *testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect);
EXPECT_TRUE(testDialectInterface != nullptr);
// Try adding the same dialect interface again and check that we don't crash
// on repeated interface registration.
DialectRegistry secondRegistry;
secondRegistry.insert<TestDialect>();
secondRegistry.addExtension(+[](MLIRContext *ctx, TestDialect *dialect) {
dialect->addInterfaces<TestDialectInterface>();
});
context.appendDialectRegistry(secondRegistry);
testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect);
EXPECT_TRUE(testDialectInterface != nullptr);
}
namespace {
/// A dummy extension that increases a counter when being applied and
/// recursively adds additional extensions.
struct DummyExtension : DialectExtension<DummyExtension, TestDialect> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DummyExtension)
DummyExtension(int *counter, int numRecursive)
: DialectExtension(), counter(counter), numRecursive(numRecursive) {}
void apply(MLIRContext *ctx, TestDialect *dialect) const final {
++(*counter);
DialectRegistry nestedRegistry;
for (int i = 0; i < numRecursive; ++i) {
// Create unique TypeIDs for these recursive extensions so they don't get
// de-duplicated.
auto extension =
std::make_unique<DummyExtension>(counter, /*numRecursive=*/0);
auto typeID = TypeID::getFromOpaquePointer(extension.get());
nestedRegistry.addExtension(typeID, std::move(extension));
}
// Adding additional extensions may trigger a reallocation of the
// `extensions` vector in the dialect registry.
ctx->appendDialectRegistry(nestedRegistry);
}
private:
int *counter;
int numRecursive;
};
} // namespace
TEST(Dialect, NestedDialectExtension) {
DialectRegistry registry;
registry.insert<TestDialect>();
// Add an extension that adds 100 more extensions.
int counter1 = 0;
registry.addExtension(TypeID::get<DummyExtension>(),
std::make_unique<DummyExtension>(&counter1, 100));
// Add one more extension. This should not crash.
int counter2 = 0;
registry.addExtension(TypeID::getFromOpaquePointer(&counter2),
std::make_unique<DummyExtension>(&counter2, 0));
// Load dialect and apply extensions.
MLIRContext context(registry);
Dialect *testDialect = context.getOrLoadDialect<TestDialect>();
ASSERT_TRUE(testDialect != nullptr);
// Extensions are de-duplicated by typeID. Make sure that each expected
// extension was applied at least once.
EXPECT_GE(counter1, 101);
EXPECT_GE(counter2, 1);
}
TEST(Dialect, SubsetWithExtensions) {
DialectRegistry registry1, registry2;
registry1.insert<TestDialect>();
registry2.insert<TestDialect>();
// Validate that the registries are equivalent.
ASSERT_TRUE(registry1.isSubsetOf(registry2));
ASSERT_TRUE(registry2.isSubsetOf(registry1));
// Add extensions to registry2.
int counter = 0;
registry2.addExtension(TypeID::get<DummyExtension>(),
std::make_unique<DummyExtension>(&counter, 0));
// Expect that (1) is a subset of (2) but not the other way around.
ASSERT_TRUE(registry1.isSubsetOf(registry2));
ASSERT_FALSE(registry2.isSubsetOf(registry1));
// Add extensions to registry1.
registry1.addExtension(TypeID::get<DummyExtension>(),
std::make_unique<DummyExtension>(&counter, 0));
// Expect that (1) and (2) are equivalent.
ASSERT_TRUE(registry1.isSubsetOf(registry2));
ASSERT_TRUE(registry2.isSubsetOf(registry1));
// Load dialect and apply extensions.
MLIRContext context(registry1);
context.getOrLoadDialect<TestDialect>();
context.appendDialectRegistry(registry2);
// Expect that the extension as only invoked once.
ASSERT_EQ(counter, 1);
}
} // namespace