llvm-project/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp

155 lines
5.3 KiB
C++

//- NVPTXForwardParams.cpp - NVPTX Forward Device Params Removing Local Copy -//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// PTX supports 2 methods of accessing device function parameters:
//
// - "simple" case: If a parameter is only loaded, and all loads can address
// the parameter via a constant offset, then the parameter may be loaded via
// the ".param" address space. This case is not possible if the parameter
// is stored to or has its address taken. This method is preferable when
// possible. Ex:
//
// ld.param.u32 %r1, [foo_param_1];
// ld.param.u32 %r2, [foo_param_1+4];
//
// - "move param" case: For more complex cases the address of the param may be
// placed in a register via a "mov" instruction. This "mov" also implicitly
// moves the param to the ".local" address space and allows for it to be
// written to. This essentially defers the responsibility of the byval copy
// to the PTX calling convention.
//
// mov.b64 %rd1, foo_param_0;
// st.local.u32 [%rd1], 42;
// add.u64 %rd3, %rd1, %rd2;
// ld.local.u32 %r2, [%rd3];
//
// In NVPTXLowerArgs and SelectionDAG, we pessimistically assume that all
// parameters will use the "move param" case and the local address space. This
// pass is responsible for switching to the "simple" case when possible, as it
// is more efficient.
//
// We do this by simply traversing uses of the param "mov" instructions (looking
// through cvta conversions to/from the local address space) and trivially
// checking if they are all loads.
//
//===----------------------------------------------------------------------===//
#include "NVPTX.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineOperand.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/Support/ErrorHandling.h"
using namespace llvm;
/// Decide whether \p U is an acceptable user of a param "mov" result once the
/// parameter is read directly from the ".param" space. \p U is a user of the
/// mov result, either directly or through earlier cvta instructions.
///
/// - For the listed LD/LDV opcodes, \p U is appended to \p LoadInsts and the
///   function returns true.
/// - For a cvta to/from the local address space, every user of its result
///   (operand 0) is checked recursively. If all of them pass, \p U is appended
///   to \p RemoveList after anything its users appended, and true is returned.
/// - Any other opcode makes the function return false.
///
/// On a false return the two lists may hold partial results; the caller is
/// expected to discard them.
///
/// NOTE(review): MRI.use_instructions() also visits debug uses, so a
/// DBG_VALUE user falls into the "any other opcode" case and blocks the
/// transform. Confirm whether that is intended.
static bool traverseMoveUse(MachineInstr &U, const MachineRegisterInfo &MRI,
                            SmallVectorImpl<MachineInstr *> &RemoveList,
                            SmallVectorImpl<MachineInstr *> &LoadInsts) {
  switch (U.getOpcode()) {
  case NVPTX::LD_i16:
  case NVPTX::LD_i32:
  case NVPTX::LD_i64:
  case NVPTX::LDV_i16_v2:
  case NVPTX::LDV_i16_v4:
  case NVPTX::LDV_i32_v2:
  case NVPTX::LDV_i32_v4:
  case NVPTX::LDV_i64_v2:
  case NVPTX::LDV_i64_v4:
    // A plain load: it will be retargeted to the param symbol by the caller.
    LoadInsts.push_back(&U);
    return true;

  case NVPTX::cvta_local:
  case NVPTX::cvta_local_64:
  case NVPTX::cvta_to_local:
  case NVPTX::cvta_to_local_64:
    // Address-space conversion: removable only if all of its users are
    // acceptable. Handled after the switch.
    break;

  default:
    return false;
  }

  const auto Converted = U.getOperand(0).getReg();
  for (MachineInstr &User : MRI.use_instructions(Converted)) {
    if (!traverseMoveUse(User, MRI, RemoveList, LoadInsts))
      return false;
  }
  RemoveList.push_back(&U);
  return true;
}
/// Try to forward the parameter behind a param "mov" directly into its loads.
///
/// \p Mov's first use operand must name the parameter symbol (asserted).
/// Every user of its result (operand 0) is checked with traverseMoveUse().
///
/// If any user is rejected, nothing is modified and false is returned.
///
/// Otherwise:
/// - every collected load has its base-pointer operand replaced by the
///   parameter symbol and its address-space operand set to
///   NVPTX::AddressSpace::Param;
/// - the now-unneeded cvta instructions, followed by \p Mov itself, are
///   appended to \p RemoveList so the caller can erase them;
/// - true is returned.
static bool eliminateMove(MachineInstr &Mov, const MachineRegisterInfo &MRI,
                          SmallVectorImpl<MachineInstr *> &RemoveList) {
  SmallVector<MachineInstr *, 16> DeadConversions;
  SmallVector<MachineInstr *, 16> ParamLoads;

  const auto MovResult = Mov.getOperand(0).getReg();
  for (MachineInstr &User : MRI.use_instructions(MovResult)) {
    // Bail out before touching anything: the local copy must stay.
    if (!traverseMoveUse(User, MRI, DeadConversions, ParamLoads))
      return false;
  }

  const MachineOperand &Param = *Mov.uses().begin();
  assert(Param.isSymbol());

  // These positions are counted within uses(), i.e. after the explicit defs.
  // That is why the same constants serve scalar LD and vector LDV, whose
  // number of results differs.
  constexpr unsigned LDInstBasePtrOpIdx = 5;
  constexpr unsigned LDInstAddrSpaceOpIdx = 2;
  for (MachineInstr *Load : ParamLoads) {
    MachineOperand *LoadUses = Load->uses().begin();
    LoadUses[LDInstBasePtrOpIdx].ChangeToES(Param.getSymbolName());
    LoadUses[LDInstAddrSpaceOpIdx].ChangeToImmediate(
        NVPTX::AddressSpace::Param);
  }

  RemoveList.append(DeadConversions.begin(), DeadConversions.end());
  RemoveList.push_back(&Mov);
  return true;
}
/// Run eliminateMove() on every MOV32_PARAM / MOV64_PARAM in the entry block.
/// Afterwards, erase every instruction it reported as dead.
///
/// Erasure is deferred until after the scan, so the scan never walks over a
/// removed instruction. Returns true if at least one param mov was forwarded.
static bool forwardDeviceParams(MachineFunction &MF) {
  const MachineRegisterInfo &MRI = MF.getRegInfo();
  SmallVector<MachineInstr *, 16> RemoveList;
  bool Changed = false;

  for (MachineInstr &MI : MF.front()) {
    switch (MI.getOpcode()) {
    case NVPTX::MOV32_PARAM:
    case NVPTX::MOV64_PARAM:
      if (eliminateMove(MI, MRI, RemoveList))
        Changed = true;
      break;
    default:
      break;
    }
  }

  for (MachineInstr *Dead : RemoveList)
    Dead->eraseFromParent();
  return Changed;
}
/// ----------------------------------------------------------------------------
/// Pass (Manager) Boilerplate
/// ----------------------------------------------------------------------------
namespace {
struct NVPTXForwardParamsPass : public MachineFunctionPass {
static char ID;
NVPTXForwardParamsPass() : MachineFunctionPass(ID) {}
bool runOnMachineFunction(MachineFunction &MF) override;
void getAnalysisUsage(AnalysisUsage &AU) const override {
MachineFunctionPass::getAnalysisUsage(AU);
}
};
} // namespace
// Legacy pass-manager identity for the pass.
char NVPTXForwardParamsPass::ID = 0;
// Register the pass under the command-line name "nvptx-forward-params". The
// trailing `false, false` are INITIALIZE_PASS's CFG-only and is-analysis flags.
INITIALIZE_PASS(NVPTXForwardParamsPass, "nvptx-forward-params",
"NVPTX Forward Params", false, false)
/// Legacy pass-manager entry point. All of the work is done by
/// forwardDeviceParams(); its "changed" result is reported unchanged.
bool NVPTXForwardParamsPass::runOnMachineFunction(MachineFunction &MF) {
  const bool Modified = forwardDeviceParams(MF);
  return Modified;
}
/// Factory for the NVPTX forward-params pass. Returns a heap-allocated
/// instance whose ownership passes to the caller.
MachineFunctionPass *llvm::createNVPTXForwardParamsPass() {
  auto *P = new NVPTXForwardParamsPass();
  return P;
}