[mlir:Bytecode] Add initial support for dialect defined attribute/type encodings
Dialects can opt-in to providing custom encodings by implementing the `BytecodeDialectInterface`. This interface provides hooks, namely `readAttribute`/`readType` and `writeAttribute`/`writeType`, that will be used by the bytecode reader and writer. These hooks are provided a reader and writer implementation that can be used to encode various constructs in the underlying bytecode format. A unique feature of this interface is that dialects may choose to only encode a subset of their attributes and types in a custom bytecode format, which can simplify adding new or experimental components that aren't fully baked. Differential Revision: https://reviews.llvm.org/D132498
This commit is contained in:
parent
b3449392f5
commit
02c2ecb9c6
@ -207,7 +207,26 @@ reference to the parent dialect instead.
|
||||
|
||||
##### Dialect Defined Encoding
|
||||
|
||||
TODO: This is not yet supported.
|
||||
In addition to the assembly format fallback, dialects may also provide a custom
|
||||
encoding for their attributes and types. Custom encodings are very beneficial in
|
||||
that they are significantly smaller and faster to read and write.
|
||||
|
||||
Dialects can opt-in to providing custom encodings by implementing the
|
||||
`BytecodeDialectInterface`. This interface provides hooks, namely
|
||||
`readAttribute`/`readType` and `writeAttribute`/`writeType`, that will be used
|
||||
by the bytecode reader and writer. These hooks are provided a reader and writer
|
||||
implementation that can be used to encode various constructs in the underlying
|
||||
bytecode format. A unique feature of this interface is that dialects may choose
|
||||
to only encode a subset of their attributes and types in a custom bytecode
|
||||
format, which can simplify adding new or experimental components that aren't
|
||||
fully baked.
|
||||
|
||||
When implementing the bytecode interface, dialects are responsible for all
|
||||
aspects of the encoding. This includes the indicator for which kind of attribute
|
||||
or type is being encoded; the bytecode reader will only know that it has
|
||||
encountered an attribute or type of a given dialect, it doesn't encode any
|
||||
further information. As such, a common encoding idiom is to use a leading
|
||||
`varint` code to indicate how the attribute or type was encoded.
|
||||
|
||||
### IR Section
|
||||
|
||||
|
220
mlir/include/mlir/Bytecode/BytecodeImplementation.h
Normal file
220
mlir/include/mlir/Bytecode/BytecodeImplementation.h
Normal file
@ -0,0 +1,220 @@
|
||||
//===- BytecodeImplementation.h - MLIR Bytecode Implementation --*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This header defines various interfaces and utilities necessary for dialects
|
||||
// to hook into bytecode serialization.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_BYTECODE_BYTECODEIMPLEMENTATION_H
|
||||
#define MLIR_BYTECODE_BYTECODEIMPLEMENTATION_H
|
||||
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/DialectInterface.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "llvm/ADT/Twine.h"
|
||||
|
||||
namespace mlir {
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DialectBytecodeReader
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This class defines a virtual interface for reading a bytecode stream,
|
||||
/// providing hooks into the bytecode reader. As such, this class should only be
|
||||
/// derived and defined by the main bytecode reader, users (i.e. dialects)
|
||||
/// should generally only interact with this class via the
|
||||
/// BytecodeDialectInterface below.
|
||||
class DialectBytecodeReader {
|
||||
public:
|
||||
virtual ~DialectBytecodeReader() = default;
|
||||
|
||||
/// Emit an error to the reader.
|
||||
virtual InFlightDiagnostic emitError(const Twine &msg = {}) = 0;
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// IR
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
/// Read out a list of elements, invoking the provided callback for each
|
||||
/// element. The callback function may be in any of the following forms:
|
||||
/// * LogicalResult(T &)
|
||||
/// * FailureOr<T>()
|
||||
template <typename T, typename CallbackFn>
|
||||
LogicalResult readList(SmallVectorImpl<T> &result, CallbackFn &&callback) {
|
||||
uint64_t size;
|
||||
if (failed(readVarInt(size)))
|
||||
return failure();
|
||||
result.reserve(size);
|
||||
|
||||
for (uint64_t i = 0; i < size; ++i) {
|
||||
// Check if the callback uses FailureOr, or populates the result by
|
||||
// reference.
|
||||
if constexpr (llvm::function_traits<std::decay_t<CallbackFn>>::num_args) {
|
||||
T element = {};
|
||||
if (failed(callback(element)))
|
||||
return failure();
|
||||
result.emplace_back(std::move(element));
|
||||
} else {
|
||||
FailureOr<T> element = callback();
|
||||
if (failed(element))
|
||||
return failure();
|
||||
result.emplace_back(std::move(*element));
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Read a reference to the given attribute.
|
||||
virtual LogicalResult readAttribute(Attribute &result) = 0;
|
||||
template <typename T>
|
||||
LogicalResult readAttributes(SmallVectorImpl<T> &attrs) {
|
||||
return readList(attrs, [this](T &attr) { return readAttribute(attr); });
|
||||
}
|
||||
template <typename T>
|
||||
LogicalResult parseAttribute(T &result) {
|
||||
Attribute baseResult;
|
||||
if (failed(parseAttribute(baseResult)))
|
||||
return failure();
|
||||
if ((result = baseResult.dyn_cast<T>()))
|
||||
return success();
|
||||
return emitError() << "expected attribute of type: "
|
||||
<< llvm::getTypeName<T>() << ", but got: " << baseResult;
|
||||
}
|
||||
|
||||
/// Read a reference to the given type.
|
||||
virtual LogicalResult readType(Type &result) = 0;
|
||||
template <typename T>
|
||||
LogicalResult readTypes(SmallVectorImpl<T> &types) {
|
||||
return readList(types, [this](T &type) { return readType(type); });
|
||||
}
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Primitives
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
/// Read a variable width integer.
|
||||
// TODO: Add a signed variant when necessary.
|
||||
virtual LogicalResult readVarInt(uint64_t &result) = 0;
|
||||
|
||||
/// Read a string from the bytecode.
|
||||
virtual LogicalResult readString(StringRef &result) = 0;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DialectBytecodeWriter
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This class defines a virtual interface for writing to a bytecode stream,
|
||||
/// providing hooks into the bytecode writer. As such, this class should only be
|
||||
/// derived and defined by the main bytecode writer, users (i.e. dialects)
|
||||
/// should generally only interact with this class via the
|
||||
/// BytecodeDialectInterface below.
|
||||
class DialectBytecodeWriter {
|
||||
public:
|
||||
virtual ~DialectBytecodeWriter() = default;
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// IR
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
/// Write out a list of elements, invoking the provided callback for each
|
||||
/// element.
|
||||
template <typename RangeT, typename CallbackFn>
|
||||
void writeList(RangeT &&range, CallbackFn &&callback) {
|
||||
writeVarInt(llvm::size(range));
|
||||
for (auto &element : range)
|
||||
callback(element);
|
||||
}
|
||||
|
||||
/// Write a reference to the given attribute.
|
||||
virtual void writeAttribute(Attribute attr) = 0;
|
||||
template <typename T>
|
||||
void writeAttributes(ArrayRef<T> attrs) {
|
||||
writeList(attrs, [this](T attr) { writeAttribute(attr); });
|
||||
}
|
||||
|
||||
/// Write a reference to the given type.
|
||||
virtual void writeType(Type type) = 0;
|
||||
template <typename T>
|
||||
void writeTypes(ArrayRef<T> types) {
|
||||
writeList(types, [this](T type) { writeType(type); });
|
||||
}
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Primitives
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
/// Write a variable width integer to the output stream. This should be the
|
||||
/// preferred method for emitting integers whenever possible.
|
||||
// TODO: Add a signed variant when necessary.
|
||||
virtual void writeVarInt(uint64_t value) = 0;
|
||||
|
||||
/// Write a string to the bytecode, which is owned by the caller and is
|
||||
/// guaranteed to not die before the end of the bytecode process. This should
|
||||
/// only be called if such a guarantee can be made, such as when the string is
|
||||
/// owned by an attribute or type.
|
||||
virtual void writeOwnedString(StringRef str) = 0;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BytecodeDialectInterface
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class BytecodeDialectInterface
|
||||
: public DialectInterface::Base<BytecodeDialectInterface> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Reading
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
/// Read an attribute belonging to this dialect from the given reader. This
|
||||
/// method should return null in the case of failure.
|
||||
virtual Attribute readAttribute(DialectBytecodeReader &reader) const {
|
||||
reader.emitError() << "dialect " << getDialect()->getNamespace()
|
||||
<< " does not support reading attributes from bytecode";
|
||||
return Attribute();
|
||||
}
|
||||
|
||||
/// Read a type belonging to this dialect from the given reader. This method
|
||||
/// should return null in the case of failure.
|
||||
virtual Type readType(DialectBytecodeReader &reader) const {
|
||||
reader.emitError() << "dialect " << getDialect()->getNamespace()
|
||||
<< " does not support reading types from bytecode";
|
||||
return Type();
|
||||
}
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Writing
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
/// Write the given attribute, which belongs to this dialect, to the given
|
||||
/// writer. This method may return failure to indicate that the given
|
||||
/// attribute could not be encoded, in which case the textual format will be
|
||||
/// used to encode this attribute instead.
|
||||
virtual LogicalResult writeAttribute(Attribute attr,
|
||||
DialectBytecodeWriter &writer) const {
|
||||
return failure();
|
||||
}
|
||||
|
||||
/// Write the given type, which belongs to this dialect, to the given writer.
|
||||
/// This method may return failure to indicate that the given type could not
|
||||
/// be encoded, in which case the textual format will be used to encode this
|
||||
/// type instead.
|
||||
virtual LogicalResult writeType(Type type,
|
||||
DialectBytecodeWriter &writer) const {
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_BYTECODE_BYTECODEIMPLEMENTATION_H
|
@ -50,6 +50,9 @@ public:
|
||||
/// Return the dialect that this interface represents.
|
||||
Dialect *getDialect() const { return dialect; }
|
||||
|
||||
/// Return the context that holds the parent dialect of this interface.
|
||||
MLIRContext *getContext() const;
|
||||
|
||||
/// Return the derived interface id.
|
||||
TypeID getID() const { return interfaceID; }
|
||||
|
||||
|
@ -12,6 +12,7 @@
|
||||
#include "mlir/Bytecode/BytecodeReader.h"
|
||||
#include "../Encoding.h"
|
||||
#include "mlir/AsmParser/AsmParser.h"
|
||||
#include "mlir/Bytecode/BytecodeImplementation.h"
|
||||
#include "mlir/IR/BuiltinDialect.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
@ -66,7 +67,7 @@ public:
|
||||
|
||||
/// Emit an error using the given arguments.
|
||||
template <typename... Args>
|
||||
LogicalResult emitError(Args &&...args) const {
|
||||
InFlightDiagnostic emitError(Args &&...args) const {
|
||||
return ::emitError(fileLoc).append(std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
@ -326,6 +327,11 @@ struct BytecodeDialect {
|
||||
"-allow-unregistered-dialect with the MLIR tool used.");
|
||||
}
|
||||
dialect = loadedDialect;
|
||||
|
||||
// If the dialect was actually loaded, check to see if it has a bytecode
|
||||
// interface.
|
||||
if (loadedDialect)
|
||||
interface = dyn_cast<BytecodeDialectInterface>(loadedDialect);
|
||||
return success();
|
||||
}
|
||||
|
||||
@ -333,6 +339,11 @@ struct BytecodeDialect {
|
||||
/// load, nullptr if we failed to load, otherwise the loaded dialect.
|
||||
Optional<Dialect *> dialect;
|
||||
|
||||
/// The bytecode interface of the dialect, or nullptr if the dialect does not
|
||||
/// implement the bytecode interface. This field should only be checked if the
|
||||
/// `dialect` field is non-None.
|
||||
const BytecodeDialectInterface *interface = nullptr;
|
||||
|
||||
/// The name of the dialect.
|
||||
StringRef name;
|
||||
};
|
||||
@ -397,7 +408,8 @@ class AttrTypeReader {
|
||||
using TypeEntry = Entry<Type>;
|
||||
|
||||
public:
|
||||
AttrTypeReader(Location fileLoc) : fileLoc(fileLoc) {}
|
||||
AttrTypeReader(StringSectionReader &stringReader, Location fileLoc)
|
||||
: stringReader(stringReader), fileLoc(fileLoc) {}
|
||||
|
||||
/// Initialize the attribute and type information within the reader.
|
||||
LogicalResult initialize(MutableArrayRef<BytecodeDialect> dialects,
|
||||
@ -456,6 +468,10 @@ private:
|
||||
LogicalResult parseCustomEntry(Entry<T> &entry, EncodingReader &reader,
|
||||
StringRef entryType);
|
||||
|
||||
/// The string section reader used to resolve string references when parsing
|
||||
/// custom encoded attribute/type entries.
|
||||
StringSectionReader &stringReader;
|
||||
|
||||
/// The set of attribute and type entries.
|
||||
SmallVector<AttrEntry> attributes;
|
||||
SmallVector<TypeEntry> types;
|
||||
@ -463,6 +479,47 @@ private:
|
||||
/// A location used for error emission.
|
||||
Location fileLoc;
|
||||
};
|
||||
|
||||
class DialectReader : public DialectBytecodeReader {
|
||||
public:
|
||||
DialectReader(AttrTypeReader &attrTypeReader,
|
||||
StringSectionReader &stringReader, EncodingReader &reader)
|
||||
: attrTypeReader(attrTypeReader), stringReader(stringReader),
|
||||
reader(reader) {}
|
||||
|
||||
InFlightDiagnostic emitError(const Twine &msg) override {
|
||||
return reader.emitError(msg);
|
||||
}
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// IR
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult readAttribute(Attribute &result) override {
|
||||
return attrTypeReader.parseAttribute(reader, result);
|
||||
}
|
||||
|
||||
LogicalResult readType(Type &result) override {
|
||||
return attrTypeReader.parseType(reader, result);
|
||||
}
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Primitives
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult readVarInt(uint64_t &result) override {
|
||||
return reader.parseVarInt(result);
|
||||
}
|
||||
|
||||
LogicalResult readString(StringRef &result) override {
|
||||
return stringReader.parseString(reader, result);
|
||||
}
|
||||
|
||||
private:
|
||||
AttrTypeReader &attrTypeReader;
|
||||
StringSectionReader &stringReader;
|
||||
EncodingReader &reader;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
LogicalResult
|
||||
@ -486,7 +543,7 @@ AttrTypeReader::initialize(MutableArrayRef<BytecodeDialect> dialects,
|
||||
size_t currentIndex = 0, endIndex = range.size();
|
||||
|
||||
// Parse an individual entry.
|
||||
auto parseEntryFn = [&](BytecodeDialect *dialect) {
|
||||
auto parseEntryFn = [&](BytecodeDialect *dialect) -> LogicalResult {
|
||||
auto &entry = range[currentIndex++];
|
||||
|
||||
uint64_t entrySize;
|
||||
@ -548,8 +605,7 @@ T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index,
|
||||
}
|
||||
|
||||
if (!reader.empty()) {
|
||||
(void)reader.emitError("unexpected trailing bytes after " + entryType +
|
||||
" entry");
|
||||
reader.emitError("unexpected trailing bytes after " + entryType + " entry");
|
||||
return T();
|
||||
}
|
||||
return entry.entry;
|
||||
@ -584,8 +640,22 @@ template <typename T>
|
||||
LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
|
||||
EncodingReader &reader,
|
||||
StringRef entryType) {
|
||||
// FIXME: Add support for reading custom attribute/type encodings.
|
||||
return reader.emitError("unexpected Attribute encoding");
|
||||
if (failed(entry.dialect->load(reader, fileLoc.getContext())))
|
||||
return failure();
|
||||
|
||||
// Ensure that the dialect implements the bytecode interface.
|
||||
if (!entry.dialect->interface) {
|
||||
return reader.emitError("dialect '", entry.dialect->name,
|
||||
"' does not implement the bytecode interface");
|
||||
}
|
||||
|
||||
// Ask the dialect to parse the entry.
|
||||
DialectReader dialectReader(*this, stringReader, reader);
|
||||
if constexpr (std::is_same_v<T, Type>)
|
||||
entry.entry = entry.dialect->interface->readType(dialectReader);
|
||||
else
|
||||
entry.entry = entry.dialect->interface->readAttribute(dialectReader);
|
||||
return success(!!entry.entry);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -597,7 +667,7 @@ namespace {
|
||||
class BytecodeReader {
|
||||
public:
|
||||
BytecodeReader(Location fileLoc, const ParserConfig &config)
|
||||
: config(config), fileLoc(fileLoc), attrTypeReader(fileLoc),
|
||||
: config(config), fileLoc(fileLoc), attrTypeReader(stringReader, fileLoc),
|
||||
// Use the builtin unrealized conversion cast operation to represent
|
||||
// forward references to values that aren't yet defined.
|
||||
forwardRefOpState(UnknownLoc::get(config.getContext()),
|
||||
|
@ -9,6 +9,7 @@
|
||||
#include "mlir/Bytecode/BytecodeWriter.h"
|
||||
#include "../Encoding.h"
|
||||
#include "IRNumbering.h"
|
||||
#include "mlir/Bytecode/BytecodeImplementation.h"
|
||||
#include "mlir/IR/BuiltinDialect.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "llvm/ADT/CachedHashString.h"
|
||||
@ -358,22 +359,78 @@ void BytecodeWriter::writeDialectSection(EncodingEmitter &emitter) {
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Attributes and Types
|
||||
|
||||
namespace {
|
||||
class DialectWriter : public DialectBytecodeWriter {
|
||||
public:
|
||||
DialectWriter(EncodingEmitter &emitter, IRNumberingState &numberingState,
|
||||
StringSectionBuilder &stringSection)
|
||||
: emitter(emitter), numberingState(numberingState),
|
||||
stringSection(stringSection) {}
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// IR
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
void writeAttribute(Attribute attr) override {
|
||||
emitter.emitVarInt(numberingState.getNumber(attr));
|
||||
}
|
||||
void writeType(Type type) override {
|
||||
emitter.emitVarInt(numberingState.getNumber(type));
|
||||
}
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Primitives
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
void writeVarInt(uint64_t value) override { emitter.emitVarInt(value); }
|
||||
|
||||
void writeOwnedString(StringRef str) override {
|
||||
emitter.emitVarInt(stringSection.insert(str));
|
||||
}
|
||||
|
||||
private:
|
||||
EncodingEmitter &emitter;
|
||||
IRNumberingState &numberingState;
|
||||
StringSectionBuilder &stringSection;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void BytecodeWriter::writeAttrTypeSection(EncodingEmitter &emitter) {
|
||||
EncodingEmitter attrTypeEmitter;
|
||||
EncodingEmitter offsetEmitter;
|
||||
offsetEmitter.emitVarInt(llvm::size(numberingState.getAttributes()));
|
||||
offsetEmitter.emitVarInt(llvm::size(numberingState.getTypes()));
|
||||
|
||||
// The writer used when emitting using a custom bytecode encoding.
|
||||
DialectWriter dialectWriter(attrTypeEmitter, numberingState, stringSection);
|
||||
|
||||
// A functor used to emit an attribute or type entry.
|
||||
uint64_t prevOffset = 0;
|
||||
auto emitAttrOrType = [&](auto &entry) {
|
||||
// TODO: Allow dialects to provide more optimal implementations of attribute
|
||||
// and type encodings.
|
||||
bool hasCustomEncoding = false;
|
||||
auto entryValue = entry.getValue();
|
||||
|
||||
// Emit the entry using the textual format.
|
||||
raw_emitter_ostream(attrTypeEmitter) << entry.getValue();
|
||||
attrTypeEmitter.emitByte(0);
|
||||
// First, try to emit this entry using the dialect bytecode interface.
|
||||
bool hasCustomEncoding = false;
|
||||
if (const BytecodeDialectInterface *interface = entry.dialect->interface) {
|
||||
if constexpr (std::is_same_v<std::decay_t<decltype(entryValue)>, Type>) {
|
||||
// TODO: We don't currently support custom encoded mutable types.
|
||||
hasCustomEncoding =
|
||||
!entryValue.template hasTrait<TypeTrait::IsMutable>() &&
|
||||
succeeded(interface->writeType(entryValue, dialectWriter));
|
||||
} else {
|
||||
// TODO: We don't currently support custom encoded mutable attributes.
|
||||
hasCustomEncoding =
|
||||
!entryValue.template hasTrait<AttributeTrait::IsMutable>() &&
|
||||
succeeded(interface->writeAttribute(entryValue, dialectWriter));
|
||||
}
|
||||
}
|
||||
|
||||
// If the entry was not emitted using the dialect interface, emit it using
|
||||
// the textual format.
|
||||
if (!hasCustomEncoding) {
|
||||
raw_emitter_ostream(attrTypeEmitter) << entryValue;
|
||||
attrTypeEmitter.emitByte(0);
|
||||
}
|
||||
|
||||
// Record the offset of this entry.
|
||||
uint64_t curOffset = attrTypeEmitter.size();
|
||||
|
@ -7,6 +7,7 @@
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "IRNumbering.h"
|
||||
#include "mlir/Bytecode/BytecodeImplementation.h"
|
||||
#include "mlir/Bytecode/BytecodeWriter.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
@ -14,6 +15,28 @@
|
||||
using namespace mlir;
|
||||
using namespace mlir::bytecode::detail;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// NumberingDialectWriter
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
struct IRNumberingState::NumberingDialectWriter : public DialectBytecodeWriter {
|
||||
NumberingDialectWriter(IRNumberingState &state) : state(state) {}
|
||||
|
||||
void writeAttribute(Attribute attr) override { state.number(attr); }
|
||||
void writeType(Type type) override { state.number(type); }
|
||||
|
||||
/// Stubbed out methods that are not used for numbering.
|
||||
void writeVarInt(uint64_t) override {}
|
||||
void writeOwnedString(StringRef) override {
|
||||
// TODO: It might be nice to prenumber strings and sort by the number of
|
||||
// references. This could potentially be useful for optimizing things like
|
||||
// file locations.
|
||||
}
|
||||
|
||||
/// The parent numbering state that is populated by this writer.
|
||||
IRNumberingState &state;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// IR Numbering
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -138,10 +161,22 @@ void IRNumberingState::number(Attribute attr) {
|
||||
// have a registered dialect when it got created. We don't want to encode this
|
||||
// as the builtin OpaqueAttr, we want to encode it as if the dialect was
|
||||
// actually loaded.
|
||||
if (OpaqueAttr opaqueAttr = attr.dyn_cast<OpaqueAttr>())
|
||||
if (OpaqueAttr opaqueAttr = attr.dyn_cast<OpaqueAttr>()) {
|
||||
numbering->dialect = &numberDialect(opaqueAttr.getDialectNamespace());
|
||||
else
|
||||
numbering->dialect = &numberDialect(&attr.getDialect());
|
||||
return;
|
||||
}
|
||||
numbering->dialect = &numberDialect(&attr.getDialect());
|
||||
|
||||
// If this attribute will be emitted using the bytecode format, perform a
|
||||
// dummy writing to number any nested components.
|
||||
if (const auto *interface = numbering->dialect->interface) {
|
||||
// TODO: We don't allow custom encodings for mutable attributes right now.
|
||||
if (attr.hasTrait<AttributeTrait::IsMutable>())
|
||||
return;
|
||||
|
||||
NumberingDialectWriter writer(*this);
|
||||
(void)interface->writeAttribute(attr, writer);
|
||||
}
|
||||
}
|
||||
|
||||
void IRNumberingState::number(Block &block) {
|
||||
@ -164,7 +199,7 @@ auto IRNumberingState::numberDialect(Dialect *dialect) -> DialectNumbering & {
|
||||
DialectNumbering *&numbering = registeredDialects[dialect];
|
||||
if (!numbering) {
|
||||
numbering = &numberDialect(dialect->getNamespace());
|
||||
numbering->dialect = dialect;
|
||||
numbering->interface = dyn_cast<BytecodeDialectInterface>(dialect);
|
||||
}
|
||||
return *numbering;
|
||||
}
|
||||
@ -244,8 +279,20 @@ void IRNumberingState::number(Type type) {
|
||||
// registered dialect when it got created. We don't want to encode this as the
|
||||
// builtin OpaqueType, we want to encode it as if the dialect was actually
|
||||
// loaded.
|
||||
if (OpaqueType opaqueType = type.dyn_cast<OpaqueType>())
|
||||
if (OpaqueType opaqueType = type.dyn_cast<OpaqueType>()) {
|
||||
numbering->dialect = &numberDialect(opaqueType.getDialectNamespace());
|
||||
else
|
||||
numbering->dialect = &numberDialect(&type.getDialect());
|
||||
return;
|
||||
}
|
||||
numbering->dialect = &numberDialect(&type.getDialect());
|
||||
|
||||
// If this type will be emitted using the bytecode format, perform a dummy
|
||||
// writing to number any nested components.
|
||||
if (const auto *interface = numbering->dialect->interface) {
|
||||
// TODO: We don't allow custom encodings for mutable types right now.
|
||||
if (type.hasTrait<TypeTrait::IsMutable>())
|
||||
return;
|
||||
|
||||
NumberingDialectWriter writer(*this);
|
||||
(void)interface->writeType(type, writer);
|
||||
}
|
||||
}
|
||||
|
@ -18,6 +18,7 @@
|
||||
#include "llvm/ADT/MapVector.h"
|
||||
|
||||
namespace mlir {
|
||||
class BytecodeDialectInterface;
|
||||
class BytecodeWriterConfig;
|
||||
|
||||
namespace bytecode {
|
||||
@ -90,8 +91,8 @@ struct DialectNumbering {
|
||||
/// The number assigned to the dialect.
|
||||
unsigned number;
|
||||
|
||||
/// The loaded dialect, or nullptr if the dialect isn't loaded.
|
||||
Dialect *dialect = nullptr;
|
||||
/// The bytecode dialect interface of the dialect if defined.
|
||||
const BytecodeDialectInterface *interface = nullptr;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -147,6 +148,10 @@ public:
|
||||
}
|
||||
|
||||
private:
|
||||
/// This class is used to provide a fake dialect writer for numbering nested
|
||||
/// attributes and types.
|
||||
struct NumberingDialectWriter;
|
||||
|
||||
/// Number the given IR unit for bytecode emission.
|
||||
void number(Attribute attr);
|
||||
void number(Block &block);
|
||||
|
@ -12,6 +12,7 @@
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/IR/BuiltinDialect.h"
|
||||
#include "BuiltinDialectBytecode.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
@ -117,6 +118,7 @@ void BuiltinDialect::initialize() {
|
||||
|
||||
auto &blobInterface = addInterface<BuiltinBlobManagerInterface>();
|
||||
addInterface<BuiltinOpAsmDialectInterface>(blobInterface);
|
||||
builtin_dialect_detail::addBytecodeInterface(this);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
269
mlir/lib/IR/BuiltinDialectBytecode.cpp
Normal file
269
mlir/lib/IR/BuiltinDialectBytecode.cpp
Normal file
@ -0,0 +1,269 @@
|
||||
//===- BuiltinDialectBytecode.cpp - Builtin Bytecode Implementation -------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "BuiltinDialectBytecode.h"
|
||||
#include "mlir/Bytecode/BytecodeImplementation.h"
|
||||
#include "mlir/IR/BuiltinDialect.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Encoding
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
namespace builtin_encoding {
|
||||
/// This enum contains marker codes used to indicate which attribute is
|
||||
/// currently being decoded, and how it should be decoded. The order of these
|
||||
/// codes should generally be unchanged, as any changes will inevitably break
|
||||
/// compatibility with older bytecode.
|
||||
enum AttributeCode {
|
||||
/// ArrayAttr {
|
||||
/// elements: Attribute[]
|
||||
/// }
|
||||
///
|
||||
kArrayAttr = 0,
|
||||
|
||||
/// DictionaryAttr {
|
||||
/// attrs: <StringAttr, Attribute>[]
|
||||
/// }
|
||||
kDictionaryAttr = 1,
|
||||
|
||||
/// StringAttr {
|
||||
/// string
|
||||
/// }
|
||||
kStringAttr = 2,
|
||||
};
|
||||
|
||||
/// This enum contains marker codes used to indicate which type is currently
|
||||
/// being decoded, and how it should be decoded. The order of these codes should
|
||||
/// generally be unchanged, as any changes will inevitably break compatibility
|
||||
/// with older bytecode.
|
||||
enum TypeCode {
|
||||
/// IntegerType {
|
||||
/// widthAndSignedness: varint // (width << 2) | (signedness)
|
||||
/// }
|
||||
///
|
||||
kIntegerType = 0,
|
||||
|
||||
/// IndexType {
|
||||
/// }
|
||||
///
|
||||
kIndexType = 1,
|
||||
|
||||
/// FunctionType {
|
||||
/// inputs: Type[],
|
||||
/// results: Type[]
|
||||
/// }
|
||||
///
|
||||
kFunctionType = 2,
|
||||
};
|
||||
|
||||
} // namespace builtin_encoding
|
||||
} // namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BuiltinDialectBytecodeInterface
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
/// This class implements the bytecode interface for the builtin dialect.
|
||||
struct BuiltinDialectBytecodeInterface : public BytecodeDialectInterface {
|
||||
BuiltinDialectBytecodeInterface(Dialect *dialect)
|
||||
: BytecodeDialectInterface(dialect) {}
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Attributes
|
||||
|
||||
Attribute readAttribute(DialectBytecodeReader &reader) const override;
|
||||
ArrayAttr readArrayAttr(DialectBytecodeReader &reader) const;
|
||||
DictionaryAttr readDictionaryAttr(DialectBytecodeReader &reader) const;
|
||||
StringAttr readStringAttr(DialectBytecodeReader &reader) const;
|
||||
|
||||
LogicalResult writeAttribute(Attribute attr,
|
||||
DialectBytecodeWriter &writer) const override;
|
||||
void write(ArrayAttr attr, DialectBytecodeWriter &writer) const;
|
||||
void write(DictionaryAttr attr, DialectBytecodeWriter &writer) const;
|
||||
void write(StringAttr attr, DialectBytecodeWriter &writer) const;
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Types
|
||||
|
||||
Type readType(DialectBytecodeReader &reader) const override;
|
||||
IntegerType readIntegerType(DialectBytecodeReader &reader) const;
|
||||
FunctionType readFunctionType(DialectBytecodeReader &reader) const;
|
||||
|
||||
LogicalResult writeType(Type type,
|
||||
DialectBytecodeWriter &writer) const override;
|
||||
void write(IntegerType type, DialectBytecodeWriter &writer) const;
|
||||
void write(FunctionType type, DialectBytecodeWriter &writer) const;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void builtin_dialect_detail::addBytecodeInterface(BuiltinDialect *dialect) {
|
||||
dialect->addInterfaces<BuiltinDialectBytecodeInterface>();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Attributes: Reader
|
||||
|
||||
Attribute BuiltinDialectBytecodeInterface::readAttribute(
|
||||
DialectBytecodeReader &reader) const {
|
||||
uint64_t code;
|
||||
if (failed(reader.readVarInt(code)))
|
||||
return Attribute();
|
||||
switch (code) {
|
||||
case builtin_encoding::kArrayAttr:
|
||||
return readArrayAttr(reader);
|
||||
case builtin_encoding::kDictionaryAttr:
|
||||
return readDictionaryAttr(reader);
|
||||
case builtin_encoding::kStringAttr:
|
||||
return readStringAttr(reader);
|
||||
default:
|
||||
reader.emitError() << "unknown builtin attribute code: " << code;
|
||||
return Attribute();
|
||||
}
|
||||
}
|
||||
|
||||
ArrayAttr BuiltinDialectBytecodeInterface::readArrayAttr(
|
||||
DialectBytecodeReader &reader) const {
|
||||
SmallVector<Attribute> elements;
|
||||
if (failed(reader.readAttributes(elements)))
|
||||
return ArrayAttr();
|
||||
return ArrayAttr::get(getContext(), elements);
|
||||
}
|
||||
|
||||
DictionaryAttr BuiltinDialectBytecodeInterface::readDictionaryAttr(
|
||||
DialectBytecodeReader &reader) const {
|
||||
auto readNamedAttr = [&]() -> FailureOr<NamedAttribute> {
|
||||
StringAttr name;
|
||||
Attribute value;
|
||||
if (failed(reader.readAttribute(name)) ||
|
||||
failed(reader.readAttribute(value)))
|
||||
return failure();
|
||||
return NamedAttribute(name, value);
|
||||
};
|
||||
SmallVector<NamedAttribute> attrs;
|
||||
if (failed(reader.readList(attrs, readNamedAttr)))
|
||||
return DictionaryAttr();
|
||||
return DictionaryAttr::get(getContext(), attrs);
|
||||
}
|
||||
|
||||
StringAttr BuiltinDialectBytecodeInterface::readStringAttr(
|
||||
DialectBytecodeReader &reader) const {
|
||||
StringRef string;
|
||||
if (failed(reader.readString(string)))
|
||||
return StringAttr();
|
||||
return StringAttr::get(getContext(), string);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Attributes: Writer
|
||||
|
||||
LogicalResult BuiltinDialectBytecodeInterface::writeAttribute(
|
||||
Attribute attr, DialectBytecodeWriter &writer) const {
|
||||
return TypeSwitch<Attribute, LogicalResult>(attr)
|
||||
.Case<ArrayAttr, DictionaryAttr, StringAttr>([&](auto attr) {
|
||||
write(attr, writer);
|
||||
return success();
|
||||
})
|
||||
.Default([&](Attribute) { return failure(); });
|
||||
}
|
||||
|
||||
void BuiltinDialectBytecodeInterface::write(
|
||||
ArrayAttr attr, DialectBytecodeWriter &writer) const {
|
||||
writer.writeVarInt(builtin_encoding::kArrayAttr);
|
||||
writer.writeAttributes(attr.getValue());
|
||||
}
|
||||
|
||||
void BuiltinDialectBytecodeInterface::write(
|
||||
DictionaryAttr attr, DialectBytecodeWriter &writer) const {
|
||||
writer.writeVarInt(builtin_encoding::kDictionaryAttr);
|
||||
writer.writeList(attr.getValue(), [&](NamedAttribute attr) {
|
||||
writer.writeAttribute(attr.getName());
|
||||
writer.writeAttribute(attr.getValue());
|
||||
});
|
||||
}
|
||||
|
||||
void BuiltinDialectBytecodeInterface::write(
|
||||
StringAttr attr, DialectBytecodeWriter &writer) const {
|
||||
writer.writeVarInt(builtin_encoding::kStringAttr);
|
||||
writer.writeOwnedString(attr.getValue());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Types: Reader
|
||||
|
||||
Type BuiltinDialectBytecodeInterface::readType(
|
||||
DialectBytecodeReader &reader) const {
|
||||
uint64_t code;
|
||||
if (failed(reader.readVarInt(code)))
|
||||
return Type();
|
||||
switch (code) {
|
||||
case builtin_encoding::kIntegerType:
|
||||
return readIntegerType(reader);
|
||||
case builtin_encoding::kIndexType:
|
||||
return IndexType::get(getContext());
|
||||
|
||||
case builtin_encoding::kFunctionType:
|
||||
return readFunctionType(reader);
|
||||
default:
|
||||
reader.emitError() << "unknown builtin type code: " << code;
|
||||
return Type();
|
||||
}
|
||||
}
|
||||
|
||||
IntegerType BuiltinDialectBytecodeInterface::readIntegerType(
|
||||
DialectBytecodeReader &reader) const {
|
||||
uint64_t encoding;
|
||||
if (failed(reader.readVarInt(encoding)))
|
||||
return IntegerType();
|
||||
return IntegerType::get(
|
||||
getContext(), encoding >> 2,
|
||||
static_cast<IntegerType::SignednessSemantics>(encoding & 0x3));
|
||||
}
|
||||
|
||||
FunctionType BuiltinDialectBytecodeInterface::readFunctionType(
|
||||
DialectBytecodeReader &reader) const {
|
||||
SmallVector<Type> inputs, results;
|
||||
if (failed(reader.readTypes(inputs)) || failed(reader.readTypes(results)))
|
||||
return FunctionType();
|
||||
return FunctionType::get(getContext(), inputs, results);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Types: Writer
|
||||
|
||||
LogicalResult BuiltinDialectBytecodeInterface::writeType(
|
||||
Type type, DialectBytecodeWriter &writer) const {
|
||||
return TypeSwitch<Type, LogicalResult>(type)
|
||||
.Case<IntegerType, FunctionType>([&](auto type) {
|
||||
write(type, writer);
|
||||
return success();
|
||||
})
|
||||
.Case([&](IndexType) {
|
||||
return writer.writeVarInt(builtin_encoding::kIndexType), success();
|
||||
})
|
||||
.Default([&](Type) { return failure(); });
|
||||
}
|
||||
|
||||
void BuiltinDialectBytecodeInterface::write(
|
||||
IntegerType type, DialectBytecodeWriter &writer) const {
|
||||
writer.writeVarInt(builtin_encoding::kIntegerType);
|
||||
writer.writeVarInt((type.getWidth() << 2) | type.getSignedness());
|
||||
}
|
||||
|
||||
void BuiltinDialectBytecodeInterface::write(
|
||||
FunctionType type, DialectBytecodeWriter &writer) const {
|
||||
writer.writeVarInt(builtin_encoding::kFunctionType);
|
||||
writer.writeTypes(type.getInputs());
|
||||
writer.writeTypes(type.getResults());
|
||||
}
|
26
mlir/lib/IR/BuiltinDialectBytecode.h
Normal file
26
mlir/lib/IR/BuiltinDialectBytecode.h
Normal file
@ -0,0 +1,26 @@
|
||||
//===- BuiltinDialectBytecode.h - MLIR Bytecode Implementation --*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This header defines hooks into the builtin dialect bytecode implementation.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef LIB_MLIR_IR_BUILTINDIALECTBYTECODE_H
|
||||
#define LIB_MLIR_IR_BUILTINDIALECTBYTECODE_H
|
||||
|
||||
namespace mlir {
|
||||
class BuiltinDialect;
|
||||
|
||||
namespace builtin_dialect_detail {
|
||||
/// Add the interfaces necessary for encoding the builtin dialect components in
|
||||
/// bytecode.
|
||||
void addBytecodeInterface(BuiltinDialect *dialect);
|
||||
} // namespace builtin_dialect_detail
|
||||
} // namespace mlir
|
||||
|
||||
#endif // LIB_MLIR_IR_BUILTINDIALECTBYTECODE_H
|
@ -8,6 +8,7 @@ add_mlir_library(MLIRIR
|
||||
BuiltinAttributeInterfaces.cpp
|
||||
BuiltinAttributes.cpp
|
||||
BuiltinDialect.cpp
|
||||
BuiltinDialectBytecode.cpp
|
||||
BuiltinTypes.cpp
|
||||
BuiltinTypeInterfaces.cpp
|
||||
Diagnostics.cpp
|
||||
|
@ -113,6 +113,10 @@ void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) {
|
||||
|
||||
DialectInterface::~DialectInterface() = default;
|
||||
|
||||
MLIRContext *DialectInterface::getContext() const {
|
||||
return dialect->getContext();
|
||||
}
|
||||
|
||||
DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
|
||||
MLIRContext *ctx, TypeID interfaceKind) {
|
||||
for (auto *dialect : ctx->getLoadedDialects()) {
|
||||
|
16
mlir/test/Dialect/Builtin/Bytecode/attrs.mlir
Normal file
16
mlir/test/Dialect/Builtin/Bytecode/attrs.mlir
Normal file
@ -0,0 +1,16 @@
|
||||
// RUN: mlir-opt -emit-bytecode %s | mlir-opt | FileCheck %s
|
||||
|
||||
// Bytecode currently does not support big-endian platforms
|
||||
// UNSUPPORTED: s390x-
|
||||
|
||||
// CHECK-LABEL: @TestArray
|
||||
module @TestArray attributes {
|
||||
// CHECK: bytecode.array = [unit]
|
||||
bytecode.array = [unit]
|
||||
} {}
|
||||
|
||||
// CHECK-LABEL: @TestString
|
||||
module @TestString attributes {
|
||||
// CHECK: bytecode.string = "hello"
|
||||
bytecode.string = "hello"
|
||||
} {}
|
28
mlir/test/Dialect/Builtin/Bytecode/types.mlir
Normal file
28
mlir/test/Dialect/Builtin/Bytecode/types.mlir
Normal file
@ -0,0 +1,28 @@
|
||||
// RUN: mlir-opt -emit-bytecode %s | mlir-opt | FileCheck %s
|
||||
|
||||
// Bytecode currently does not support big-endian platforms
|
||||
// UNSUPPORTED: s390x-
|
||||
|
||||
// CHECK-LABEL: @TestInteger
|
||||
module @TestInteger attributes {
|
||||
// CHECK: bytecode.int = i1024,
|
||||
// CHECK: bytecode.int1 = si32,
|
||||
// CHECK: bytecode.int2 = ui512
|
||||
bytecode.int = i1024,
|
||||
bytecode.int1 = si32,
|
||||
bytecode.int2 = ui512
|
||||
} {}
|
||||
|
||||
// CHECK-LABEL: @TestIndex
|
||||
module @TestIndex attributes {
|
||||
// CHECK: bytecode.index = index
|
||||
bytecode.index = index
|
||||
} {}
|
||||
|
||||
// CHECK-LABEL: @TestFunc
|
||||
module @TestFunc attributes {
|
||||
// CHECK: bytecode.func = () -> (),
|
||||
// CHECK: bytecode.func1 = (i1) -> i32
|
||||
bytecode.func = () -> (),
|
||||
bytecode.func1 = (i1) -> (i32)
|
||||
} {}
|
Loading…
x
Reference in New Issue
Block a user