Maksim Levental b0e00ca6a6
[mlir][python] fix replace=True for register_operation and register_type_caster (#70264)
<img
src="https://github.com/llvm/llvm-project/assets/5657668/443852b6-ac25-45bb-a38b-5dfbda09d5a7"
height="400" />
<p></p>


It turns out that none of the `replace=True` paths actually work
because of the map caches (the one exception is
`register_attribute_builder(replace=True)`, which doesn't use such a
cache). This was hidden by a series of unfortunate events:

1. `register_type_caster` failure was hidden because it was the same
`TestIntegerRankedTensorType` being replaced with itself (d'oh).
2. `register_operation` failure was hidden behind the "order of events"
in the lifecycle of typical extension import/use. Since extensions are
loaded/registered almost immediately after generated builders are
registered, there is no opportunity for the `operationClassMapCache` to
be populated (through e.g., `module.body.operations[2]` or
`module.body.operations[2].opview` or something). Of course, as soon as
you actually "late-bind/late-register" the extension, you see that it
does not replace the stale entry in `operationClassMapCache`.

I'll take this opportunity to propose that we ditch the caches altogether.
I've been cargo-culting them, but I really don't understand how they
work. There's this comment above `operationClassMapCache`:

```cpp
  /// Cache of operation name to external operation class object. This is
  /// maintained on lookup as a shadow of operationClassMap in order for repeat
  /// lookups of the classes to only incur the cost of one hashtable lookup.
  llvm::StringMap<pybind11::object> operationClassMapCache;
```

But I don't understand how that's true, given that the canonical
`operationClassMap` is already a map:

```cpp
  /// Map of full operation name to external operation class object.
  llvm::StringMap<pybind11::object> operationClassMap;
```

Maybe it wasn't always the case? Anyway, things work now, but the caches
seem like an unnecessary layer of complexity for not much gain. But maybe
I'm wrong.
2023-10-30 20:22:27 -05:00

222 lines
7.6 KiB
C++

//===- IRModule.cpp - IR pybind module ------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "IRModule.h"
#include "Globals.h"
#include "PybindUtils.h"
#include <optional>
#include <vector>
#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/Support.h"
namespace py = pybind11;
using namespace mlir;
using namespace mlir::python;
// -----------------------------------------------------------------------------
// PyGlobals
// -----------------------------------------------------------------------------
// Pointer to the single live PyGlobals object; assigned by the constructor and
// reset to null by the destructor.
PyGlobals *PyGlobals::instance{nullptr};

/// Installs this object as the global instance (asserting that no other one
/// exists) and seeds the dialect module search path.
PyGlobals::PyGlobals() {
  assert(!instance && "PyGlobals already constructed");
  instance = this;
  // Out of the box, dialect modules are looked up under {mlir.}dialects; the
  // {mlir.} package prefix is fixed at compile time.
  dialectSearchPrefixes.emplace_back(MAKE_MLIR_PYTHON_QUALNAME("dialects"));
}

/// Clears the global instance pointer.
PyGlobals::~PyGlobals() {
  instance = nullptr;
}
/// Imports the Python module providing bindings for `dialectNamespace`.
///
/// Each entry of `dialectSearchPrefixes` is tried in order as
/// "<prefix>.<dialectNamespace>". The first import that does not raise
/// ModuleNotFoundError ends the search; any other Python exception propagates
/// to the caller. The namespace is recorded in `loadedDialectModulesCache`
/// whether or not a module was found, so the search runs at most once per
/// namespace until clearImportCache() is called.
void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
  if (loadedDialectModulesCache.contains(dialectNamespace))
    return;
  // Since re-entrancy is possible, iterate over a copy of the search prefixes.
  std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
  for (std::string moduleName : localSearchPrefixes) {
    moduleName.push_back('.');
    moduleName.append(dialectNamespace.data(), dialectNamespace.size());
    try {
      // Only the import's side effects matter here: callers re-check the
      // registries after this returns, so the module object is not kept.
      py::module::import(moduleName.c_str());
    } catch (py::error_already_set &e) {
      if (e.matches(PyExc_ModuleNotFoundError))
        continue;
      throw;
    }
    break;
  }
  // Note: Iterator cannot be shared from prior to loading, since re-entrancy
  // may have occurred, which may do anything.
  loadedDialectModulesCache.insert(dialectNamespace);
}
void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
py::function pyFunc, bool replace) {
py::object &found = attributeBuilderMap[attributeKind];
if (found && !found.is_none() && !replace) {
throw std::runtime_error((llvm::Twine("Attribute builder for '") +
attributeKind +
"' is already registered with func: " +
py::str(found).operator std::string())
.str());
}
found = std::move(pyFunc);
}
/// Registers `typeCaster` as the Python type caster for `mlirTypeID`.
///
/// A `None` entry in typeCasterMap (the negative cache written by
/// lookupTypeCaster) does not count as an existing registration. If a real
/// caster is already present and `replace` is false, throws
/// std::runtime_error.
void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
                                   pybind11::function typeCaster,
                                   bool replace) {
  pybind11::object &found = typeCasterMap[mlirTypeID];
  if (found && !found.is_none() && !replace)
    throw std::runtime_error("Type caster is already registered");
  found = std::move(typeCaster);
  // Invalidate any memoized lookup so the next lookupTypeCaster re-reads the
  // canonical map. Erasing unconditionally (instead of overwriting only
  // non-None entries) guarantees a stale cache entry, including a cached
  // `None`, can never shadow the new caster.
  typeCasterMapCache.erase(mlirTypeID);
}
/// Registers `pyClass` as the Python class for dialect `dialectNamespace`.
///
/// Throws std::runtime_error if a class is already registered for that
/// namespace. A `None` entry is not a registration: lookupDialectClass stores
/// `py::none()` in dialectClassMap when a lookup misses. Because a `None`
/// py::object is truthy (non-null pointer), it must be excluded explicitly, or
/// registering after a failed lookup would be rejected.
void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
                                    py::object pyClass) {
  py::object &found = dialectClassMap[dialectNamespace];
  if (found && !found.is_none()) {
    throw std::runtime_error((llvm::Twine("Dialect namespace '") +
                              dialectNamespace + "' is already registered.")
                                 .str());
  }
  found = std::move(pyClass);
}
/// Registers `pyClass` as the Python class for operation `operationName`.
///
/// Throws std::runtime_error if a class is already registered and `replace`
/// is false. A `None` entry is not a registration: lookupOperationClass
/// stores `py::none()` in operationClassMap when a lookup misses. Because a
/// `None` py::object is truthy (non-null pointer), it is excluded explicitly,
/// matching registerAttributeBuilder and registerTypeCaster.
void PyGlobals::registerOperationImpl(const std::string &operationName,
                                      py::object pyClass, bool replace) {
  py::object &found = operationClassMap[operationName];
  if (found && !found.is_none() && !replace) {
    throw std::runtime_error((llvm::Twine("Operation '") + operationName +
                              "' is already registered.")
                                 .str());
  }
  found = std::move(pyClass);
  // Invalidate any memoized lookup so the next lookupOperationClass re-reads
  // operationClassMap and picks up the replacement; erasing unconditionally
  // means no stale cache entry can shadow it.
  operationClassMapCache.erase(operationName);
}
/// Returns the builder registered for `attributeKind`, or std::nullopt if
/// there is none.
///
/// Unlike the other lookups, this does not import any dialect module. On a
/// miss it stores a `None` placeholder in attributeBuilderMap, so later misses
/// are answered by the `is_none()` check.
std::optional<py::function>
PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) {
  auto it = attributeBuilderMap.find(attributeKind);
  if (it == attributeBuilderMap.end()) {
    // Nothing registered: remember the miss.
    attributeBuilderMap[attributeKind] = py::none();
    return std::nullopt;
  }
  if (it->second.is_none())
    return std::nullopt;
  assert(it->second && "py::function is defined");
  return it->second;
}
/// Returns the Python type caster registered for `mlirTypeID`, or
/// std::nullopt if there is none.
///
/// The memoized typeCasterMapCache is checked first. On a cache miss, the
/// Python module for `dialect`'s namespace is loaded and then typeCasterMap is
/// consulted. A hit is copied into the cache; a miss is recorded as `None` in
/// typeCasterMap itself.
std::optional<py::function> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
                                                        MlirDialect dialect) {
  auto cached = typeCasterMapCache.find(mlirTypeID);
  if (cached != typeCasterMapCache.end()) {
    if (cached->second.is_none())
      return std::nullopt;
    assert(cached->second && "py::function is defined");
    return cached->second;
  }

  // Cache miss: ensure the dialect's Python module has been imported before
  // consulting the canonical map.
  loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)));

  auto registered = typeCasterMap.find(mlirTypeID);
  if (registered == typeCasterMap.end()) {
    // Negative cache, kept in the canonical map.
    typeCasterMap[mlirTypeID] = py::none();
    return std::nullopt;
  }
  if (registered->second.is_none())
    return std::nullopt;
  assert(registered->second && "py::object is defined");
  // Positive cache.
  typeCasterMapCache[mlirTypeID] = registered->second;
  return registered->second;
}
/// Returns the Python class registered for `dialectNamespace`, or
/// std::nullopt if there is none.
///
/// The dialect's Python module is loaded first. A miss is remembered by
/// storing `None` in dialectClassMap.
std::optional<py::object>
PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
  loadDialectModule(dialectNamespace);

  auto it = dialectClassMap.find(dialectNamespace);
  if (it == dialectClassMap.end()) {
    // Nothing registered even after loading: remember the miss.
    dialectClassMap[dialectNamespace] = py::none();
    return std::nullopt;
  }
  if (it->second.is_none())
    return std::nullopt;
  assert(it->second && "py::object is defined");
  return it->second;
}
/// Returns the Python class registered for `operationName`, or std::nullopt
/// if there is none.
///
/// The memoized operationClassMapCache is checked first. On a cache miss, the
/// dialect module named by the part of `operationName` before the first '.'
/// is loaded and then operationClassMap is consulted. A hit is copied into the
/// cache; a miss is recorded as `None` in operationClassMap itself.
std::optional<pybind11::object>
PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
  auto cached = operationClassMapCache.find(operationName);
  if (cached != operationClassMapCache.end()) {
    if (cached->second.is_none())
      return std::nullopt;
    assert(cached->second && "py::object is defined");
    return cached->second;
  }

  // Cache miss: load the dialect namespace (prefix before the first '.').
  loadDialectModule(operationName.split('.').first);

  auto registered = operationClassMap.find(operationName);
  if (registered == operationClassMap.end()) {
    // Negative cache, kept in the canonical map.
    operationClassMap[operationName] = py::none();
    return std::nullopt;
  }
  if (registered->second.is_none())
    return std::nullopt;
  assert(registered->second && "py::object is defined");
  // Positive cache.
  operationClassMapCache[operationName] = registered->second;
  return registered->second;
}
/// Drops the memoized type-caster and operation-class lookups and forgets
/// which dialect modules have been loaded, so later lookups re-attempt the
/// imports and re-read the canonical maps. The canonical maps themselves
/// (including any `None` entries) are left untouched.
void PyGlobals::clearImportCache() {
  typeCasterMapCache.clear();
  operationClassMapCache.clear();
  loadedDialectModulesCache.clear();
}