llvm-project/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
S. Bharadwaj Yadavalli 96547decd5
[DirectX] Infrastructure to collect shader flags for each function (#112967)
Currently, ShaderFlagsAnalysis pass represents various module-level
properties as well as function-level properties of a DXIL Module using a
single mask. However, one mask per function is needed for accurate
computation of shader flags mask, such as for entry function metadata
creation.

This change introduces a structure that wraps a sorted vector of
function-shader flag mask pairs that represent function properties
instead of a single shader flag mask that represents module properties
and properties of all functions. The result type of ShaderFlagsAnalysis
pass is changed to newly-defined structure type instead of a single
shader flags mask.

This allows accurate computation of shader flags of an entry function
(and all functions in a library shader) for use during its metadata
generation (DXILTranslateMetadata pass) and its feature flags in DX
container globals construction (DXContainerGlobals pass) based on the
shader flags mask of functions. However, note that the change to
implement propagation of such callee-based shader flags mask computation
is planned in a follow-on PR. Consequently, this PR changes shader flag
mask computation in DXILTranslateMetadata and DXContainerGlobals passes
to simply be a union of module flags and shader flags of all functions,
thereby retaining the existing effect of using a single shader flag
mask.
2024-11-25 16:02:46 -05:00

140 lines
4.9 KiB
C++

//===- DXILShaderFlags.cpp - DXIL Shader Flags helper objects -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file This file contains helper objects and APIs for working with DXIL
/// Shader Flags.
///
//===----------------------------------------------------------------------===//
#include "DXILShaderFlags.h"
#include "DirectX.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
using namespace llvm;
using namespace llvm::dxil;
/// Fold the shader-flag contribution of instruction \p I into \p CSF.
///
/// \param CSF Flags accumulated so far for the function being scanned; updated
///            in place.
/// \param I   Instruction whose result type, operand types and opcode are
///            inspected.
static void updateFunctionFlags(ComputedShaderFlags &CSF,
                                const Instruction &I) {
  // Doubles is sticky: it is only ever set here, never cleared, so the type
  // inspection is needed only while the flag is still unset.
  if (!CSF.Doubles) {
    const bool ResultIsDouble = I.getType()->isDoubleTy();
    CSF.Doubles =
        ResultIsDouble || llvm::any_of(I.operands(), [](const Use &Operand) {
          return Operand.get()->getType()->isDoubleTy();
        });
  }

  if (!CSF.Doubles)
    return;

  // NOTE(review): the opcode check below runs whenever Doubles has been set by
  // this or any earlier instruction passed with the same CSF, not only when I
  // itself involves a double type -- confirm this is intended.
  switch (I.getOpcode()) {
  case Instruction::FDiv:
  case Instruction::UIToFP:
  case Instruction::SIToFP:
  case Instruction::FPToUI:
  case Instruction::FPToSI:
    // TODO: To be set if I is a call to DXIL intrinsic DXIL::Opcode::Fma
    // https://github.com/llvm/llvm-project/issues/114554
    CSF.DX11_1_DoubleExtensions = true;
    break;
  default:
    break;
  }
}
void ModuleShaderFlags::initialize(const Module &M) {
// Collect shader flags for each of the functions
for (const auto &F : M.getFunctionList()) {
if (F.isDeclaration())
continue;
ComputedShaderFlags CSF;
for (const auto &BB : F)
for (const auto &I : BB)
updateFunctionFlags(CSF, I);
// Insert shader flag mask for function F
FunctionFlags.push_back({&F, CSF});
// Update combined shader flags mask
CombinedSFMask.merge(CSF);
}
llvm::sort(FunctionFlags);
}
/// Write a ';'-prefixed, human-readable description of these flags to \p OS.
///
/// The numeric value of the mask is always printed. When it is non-zero, one
/// line is additionally emitted for every flag that is set, using the
/// SHADER_FEATURE_FLAG and DXIL_MODULE_FLAG entries of
/// DXContainerConstants.def.
void ComputedShaderFlags::print(raw_ostream &OS) const {
  uint64_t MaskValue = static_cast<uint64_t>(*this);
  OS << formatv("; Shader Flags Value: {0:x8}\n;\n", MaskValue);

  // With no bit set there is nothing further to list.
  if (MaskValue == 0)
    return;

  OS << "; Note: shader requires additional functionality:\n";
  // Each .def entry expands to a test of the same-named flag on this object.
#define SHADER_FEATURE_FLAG(FeatureBit, DxilModuleNum, FlagName, Str)          \
  if (FlagName)                                                                \
    (OS << ";").indent(7) << Str << "\n";
#include "llvm/BinaryFormat/DXContainerConstants.def"

  OS << "; Note: extra DXIL module flags:\n";
#define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str)                         \
  if (FlagName)                                                                \
    (OS << ";").indent(7) << Str << "\n";
#include "llvm/BinaryFormat/DXContainerConstants.def"

  OS << ";\n";
}
/// Return the shader flags mask of the specified function \p Func.
///
/// FunctionFlags is searched with lower_bound keyed on the function pointer,
/// so it must already be sorted (initialize() sorts it after populating it).
///
/// \pre An entry for \p Func exists. This is only checked by an assertion;
///      in builds with assertions disabled a missing entry is not detected.
/// \returns A reference to the mask stored in FunctionFlags for \p Func.
const ComputedShaderFlags &
ModuleShaderFlags::getFunctionFlags(const Function *Func) const {
  // The element is taken by const reference so that each comparison made by
  // the binary search does not copy a (pointer, flags) pair.
  const auto Iter = llvm::lower_bound(
      FunctionFlags, Func,
      [](const std::pair<const Function *, ComputedShaderFlags> &FSM,
         const Function *FindFunc) { return FSM.first < FindFunc; });
  assert((Iter != FunctionFlags.end() && Iter->first == Func) &&
         "No Shader Flags Mask exists for function");
  return Iter->second;
}
//===----------------------------------------------------------------------===//
// ShaderFlagsAnalysis and ShaderFlagsAnalysisPrinterPass
// Definition of the static AnalysisKey member declared by ShaderFlagsAnalysis.
AnalysisKey ShaderFlagsAnalysis::Key;

/// Build and return the shader flag information for module \p M.
///
/// \p AM is not used by this analysis.
ModuleShaderFlags ShaderFlagsAnalysis::run(Module &M,
                                           ModuleAnalysisManager &AM) {
  ModuleShaderFlags Result;
  Result.initialize(M);
  return Result;
}
/// Print the ShaderFlagsAnalysis result for \p M to this printer's stream.
///
/// The combined module mask is described first, followed by one line giving
/// the mask value of each function that has a body. All analyses are reported
/// as preserved.
PreservedAnalyses ShaderFlagsAnalysisPrinter::run(Module &M,
                                                  ModuleAnalysisManager &AM) {
  const ModuleShaderFlags &ModuleFlags = AM.getResult<ShaderFlagsAnalysis>(M);

  // Description of the combined shader flags of all module functions.
  OS << "; Combined Shader Flags for Module\n";
  ModuleFlags.getCombinedFlags().print(OS);

  // Mask value of each function; declarations are skipped, matching the set
  // of functions that ModuleShaderFlags::initialize() records.
  OS << "; Shader Flags for Module Functions\n";
  for (const Function &F : M) {
    if (F.isDeclaration())
      continue;
    const ComputedShaderFlags &FnFlags = ModuleFlags.getFunctionFlags(&F);
    OS << formatv("; Function {0} : {1:x8}\n;\n", F.getName(),
                  static_cast<uint64_t>(FnFlags));
  }

  return PreservedAnalyses::all();
}
//===----------------------------------------------------------------------===//
// ShaderFlagsAnalysisWrapper
/// Compute the shader flags of \p M into the MSFI member.
///
/// \returns false always: ModuleShaderFlags::initialize() takes the module by
///          const reference, so \p M is not modified.
bool ShaderFlagsAnalysisWrapper::runOnModule(Module &M) {
  MSFI.initialize(M);
  return false;
}
// Definition of the static pass ID member of ShaderFlagsAnalysisWrapper.
char ShaderFlagsAnalysisWrapper::ID = 0;
// Register the wrapper pass under the argument name "dx-shader-flag-analysis"
// with the description "DXIL Shader Flag Analysis".
// NOTE(review): the two trailing `true` arguments are presumably the
// cfg-only and is-analysis parameters of LLVM's
// INITIALIZE_PASS(passName, arg, name, cfg, analysis) -- confirm against
// llvm/PassSupport.h.
INITIALIZE_PASS(ShaderFlagsAnalysisWrapper, "dx-shader-flag-analysis",
"DXIL Shader Flag Analysis", true, true)