llvm-project/llvm/utils/TableGen/DXILEmitter.cpp
S. Bharadwaj Yadavalli b1c8b9f89c
[DirectX][NFC] Leverage LLVM and DirectX intrinsic description in DXIL Op records (#83193)
* Leverage TableGen record descriptions of LLVM or DirectX intrinsics
that can be directly mapped in DXIL Ops TableGen description. As a
result, such DXIL Ops can be succinctly described without duplication.
DXILEmitter backend can derive the properties of DXIL Ops accordingly.
* Ensured that corresponding lit tests pass.
2024-02-29 06:21:44 -08:00

434 lines
15 KiB
C++

//===- DXILEmitter.cpp - DXIL operation Emitter ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// DXILEmitter uses the descriptions of DXIL operation to construct enum and
// helper functions for DXIL operation.
//
//===----------------------------------------------------------------------===//
#include "CodeGenTarget.h"
#include "SequenceToOffsetTable.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/CodeGenTypes/MachineValueType.h"
#include "llvm/Support/DXILABI.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
using namespace llvm;
using namespace llvm::dxil;
namespace {

/// Minimum shader model version as a (Major, Minor) pair.
struct DXILShaderModel {
  int Major = 0;
  int Minor = 0;
};

/// Properties of a single DXIL operation, gathered from its TableGen record
/// and, when one is mapped, from the LLVM/DirectX intrinsic record.
struct DXILOperationDesc {
  std::string OpName; // name of DXIL operation
  int OpCode = 0;     // ID of DXIL operation
  StringRef OpClass;  // name of the opcode class
  StringRef Doc;      // the documentation description of this instruction
  // Operand types; the return type is at index 0, parameters follow.
  SmallVector<MVT::SimpleValueType> OpTypes;
  // Operation attributes represented as strings.
  SmallVector<std::string> OpAttributes;
  // The llvm intrinsic mapped to OpName. Default is "" which means no map
  // exists.
  StringRef Intrinsic;
  bool IsDeriv = false;    // whether this is some kind of derivative
  bool IsGradient = false; // whether this requires a gradient calculation
  bool IsFeedback = false; // whether this is a sampler feedback op
  bool IsWave = false; // whether this requires in-wave, cross-lane functionality
  bool RequiresUniformInputs = false; // whether this operation requires that
                                      // all of its inputs are uniform across
                                      // the wave
  // Shader stages to which this applies, empty for all.
  SmallVector<StringRef, 4> ShaderStages;
  DXILShaderModel ShaderModel;           // minimum shader model required
  DXILShaderModel ShaderModelTranslated; // minimum shader model required with
                                         // translation by linker
  // Parameter index which controls the overload. When < 0, there should be
  // only one overload type. Defaults to -1 so that a description built from a
  // record without an intrinsic never carries an indeterminate value (it is
  // printed into the generated operation table).
  int OverloadParamIndex = -1;
  SmallVector<StringRef, 4> counters; // counters for this inst.

  explicit DXILOperationDesc(const Record *);
};

} // end anonymous namespace
/// Convert a Simple Value Type to the corresponding dxil::ParameterKind.
///
/// \param VT Simple Value Type of a DXIL operation's return value or operand
/// \return ParameterKind As defined in llvm/Support/DXILABI.h; the overloaded
///         types iAny and fAny map to ParameterKind::OVERLOAD.
static ParameterKind getParameterKind(MVT::SimpleValueType VT) {
  switch (VT) {
  case MVT::isVoid:
    return ParameterKind::VOID;
  case MVT::f16:
    return ParameterKind::HALF;
  case MVT::f32:
    return ParameterKind::FLOAT;
  case MVT::f64:
    return ParameterKind::DOUBLE;
  case MVT::i1:
    return ParameterKind::I1;
  case MVT::i8:
    return ParameterKind::I8;
  case MVT::i16:
    return ParameterKind::I16;
  case MVT::i32:
    return ParameterKind::I32;
  case MVT::i64:
    // ParameterKind::I64 exists and getOverloadKindStr() already accepts i64;
    // without this case an i64 operand would abort in llvm_unreachable.
    return ParameterKind::I64;
  case MVT::fAny:
  case MVT::iAny:
    return ParameterKind::OVERLOAD;
  default:
    llvm_unreachable("Support for specified DXIL Type not yet implemented");
  }
}
/// Construct an object using the DXIL Operation records specified
/// in DXIL.td. This serves as the single source of reference of
/// the information extracted from the specified Record R, for
/// C++ code generated by this TableGen backend.
// \param R Object representing TableGen record of a DXIL Operation
DXILOperationDesc::DXILOperationDesc(const Record *R) {
  OpName = R->getNameInitAsString();
  OpCode = R->getValueAsInt("OpCode");
  Doc = R->getValueAsString("Doc");
  // The operation class is a field of the DXIL Op record itself, not of the
  // intrinsic, so read it whether or not an intrinsic is mapped.
  OpClass = R->getValueAsDef("OpClass")->getName();
  // No overload parameter unless one is found among the intrinsic operands.
  // Set unconditionally so records without an intrinsic are well-defined.
  OverloadParamIndex = -1;

  if (!R->getValue("LLVMIntrinsic"))
    return;

  auto *IntrinsicDef = R->getValueAsDef("LLVMIntrinsic");
  StringRef DefName = IntrinsicDef->getName();
  assert(DefName.starts_with("int_") && "invalid intrinsic name");
  // Remove the "int_" prefix from the intrinsic name.
  Intrinsic = DefName.substr(4);

  // TODO: It is expected that return type and parameter types of
  // DXIL Operation are the same as that of the intrinsic. Deviations
  // are expected to be encoded in TableGen record specification and
  // handled accordingly here. Support to be added later, as needed.
  // The intrinsic's "Types" field lists
  // [returnType, param1Type, param2Type, ...].
  auto TypeRecs = IntrinsicDef->getValueAsListOfDefs("Types");
  OpTypes.reserve(TypeRecs.size());
  for (unsigned I = 0, E = TypeRecs.size(); I != E; ++I) {
    OpTypes.emplace_back(getValueType(TypeRecs[I]->getValueAsDef("VT")));
    // Index 0 is the return type; only parameters select the overload.
    // TODO : Seems hacky. Is it possible that more than one parameter can
    // be of overload kind?? The last overloaded parameter currently wins.
    // TODO: Check for any additional constraints specified for DXIL operation
    // restricting return type.
    if (I > 0 && getParameterKind(OpTypes.back()) >= ParameterKind::OVERLOAD)
      OverloadParamIndex = I;
  }

  // NOTE: For now, assume that attributes of DXIL Operation are the same as
  // that of the intrinsic. Deviations are expected to be encoded in TableGen
  // record specification and handled accordingly here. Support to be added
  // later.
  auto *IntrPropList = IntrinsicDef->getValueAsListInit("IntrProperties");
  for (const Init *Prop : IntrPropList->getValues())
    OpAttributes.emplace_back(Prop->getAsString());
}
/// Spell a ParameterKind enumerator exactly as it is written after
/// "ParameterKind::" in the generated code.
/// \param Kind Parameter Kind enum value
/// \return std::string enumerator name, e.g. "I32" for ParameterKind::I32
static std::string getParameterKindStr(ParameterKind Kind) {
  const char *Spelling = nullptr;
  switch (Kind) {
  case ParameterKind::INVALID:
    Spelling = "INVALID";
    break;
  case ParameterKind::VOID:
    Spelling = "VOID";
    break;
  case ParameterKind::HALF:
    Spelling = "HALF";
    break;
  case ParameterKind::FLOAT:
    Spelling = "FLOAT";
    break;
  case ParameterKind::DOUBLE:
    Spelling = "DOUBLE";
    break;
  case ParameterKind::I1:
    Spelling = "I1";
    break;
  case ParameterKind::I8:
    Spelling = "I8";
    break;
  case ParameterKind::I16:
    Spelling = "I16";
    break;
  case ParameterKind::I32:
    Spelling = "I32";
    break;
  case ParameterKind::I64:
    Spelling = "I64";
    break;
  case ParameterKind::OVERLOAD:
    Spelling = "OVERLOAD";
    break;
  case ParameterKind::CBUFFER_RET:
    Spelling = "CBUFFER_RET";
    break;
  case ParameterKind::RESOURCE_RET:
    Spelling = "RESOURCE_RET";
    break;
  case ParameterKind::DXIL_HANDLE:
    Spelling = "DXIL_HANDLE";
    break;
  }
  // Only reachable with a value outside the enumerators handled above.
  if (!Spelling)
    llvm_unreachable("Unknown llvm::dxil::ParameterKind enum");
  return Spelling;
}
/// Produce the C++ expression (as text) naming the OverloadKind value(s)
/// that correspond to a Simple Value Type.
/// Overloaded "any" types expand to an OR of every concrete kind they cover.
/// \param VT Simple Value Type enum
/// \return std::string e.g. "OverloadKind::FLOAT"
static std::string getOverloadKindStr(MVT::SimpleValueType VT) {
  if (VT == MVT::iAny)
    return "OverloadKind::I16 | OverloadKind::I32 | OverloadKind::I64";
  if (VT == MVT::fAny)
    return "OverloadKind::HALF | OverloadKind::FLOAT | OverloadKind::DOUBLE";

  // Concrete types map to exactly one enumerator.
  const char *Enumerator = nullptr;
  switch (VT) {
  case MVT::isVoid:
    Enumerator = "VOID";
    break;
  case MVT::f16:
    Enumerator = "HALF";
    break;
  case MVT::f32:
    Enumerator = "FLOAT";
    break;
  case MVT::f64:
    Enumerator = "DOUBLE";
    break;
  case MVT::i1:
    Enumerator = "I1";
    break;
  case MVT::i8:
    Enumerator = "I8";
    break;
  case MVT::i16:
    Enumerator = "I16";
    break;
  case MVT::i32:
    Enumerator = "I32";
    break;
  case MVT::i64:
    Enumerator = "I64";
    break;
  default:
    llvm_unreachable(
        "Support for specified parameter OverloadKind not yet implemented");
  }
  return std::string("OverloadKind::") + Enumerator;
}
/// Emit Enums of DXIL Ops
/// \param A vector of DXIL Ops
/// \param Output stream
static void emitDXILEnums(std::vector<DXILOperationDesc> &Ops,
raw_ostream &OS) {
// Sort by OpCode
llvm::sort(Ops, [](DXILOperationDesc &A, DXILOperationDesc &B) {
return A.OpCode < B.OpCode;
});
OS << "// Enumeration for operations specified by DXIL\n";
OS << "enum class OpCode : unsigned {\n";
for (auto &Op : Ops) {
// Name = ID, // Doc
OS << Op.OpName << " = " << Op.OpCode << ", // " << Op.Doc << "\n";
}
OS << "\n};\n\n";
OS << "// Groups for DXIL operations with equivalent function templates\n";
OS << "enum class OpCodeClass : unsigned {\n";
// Build an OpClass set to print
SmallSet<StringRef, 2> OpClassSet;
for (auto &Op : Ops) {
OpClassSet.insert(Op.OpClass);
}
for (auto &C : OpClassSet) {
OS << C << ",\n";
}
OS << "\n};\n\n";
}
/// Emit map of DXIL operation to LLVM or DirectX intrinsic
/// \param A vector of DXIL Ops
/// \param Output stream
static void emitDXILIntrinsicMap(std::vector<DXILOperationDesc> &Ops,
raw_ostream &OS) {
OS << "\n";
// FIXME: use array instead of SmallDenseMap.
OS << "static const SmallDenseMap<Intrinsic::ID, dxil::OpCode> LowerMap = "
"{\n";
for (auto &Op : Ops) {
if (Op.Intrinsic.empty())
continue;
// {Intrinsic::sin, dxil::OpCode::Sin},
OS << " { Intrinsic::" << Op.Intrinsic << ", dxil::OpCode::" << Op.OpName
<< "},\n";
}
OS << "};\n";
OS << "\n";
}
/// Convert operation attribute string to Attribute enum
///
/// \param Attr string reference
/// \return std::string Attribute enum string
static std::string emitDXILOperationAttr(SmallVector<std::string> Attrs) {
for (auto Attr : Attrs) {
// TODO: For now just recognize IntrNoMem and IntrReadMem as valid and
// ignore others.
if (Attr == "IntrNoMem") {
return "Attribute::ReadNone";
} else if (Attr == "IntrReadMem") {
return "Attribute::ReadOnly";
}
}
return "Attribute::None";
}
/// Emit DXIL operation table
/// \param A vector of DXIL Ops
/// \param Output stream
static void emitDXILOperationTable(std::vector<DXILOperationDesc> &Ops,
raw_ostream &OS) {
// Sort by OpCode.
llvm::sort(Ops, [](DXILOperationDesc &A, DXILOperationDesc &B) {
return A.OpCode < B.OpCode;
});
// Collect Names.
SequenceToOffsetTable<std::string> OpClassStrings;
SequenceToOffsetTable<std::string> OpStrings;
SequenceToOffsetTable<SmallVector<ParameterKind>> Parameters;
StringMap<SmallVector<ParameterKind>> ParameterMap;
StringSet<> ClassSet;
for (auto &Op : Ops) {
OpStrings.add(Op.OpName);
if (ClassSet.contains(Op.OpClass))
continue;
ClassSet.insert(Op.OpClass);
OpClassStrings.add(Op.OpClass.data());
SmallVector<ParameterKind> ParamKindVec;
// ParamKindVec is a vector of parameters. Skip return type at index 0
for (unsigned i = 1; i < Op.OpTypes.size(); i++) {
ParamKindVec.emplace_back(getParameterKind(Op.OpTypes[i]));
}
ParameterMap[Op.OpClass] = ParamKindVec;
Parameters.add(ParamKindVec);
}
// Layout names.
OpStrings.layout();
OpClassStrings.layout();
Parameters.layout();
// Emit the DXIL operation table.
//{dxil::OpCode::Sin, OpCodeNameIndex, OpCodeClass::unary,
// OpCodeClassNameIndex,
// OverloadKind::FLOAT | OverloadKind::HALF, Attribute::AttrKind::ReadNone, 0,
// 3, ParameterTableOffset},
OS << "static const OpCodeProperty *getOpCodeProperty(dxil::OpCode Op) "
"{\n";
OS << " static const OpCodeProperty OpCodeProps[] = {\n";
for (auto &Op : Ops) {
OS << " { dxil::OpCode::" << Op.OpName << ", " << OpStrings.get(Op.OpName)
<< ", OpCodeClass::" << Op.OpClass << ", "
<< OpClassStrings.get(Op.OpClass.data()) << ", "
<< getOverloadKindStr(Op.OpTypes[0]) << ", "
<< emitDXILOperationAttr(Op.OpAttributes) << ", "
<< Op.OverloadParamIndex << ", " << Op.OpTypes.size() - 1 << ", "
<< Parameters.get(ParameterMap[Op.OpClass]) << " },\n";
}
OS << " };\n";
OS << " // FIXME: change search to indexing with\n";
OS << " // Op once all DXIL operations are added.\n";
OS << " OpCodeProperty TmpProp;\n";
OS << " TmpProp.OpCode = Op;\n";
OS << " const OpCodeProperty *Prop =\n";
OS << " llvm::lower_bound(OpCodeProps, TmpProp,\n";
OS << " [](const OpCodeProperty &A, const "
"OpCodeProperty &B) {\n";
OS << " return A.OpCode < B.OpCode;\n";
OS << " });\n";
OS << " assert(Prop && \"failed to find OpCodeProperty\");\n";
OS << " return Prop;\n";
OS << "}\n\n";
// Emit the string tables.
OS << "static const char *getOpCodeName(dxil::OpCode Op) {\n\n";
OpStrings.emitStringLiteralDef(OS,
" static const char DXILOpCodeNameTable[]");
OS << " auto *Prop = getOpCodeProperty(Op);\n";
OS << " unsigned Index = Prop->OpCodeNameOffset;\n";
OS << " return DXILOpCodeNameTable + Index;\n";
OS << "}\n\n";
OS << "static const char *getOpCodeClassName(const OpCodeProperty &Prop) "
"{\n\n";
OpClassStrings.emitStringLiteralDef(
OS, " static const char DXILOpCodeClassNameTable[]");
OS << " unsigned Index = Prop.OpCodeClassNameOffset;\n";
OS << " return DXILOpCodeClassNameTable + Index;\n";
OS << "}\n ";
OS << "static const ParameterKind *getOpCodeParameterKind(const "
"OpCodeProperty &Prop) "
"{\n\n";
OS << " static const ParameterKind DXILOpParameterKindTable[] = {\n";
Parameters.emit(
OS,
[](raw_ostream &ParamOS, ParameterKind Kind) {
ParamOS << "ParameterKind::" << getParameterKindStr(Kind);
},
"ParameterKind::INVALID");
OS << " };\n\n";
OS << " unsigned Index = Prop.ParameterTableOffset;\n";
OS << " return DXILOpParameterKindTable + Index;\n";
OS << "}\n ";
}
/// Entry function call that invokes the functionality of this TableGen backend
/// \param Records TableGen records of DXIL Operations defined in DXIL.td
/// \param OS output stream
static void EmitDXILOperation(RecordKeeper &Records, raw_ostream &OS) {
  OS << "// Generated code, do not edit.\n";
  OS << "\n";
  // Get all DXIL Ops to intrinsic mapping records
  std::vector<Record *> OpIntrMaps =
      Records.getAllDerivedDefinitions("DXILOpMapping");
  std::vector<DXILOperationDesc> DXILOps;
  DXILOps.reserve(OpIntrMaps.size());
  // Construct each description in place. The loop variable is named OpRec so
  // it does not shadow the llvm::Record class.
  for (const Record *OpRec : OpIntrMaps)
    DXILOps.emplace_back(OpRec);

  OS << "#ifdef DXIL_OP_ENUM\n";
  emitDXILEnums(DXILOps, OS);
  OS << "#endif\n\n";
  OS << "#ifdef DXIL_OP_INTRINSIC_MAP\n";
  emitDXILIntrinsicMap(DXILOps, OS);
  OS << "#endif\n\n";
  OS << "#ifdef DXIL_OP_OPERATION_TABLE\n";
  emitDXILOperationTable(DXILOps, OS);
  OS << "#endif\n\n";
}
// Registers this backend with TableGen under the name "gen-dxil-operation"
// (presumably selected as `-gen-dxil-operation` on the llvm-tblgen command
// line); EmitDXILOperation is the callback that produces the output.
static TableGen::Emitter::Opt X("gen-dxil-operation", EmitDXILOperation,
                                "Generate DXIL operation information");