[mlir][python] automatic location inference (#151246)
This PR implements "automatic" location inference in the bindings. The
way it works is it walks the frame stack collecting source locations
(Python captures these in the frame itself). It is inspired by JAX's
[implementation](523ddcfbca/jax/_src/interpreters/mlir.py (L462)
)
but moves the frame stack traversal into the bindings for better
performance.
The system supports registering "included" and "excluded" filenames;
frames originating from functions in included filenames **will not** be
filtered and frames originating from functions in excluded filenames
**will** be filtered (in that order). This allows excluding all the
generated `*_ops_gen.py` files.
The system is also "toggleable" and off by default to save people who
have their own systems (such as JAX) from the added cost.
Note, the system stores the entire stacktrace (subject to
`locTracebackFramesLimit`) in the `Location` using specifically a
`CallSiteLoc`. This can be useful for profiling tools (flamegraphs
etc.).
Shoutout to the folks at JAX for coming up with a good system.
---------
Co-authored-by: Jacques Pienaar <jpienaar@google.com>
This commit is contained in:
parent
da3182a288
commit
a40f47c972
@ -10,15 +10,19 @@
|
|||||||
#define MLIR_BINDINGS_PYTHON_GLOBALS_H
|
#define MLIR_BINDINGS_PYTHON_GLOBALS_H
|
||||||
|
|
||||||
#include <optional>
|
#include <optional>
|
||||||
|
#include <regex>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <unordered_set>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "NanobindUtils.h"
|
#include "NanobindUtils.h"
|
||||||
#include "mlir-c/IR.h"
|
#include "mlir-c/IR.h"
|
||||||
#include "mlir/CAPI/Support.h"
|
#include "mlir/CAPI/Support.h"
|
||||||
#include "llvm/ADT/DenseMap.h"
|
#include "llvm/ADT/DenseMap.h"
|
||||||
|
#include "llvm/ADT/StringExtras.h"
|
||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
#include "llvm/ADT/StringSet.h"
|
#include "llvm/ADT/StringSet.h"
|
||||||
|
#include "llvm/Support/Regex.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace python {
|
namespace python {
|
||||||
@ -114,6 +118,39 @@ public:
|
|||||||
std::optional<nanobind::object>
|
std::optional<nanobind::object>
|
||||||
lookupOperationClass(llvm::StringRef operationName);
|
lookupOperationClass(llvm::StringRef operationName);
|
||||||
|
|
||||||
|
class TracebackLoc {
|
||||||
|
public:
|
||||||
|
bool locTracebacksEnabled();
|
||||||
|
|
||||||
|
void setLocTracebacksEnabled(bool value);
|
||||||
|
|
||||||
|
size_t locTracebackFramesLimit();
|
||||||
|
|
||||||
|
void setLocTracebackFramesLimit(size_t value);
|
||||||
|
|
||||||
|
void registerTracebackFileInclusion(const std::string &file);
|
||||||
|
|
||||||
|
void registerTracebackFileExclusion(const std::string &file);
|
||||||
|
|
||||||
|
bool isUserTracebackFilename(llvm::StringRef file);
|
||||||
|
|
||||||
|
static constexpr size_t kMaxFrames = 512;
|
||||||
|
|
||||||
|
private:
|
||||||
|
nanobind::ft_mutex mutex;
|
||||||
|
bool locTracebackEnabled_ = false;
|
||||||
|
size_t locTracebackFramesLimit_ = 10;
|
||||||
|
std::unordered_set<std::string> userTracebackIncludeFiles;
|
||||||
|
std::unordered_set<std::string> userTracebackExcludeFiles;
|
||||||
|
std::regex userTracebackIncludeRegex;
|
||||||
|
bool rebuildUserTracebackIncludeRegex = false;
|
||||||
|
std::regex userTracebackExcludeRegex;
|
||||||
|
bool rebuildUserTracebackExcludeRegex = false;
|
||||||
|
llvm::StringMap<bool> isUserTracebackFilenameCache;
|
||||||
|
};
|
||||||
|
|
||||||
|
TracebackLoc &getTracebackLoc() { return tracebackLoc; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
static PyGlobals *instance;
|
static PyGlobals *instance;
|
||||||
|
|
||||||
@ -134,6 +171,8 @@ private:
|
|||||||
/// Set of dialect namespaces that we have attempted to import implementation
|
/// Set of dialect namespaces that we have attempted to import implementation
|
||||||
/// modules for.
|
/// modules for.
|
||||||
llvm::StringSet<> loadedDialectModules;
|
llvm::StringSet<> loadedDialectModules;
|
||||||
|
|
||||||
|
TracebackLoc tracebackLoc;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace python
|
} // namespace python
|
||||||
|
@ -20,11 +20,8 @@
|
|||||||
#include "nanobind/nanobind.h"
|
#include "nanobind/nanobind.h"
|
||||||
#include "llvm/ADT/ArrayRef.h"
|
#include "llvm/ADT/ArrayRef.h"
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
#include "llvm/Support/raw_ostream.h"
|
|
||||||
|
|
||||||
#include <optional>
|
#include <optional>
|
||||||
#include <system_error>
|
|
||||||
#include <utility>
|
|
||||||
|
|
||||||
namespace nb = nanobind;
|
namespace nb = nanobind;
|
||||||
using namespace nb::literals;
|
using namespace nb::literals;
|
||||||
@ -1523,7 +1520,7 @@ nb::object PyOperation::create(std::string_view name,
|
|||||||
llvm::ArrayRef<MlirValue> operands,
|
llvm::ArrayRef<MlirValue> operands,
|
||||||
std::optional<nb::dict> attributes,
|
std::optional<nb::dict> attributes,
|
||||||
std::optional<std::vector<PyBlock *>> successors,
|
std::optional<std::vector<PyBlock *>> successors,
|
||||||
int regions, DefaultingPyLocation location,
|
int regions, PyLocation &location,
|
||||||
const nb::object &maybeIp, bool inferType) {
|
const nb::object &maybeIp, bool inferType) {
|
||||||
llvm::SmallVector<MlirType, 4> mlirResults;
|
llvm::SmallVector<MlirType, 4> mlirResults;
|
||||||
llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
|
llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
|
||||||
@ -1627,7 +1624,7 @@ nb::object PyOperation::create(std::string_view name,
|
|||||||
if (!operation.ptr)
|
if (!operation.ptr)
|
||||||
throw nb::value_error("Operation creation failed");
|
throw nb::value_error("Operation creation failed");
|
||||||
PyOperationRef created =
|
PyOperationRef created =
|
||||||
PyOperation::createDetached(location->getContext(), operation);
|
PyOperation::createDetached(location.getContext(), operation);
|
||||||
maybeInsertOperation(created, maybeIp);
|
maybeInsertOperation(created, maybeIp);
|
||||||
|
|
||||||
return created.getObject();
|
return created.getObject();
|
||||||
@ -1937,9 +1934,9 @@ nb::object PyOpView::buildGeneric(
|
|||||||
std::optional<nb::list> resultTypeList, nb::list operandList,
|
std::optional<nb::list> resultTypeList, nb::list operandList,
|
||||||
std::optional<nb::dict> attributes,
|
std::optional<nb::dict> attributes,
|
||||||
std::optional<std::vector<PyBlock *>> successors,
|
std::optional<std::vector<PyBlock *>> successors,
|
||||||
std::optional<int> regions, DefaultingPyLocation location,
|
std::optional<int> regions, PyLocation &location,
|
||||||
const nb::object &maybeIp) {
|
const nb::object &maybeIp) {
|
||||||
PyMlirContextRef context = location->getContext();
|
PyMlirContextRef context = location.getContext();
|
||||||
|
|
||||||
// Class level operation construction metadata.
|
// Class level operation construction metadata.
|
||||||
// Operand and result segment specs are either none, which does no
|
// Operand and result segment specs are either none, which does no
|
||||||
@ -2789,6 +2786,90 @@ private:
|
|||||||
PyOperationRef operation;
|
PyOperationRef operation;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
MlirLocation tracebackToLocation(MlirContext ctx) {
|
||||||
|
size_t framesLimit =
|
||||||
|
PyGlobals::get().getTracebackLoc().locTracebackFramesLimit();
|
||||||
|
// Use a thread_local here to avoid requiring a large amount of space.
|
||||||
|
thread_local std::array<MlirLocation, PyGlobals::TracebackLoc::kMaxFrames>
|
||||||
|
frames;
|
||||||
|
size_t count = 0;
|
||||||
|
|
||||||
|
nb::gil_scoped_acquire acquire;
|
||||||
|
PyThreadState *tstate = PyThreadState_GET();
|
||||||
|
PyFrameObject *next;
|
||||||
|
PyFrameObject *pyFrame = PyThreadState_GetFrame(tstate);
|
||||||
|
// In the increment expression:
|
||||||
|
// 1. get the next prev frame;
|
||||||
|
// 2. decrement the ref count on the current frame (in order that it can get
|
||||||
|
// gc'd, along with any objects in its closure and etc);
|
||||||
|
// 3. set current = next.
|
||||||
|
for (; pyFrame != nullptr && count < framesLimit;
|
||||||
|
next = PyFrame_GetBack(pyFrame), Py_XDECREF(pyFrame), pyFrame = next) {
|
||||||
|
PyCodeObject *code = PyFrame_GetCode(pyFrame);
|
||||||
|
auto fileNameStr =
|
||||||
|
nb::cast<std::string>(nb::borrow<nb::str>(code->co_filename));
|
||||||
|
llvm::StringRef fileName(fileNameStr);
|
||||||
|
if (!PyGlobals::get().getTracebackLoc().isUserTracebackFilename(fileName))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
#if PY_VERSION_HEX < 0x030b00f0
|
||||||
|
std::string name =
|
||||||
|
nb::cast<std::string>(nb::borrow<nb::str>(code->co_name));
|
||||||
|
llvm::StringRef funcName(name);
|
||||||
|
int startLine = PyFrame_GetLineNumber(pyFrame);
|
||||||
|
MlirLocation loc =
|
||||||
|
mlirLocationFileLineColGet(ctx, wrap(fileName), startLine, 0);
|
||||||
|
#else
|
||||||
|
// co_qualname and PyCode_Addr2Location added in py3.11
|
||||||
|
std::string name =
|
||||||
|
nb::cast<std::string>(nb::borrow<nb::str>(code->co_qualname));
|
||||||
|
llvm::StringRef funcName(name);
|
||||||
|
int startLine, startCol, endLine, endCol;
|
||||||
|
int lasti = PyFrame_GetLasti(pyFrame);
|
||||||
|
if (!PyCode_Addr2Location(code, lasti, &startLine, &startCol, &endLine,
|
||||||
|
&endCol)) {
|
||||||
|
throw nb::python_error();
|
||||||
|
}
|
||||||
|
MlirLocation loc = mlirLocationFileLineColRangeGet(
|
||||||
|
ctx, wrap(fileName), startLine, startCol, endLine, endCol);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
frames[count] = mlirLocationNameGet(ctx, wrap(funcName), loc);
|
||||||
|
++count;
|
||||||
|
}
|
||||||
|
// When the loop breaks (after the last iter), current frame (if non-null)
|
||||||
|
// is leaked without this.
|
||||||
|
Py_XDECREF(pyFrame);
|
||||||
|
|
||||||
|
if (count == 0)
|
||||||
|
return mlirLocationUnknownGet(ctx);
|
||||||
|
|
||||||
|
MlirLocation callee = frames[0];
|
||||||
|
assert(!mlirLocationIsNull(callee) && "expected non-null callee location");
|
||||||
|
if (count == 1)
|
||||||
|
return callee;
|
||||||
|
|
||||||
|
MlirLocation caller = frames[count - 1];
|
||||||
|
assert(!mlirLocationIsNull(caller) && "expected non-null caller location");
|
||||||
|
for (int i = count - 2; i >= 1; i--)
|
||||||
|
caller = mlirLocationCallSiteGet(frames[i], caller);
|
||||||
|
|
||||||
|
return mlirLocationCallSiteGet(callee, caller);
|
||||||
|
}
|
||||||
|
|
||||||
|
PyLocation
|
||||||
|
maybeGetTracebackLocation(const std::optional<PyLocation> &location) {
|
||||||
|
if (location.has_value())
|
||||||
|
return location.value();
|
||||||
|
if (!PyGlobals::get().getTracebackLoc().locTracebacksEnabled())
|
||||||
|
return DefaultingPyLocation::resolve();
|
||||||
|
|
||||||
|
PyMlirContext &ctx = DefaultingPyMlirContext::resolve();
|
||||||
|
MlirLocation mlirLoc = tracebackToLocation(ctx.get());
|
||||||
|
PyMlirContextRef ref = PyMlirContext::forContext(ctx.get());
|
||||||
|
return {ref, mlirLoc};
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
//------------------------------------------------------------------------------
|
//------------------------------------------------------------------------------
|
||||||
@ -3052,10 +3133,10 @@ void mlir::python::populateIRCore(nb::module_ &m) {
|
|||||||
.def("__eq__", [](PyLocation &self, nb::object other) { return false; })
|
.def("__eq__", [](PyLocation &self, nb::object other) { return false; })
|
||||||
.def_prop_ro_static(
|
.def_prop_ro_static(
|
||||||
"current",
|
"current",
|
||||||
[](nb::object & /*class*/) {
|
[](nb::object & /*class*/) -> std::optional<PyLocation *> {
|
||||||
auto *loc = PyThreadContextEntry::getDefaultLocation();
|
auto *loc = PyThreadContextEntry::getDefaultLocation();
|
||||||
if (!loc)
|
if (!loc)
|
||||||
throw nb::value_error("No current Location");
|
return std::nullopt;
|
||||||
return loc;
|
return loc;
|
||||||
},
|
},
|
||||||
"Gets the Location bound to the current thread or raises ValueError")
|
"Gets the Location bound to the current thread or raises ValueError")
|
||||||
@ -3240,8 +3321,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
|
|||||||
kModuleParseDocstring)
|
kModuleParseDocstring)
|
||||||
.def_static(
|
.def_static(
|
||||||
"create",
|
"create",
|
||||||
[](DefaultingPyLocation loc) {
|
[](const std::optional<PyLocation> &loc) {
|
||||||
MlirModule module = mlirModuleCreateEmpty(loc);
|
PyLocation pyLoc = maybeGetTracebackLocation(loc);
|
||||||
|
MlirModule module = mlirModuleCreateEmpty(pyLoc.get());
|
||||||
return PyModule::forModule(module).releaseObject();
|
return PyModule::forModule(module).releaseObject();
|
||||||
},
|
},
|
||||||
nb::arg("loc").none() = nb::none(), "Creates an empty module")
|
nb::arg("loc").none() = nb::none(), "Creates an empty module")
|
||||||
@ -3462,8 +3544,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
|
|||||||
std::optional<std::vector<PyValue *>> operands,
|
std::optional<std::vector<PyValue *>> operands,
|
||||||
std::optional<nb::dict> attributes,
|
std::optional<nb::dict> attributes,
|
||||||
std::optional<std::vector<PyBlock *>> successors, int regions,
|
std::optional<std::vector<PyBlock *>> successors, int regions,
|
||||||
DefaultingPyLocation location, const nb::object &maybeIp,
|
const std::optional<PyLocation> &location,
|
||||||
bool inferType) {
|
const nb::object &maybeIp, bool inferType) {
|
||||||
// Unpack/validate operands.
|
// Unpack/validate operands.
|
||||||
llvm::SmallVector<MlirValue, 4> mlirOperands;
|
llvm::SmallVector<MlirValue, 4> mlirOperands;
|
||||||
if (operands) {
|
if (operands) {
|
||||||
@ -3475,8 +3557,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
PyLocation pyLoc = maybeGetTracebackLocation(location);
|
||||||
return PyOperation::create(name, results, mlirOperands, attributes,
|
return PyOperation::create(name, results, mlirOperands, attributes,
|
||||||
successors, regions, location, maybeIp,
|
successors, regions, pyLoc, maybeIp,
|
||||||
inferType);
|
inferType);
|
||||||
},
|
},
|
||||||
nb::arg("name"), nb::arg("results").none() = nb::none(),
|
nb::arg("name"), nb::arg("results").none() = nb::none(),
|
||||||
@ -3520,12 +3603,14 @@ void mlir::python::populateIRCore(nb::module_ &m) {
|
|||||||
std::optional<nb::list> resultTypeList, nb::list operandList,
|
std::optional<nb::list> resultTypeList, nb::list operandList,
|
||||||
std::optional<nb::dict> attributes,
|
std::optional<nb::dict> attributes,
|
||||||
std::optional<std::vector<PyBlock *>> successors,
|
std::optional<std::vector<PyBlock *>> successors,
|
||||||
std::optional<int> regions, DefaultingPyLocation location,
|
std::optional<int> regions,
|
||||||
|
const std::optional<PyLocation> &location,
|
||||||
const nb::object &maybeIp) {
|
const nb::object &maybeIp) {
|
||||||
|
PyLocation pyLoc = maybeGetTracebackLocation(location);
|
||||||
new (self) PyOpView(PyOpView::buildGeneric(
|
new (self) PyOpView(PyOpView::buildGeneric(
|
||||||
name, opRegionSpec, operandSegmentSpecObj,
|
name, opRegionSpec, operandSegmentSpecObj,
|
||||||
resultSegmentSpecObj, resultTypeList, operandList,
|
resultSegmentSpecObj, resultTypeList, operandList,
|
||||||
attributes, successors, regions, location, maybeIp));
|
attributes, successors, regions, pyLoc, maybeIp));
|
||||||
},
|
},
|
||||||
nb::arg("name"), nb::arg("opRegionSpec"),
|
nb::arg("name"), nb::arg("opRegionSpec"),
|
||||||
nb::arg("operandSegmentSpecObj").none() = nb::none(),
|
nb::arg("operandSegmentSpecObj").none() = nb::none(),
|
||||||
@ -3559,17 +3644,18 @@ void mlir::python::populateIRCore(nb::module_ &m) {
|
|||||||
[](nb::handle cls, std::optional<nb::list> resultTypeList,
|
[](nb::handle cls, std::optional<nb::list> resultTypeList,
|
||||||
nb::list operandList, std::optional<nb::dict> attributes,
|
nb::list operandList, std::optional<nb::dict> attributes,
|
||||||
std::optional<std::vector<PyBlock *>> successors,
|
std::optional<std::vector<PyBlock *>> successors,
|
||||||
std::optional<int> regions, DefaultingPyLocation location,
|
std::optional<int> regions, std::optional<PyLocation> location,
|
||||||
const nb::object &maybeIp) {
|
const nb::object &maybeIp) {
|
||||||
std::string name = nb::cast<std::string>(cls.attr("OPERATION_NAME"));
|
std::string name = nb::cast<std::string>(cls.attr("OPERATION_NAME"));
|
||||||
std::tuple<int, bool> opRegionSpec =
|
std::tuple<int, bool> opRegionSpec =
|
||||||
nb::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
|
nb::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
|
||||||
nb::object operandSegmentSpec = cls.attr("_ODS_OPERAND_SEGMENTS");
|
nb::object operandSegmentSpec = cls.attr("_ODS_OPERAND_SEGMENTS");
|
||||||
nb::object resultSegmentSpec = cls.attr("_ODS_RESULT_SEGMENTS");
|
nb::object resultSegmentSpec = cls.attr("_ODS_RESULT_SEGMENTS");
|
||||||
|
PyLocation pyLoc = maybeGetTracebackLocation(location);
|
||||||
return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec,
|
return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec,
|
||||||
resultSegmentSpec, resultTypeList,
|
resultSegmentSpec, resultTypeList,
|
||||||
operandList, attributes, successors,
|
operandList, attributes, successors,
|
||||||
regions, location, maybeIp);
|
regions, pyLoc, maybeIp);
|
||||||
},
|
},
|
||||||
nb::arg("cls"), nb::arg("results").none() = nb::none(),
|
nb::arg("cls"), nb::arg("results").none() = nb::none(),
|
||||||
nb::arg("operands").none() = nb::none(),
|
nb::arg("operands").none() = nb::none(),
|
||||||
|
@ -13,9 +13,9 @@
|
|||||||
|
|
||||||
#include "Globals.h"
|
#include "Globals.h"
|
||||||
#include "NanobindUtils.h"
|
#include "NanobindUtils.h"
|
||||||
|
#include "mlir-c/Bindings/Python/Interop.h"
|
||||||
#include "mlir-c/Support.h"
|
#include "mlir-c/Support.h"
|
||||||
#include "mlir/Bindings/Python/Nanobind.h"
|
#include "mlir/Bindings/Python/Nanobind.h"
|
||||||
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
|
|
||||||
|
|
||||||
namespace nb = nanobind;
|
namespace nb = nanobind;
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
@ -197,3 +197,71 @@ PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
|
|||||||
// Not found and loading did not yield a registration.
|
// Not found and loading did not yield a registration.
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool PyGlobals::TracebackLoc::locTracebacksEnabled() {
|
||||||
|
nanobind::ft_lock_guard lock(mutex);
|
||||||
|
return locTracebackEnabled_;
|
||||||
|
}
|
||||||
|
|
||||||
|
void PyGlobals::TracebackLoc::setLocTracebacksEnabled(bool value) {
|
||||||
|
nanobind::ft_lock_guard lock(mutex);
|
||||||
|
locTracebackEnabled_ = value;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t PyGlobals::TracebackLoc::locTracebackFramesLimit() {
|
||||||
|
nanobind::ft_lock_guard lock(mutex);
|
||||||
|
return locTracebackFramesLimit_;
|
||||||
|
}
|
||||||
|
|
||||||
|
void PyGlobals::TracebackLoc::setLocTracebackFramesLimit(size_t value) {
|
||||||
|
nanobind::ft_lock_guard lock(mutex);
|
||||||
|
locTracebackFramesLimit_ = std::min(value, kMaxFrames);
|
||||||
|
}
|
||||||
|
|
||||||
|
void PyGlobals::TracebackLoc::registerTracebackFileInclusion(
|
||||||
|
const std::string &file) {
|
||||||
|
nanobind::ft_lock_guard lock(mutex);
|
||||||
|
auto reg = "^" + llvm::Regex::escape(file);
|
||||||
|
if (userTracebackIncludeFiles.insert(reg).second)
|
||||||
|
rebuildUserTracebackIncludeRegex = true;
|
||||||
|
if (userTracebackExcludeFiles.count(reg)) {
|
||||||
|
if (userTracebackExcludeFiles.erase(reg))
|
||||||
|
rebuildUserTracebackExcludeRegex = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void PyGlobals::TracebackLoc::registerTracebackFileExclusion(
|
||||||
|
const std::string &file) {
|
||||||
|
nanobind::ft_lock_guard lock(mutex);
|
||||||
|
auto reg = "^" + llvm::Regex::escape(file);
|
||||||
|
if (userTracebackExcludeFiles.insert(reg).second)
|
||||||
|
rebuildUserTracebackExcludeRegex = true;
|
||||||
|
if (userTracebackIncludeFiles.count(reg)) {
|
||||||
|
if (userTracebackIncludeFiles.erase(reg))
|
||||||
|
rebuildUserTracebackIncludeRegex = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool PyGlobals::TracebackLoc::isUserTracebackFilename(
|
||||||
|
const llvm::StringRef file) {
|
||||||
|
nanobind::ft_lock_guard lock(mutex);
|
||||||
|
if (rebuildUserTracebackIncludeRegex) {
|
||||||
|
userTracebackIncludeRegex.assign(
|
||||||
|
llvm::join(userTracebackIncludeFiles, "|"));
|
||||||
|
rebuildUserTracebackIncludeRegex = false;
|
||||||
|
isUserTracebackFilenameCache.clear();
|
||||||
|
}
|
||||||
|
if (rebuildUserTracebackExcludeRegex) {
|
||||||
|
userTracebackExcludeRegex.assign(
|
||||||
|
llvm::join(userTracebackExcludeFiles, "|"));
|
||||||
|
rebuildUserTracebackExcludeRegex = false;
|
||||||
|
isUserTracebackFilenameCache.clear();
|
||||||
|
}
|
||||||
|
if (!isUserTracebackFilenameCache.contains(file)) {
|
||||||
|
std::string fileStr = file.str();
|
||||||
|
bool include = std::regex_search(fileStr, userTracebackIncludeRegex);
|
||||||
|
bool exclude = std::regex_search(fileStr, userTracebackExcludeRegex);
|
||||||
|
isUserTracebackFilenameCache[file] = include || !exclude;
|
||||||
|
}
|
||||||
|
return isUserTracebackFilenameCache[file];
|
||||||
|
}
|
||||||
|
@ -192,16 +192,6 @@ public:
|
|||||||
PyMlirContext(const PyMlirContext &) = delete;
|
PyMlirContext(const PyMlirContext &) = delete;
|
||||||
PyMlirContext(PyMlirContext &&) = delete;
|
PyMlirContext(PyMlirContext &&) = delete;
|
||||||
|
|
||||||
/// For the case of a python __init__ (nanobind::init) method, pybind11 is
|
|
||||||
/// quite strict about needing to return a pointer that is not yet associated
|
|
||||||
/// to an nanobind::object. Since the forContext() method acts like a pool,
|
|
||||||
/// possibly returning a recycled context, it does not satisfy this need. The
|
|
||||||
/// usual way in python to accomplish such a thing is to override __new__, but
|
|
||||||
/// that is also not supported by pybind11. Instead, we use this entry
|
|
||||||
/// point which always constructs a fresh context (which cannot alias an
|
|
||||||
/// existing one because it is fresh).
|
|
||||||
static PyMlirContext *createNewContextForInit();
|
|
||||||
|
|
||||||
/// Returns a context reference for the singleton PyMlirContext wrapper for
|
/// Returns a context reference for the singleton PyMlirContext wrapper for
|
||||||
/// the given context.
|
/// the given context.
|
||||||
static PyMlirContextRef forContext(MlirContext context);
|
static PyMlirContextRef forContext(MlirContext context);
|
||||||
@ -722,8 +712,7 @@ public:
|
|||||||
llvm::ArrayRef<MlirValue> operands,
|
llvm::ArrayRef<MlirValue> operands,
|
||||||
std::optional<nanobind::dict> attributes,
|
std::optional<nanobind::dict> attributes,
|
||||||
std::optional<std::vector<PyBlock *>> successors, int regions,
|
std::optional<std::vector<PyBlock *>> successors, int regions,
|
||||||
DefaultingPyLocation location, const nanobind::object &ip,
|
PyLocation &location, const nanobind::object &ip, bool inferType);
|
||||||
bool inferType);
|
|
||||||
|
|
||||||
/// Creates an OpView suitable for this operation.
|
/// Creates an OpView suitable for this operation.
|
||||||
nanobind::object createOpView();
|
nanobind::object createOpView();
|
||||||
@ -781,7 +770,7 @@ public:
|
|||||||
nanobind::list operandList,
|
nanobind::list operandList,
|
||||||
std::optional<nanobind::dict> attributes,
|
std::optional<nanobind::dict> attributes,
|
||||||
std::optional<std::vector<PyBlock *>> successors,
|
std::optional<std::vector<PyBlock *>> successors,
|
||||||
std::optional<int> regions, DefaultingPyLocation location,
|
std::optional<int> regions, PyLocation &location,
|
||||||
const nanobind::object &maybeIp);
|
const nanobind::object &maybeIp);
|
||||||
|
|
||||||
/// Construct an instance of a class deriving from OpView, bypassing its
|
/// Construct an instance of a class deriving from OpView, bypassing its
|
||||||
|
@ -6,7 +6,6 @@
|
|||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
|
||||||
#include "Globals.h"
|
#include "Globals.h"
|
||||||
#include "IRModule.h"
|
#include "IRModule.h"
|
||||||
#include "NanobindUtils.h"
|
#include "NanobindUtils.h"
|
||||||
@ -44,7 +43,27 @@ NB_MODULE(_mlir, m) {
|
|||||||
.def("_register_operation_impl", &PyGlobals::registerOperationImpl,
|
.def("_register_operation_impl", &PyGlobals::registerOperationImpl,
|
||||||
"operation_name"_a, "operation_class"_a, nb::kw_only(),
|
"operation_name"_a, "operation_class"_a, nb::kw_only(),
|
||||||
"replace"_a = false,
|
"replace"_a = false,
|
||||||
"Testing hook for directly registering an operation");
|
"Testing hook for directly registering an operation")
|
||||||
|
.def("loc_tracebacks_enabled",
|
||||||
|
[](PyGlobals &self) {
|
||||||
|
return self.getTracebackLoc().locTracebacksEnabled();
|
||||||
|
})
|
||||||
|
.def("set_loc_tracebacks_enabled",
|
||||||
|
[](PyGlobals &self, bool enabled) {
|
||||||
|
self.getTracebackLoc().setLocTracebacksEnabled(enabled);
|
||||||
|
})
|
||||||
|
.def("set_loc_tracebacks_frame_limit",
|
||||||
|
[](PyGlobals &self, int n) {
|
||||||
|
self.getTracebackLoc().setLocTracebackFramesLimit(n);
|
||||||
|
})
|
||||||
|
.def("register_traceback_file_inclusion",
|
||||||
|
[](PyGlobals &self, const std::string &filename) {
|
||||||
|
self.getTracebackLoc().registerTracebackFileInclusion(filename);
|
||||||
|
})
|
||||||
|
.def("register_traceback_file_exclusion",
|
||||||
|
[](PyGlobals &self, const std::string &filename) {
|
||||||
|
self.getTracebackLoc().registerTracebackFileExclusion(filename);
|
||||||
|
});
|
||||||
|
|
||||||
// Aside from making the globals accessible to python, having python manage
|
// Aside from making the globals accessible to python, having python manage
|
||||||
// it is necessary to make sure it is destroyed (and releases its python
|
// it is necessary to make sure it is destroyed (and releases its python
|
||||||
|
@ -78,12 +78,12 @@ def equally_sized_accessor(
|
|||||||
def get_default_loc_context(location=None):
|
def get_default_loc_context(location=None):
|
||||||
"""
|
"""
|
||||||
Returns a context in which the defaulted location is created. If the location
|
Returns a context in which the defaulted location is created. If the location
|
||||||
is None, takes the current location from the stack, raises ValueError if there
|
is None, takes the current location from the stack.
|
||||||
is no location on the stack.
|
|
||||||
"""
|
"""
|
||||||
if location is None:
|
if location is None:
|
||||||
# Location.current raises ValueError if there is no current location.
|
if _cext.ir.Location.current:
|
||||||
return _cext.ir.Location.current.context
|
return _cext.ir.Location.current.context
|
||||||
|
return None
|
||||||
return location.context
|
return location.context
|
||||||
|
|
||||||
|
|
||||||
|
@ -378,3 +378,6 @@ if config.run_rocm_tests:
|
|||||||
|
|
||||||
if config.arm_emulator_executable:
|
if config.arm_emulator_executable:
|
||||||
config.available_features.add("arm-emulator")
|
config.available_features.add("arm-emulator")
|
||||||
|
|
||||||
|
if sys.version_info >= (3, 11):
|
||||||
|
config.available_features.add("python-ge-311")
|
||||||
|
101
mlir/test/python/ir/auto_location.py
Normal file
101
mlir/test/python/ir/auto_location.py
Normal file
@ -0,0 +1,101 @@
|
|||||||
|
# RUN: %PYTHON %s | FileCheck %s
|
||||||
|
# REQUIRES: python-ge-311
|
||||||
|
import gc
|
||||||
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
from mlir.ir import *
|
||||||
|
from mlir.dialects._ods_common import _cext
|
||||||
|
from mlir.dialects import arith, _arith_ops_gen
|
||||||
|
|
||||||
|
|
||||||
|
def run(f):
|
||||||
|
print("\nTEST:", f.__name__)
|
||||||
|
f()
|
||||||
|
gc.collect()
|
||||||
|
assert Context._get_live_count() == 0
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def with_infer_location():
|
||||||
|
_cext.globals.set_loc_tracebacks_enabled(True)
|
||||||
|
yield
|
||||||
|
_cext.globals.set_loc_tracebacks_enabled(False)
|
||||||
|
|
||||||
|
|
||||||
|
# CHECK-LABEL: TEST: testInferLocations
|
||||||
|
@run
|
||||||
|
def testInferLocations():
|
||||||
|
with Context() as ctx, with_infer_location():
|
||||||
|
ctx.allow_unregistered_dialects = True
|
||||||
|
|
||||||
|
op = Operation.create("custom.op1")
|
||||||
|
one = arith.constant(IndexType.get(), 1)
|
||||||
|
_cext.globals.register_traceback_file_exclusion(arith.__file__)
|
||||||
|
two = arith.constant(IndexType.get(), 2)
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
# CHECK: loc(callsite("testInferLocations"("{{.*}}[[SEP:[/\\]]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":31:13 to :43) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4))))
|
||||||
|
# fmt: on
|
||||||
|
print(op.location)
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
# CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]arith.py":65:12 to :76) at callsite("constant"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]arith.py":110:40 to :81) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":32:14 to :48) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4))))))
|
||||||
|
# fmt: on
|
||||||
|
print(one.location)
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
# CHECK: loc(callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":34:14 to :48) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4))))
|
||||||
|
# fmt: on
|
||||||
|
print(two.location)
|
||||||
|
|
||||||
|
_cext.globals.register_traceback_file_inclusion(_arith_ops_gen.__file__)
|
||||||
|
three = arith.constant(IndexType.get(), 3)
|
||||||
|
# fmt: off
|
||||||
|
# CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":405:4 to :218) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":52:16 to :50) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4)))))
|
||||||
|
# fmt: on
|
||||||
|
print(three.location)
|
||||||
|
|
||||||
|
def foo():
|
||||||
|
four = arith.constant(IndexType.get(), 4)
|
||||||
|
print(four.location)
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
# CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":405:4 to :218) at callsite("testInferLocations.<locals>.foo"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":59:19 to :53) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":65:8 to :13) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4))))))
|
||||||
|
# fmt: on
|
||||||
|
foo()
|
||||||
|
|
||||||
|
_cext.globals.register_traceback_file_exclusion(__file__)
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
# CHECK: loc("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":405:4 to :218))
|
||||||
|
# fmt: on
|
||||||
|
foo()
|
||||||
|
|
||||||
|
def bar1():
|
||||||
|
def bar2():
|
||||||
|
def bar3():
|
||||||
|
five = arith.constant(IndexType.get(), 5)
|
||||||
|
print(five.location)
|
||||||
|
|
||||||
|
bar3()
|
||||||
|
|
||||||
|
bar2()
|
||||||
|
|
||||||
|
_cext.globals.register_traceback_file_inclusion(__file__)
|
||||||
|
_cext.globals.register_traceback_file_exclusion(_arith_ops_gen.__file__)
|
||||||
|
|
||||||
|
_cext.globals.set_loc_tracebacks_frame_limit(2)
|
||||||
|
# fmt: off
|
||||||
|
# CHECK: loc(callsite("testInferLocations.<locals>.bar1.<locals>.bar2.<locals>.bar3"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":77:27 to :61) at "testInferLocations.<locals>.bar1.<locals>.bar2"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":80:16 to :22)))
|
||||||
|
# fmt: on
|
||||||
|
bar1()
|
||||||
|
|
||||||
|
_cext.globals.set_loc_tracebacks_frame_limit(1)
|
||||||
|
# fmt: off
|
||||||
|
# CHECK: loc("testInferLocations.<locals>.bar1.<locals>.bar2.<locals>.bar3"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":77:27 to :61))
|
||||||
|
# fmt: on
|
||||||
|
bar1()
|
||||||
|
|
||||||
|
_cext.globals.set_loc_tracebacks_frame_limit(0)
|
||||||
|
# CHECK: loc(unknown)
|
||||||
|
bar1()
|
@ -35,25 +35,14 @@ def testLocationEnterExit():
|
|||||||
# Asserting a different context should clear it.
|
# Asserting a different context should clear it.
|
||||||
with Context() as ctx2:
|
with Context() as ctx2:
|
||||||
assert Context.current is ctx2
|
assert Context.current is ctx2
|
||||||
try:
|
assert Location.current is None
|
||||||
_ = Location.current
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
assert False, "Expected exception"
|
|
||||||
|
|
||||||
# And should restore.
|
# And should restore.
|
||||||
assert Context.current is ctx1
|
assert Context.current is ctx1
|
||||||
assert Location.current is loc1
|
assert Location.current is loc1
|
||||||
|
|
||||||
# All should clear.
|
# All should clear.
|
||||||
try:
|
assert Location.current is None
|
||||||
_ = Location.current
|
|
||||||
except ValueError as e:
|
|
||||||
# CHECK: No current Location
|
|
||||||
print(e)
|
|
||||||
else:
|
|
||||||
assert False, "Expected exception"
|
|
||||||
|
|
||||||
|
|
||||||
run(testLocationEnterExit)
|
run(testLocationEnterExit)
|
||||||
@ -72,12 +61,7 @@ def testInsertionPointEnterExit():
|
|||||||
assert InsertionPoint.current is ip
|
assert InsertionPoint.current is ip
|
||||||
assert Location.current is loc1
|
assert Location.current is loc1
|
||||||
# Location should clear.
|
# Location should clear.
|
||||||
try:
|
assert Location.current is None
|
||||||
_ = Location.current
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
assert False, "Expected exception"
|
|
||||||
|
|
||||||
# Asserting the same Context should preserve.
|
# Asserting the same Context should preserve.
|
||||||
with ctx1:
|
with ctx1:
|
||||||
|
@ -41,6 +41,7 @@ from ._ods_common import (
|
|||||||
segmented_accessor as _ods_segmented_accessor,
|
segmented_accessor as _ods_segmented_accessor,
|
||||||
)
|
)
|
||||||
_ods_ir = _ods_cext.ir
|
_ods_ir = _ods_cext.ir
|
||||||
|
_ods_cext.globals.register_traceback_file_exclusion(__file__)
|
||||||
|
|
||||||
import builtins
|
import builtins
|
||||||
from typing import Sequence as _Sequence, Union as _Union
|
from typing import Sequence as _Sequence, Union as _Union
|
||||||
|
Loading…
x
Reference in New Issue
Block a user