llvm-project/mlir/lib/Target/SPIRV/TranslateRegistration.cpp
rkayaith ed90f8026e [mlir-translate] Support parsing operations other than 'builtin.module' as top-level
This adds a '--no-implicit-module' option, which disables the insertion
of a top-level 'builtin.module' during parsing.

The translation APIs are also updated to take/return 'Operation*'
instead of 'ModuleOp', to allow other operation types to be used. To
simplify translations which are restricted to specific operation types,
'TranslateFromMLIRRegistration' has an overload which performs the
necessary cast and error checking.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D134237
2022-10-21 15:54:06 -04:00

185 lines
6.5 KiB
C++

//===- TranslateRegistration.cpp - hooks to mlir-translate ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
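// This file registers the mlir-translate hooks for SPIR-V: deserializing a
// SPIR-V binary module into an MLIR SPIR-V ModuleOp, serializing an MLIR
// SPIR-V ModuleOp into a SPIR-V binary module, and round-trip testing of both.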
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Target/SPIRV/Deserialization.h"
#include "mlir/Target/SPIRV/Serialization.h"
#include "mlir/Tools/mlir-translate/Translation.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SMLoc.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/ToolOutputFile.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// Deserialization registration
//===----------------------------------------------------------------------===//
// Deserializes the SPIR-V binary module stored in the file named as
// `inputFilename` and returns a module containing the SPIR-V module.
static OwningOpRef<Operation *>
deserializeModule(const llvm::MemoryBuffer *input, MLIRContext *context) {
context->loadDialect<spirv::SPIRVDialect>();
// Make sure the input stream can be treated as a stream of SPIR-V words
auto *start = input->getBufferStart();
auto size = input->getBufferSize();
if (size % sizeof(uint32_t) != 0) {
emitError(UnknownLoc::get(context))
<< "SPIR-V binary module must contain integral number of 32-bit words";
return {};
}
auto binary = llvm::makeArrayRef(reinterpret_cast<const uint32_t *>(start),
size / sizeof(uint32_t));
OwningOpRef<spirv::ModuleOp> spirvModule =
spirv::deserialize(binary, context);
if (!spirvModule)
return {};
OwningOpRef<ModuleOp> module(ModuleOp::create(FileLineColLoc::get(
context, input->getBufferIdentifier(), /*line=*/0, /*column=*/0)));
module->getBody()->push_front(spirvModule.release());
return std::move(module);
}
namespace mlir {
void registerFromSPIRVTranslation() {
TranslateToMLIRRegistration fromBinary(
"deserialize-spirv", "deserializes the SPIR-V module",
[](llvm::SourceMgr &sourceMgr, MLIRContext *context) {
assert(sourceMgr.getNumBuffers() == 1 && "expected one buffer");
return deserializeModule(
sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()), context);
});
}
} // namespace mlir
//===----------------------------------------------------------------------===//
// Serialization registration
//===----------------------------------------------------------------------===//
static LogicalResult serializeModule(ModuleOp module, raw_ostream &output) {
if (!module)
return failure();
SmallVector<uint32_t, 0> binary;
SmallVector<spirv::ModuleOp, 1> spirvModules;
module.walk([&](spirv::ModuleOp op) { spirvModules.push_back(op); });
if (spirvModules.empty())
return module.emitError("found no 'spirv.module' op");
if (spirvModules.size() != 1)
return module.emitError("found more than one 'spirv.module' op");
if (failed(spirv::serialize(spirvModules[0], binary)))
return failure();
output.write(reinterpret_cast<char *>(binary.data()),
binary.size() * sizeof(uint32_t));
return mlir::success();
}
namespace mlir {
void registerToSPIRVTranslation() {
TranslateFromMLIRRegistration toBinary(
"serialize-spirv", "serialize SPIR-V dialect",
[](ModuleOp module, raw_ostream &output) {
return serializeModule(module, output);
},
[](DialectRegistry &registry) {
registry.insert<spirv::SPIRVDialect>();
});
}
} // namespace mlir
//===----------------------------------------------------------------------===//
// Round-trip registration
//===----------------------------------------------------------------------===//
static LogicalResult roundTripModule(ModuleOp srcModule, bool emitDebugInfo,
raw_ostream &output) {
SmallVector<uint32_t, 0> binary;
MLIRContext *context = srcModule.getContext();
auto spirvModules = srcModule.getOps<spirv::ModuleOp>();
if (spirvModules.begin() == spirvModules.end())
return srcModule.emitError("found no 'spirv.module' op");
if (std::next(spirvModules.begin()) != spirvModules.end())
return srcModule.emitError("found more than one 'spirv.module' op");
spirv::SerializationOptions options;
options.emitDebugInfo = emitDebugInfo;
if (failed(spirv::serialize(*spirvModules.begin(), binary, options)))
return failure();
MLIRContext deserializationContext(context->getDialectRegistry());
// TODO: we should only load the required dialects instead of all dialects.
deserializationContext.loadAllAvailableDialects();
// Then deserialize to get back a SPIR-V module.
OwningOpRef<spirv::ModuleOp> spirvModule =
spirv::deserialize(binary, &deserializationContext);
if (!spirvModule)
return failure();
// Wrap around in a new MLIR module.
OwningOpRef<ModuleOp> dstModule(ModuleOp::create(
FileLineColLoc::get(&deserializationContext,
/*filename=*/"", /*line=*/0, /*column=*/0)));
dstModule->getBody()->push_front(spirvModule.release());
if (failed(verify(*dstModule)))
return failure();
dstModule->print(output);
return mlir::success();
}
namespace mlir {
void registerTestRoundtripSPIRV() {
TranslateFromMLIRRegistration roundtrip(
"test-spirv-roundtrip", "test roundtrip in SPIR-V dialect",
[](ModuleOp module, raw_ostream &output) {
return roundTripModule(module, /*emitDebugInfo=*/false, output);
},
[](DialectRegistry &registry) {
registry.insert<spirv::SPIRVDialect>();
});
}
void registerTestRoundtripDebugSPIRV() {
TranslateFromMLIRRegistration roundtrip(
"test-spirv-roundtrip-debug", "test roundtrip debug in SPIR-V",
[](ModuleOp module, raw_ostream &output) {
return roundTripModule(module, /*emitDebugInfo=*/true, output);
},
[](DialectRegistry &registry) {
registry.insert<spirv::SPIRVDialect>();
});
}
} // namespace mlir