From 84cc1865ef9202af39404ff4524a9b13df80cfc1 Mon Sep 17 00:00:00 2001 From: Nikhil Kalra Date: Mon, 5 Aug 2024 16:32:36 -0700 Subject: [PATCH] [mlir] Support DialectRegistry extension comparison (#101119) `PassManager::run` loads the dependent dialects for each pass into the current context prior to invoking the individual passes. If the dependent dialect is already loaded into the context, this should be a no-op. However, if there are extensions registered in the `DialectRegistry`, the dependent dialects are unconditionally registered into the context. This poses a problem for dynamic pass pipelines, however, because they will likely be executing while the context is in an immutable state (because of the parent pass pipeline being run). To solve this, we'll update the extension registration API on `DialectRegistry` to require a type ID for each extension that is registered. Then, instead of unconditionally registered dialects into a context if extensions are present, we'll check against the extension type IDs already present in the context's internal `DialectRegistry`. The context will only be marked as dirty if there are net-new extension types present in the `DialectRegistry` populated by `PassManager::getDependentDialects`. Note: this PR removes the `addExtension` overload that utilizes `std::function` as the parameter. This is because `std::function` is copyable and potentially allocates memory for the contained function so we can't use the function pointer as the unique type ID for the extension. Downstream changes required: - Existing `DialectExtension` subclasses will need a type ID to be registered for each subclass. More details on how to register a type ID can be found here: https://github.com/llvm/llvm-project/blob/8b68e06731e0033ed3f8d6fe6292ae671611cfa1/mlir/include/mlir/Support/TypeID.h#L30 - Existing uses of the `std::function` overload of `addExtension` will need to be refactored into dedicated `DialectExtension` classes with associated type IDs. The attached `std::function` can either be inlined into or called directly from `DialectExtension::apply`. --------- Co-authored-by: Mehdi Amini --- .../transform/Ch2/lib/MyExtension.cpp | 3 + .../transform/Ch3/lib/MyExtension.cpp | 3 + .../transform/Ch4/lib/MyExtension.cpp | 3 + mlir/include/mlir/IR/DialectRegistry.h | 37 ++++++------ .../ConvertToLLVM/ConvertToLLVMPass.cpp | 2 + .../TransformOps/AffineTransformOps.cpp | 2 + .../BufferizationTransformOps.cpp | 3 + .../Func/TransformOps/FuncTransformOps.cpp | 2 + .../GPU/TransformOps/GPUTransformOps.cpp | 2 + .../Linalg/TransformOps/DialectExtension.cpp | 2 + .../TransformOps/MemRefTransformOps.cpp | 2 + .../NVGPU/TransformOps/NVGPUTransformOps.cpp | 2 + .../SCF/TransformOps/SCFTransformOps.cpp | 2 + .../TransformOps/SparseTensorTransformOps.cpp | 3 + .../TransformOps/TensorTransformOps.cpp | 2 + .../DebugExtension/DebugExtension.cpp | 2 + .../Transform/IRDLExtension/IRDLExtension.cpp | 2 + .../Transform/LoopExtension/LoopExtension.cpp | 2 + .../Transform/PDLExtension/PDLExtension.cpp | 2 + .../TransformOps/VectorTransformOps.cpp | 2 + mlir/lib/IR/Dialect.cpp | 56 ++++++++++++++++--- .../TestTransformDialectExtension.cpp | 2 + .../TestTilingInterfaceTransformOps.cpp | 3 + .../Transform/BuildOnlyExtensionTest.cpp | 2 + mlir/unittests/IR/DialectTest.cpp | 56 +++++++++++++++++-- 25 files changed, 167 insertions(+), 32 deletions(-) diff --git a/mlir/examples/transform/Ch2/lib/MyExtension.cpp b/mlir/examples/transform/Ch2/lib/MyExtension.cpp index 68d538a09801..b4b27e97d266 100644 --- a/mlir/examples/transform/Ch2/lib/MyExtension.cpp +++ b/mlir/examples/transform/Ch2/lib/MyExtension.cpp @@ -29,6 +29,9 @@ class MyExtension : public ::mlir::transform::TransformDialectExtension { public: + // The TypeID of this extension. + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MyExtension) + // The extension must derive the base constructor. using Base::Base; diff --git a/mlir/examples/transform/Ch3/lib/MyExtension.cpp b/mlir/examples/transform/Ch3/lib/MyExtension.cpp index f7a99423a52e..4b2123fa71d3 100644 --- a/mlir/examples/transform/Ch3/lib/MyExtension.cpp +++ b/mlir/examples/transform/Ch3/lib/MyExtension.cpp @@ -35,6 +35,9 @@ class MyExtension : public ::mlir::transform::TransformDialectExtension { public: + // The TypeID of this extension. + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MyExtension) + // The extension must derive the base constructor. using Base::Base; diff --git a/mlir/examples/transform/Ch4/lib/MyExtension.cpp b/mlir/examples/transform/Ch4/lib/MyExtension.cpp index 38c8ca1125a2..fa0ffc9dc2e8 100644 --- a/mlir/examples/transform/Ch4/lib/MyExtension.cpp +++ b/mlir/examples/transform/Ch4/lib/MyExtension.cpp @@ -31,6 +31,9 @@ class MyExtension : public ::mlir::transform::TransformDialectExtension { public: + // The TypeID of this extension. + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MyExtension) + // The extension must derive the base constructor. using Base::Base; diff --git a/mlir/include/mlir/IR/DialectRegistry.h b/mlir/include/mlir/IR/DialectRegistry.h index 8e394988119d..2c1f6964998e 100644 --- a/mlir/include/mlir/IR/DialectRegistry.h +++ b/mlir/include/mlir/IR/DialectRegistry.h @@ -14,9 +14,9 @@ #define MLIR_IR_DIALECTREGISTRY_H #include "mlir/IR/MLIRContext.h" +#include "mlir/Support/TypeID.h" #include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/MapVector.h" #include #include @@ -187,7 +187,8 @@ public: nameAndRegistrationIt.second.second); // Merge the extensions. for (const auto &extension : extensions) - destination.extensions.push_back(extension->clone()); + destination.extensions.try_emplace(extension.first, + extension.second->clone()); } /// Return the names of dialects known to this registry. @@ -206,39 +207,37 @@ public: void applyExtensions(MLIRContext *ctx) const; /// Add the given extension to the registry. - void addExtension(std::unique_ptr extension) { - extensions.push_back(std::move(extension)); + bool addExtension(TypeID extensionID, + std::unique_ptr extension) { + return extensions.try_emplace(extensionID, std::move(extension)).second; } /// Add the given extensions to the registry. template void addExtensions() { - (addExtension(std::make_unique()), ...); + (addExtension(TypeID::get(), std::make_unique()), + ...); } /// Add an extension function that requires the given dialects. /// Note: This bare functor overload is provided in addition to the /// std::function variant to enable dialect type deduction, e.g.: - /// registry.addExtension(+[](MLIRContext *ctx, MyDialect *dialect) { ... }) + /// registry.addExtension(+[](MLIRContext *ctx, MyDialect *dialect) { + /// ... }) /// /// is equivalent to: /// registry.addExtension( /// [](MLIRContext *ctx, MyDialect *dialect){ ... } /// ) template - void addExtension(void (*extensionFn)(MLIRContext *, DialectsT *...)) { - addExtension( - std::function(extensionFn)); - } - template - void - addExtension(std::function extensionFn) { - using ExtensionFnT = std::function; + bool addExtension(void (*extensionFn)(MLIRContext *, DialectsT *...)) { + using ExtensionFnT = void (*)(MLIRContext *, DialectsT *...); struct Extension : public DialectExtension { Extension(const Extension &) = default; Extension(ExtensionFnT extensionFn) - : extensionFn(std::move(extensionFn)) {} + : DialectExtension(), + extensionFn(extensionFn) {} ~Extension() override = default; void apply(MLIRContext *context, DialectsT *...dialects) const final { @@ -246,7 +245,9 @@ public: } ExtensionFnT extensionFn; }; - addExtension(std::make_unique(std::move(extensionFn))); + return addExtension(TypeID::getFromOpaquePointer( + reinterpret_cast(extensionFn)), + std::make_unique(extensionFn)); } /// Returns true if the current registry is a subset of 'rhs', i.e. if 'rhs' @@ -255,7 +256,7 @@ public: private: MapTy registry; - std::vector> extensions; + llvm::MapVector> extensions; }; } // namespace mlir diff --git a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp index 6135117348a5..b2407a258c27 100644 --- a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp +++ b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp @@ -35,6 +35,8 @@ namespace { /// starting a pass pipeline that involves dialect conversion to LLVM. class LoadDependentDialectExtension : public DialectExtensionBase { public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LoadDependentDialectExtension) + LoadDependentDialectExtension() : DialectExtensionBase(/*dialectNames=*/{}) {} void apply(MLIRContext *context, diff --git a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp index 6457655cfe41..eb5229794072 100644 --- a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp +++ b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp @@ -157,6 +157,8 @@ class AffineTransformDialectExtension : public transform::TransformDialectExtension< AffineTransformDialectExtension> { public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AffineTransformDialectExtension) + using Base::Base; void init() { diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp index e10c7bd914e3..a1d7bb995fc7 100644 --- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp +++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp @@ -150,6 +150,9 @@ class BufferizationTransformDialectExtension : public transform::TransformDialectExtension< BufferizationTransformDialectExtension> { public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + BufferizationTransformDialectExtension) + using Base::Base; void init() { diff --git a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp index b632b25d0cc6..2728936bf33f 100644 --- a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp +++ b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp @@ -236,6 +236,8 @@ class FuncTransformDialectExtension : public transform::TransformDialectExtension< FuncTransformDialectExtension> { public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FuncTransformDialectExtension) + using Base::Base; void init() { diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp index 3661c5dea452..1528da914d54 100644 --- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp +++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp @@ -924,6 +924,8 @@ class GPUTransformDialectExtension : public transform::TransformDialectExtension< GPUTransformDialectExtension> { public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GPUTransformDialectExtension) + GPUTransformDialectExtension() { declareGeneratedDialect(); declareGeneratedDialect(); diff --git a/mlir/lib/Dialect/Linalg/TransformOps/DialectExtension.cpp b/mlir/lib/Dialect/Linalg/TransformOps/DialectExtension.cpp index f4244ca96223..4591802ce74a 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/DialectExtension.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/DialectExtension.cpp @@ -30,6 +30,8 @@ class LinalgTransformDialectExtension : public transform::TransformDialectExtension< LinalgTransformDialectExtension> { public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LinalgTransformDialectExtension) + using Base::Base; void init() { diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp index 8469e84c668c..89640ac323b6 100644 --- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp +++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp @@ -309,6 +309,8 @@ class MemRefTransformDialectExtension : public transform::TransformDialectExtension< MemRefTransformDialectExtension> { public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MemRefTransformDialectExtension) + using Base::Base; void init() { diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp index 733fde78e425..0c2275bbc4b2 100644 --- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp +++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp @@ -1135,6 +1135,8 @@ class NVGPUTransformDialectExtension : public transform::TransformDialectExtension< NVGPUTransformDialectExtension> { public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NVGPUTransformDialectExtension) + NVGPUTransformDialectExtension() { declareGeneratedDialect(); declareGeneratedDialect(); diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp index c4a55c302d0a..551411bb1476 100644 --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -613,6 +613,8 @@ class SCFTransformDialectExtension : public transform::TransformDialectExtension< SCFTransformDialectExtension> { public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SCFTransformDialectExtension) + using Base::Base; void init() { diff --git a/mlir/lib/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.cpp b/mlir/lib/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.cpp index ca19259ebffa..bdec43825ddc 100644 --- a/mlir/lib/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.cpp +++ b/mlir/lib/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.cpp @@ -38,6 +38,9 @@ class SparseTensorTransformDialectExtension : public transform::TransformDialectExtension< SparseTensorTransformDialectExtension> { public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + SparseTensorTransformDialectExtension) + SparseTensorTransformDialectExtension() { declareGeneratedDialect(); registerTransformOps< diff --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp index 33016f84056e..f911619d7122 100644 --- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp +++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp @@ -236,6 +236,8 @@ class TensorTransformDialectExtension : public transform::TransformDialectExtension< TensorTransformDialectExtension> { public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TensorTransformDialectExtension) + using Base::Base; void init() { diff --git a/mlir/lib/Dialect/Transform/DebugExtension/DebugExtension.cpp b/mlir/lib/Dialect/Transform/DebugExtension/DebugExtension.cpp index e369daddb00c..d69535169f95 100644 --- a/mlir/lib/Dialect/Transform/DebugExtension/DebugExtension.cpp +++ b/mlir/lib/Dialect/Transform/DebugExtension/DebugExtension.cpp @@ -20,6 +20,8 @@ namespace { class DebugExtension : public transform::TransformDialectExtension { public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DebugExtension) + void init() { registerTransformOps< #define GET_OP_LIST diff --git a/mlir/lib/Dialect/Transform/IRDLExtension/IRDLExtension.cpp b/mlir/lib/Dialect/Transform/IRDLExtension/IRDLExtension.cpp index 94004365b8a1..9dc95490b14b 100644 --- a/mlir/lib/Dialect/Transform/IRDLExtension/IRDLExtension.cpp +++ b/mlir/lib/Dialect/Transform/IRDLExtension/IRDLExtension.cpp @@ -18,6 +18,8 @@ namespace { class IRDLExtension : public transform::TransformDialectExtension { public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(IRDLExtension) + void init() { registerTransformOps< #define GET_OP_LIST diff --git a/mlir/lib/Dialect/Transform/LoopExtension/LoopExtension.cpp b/mlir/lib/Dialect/Transform/LoopExtension/LoopExtension.cpp index b33288fd7b99..0a099b5bc75a 100644 --- a/mlir/lib/Dialect/Transform/LoopExtension/LoopExtension.cpp +++ b/mlir/lib/Dialect/Transform/LoopExtension/LoopExtension.cpp @@ -20,6 +20,8 @@ namespace { class LoopExtension : public transform::TransformDialectExtension { public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LoopExtension) + void init() { registerTransformOps< #define GET_OP_LIST diff --git a/mlir/lib/Dialect/Transform/PDLExtension/PDLExtension.cpp b/mlir/lib/Dialect/Transform/PDLExtension/PDLExtension.cpp index 2c770abd56d5..27c5dc332a42 100644 --- a/mlir/lib/Dialect/Transform/PDLExtension/PDLExtension.cpp +++ b/mlir/lib/Dialect/Transform/PDLExtension/PDLExtension.cpp @@ -38,6 +38,8 @@ namespace { /// with Transform dialect operations. class PDLExtension : public transform::TransformDialectExtension { public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PDLExtension) + void init() { registerTransformOps< #define GET_OP_LIST diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp index 2e9aa8801182..bc423a3781bf 100644 --- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp +++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp @@ -212,6 +212,8 @@ class VectorTransformDialectExtension : public transform::TransformDialectExtension< VectorTransformDialectExtension> { public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VectorTransformDialectExtension) + VectorTransformDialectExtension() { declareGeneratedDialect(); declareGeneratedDialect(); diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp index 965386681f27..cc80677a4078 100644 --- a/mlir/lib/IR/Dialect.cpp +++ b/mlir/lib/IR/Dialect.cpp @@ -11,14 +11,20 @@ #include "mlir/IR/Diagnostics.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/DialectInterface.h" +#include "mlir/IR/DialectRegistry.h" #include "mlir/IR/ExtensibleDialect.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Operation.h" +#include "mlir/Support/TypeID.h" #include "llvm/ADT/MapVector.h" +#include "llvm/ADT/SetOperations.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/SmallVectorExtras.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ManagedStatic.h" #include "llvm/Support/Regex.h" +#include #define DEBUG_TYPE "dialect" @@ -173,6 +179,40 @@ bool dialect_extension_detail::hasPromisedInterface(Dialect &dialect, // DialectRegistry //===----------------------------------------------------------------------===// +namespace { +template +void applyExtensionsFn( + Fn &&applyExtension, + const llvm::MapVector> + &extensions) { + // Note: Additional extensions may be added while applying an extension. + // The iterators will be invalidated if extensions are added so we'll keep + // a copy of the extensions for ourselves. + + const auto extractExtension = + [](const auto &entry) -> DialectExtensionBase * { + return entry.second.get(); + }; + + auto startIt = extensions.begin(), endIt = extensions.end(); + size_t count = 0; + while (startIt != endIt) { + count += endIt - startIt; + + // Grab the subset of extensions we'll apply in this iteration. + const auto subset = + llvm::map_to_vector(llvm::make_range(startIt, endIt), extractExtension); + + for (const auto *ext : subset) + applyExtension(*ext); + + // Book-keep for the next iteration. + startIt = extensions.begin() + count; + endIt = extensions.end(); + } +} +} // namespace + DialectRegistry::DialectRegistry() { insert(); } DialectAllocatorFunctionRef @@ -258,9 +298,7 @@ void DialectRegistry::applyExtensions(Dialect *dialect) const { extension.apply(ctx, requiredDialects); }; - // Note: Additional extensions may be added while applying an extension. - for (int i = 0; i < static_cast(extensions.size()); ++i) - applyExtension(*extensions[i]); + applyExtensionsFn(applyExtension, extensions); } void DialectRegistry::applyExtensions(MLIRContext *ctx) const { @@ -285,15 +323,17 @@ void DialectRegistry::applyExtensions(MLIRContext *ctx) const { extension.apply(ctx, requiredDialects); }; - // Note: Additional extensions may be added while applying an extension. - for (int i = 0; i < static_cast(extensions.size()); ++i) - applyExtension(*extensions[i]); + applyExtensionsFn(applyExtension, extensions); } bool DialectRegistry::isSubsetOf(const DialectRegistry &rhs) const { - // Treat any extensions conservatively. - if (!extensions.empty()) + // Check that all extension keys are present in 'rhs'. + const auto hasExtension = [&](const auto &key) { + return rhs.extensions.contains(key); + }; + if (!llvm::all_of(make_first_range(extensions), hasExtension)) return false; + // Check that the current dialects fully overlap with the dialects in 'rhs'. return llvm::all_of( registry, [&](const auto &it) { return rhs.registry.count(it.first); }); diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp index b8a4b9470d73..c023aad4a3ee 100644 --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -874,6 +874,8 @@ class TestTransformDialectExtension : public transform::TransformDialectExtension< TestTransformDialectExtension> { public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformDialectExtension) + using Base::Base; void init() { diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp index a99441cd7147..7aa7b58433f3 100644 --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp @@ -382,6 +382,9 @@ class TestTilingInterfaceDialectExtension : public transform::TransformDialectExtension< TestTilingInterfaceDialectExtension> { public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + TestTilingInterfaceDialectExtension) + using Base::Base; void init() { diff --git a/mlir/unittests/Dialect/Transform/BuildOnlyExtensionTest.cpp b/mlir/unittests/Dialect/Transform/BuildOnlyExtensionTest.cpp index 40fb752ffd6e..d2a4999594a9 100644 --- a/mlir/unittests/Dialect/Transform/BuildOnlyExtensionTest.cpp +++ b/mlir/unittests/Dialect/Transform/BuildOnlyExtensionTest.cpp @@ -18,6 +18,8 @@ using namespace mlir::transform; namespace { class Extension : public TransformDialectExtension { public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(Extension) + using Base::Base; void init() { declareGeneratedDialect(); } }; diff --git a/mlir/unittests/IR/DialectTest.cpp b/mlir/unittests/IR/DialectTest.cpp index e99d46e6d264..7dd6a01c3389 100644 --- a/mlir/unittests/IR/DialectTest.cpp +++ b/mlir/unittests/IR/DialectTest.cpp @@ -8,6 +8,7 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/DialectInterface.h" +#include "mlir/Support/TypeID.h" #include "gtest/gtest.h" using namespace mlir; @@ -140,15 +141,22 @@ namespace { /// A dummy extension that increases a counter when being applied and /// recursively adds additional extensions. struct DummyExtension : DialectExtension { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DummyExtension) + DummyExtension(int *counter, int numRecursive) : DialectExtension(), counter(counter), numRecursive(numRecursive) {} void apply(MLIRContext *ctx, TestDialect *dialect) const final { ++(*counter); DialectRegistry nestedRegistry; - for (int i = 0; i < numRecursive; ++i) - nestedRegistry.addExtension( - std::make_unique(counter, /*numRecursive=*/0)); + for (int i = 0; i < numRecursive; ++i) { + // Create unique TypeIDs for these recursive extensions so they don't get + // de-duplicated. + auto extension = + std::make_unique(counter, /*numRecursive=*/0); + auto typeID = TypeID::getFromOpaquePointer(extension.get()); + nestedRegistry.addExtension(typeID, std::move(extension)); + } // Adding additional extensions may trigger a reallocation of the // `extensions` vector in the dialect registry. ctx->appendDialectRegistry(nestedRegistry); @@ -166,20 +174,56 @@ TEST(Dialect, NestedDialectExtension) { // Add an extension that adds 100 more extensions. int counter1 = 0; - registry.addExtension(std::make_unique(&counter1, 100)); + registry.addExtension(TypeID::get(), + std::make_unique(&counter1, 100)); // Add one more extension. This should not crash. int counter2 = 0; - registry.addExtension(std::make_unique(&counter2, 0)); + registry.addExtension(TypeID::getFromOpaquePointer(&counter2), + std::make_unique(&counter2, 0)); // Load dialect and apply extensions. MLIRContext context(registry); Dialect *testDialect = context.getOrLoadDialect(); ASSERT_TRUE(testDialect != nullptr); - // Extensions may be applied multiple times. Make sure that each expected + // Extensions are de-duplicated by typeID. Make sure that each expected // extension was applied at least once. EXPECT_GE(counter1, 101); EXPECT_GE(counter2, 1); } +TEST(Dialect, SubsetWithExtensions) { + DialectRegistry registry1, registry2; + registry1.insert(); + registry2.insert(); + + // Validate that the registries are equivalent. + ASSERT_TRUE(registry1.isSubsetOf(registry2)); + ASSERT_TRUE(registry2.isSubsetOf(registry1)); + + // Add extensions to registry2. + int counter = 0; + registry2.addExtension(TypeID::get(), + std::make_unique(&counter, 0)); + + // Expect that (1) is a subset of (2) but not the other way around. + ASSERT_TRUE(registry1.isSubsetOf(registry2)); + ASSERT_FALSE(registry2.isSubsetOf(registry1)); + + // Add extensions to registry1. + registry1.addExtension(TypeID::get(), + std::make_unique(&counter, 0)); + + // Expect that (1) and (2) are equivalent. + ASSERT_TRUE(registry1.isSubsetOf(registry2)); + ASSERT_TRUE(registry2.isSubsetOf(registry1)); + + // Load dialect and apply extensions. + MLIRContext context(registry1); + context.getOrLoadDialect(); + context.appendDialectRegistry(registry2); + // Expect that the extension as only invoked once. + ASSERT_EQ(counter, 1); +} + } // namespace