llvm-project/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
Kai 751a546fa9
[HLSL][DXIL][SPIRV] WavePrefixSum intrinsic support (#167946)
Issue: https://github.com/llvm/llvm-project/issues/99172
- [x] Implement `WavePrefixSum` clang builtin
- [x] Link `WavePrefixSum` clang builtin with `hlsl_intrinsics.h`
- [x] Add sema checks for `WavePrefixSum` to
`CheckHLSLBuiltinFunctionCall` in `SemaChecking.cpp`
- [x] Add codegen for `WavePrefixSum` to `EmitHLSLBuiltinExpr` in
`CGBuiltin.cpp`
- [x] Add codegen tests to
`clang/test/CodeGenHLSL/builtins/WavePrefixSum.hlsl`
- [x] Add sema tests to
`clang/test/SemaHLSL/BuiltIns/WavePrefixSum-errors.hlsl`
- [x] Create the `int_dx_WavePrefixSum` intrinsic in
`IntrinsicsDirectX.td`
- [x] Create the `DXILOpMapping` of `int_dx_WavePrefixSum` to `121` in
`DXIL.td`
- [x] Create the `WavePrefixSum.ll` and `WavePrefixSum_errors.ll` tests
in `llvm/test/CodeGen/DirectX/`
- [x] Create the `int_spv_WavePrefixSum` intrinsic in
`IntrinsicsSPIRV.td`
- [x] In SPIRVInstructionSelector.cpp create the `WavePrefixSum`
lowering and map it to `int_spv_WavePrefixSum` in
`SPIRVInstructionSelector::selectIntrinsic`.
- [x] Create SPIR-V backend test case in
`llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WavePrefixSum.ll`

I also added a new macro
`GENERATE_HLSL_INTRINSIC_FUNCTION_SELECT_UNSIGNED` in conjunction with
the new function `getUnsignedIntrinsicVariant` to make selecting
unsigned variants of the intrinsic easier. As a result, I was able to
replace `getWaveActiveSumIntrinsic`, `getWaveActiveMaxIntrinsic`, and
`getWaveActiveMinIntrinsic` using the new macro.
2026-02-03 03:00:45 -05:00

80 lines
2.3 KiB
C++

//===- DirectXTargetTransformInfo.cpp - DirectX TTI -------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
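///
/// \file
/// Implements the DirectXTTIImpl member functions that classify DirectX
/// target intrinsics.
///
//===----------------------------------------------------------------------===//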
#include "DirectXTargetTransformInfo.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsDirectX.h"
using namespace llvm;
/// Reports whether operand \p ScalarOpdIdx of the target intrinsic \p ID is
/// flagged as a scalar operand.
///
/// Only Intrinsic::dx_wave_readlane has such an operand, at index 1. Every
/// other intrinsic ID yields false regardless of the index.
///
/// NOTE(review): presumably queried through TargetTransformInfo by the
/// generic scalarization code -- confirm against the callers.
bool DirectXTTIImpl::isTargetIntrinsicWithScalarOpAtArg(
    Intrinsic::ID ID, unsigned ScalarOpdIdx) const {
  // Guard: nothing but dx_wave_readlane carries a scalar operand.
  if (ID != Intrinsic::dx_wave_readlane)
    return false;
  return ScalarOpdIdx == 1;
}
/// Reports whether \p OpdIdx is the overload-type position recorded for the
/// target intrinsic \p ID.
///
/// For the intrinsics listed in the switch the matching position is operand
/// 0; for every other ID the only matching value is -1.
///
/// NOTE(review): -1 presumably denotes the return type, following the
/// TargetTransformInfo convention -- confirm against the interface docs.
bool DirectXTTIImpl::isTargetIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
                                                            int OpdIdx) const {
  // Position that OpdIdx has to equal; stays -1 unless ID is listed below.
  int MatchingIdx = -1;
  switch (ID) {
  case Intrinsic::dx_asdouble:
  case Intrinsic::dx_firstbitlow:
  case Intrinsic::dx_firstbitshigh:
  case Intrinsic::dx_firstbituhigh:
  case Intrinsic::dx_isinf:
  case Intrinsic::dx_isnan:
  case Intrinsic::dx_legacyf16tof32:
  case Intrinsic::dx_legacyf32tof16:
    MatchingIdx = 0;
    break;
  default:
    break;
  }
  return OpdIdx == MatchingIdx;
}
/// Reports whether the target intrinsic \p ID is on the DirectX list of
/// trivially scalarizable intrinsics.
///
/// Returns true exactly for the IDs enumerated in the switch (kept in
/// alphabetical order) and false for anything else.
///
/// NOTE(review): presumably consulted by the generic scalarizer to decide
/// whether a vector call may be split per element -- confirm against the
/// callers.
bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
    Intrinsic::ID ID) const {
  bool Scalarizable = false;
  switch (ID) {
  case Intrinsic::dx_asdouble:
  case Intrinsic::dx_ddx_coarse:
  case Intrinsic::dx_ddx_fine:
  case Intrinsic::dx_ddy_coarse:
  case Intrinsic::dx_ddy_fine:
  case Intrinsic::dx_firstbitlow:
  case Intrinsic::dx_firstbitshigh:
  case Intrinsic::dx_firstbituhigh:
  case Intrinsic::dx_frac:
  case Intrinsic::dx_imad:
  case Intrinsic::dx_isinf:
  case Intrinsic::dx_isnan:
  case Intrinsic::dx_legacyf16tof32:
  case Intrinsic::dx_legacyf32tof16:
  case Intrinsic::dx_rsqrt:
  case Intrinsic::dx_saturate:
  case Intrinsic::dx_splitdouble:
  case Intrinsic::dx_umad:
  case Intrinsic::dx_wave_prefix_sum:
  case Intrinsic::dx_wave_prefix_usum:
  case Intrinsic::dx_wave_readlane:
  case Intrinsic::dx_wave_reduce_max:
  case Intrinsic::dx_wave_reduce_min:
  case Intrinsic::dx_wave_reduce_sum:
  case Intrinsic::dx_wave_reduce_umax:
  case Intrinsic::dx_wave_reduce_umin:
  case Intrinsic::dx_wave_reduce_usum:
    Scalarizable = true;
    break;
  default:
    break;
  }
  return Scalarizable;
}