
This is a model runner for ML researchers using environments like CompilerGym. In such environments, researchers host the compiler and want to be able to observe the problem space (features) at each decision step of some optimization pass. At that point the compiler is stopped, waiting for the host to make a decision and provide advice back to the compiler, which then continues its normal operation, and so on.

The InteractiveModelRunner supports this scenario for the feature set exposed by the compiler at a given time. It uses two files - ideally FIFO pipes - one to pass data to the host, the other to get advice back from the host. This means the scenario is supported with no special dependencies. Creating and deleting the files is the responsibility of the host. Hooking this model evaluator up to a MLGO-ed pass is the responsibility of the pass author; subsequent patches will do so for the current set of MLGO passes, and will offer an API to easily "just opt in" by default when MLGO-ing a new pass.

The data protocol is that of the training logger: the host sees a training log doled out observation by observation by reading from one of the files, and passes back its advice as a serialized tensor (i.e., a memory dump of the tensor value) via the other file. There are some differences from the log seen during training: the interactive model doesn't currently include the outcome (because it should be identical to the decision, and it's also not present in "release" mode), and partial rewards aren't currently communicated back.

The assumption - just like with the training logger - is that the host is co-located, thus avoiding any endianness concerns. In a distributed environment, it is up to the hosting infrastructure to intermediate that.

Differential Revision: https://reviews.llvm.org/D142642
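To make the handshake concrete, here is a minimal host-side sketch (illustrative only, not part of this patch). The pipe paths and the decision policy are assumptions, and it pretends, for simplicity, that each observation is a single int64 feature and the advice is a single int64 tensor; a real host would first read the training-log JSON header describing the feature tensors, then read exactly the advertised number of bytes per observation.

// Minimal host-side sketch. Assumptions: the pipe paths below, a single
// int64 feature per observation, and a single int64 advice tensor.
#include <fcntl.h>
#include <sys/stat.h>
#include <unistd.h>
#include <cstdint>

int main() {
  // The host owns the lifetime of the two pipes.
  ::mkfifo("/tmp/mlgo-out.pipe", 0666); // compiler -> host: observations
  ::mkfifo("/tmp/mlgo-in.pipe", 0666);  // host -> compiler: advice
  // ... launch the compiler here, pointing the pass at the two pipes ...
  int FromCompiler = ::open("/tmp/mlgo-out.pipe", O_RDONLY);
  int ToCompiler = ::open("/tmp/mlgo-in.pipe", O_WRONLY);
  int64_t Feature = 0;
  while (::read(FromCompiler, &Feature, sizeof(Feature)) == sizeof(Feature)) {
    int64_t Advice = Feature > 0 ? 1 : 0; // hypothetical policy
    // The advice travels back as a raw tensor value dump; the host is
    // co-located, so there are no endianness concerns.
    ::write(ToCompiler, &Advice, sizeof(Advice));
  }
  ::close(FromCompiler);
  ::close(ToCompiler);
  ::unlink("/tmp/mlgo-out.pipe");
  ::unlink("/tmp/mlgo-in.pipe");
  return 0;
}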
125 lines · 4.5 KiB · C++
//===- TensorSpec.cpp - tensor type abstraction ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implementation file for the abstraction of a tensor type, and JSON loading
// utils.
//
//===----------------------------------------------------------------------===//
#include "llvm/ADT/STLExtras.h"
|
|
#include "llvm/Config/config.h"
|
|
|
|
#include "llvm/ADT/StringExtras.h"
|
|
#include "llvm/ADT/Twine.h"
|
|
#include "llvm/Analysis/TensorSpec.h"
|
|
#include "llvm/Support/CommandLine.h"
|
|
#include "llvm/Support/Debug.h"
|
|
#include "llvm/Support/JSON.h"
|
|
#include "llvm/Support/ManagedStatic.h"
|
|
#include "llvm/Support/raw_ostream.h"
|
|
#include <array>
|
|
#include <cassert>
|
|
#include <numeric>
|
|
|
|
using namespace llvm;
|
|
|
|
namespace llvm {

#define TFUTILS_GETDATATYPE_IMPL(T, E)                                         \
  template <> TensorType TensorSpec::getDataType<T>() { return TensorType::E; }

SUPPORTED_TENSOR_TYPES(TFUTILS_GETDATATYPE_IMPL)

#undef TFUTILS_GETDATATYPE_IMPL

static std::array<std::string, static_cast<size_t>(TensorType::Total)>
    TensorTypeNames{"INVALID",
#define TFUTILS_GETNAME_IMPL(T, _) #T,
                    SUPPORTED_TENSOR_TYPES(TFUTILS_GETNAME_IMPL)
#undef TFUTILS_GETNAME_IMPL
};

StringRef toString(TensorType TT) {
  return TensorTypeNames[static_cast<size_t>(TT)];
}

void TensorSpec::toJSON(json::OStream &OS) const {
  OS.object([&]() {
    OS.attribute("name", name());
    OS.attribute("type", toString(type()));
    OS.attribute("port", port());
    OS.attributeArray("shape", [&]() {
      for (size_t D : shape())
        OS.value(static_cast<int64_t>(D));
    });
  });
}
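
// Example (illustrative; the spec name is hypothetical): a spec created as
//   TensorSpec::createSpec<int64_t>("reply", {1})
// serializes as:
//   {"name":"reply","type":"int64_t","port":0,"shape":[1]}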

TensorSpec::TensorSpec(const std::string &Name, int Port, TensorType Type,
                       size_t ElementSize, const std::vector<int64_t> &Shape)
    : Name(Name), Port(Port), Type(Type), Shape(Shape),
      ElementCount(std::accumulate(Shape.begin(), Shape.end(), 1,
                                   std::multiplies<int64_t>())),
      ElementSize(ElementSize) {}

std::optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
                                                const json::Value &Value) {
  auto EmitError =
      [&](const llvm::Twine &Message) -> std::optional<TensorSpec> {
    std::string S;
    llvm::raw_string_ostream OS(S);
    OS << Value;
    Ctx.emitError("Unable to parse JSON Value as spec (" + Message + "): " + S);
    return std::nullopt;
  };
  // FIXME: accept a Path as a parameter, and use it for error reporting.
  json::Path::Root Root("tensor_spec");
  json::ObjectMapper Mapper(Value, Root);
  if (!Mapper)
    return EmitError("Value is not a dict");

  std::string TensorName;
  int TensorPort = -1;
  std::string TensorType;
  std::vector<int64_t> TensorShape;

  if (!Mapper.map<std::string>("name", TensorName))
    return EmitError("'name' property not present or not a string");
  if (!Mapper.map<std::string>("type", TensorType))
    return EmitError("'type' property not present or not a string");
  if (!Mapper.map<int>("port", TensorPort))
    return EmitError("'port' property not present or not an int");
  if (!Mapper.map<std::vector<int64_t>>("shape", TensorShape))
    return EmitError("'shape' property not present or not an int array");

#define PARSE_TYPE(T, E)                                                       \
  if (TensorType == #T)                                                        \
    return TensorSpec::createSpec<T>(TensorName, TensorShape, TensorPort);
  SUPPORTED_TENSOR_TYPES(PARSE_TYPE)
#undef PARSE_TYPE
  return std::nullopt;
}
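
// Example (illustrative; the spec name is hypothetical): parsing
//   {"name": "input", "port": 0, "type": "float", "shape": [2, 3]}
// yields the same spec as TensorSpec::createSpec<float>("input", {2, 3}, 0).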

std::string tensorValueToString(const char *Buffer, const TensorSpec &Spec) {
  switch (Spec.type()) {
#define _IMR_DBG_PRINTER(T, N)                                                 \
  case TensorType::N: {                                                        \
    const T *TypedBuff = reinterpret_cast<const T *>(Buffer);                  \
    auto R = llvm::make_range(TypedBuff, TypedBuff + Spec.getElementCount());  \
    return llvm::join(                                                         \
        llvm::map_range(R, [](T V) { return std::to_string(V); }), ",");       \
  }
    SUPPORTED_TENSOR_TYPES(_IMR_DBG_PRINTER)
#undef _IMR_DBG_PRINTER
  case TensorType::Total:
  case TensorType::Invalid:
    llvm_unreachable("invalid tensor type");
  }
}
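
// Example (illustrative): for a float spec of shape {3} and a buffer holding
// {1.0f, 2.0f, 3.0f}, this returns "1.000000,2.000000,3.000000"
// (std::to_string formatting).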

} // namespace llvm