[mlir] Python: write bytecode to a file path (#127118)
The current `write_bytecode` implementation necessarily requires the serialized module to be duplicated in memory when the Python `bytes` object is created and sent over the binding. For modules with large resources, we may want to avoid this in-memory copy by serializing directly to a file instead of sending bytes across the boundary.
This commit is contained in:
parent
62ec7b8de9
commit
a60e8a2c25
@ -6,12 +6,10 @@
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include <optional>
|
||||
#include <utility>
|
||||
|
||||
#include "Globals.h"
|
||||
#include "IRModule.h"
|
||||
#include "NanobindUtils.h"
|
||||
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
|
||||
#include "mlir-c/BuiltinAttributes.h"
|
||||
#include "mlir-c/Debug.h"
|
||||
#include "mlir-c/Diagnostics.h"
|
||||
@ -19,9 +17,14 @@
|
||||
#include "mlir-c/Support.h"
|
||||
#include "mlir/Bindings/Python/Nanobind.h"
|
||||
#include "mlir/Bindings/Python/NanobindAdaptors.h"
|
||||
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
|
||||
#include "nanobind/nanobind.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <optional>
|
||||
#include <system_error>
|
||||
#include <utility>
|
||||
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
@ -1329,11 +1332,11 @@ void PyOperationBase::print(PyAsmState &state, nb::object fileObject,
|
||||
accum.getUserData());
|
||||
}
|
||||
|
||||
void PyOperationBase::writeBytecode(const nb::object &fileObject,
|
||||
void PyOperationBase::writeBytecode(const nb::object &fileOrStringObject,
|
||||
std::optional<int64_t> bytecodeVersion) {
|
||||
PyOperation &operation = getOperation();
|
||||
operation.checkValid();
|
||||
PyFileAccumulator accum(fileObject, /*binary=*/true);
|
||||
PyFileAccumulator accum(fileOrStringObject, /*binary=*/true);
|
||||
|
||||
if (!bytecodeVersion.has_value())
|
||||
return mlirOperationWriteBytecode(operation, accum.getCallback(),
|
||||
|
@ -13,8 +13,13 @@
|
||||
#include "mlir-c/Support.h"
|
||||
#include "mlir/Bindings/Python/Nanobind.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/ADT/Twine.h"
|
||||
#include "llvm/Support/DataTypes.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <string>
|
||||
#include <variant>
|
||||
|
||||
template <>
|
||||
struct std::iterator_traits<nanobind::detail::fast_iterator> {
|
||||
@ -128,33 +133,59 @@ struct PyPrintAccumulator {
|
||||
}
|
||||
};
|
||||
|
||||
/// Accumulates int a python file-like object, either writing text (default)
|
||||
/// or binary.
|
||||
/// Accumulates into a file, either writing text (default)
|
||||
/// or binary. The file may be a Python file-like object or a path to a file.
|
||||
class PyFileAccumulator {
|
||||
public:
|
||||
PyFileAccumulator(const nanobind::object &fileObject, bool binary)
|
||||
: pyWriteFunction(fileObject.attr("write")), binary(binary) {}
|
||||
PyFileAccumulator(const nanobind::object &fileOrStringObject, bool binary)
|
||||
: binary(binary) {
|
||||
std::string filePath;
|
||||
if (nanobind::try_cast<std::string>(fileOrStringObject, filePath)) {
|
||||
std::error_code ec;
|
||||
writeTarget.emplace<llvm::raw_fd_ostream>(filePath, ec);
|
||||
if (ec) {
|
||||
throw nanobind::value_error(
|
||||
(std::string("Unable to open file for writing: ") + ec.message())
|
||||
.c_str());
|
||||
}
|
||||
} else {
|
||||
writeTarget.emplace<nanobind::object>(fileOrStringObject.attr("write"));
|
||||
}
|
||||
}
|
||||
|
||||
MlirStringCallback getCallback() {
|
||||
return writeTarget.index() == 0 ? getPyWriteCallback()
|
||||
: getOstreamCallback();
|
||||
}
|
||||
|
||||
void *getUserData() { return this; }
|
||||
|
||||
MlirStringCallback getCallback() {
|
||||
private:
|
||||
MlirStringCallback getPyWriteCallback() {
|
||||
return [](MlirStringRef part, void *userData) {
|
||||
nanobind::gil_scoped_acquire acquire;
|
||||
PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData);
|
||||
if (accum->binary) {
|
||||
// Note: Still has to copy and not avoidable with this API.
|
||||
nanobind::bytes pyBytes(part.data, part.length);
|
||||
accum->pyWriteFunction(pyBytes);
|
||||
std::get<nanobind::object>(accum->writeTarget)(pyBytes);
|
||||
} else {
|
||||
nanobind::str pyStr(part.data,
|
||||
part.length); // Decodes as UTF-8 by default.
|
||||
accum->pyWriteFunction(pyStr);
|
||||
std::get<nanobind::object>(accum->writeTarget)(pyStr);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
private:
|
||||
nanobind::object pyWriteFunction;
|
||||
MlirStringCallback getOstreamCallback() {
|
||||
return [](MlirStringRef part, void *userData) {
|
||||
PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData);
|
||||
std::get<llvm::raw_fd_ostream>(accum->writeTarget)
|
||||
.write(part.data, part.length);
|
||||
};
|
||||
}
|
||||
|
||||
std::variant<nanobind::object, llvm::raw_fd_ostream> writeTarget;
|
||||
bool binary;
|
||||
};
|
||||
|
||||
|
@ -47,7 +47,7 @@ import collections
|
||||
from collections.abc import Callable, Sequence
|
||||
import io
|
||||
from pathlib import Path
|
||||
from typing import Any, ClassVar, TypeVar, overload
|
||||
from typing import Any, BinaryIO, ClassVar, TypeVar, overload
|
||||
|
||||
__all__ = [
|
||||
"AffineAddExpr",
|
||||
@ -285,12 +285,12 @@ class _OperationBase:
|
||||
"""
|
||||
Verify the operation. Raises MLIRError if verification fails, and returns true otherwise.
|
||||
"""
|
||||
def write_bytecode(self, file: Any, desired_version: int | None = None) -> None:
|
||||
def write_bytecode(self, file: BinaryIO | str, desired_version: int | None = None) -> None:
|
||||
"""
|
||||
Write the bytecode form of the operation to a file like object.
|
||||
|
||||
Args:
|
||||
file: The file like object to write to.
|
||||
file: The file like object or path to write to.
|
||||
desired_version: The version of bytecode to emit.
|
||||
Returns:
|
||||
The bytecode writer status.
|
||||
|
@ -3,6 +3,7 @@
|
||||
import gc
|
||||
import io
|
||||
import itertools
|
||||
from tempfile import NamedTemporaryFile
|
||||
from mlir.ir import *
|
||||
from mlir.dialects.builtin import ModuleOp
|
||||
from mlir.dialects import arith
|
||||
@ -617,6 +618,12 @@ def testOperationPrint():
|
||||
module.operation.write_bytecode(bytecode_stream, desired_version=1)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
assert bytecode.startswith(b"ML\xefR"), "Expected bytecode to start with MLïR"
|
||||
with NamedTemporaryFile() as tmpfile:
|
||||
module.operation.write_bytecode(str(tmpfile.name), desired_version=1)
|
||||
tmpfile.seek(0)
|
||||
assert tmpfile.read().startswith(
|
||||
b"ML\xefR"
|
||||
), "Expected bytecode to start with MLïR"
|
||||
ctx2 = Context()
|
||||
module_roundtrip = Module.parse(bytecode, ctx2)
|
||||
f = io.StringIO()
|
||||
|
Loading…
x
Reference in New Issue
Block a user