[Samples][Ray Tracing] Correct offsets/sizes in the shader binding table.

This commit is contained in:
asuessenbach 2020-08-12 11:38:42 +02:00
parent 13fb2b59e0
commit 2571778a4e

View File

@ -640,6 +640,11 @@ glm::vec3 randomVec3( float minValue, float maxValue )
randomDistribution( randomGenerator ) );
}
size_t roundUp( size_t value, size_t alignment )
{
return ( ( value + alignment - 1 ) / alignment ) * alignment;
}
int main( int /*argc*/, char ** /*argv*/ )
{
// number of cubes in x-, y-, and z-direction
@ -1097,7 +1102,7 @@ int main( int /*argc*/, char ** /*argv*/ )
uint32_t maxRecursionDepth = 2;
vk::RayTracingPipelineCreateInfoNV rayTracingPipelineCreateInfo(
{}, shaderStages, shaderGroups, maxRecursionDepth, *rayTracingPipelineLayout );
vk::UniquePipeline rayTracingPipeline;
vk::UniquePipeline rayTracingPipeline;
vk::ResultValue<vk::UniquePipeline> rvPipeline =
device->createRayTracingPipelineNVUnique( nullptr, rayTracingPipelineCreateInfo );
switch ( rvPipeline.result )
@ -1109,16 +1114,32 @@ int main( int /*argc*/, char ** /*argv*/ )
default: assert( false ); // should never happen
}
vk::StructureChain<vk::PhysicalDeviceProperties2, vk::PhysicalDeviceRayTracingPropertiesNV> propertiesChain =
physicalDevice.getProperties2<vk::PhysicalDeviceProperties2, vk::PhysicalDeviceRayTracingPropertiesNV>();
uint32_t shaderGroupBaseAlignment =
propertiesChain.get<vk::PhysicalDeviceRayTracingPropertiesNV>().shaderGroupBaseAlignment;
uint32_t shaderGroupHandleSize =
physicalDevice.getProperties2<vk::PhysicalDeviceProperties2, vk::PhysicalDeviceRayTracingPropertiesNV>()
.get<vk::PhysicalDeviceRayTracingPropertiesNV>()
.shaderGroupHandleSize;
assert( !( shaderGroupHandleSize % 16 ) );
uint32_t shaderBindingTableSize = 5 * shaderGroupHandleSize; // 1x raygen, 2x miss, 2x hitGroup
propertiesChain.get<vk::PhysicalDeviceRayTracingPropertiesNV>().shaderGroupHandleSize;
// with 5 shaders, we need a buffer to hold 5 shaderGroupHandles
vk::DeviceSize raygenShaderBindingOffset = 0; // starting with raygen
uint32_t raygenShaderTableSize = shaderGroupHandleSize; // one raygen shader
vk::DeviceSize missShaderBindingOffset =
raygenShaderBindingOffset + roundUp( raygenShaderTableSize, shaderGroupBaseAlignment );
vk::DeviceSize missShaderBindingStride = shaderGroupHandleSize;
uint32_t missShaderTableSize = vk::su::checked_cast<uint32_t>( 2 * missShaderBindingStride ); // two raygen shaders
vk::DeviceSize hitShaderBindingOffset =
missShaderBindingOffset + roundUp( missShaderTableSize, shaderGroupBaseAlignment );
vk::DeviceSize hitShaderBindingStride = shaderGroupHandleSize;
uint32_t hitShaderTableSize = vk::su::checked_cast<uint32_t>( 2 * hitShaderBindingStride ); // two hit shaders
vk::DeviceSize shaderBindingTableSize = hitShaderBindingOffset + hitShaderTableSize;
std::vector<uint8_t> shaderHandleStorage( shaderBindingTableSize );
device->getRayTracingShaderGroupHandlesNV<uint8_t>( *rayTracingPipeline, 0, 5, shaderHandleStorage );
device->getRayTracingShaderGroupHandlesNV<uint8_t>(
*rayTracingPipeline, 0, 1, { raygenShaderTableSize, &shaderHandleStorage[raygenShaderBindingOffset] } );
device->getRayTracingShaderGroupHandlesNV<uint8_t>(
*rayTracingPipeline, 1, 2, { missShaderTableSize, &shaderHandleStorage[missShaderBindingOffset] } );
device->getRayTracingShaderGroupHandlesNV<uint8_t>(
*rayTracingPipeline, 3, 2, { hitShaderTableSize, &shaderHandleStorage[hitShaderBindingOffset] } );
vk::su::BufferData shaderBindingTableBufferData( physicalDevice,
device,
@ -1250,20 +1271,14 @@ int main( int /*argc*/, char ** /*argv*/ )
*rayTracingDescriptorSets[backBufferIndex],
nullptr );
VkDeviceSize rayGenOffset = 0; // starting with raygen
VkDeviceSize missOffset = shaderGroupHandleSize; // after raygen
VkDeviceSize missStride = shaderGroupHandleSize;
VkDeviceSize hitGroupOffset = shaderGroupHandleSize + 2 * shaderGroupHandleSize; // after 1x raygen and 2x miss
VkDeviceSize hitGroupStride = shaderGroupHandleSize;
commandBuffer->traceRaysNV( *shaderBindingTableBufferData.buffer,
rayGenOffset,
raygenShaderBindingOffset,
*shaderBindingTableBufferData.buffer,
missOffset,
missStride,
missShaderBindingOffset,
missShaderBindingStride,
*shaderBindingTableBufferData.buffer,
hitGroupOffset,
hitGroupStride,
hitShaderBindingOffset,
hitShaderBindingStride,
nullptr,
0,
0,