[Samples][Ray Tracing] Correct offsets/sizes in the shader binding table.

This commit is contained in:
asuessenbach 2020-08-12 11:38:42 +02:00
parent 13fb2b59e0
commit 2571778a4e

View File

@ -640,6 +640,11 @@ glm::vec3 randomVec3( float minValue, float maxValue )
randomDistribution( randomGenerator ) ); randomDistribution( randomGenerator ) );
} }
size_t roundUp( size_t value, size_t alignment )
{
return ( ( value + alignment - 1 ) / alignment ) * alignment;
}
int main( int /*argc*/, char ** /*argv*/ ) int main( int /*argc*/, char ** /*argv*/ )
{ {
// number of cubes in x-, y-, and z-direction // number of cubes in x-, y-, and z-direction
@ -1097,7 +1102,7 @@ int main( int /*argc*/, char ** /*argv*/ )
uint32_t maxRecursionDepth = 2; uint32_t maxRecursionDepth = 2;
vk::RayTracingPipelineCreateInfoNV rayTracingPipelineCreateInfo( vk::RayTracingPipelineCreateInfoNV rayTracingPipelineCreateInfo(
{}, shaderStages, shaderGroups, maxRecursionDepth, *rayTracingPipelineLayout ); {}, shaderStages, shaderGroups, maxRecursionDepth, *rayTracingPipelineLayout );
vk::UniquePipeline rayTracingPipeline; vk::UniquePipeline rayTracingPipeline;
vk::ResultValue<vk::UniquePipeline> rvPipeline = vk::ResultValue<vk::UniquePipeline> rvPipeline =
device->createRayTracingPipelineNVUnique( nullptr, rayTracingPipelineCreateInfo ); device->createRayTracingPipelineNVUnique( nullptr, rayTracingPipelineCreateInfo );
switch ( rvPipeline.result ) switch ( rvPipeline.result )
@ -1109,16 +1114,32 @@ int main( int /*argc*/, char ** /*argv*/ )
default: assert( false ); // should never happen default: assert( false ); // should never happen
} }
vk::StructureChain<vk::PhysicalDeviceProperties2, vk::PhysicalDeviceRayTracingPropertiesNV> propertiesChain =
physicalDevice.getProperties2<vk::PhysicalDeviceProperties2, vk::PhysicalDeviceRayTracingPropertiesNV>();
uint32_t shaderGroupBaseAlignment =
propertiesChain.get<vk::PhysicalDeviceRayTracingPropertiesNV>().shaderGroupBaseAlignment;
uint32_t shaderGroupHandleSize = uint32_t shaderGroupHandleSize =
physicalDevice.getProperties2<vk::PhysicalDeviceProperties2, vk::PhysicalDeviceRayTracingPropertiesNV>() propertiesChain.get<vk::PhysicalDeviceRayTracingPropertiesNV>().shaderGroupHandleSize;
.get<vk::PhysicalDeviceRayTracingPropertiesNV>()
.shaderGroupHandleSize;
assert( !( shaderGroupHandleSize % 16 ) );
uint32_t shaderBindingTableSize = 5 * shaderGroupHandleSize; // 1x raygen, 2x miss, 2x hitGroup
// with 5 shaders, we need a buffer to hold 5 shaderGroupHandles vk::DeviceSize raygenShaderBindingOffset = 0; // starting with raygen
uint32_t raygenShaderTableSize = shaderGroupHandleSize; // one raygen shader
vk::DeviceSize missShaderBindingOffset =
raygenShaderBindingOffset + roundUp( raygenShaderTableSize, shaderGroupBaseAlignment );
vk::DeviceSize missShaderBindingStride = shaderGroupHandleSize;
uint32_t missShaderTableSize = vk::su::checked_cast<uint32_t>( 2 * missShaderBindingStride ); // two raygen shaders
vk::DeviceSize hitShaderBindingOffset =
missShaderBindingOffset + roundUp( missShaderTableSize, shaderGroupBaseAlignment );
vk::DeviceSize hitShaderBindingStride = shaderGroupHandleSize;
uint32_t hitShaderTableSize = vk::su::checked_cast<uint32_t>( 2 * hitShaderBindingStride ); // two hit shaders
vk::DeviceSize shaderBindingTableSize = hitShaderBindingOffset + hitShaderTableSize;
std::vector<uint8_t> shaderHandleStorage( shaderBindingTableSize ); std::vector<uint8_t> shaderHandleStorage( shaderBindingTableSize );
device->getRayTracingShaderGroupHandlesNV<uint8_t>( *rayTracingPipeline, 0, 5, shaderHandleStorage ); device->getRayTracingShaderGroupHandlesNV<uint8_t>(
*rayTracingPipeline, 0, 1, { raygenShaderTableSize, &shaderHandleStorage[raygenShaderBindingOffset] } );
device->getRayTracingShaderGroupHandlesNV<uint8_t>(
*rayTracingPipeline, 1, 2, { missShaderTableSize, &shaderHandleStorage[missShaderBindingOffset] } );
device->getRayTracingShaderGroupHandlesNV<uint8_t>(
*rayTracingPipeline, 3, 2, { hitShaderTableSize, &shaderHandleStorage[hitShaderBindingOffset] } );
vk::su::BufferData shaderBindingTableBufferData( physicalDevice, vk::su::BufferData shaderBindingTableBufferData( physicalDevice,
device, device,
@ -1250,20 +1271,14 @@ int main( int /*argc*/, char ** /*argv*/ )
*rayTracingDescriptorSets[backBufferIndex], *rayTracingDescriptorSets[backBufferIndex],
nullptr ); nullptr );
VkDeviceSize rayGenOffset = 0; // starting with raygen
VkDeviceSize missOffset = shaderGroupHandleSize; // after raygen
VkDeviceSize missStride = shaderGroupHandleSize;
VkDeviceSize hitGroupOffset = shaderGroupHandleSize + 2 * shaderGroupHandleSize; // after 1x raygen and 2x miss
VkDeviceSize hitGroupStride = shaderGroupHandleSize;
commandBuffer->traceRaysNV( *shaderBindingTableBufferData.buffer, commandBuffer->traceRaysNV( *shaderBindingTableBufferData.buffer,
rayGenOffset, raygenShaderBindingOffset,
*shaderBindingTableBufferData.buffer, *shaderBindingTableBufferData.buffer,
missOffset, missShaderBindingOffset,
missStride, missShaderBindingStride,
*shaderBindingTableBufferData.buffer, *shaderBindingTableBufferData.buffer,
hitGroupOffset, hitShaderBindingOffset,
hitGroupStride, hitShaderBindingStride,
nullptr, nullptr,
0, 0,
0, 0,