[mlir] Add a method on MLIRContext to retrieve the operations for a given dialect (#112344)

Currently we have `MLIRContext::getRegisteredOperations` which returns
all operations for the given context, with the addition of
`MLIRContext::getRegisteredOperationsByDialect` we can now retrieve the
same for a given dialect class.

Closes #111591
This commit is contained in:
Rajveer Singh Bharadwaj 2024-10-17 15:32:24 +05:30 committed by GitHub
parent 4091bc61e3
commit b091701d01
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 29 additions and 0 deletions

View File

@ -197,6 +197,11 @@ public:
/// operations.
ArrayRef<RegisteredOperationName> getRegisteredOperations();
/// Return a sorted array containing the information for registered operations
/// filtered by dialect name.
ArrayRef<RegisteredOperationName>
getRegisteredOperationsByDialect(StringRef dialectName);
/// Return true if this operation name is registered in this context.
bool isOperationRegistered(StringRef name);

View File

@ -711,6 +711,30 @@ ArrayRef<RegisteredOperationName> MLIRContext::getRegisteredOperations() {
return impl->sortedRegisteredOperations;
}
/// Return information for registered operations by dialect.
ArrayRef<RegisteredOperationName>
MLIRContext::getRegisteredOperationsByDialect(StringRef dialectName) {
auto lowerBound =
std::lower_bound(impl->sortedRegisteredOperations.begin(),
impl->sortedRegisteredOperations.end(), dialectName,
[](auto &lhs, auto &rhs) {
return lhs.getDialect().getNamespace().compare(rhs);
});
if (lowerBound == impl->sortedRegisteredOperations.end() ||
lowerBound->getDialect().getNamespace() != dialectName)
return ArrayRef<RegisteredOperationName>();
auto upperBound =
std::upper_bound(lowerBound, impl->sortedRegisteredOperations.end(),
dialectName, [](auto &lhs, auto &rhs) {
return lhs.compare(rhs.getDialect().getNamespace());
});
size_t count = std::distance(lowerBound, upperBound);
return ArrayRef(&*lowerBound, count);
}
bool MLIRContext::isOperationRegistered(StringRef name) {
return RegisteredOperationName::lookup(name, this).has_value();
}