
All entry points (except olInit) now check that offload has been initialized. If not, a new `OL_ERRC_UNINITIALIZED` error is returned.
164 lines
5.7 KiB
C++
164 lines
5.7 KiB
C++
//===- offload-tblgen/EntryPointGen.cpp - Tablegen backend for Offload ----===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This is a Tablegen backend that produces the actual entry points for the
|
|
// Offload API. It serves as a place to integrate functionality like tracing
|
|
// and validation before dispatching to the actual implementations.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "llvm/Support/FormatVariadic.h"
|
|
#include "llvm/TableGen/Record.h"
|
|
|
|
#include "GenCommon.hpp"
|
|
#include "RecordTypes.hpp"
|
|
|
|
using namespace llvm;
|
|
using namespace offload::tblgen;
|
|
|
|
static void EmitValidationFunc(const FunctionRec &F, raw_ostream &OS) {
|
|
OS << CommentsHeader;
|
|
// Emit preamble
|
|
OS << formatv("llvm::Error {0}_val(\n ", F.getName());
|
|
// Emit arguments
|
|
std::string ParamNameList = "";
|
|
for (auto &Param : F.getParams()) {
|
|
OS << Param.getType() << " " << Param.getName();
|
|
if (Param != F.getParams().back()) {
|
|
OS << ", ";
|
|
}
|
|
ParamNameList += Param.getName().str() + ", ";
|
|
}
|
|
OS << ") {\n";
|
|
|
|
bool HasValidation = llvm::any_of(F.getReturns(), [](auto &R) {
|
|
return llvm::any_of(R.getConditions(), [](auto &C) {
|
|
return C.starts_with("`") && C.ends_with("`");
|
|
});
|
|
});
|
|
|
|
if (HasValidation) {
|
|
OS << TAB_1 "if (llvm::offload::isValidationEnabled()) {\n";
|
|
// Emit validation checks
|
|
for (const auto &Return : F.getReturns()) {
|
|
for (auto &Condition : Return.getConditions()) {
|
|
if (Condition.starts_with("`") && Condition.ends_with("`")) {
|
|
auto ConditionString = Condition.substr(1, Condition.size() - 2);
|
|
OS << formatv(TAB_2 "if ({0}) {{\n", ConditionString);
|
|
OS << formatv(TAB_3
|
|
"return createOffloadError(error::ErrorCode::{0}, "
|
|
"\"validation failure: {1}\");\n",
|
|
Return.getUnprefixedValue(), ConditionString);
|
|
OS << TAB_2 "}\n\n";
|
|
}
|
|
}
|
|
}
|
|
OS << TAB_1 "}\n\n";
|
|
}
|
|
|
|
// Perform actual function call to the implementation
|
|
ParamNameList = ParamNameList.substr(0, ParamNameList.size() - 2);
|
|
OS << formatv(TAB_1 "return llvm::offload::{0}_impl({1});\n\n", F.getName(),
|
|
ParamNameList);
|
|
OS << "}\n";
|
|
}
|
|
|
|
static void EmitEntryPointFunc(const FunctionRec &F, raw_ostream &OS) {
|
|
// Emit preamble
|
|
OS << formatv("{1}_APIEXPORT {0}_result_t {1}_APICALL {2}(\n ", PrefixLower,
|
|
PrefixUpper, F.getName());
|
|
// Emit arguments
|
|
std::string ParamNameList = "";
|
|
for (auto &Param : F.getParams()) {
|
|
OS << Param.getType() << " " << Param.getName();
|
|
if (Param != F.getParams().back()) {
|
|
OS << ", ";
|
|
}
|
|
ParamNameList += Param.getName().str() + ", ";
|
|
}
|
|
OS << ") {\n";
|
|
|
|
// Check offload is initialized
|
|
if (F.getName() != "olInit")
|
|
OS << "if (!llvm::offload::isOffloadInitialized()) return &UninitError;";
|
|
|
|
// Emit pre-call prints
|
|
OS << TAB_1 "if (llvm::offload::isTracingEnabled()) {\n";
|
|
OS << formatv(TAB_2 "llvm::errs() << \"---> {0}\";\n", F.getName());
|
|
OS << TAB_1 "}\n\n";
|
|
|
|
// Perform actual function call to the validation wrapper
|
|
ParamNameList = ParamNameList.substr(0, ParamNameList.size() - 2);
|
|
OS << formatv(
|
|
TAB_1 "{0}_result_t Result = llvmErrorToOffloadError({1}_val({2}));\n\n",
|
|
PrefixLower, F.getName(), ParamNameList);
|
|
|
|
// Emit post-call prints
|
|
OS << TAB_1 "if (llvm::offload::isTracingEnabled()) {\n";
|
|
if (F.getParams().size() > 0) {
|
|
OS << formatv(TAB_2 "{0} Params = {{", F.getParamStructName());
|
|
for (const auto &Param : F.getParams()) {
|
|
OS << "&" << Param.getName();
|
|
if (Param != F.getParams().back()) {
|
|
OS << ", ";
|
|
}
|
|
}
|
|
OS << formatv("};\n");
|
|
OS << TAB_2 "llvm::errs() << \"(\" << &Params << \")\";\n";
|
|
} else {
|
|
OS << TAB_2 "llvm::errs() << \"()\";\n";
|
|
}
|
|
OS << TAB_2 "llvm::errs() << \"-> \" << Result << \"\\n\";\n";
|
|
OS << TAB_2 "if (Result && Result->Details) {\n";
|
|
OS << TAB_3 "llvm::errs() << \" *Error Details* \" << Result->Details "
|
|
"<< \" \\n\";\n";
|
|
OS << TAB_2 "}\n";
|
|
OS << TAB_1 "}\n";
|
|
|
|
OS << TAB_1 "return Result;\n";
|
|
OS << "}\n";
|
|
}
|
|
|
|
static void EmitCodeLocWrapper(const FunctionRec &F, raw_ostream &OS) {
|
|
// Emit preamble
|
|
OS << formatv("{0}_result_t {1}WithCodeLoc(\n ", PrefixLower, F.getName());
|
|
// Emit arguments
|
|
std::string ParamNameList = "";
|
|
for (auto &Param : F.getParams()) {
|
|
OS << Param.getType() << " " << Param.getName() << ", ";
|
|
ParamNameList += Param.getName().str();
|
|
if (Param != F.getParams().back()) {
|
|
ParamNameList += ", ";
|
|
}
|
|
}
|
|
OS << "ol_code_location_t *CodeLocation";
|
|
OS << ") {\n";
|
|
OS << TAB_1 "currentCodeLocation() = CodeLocation;\n";
|
|
OS << formatv(TAB_1 "{0}_result_t Result = ::{1}({2});\n\n", PrefixLower,
|
|
F.getName(), ParamNameList);
|
|
OS << TAB_1 "currentCodeLocation() = nullptr;\n";
|
|
OS << TAB_1 "return Result;\n";
|
|
OS << "}\n";
|
|
}
|
|
|
|
void EmitOffloadEntryPoints(const RecordKeeper &Records, raw_ostream &OS) {
|
|
OS << GenericHeader;
|
|
|
|
constexpr const char *UninitMessage =
|
|
"liboffload has not been initialized - please call olInit before using "
|
|
"this API";
|
|
OS << formatv("static {0}_error_struct_t UninitError = "
|
|
"{{{1}_ERRC_UNINITIALIZED, \"{2}\"};",
|
|
PrefixLower, PrefixUpper, UninitMessage);
|
|
|
|
for (auto *R : Records.getAllDerivedDefinitions("Function")) {
|
|
EmitValidationFunc(FunctionRec{R}, OS);
|
|
EmitEntryPointFunc(FunctionRec{R}, OS);
|
|
EmitCodeLocWrapper(FunctionRec{R}, OS);
|
|
}
|
|
}
|