From 42be7a3c53698a165e9612619f6a34a65bbf91ff Mon Sep 17 00:00:00 2001
From: Konstantin Seurer
Date: Wed, 16 Aug 2023 10:37:56 +0200
Subject: [PATCH 1/3] radv: Remove dead radix_sort_vk_get_memory_requirements call

---
 src/amd/vulkan/radv_acceleration_structure.c | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/src/amd/vulkan/radv_acceleration_structure.c b/src/amd/vulkan/radv_acceleration_structure.c
index ece47b1230c88..5c5eb16e61f9d 100644
--- a/src/amd/vulkan/radv_acceleration_structure.c
+++ b/src/amd/vulkan/radv_acceleration_structure.c
@@ -745,10 +745,6 @@ morton_sort(VkCommandBuffer commandBuffer, uint32_t infoCount,
 {
    RADV_FROM_HANDLE(radv_cmd_buffer, cmd_buffer, commandBuffer);
    for (uint32_t i = 0; i < infoCount; ++i) {
-      struct radix_sort_vk_memory_requirements requirements;
-      radix_sort_vk_get_memory_requirements(cmd_buffer->device->meta_state.accel_struct_build.radix_sort,
-                                            bvh_states[i].node_count, &requirements);
-
       struct radix_sort_vk_sort_devaddr_info info = cmd_buffer->device->meta_state.accel_struct_build.radix_sort_info;
       info.count = bvh_states[i].node_count;
 
--
GitLab


From faa17e5322ea66cd74e37aab48316059a05738d6 Mon Sep 17 00:00:00 2001
From: Konstantin Seurer
Date: Wed, 16 Aug 2023 11:09:25 +0200
Subject: [PATCH 2/3] radv/radix_sort: Vendor the radix sort dispatch code

This needs to be done so we can optimize it for occupancy when building
multiple acceleration structures in parallel.

Changes to the original code:
- Change // to /* */
- clang-format
- Replace vkCmd calls with calls to the driver entrypoints
- Add a lightweight info struct
- Use radv_fill_buffer directly
---
 src/amd/vulkan/radv_acceleration_structure.c | 218 ++++++++++++++++---
 src/amd/vulkan/radv_private.h                |   1 -
 2 files changed, 187 insertions(+), 32 deletions(-)

diff --git a/src/amd/vulkan/radv_acceleration_structure.c b/src/amd/vulkan/radv_acceleration_structure.c
index 5c5eb16e61f9d..9866de2e594a8 100644
--- a/src/amd/vulkan/radv_acceleration_structure.c
+++ b/src/amd/vulkan/radv_acceleration_structure.c
@@ -27,7 +27,9 @@
 #include "nir_builder.h"
 #include "radv_cs.h"
 
+#include "radix_sort/common/vk/barrier.h"
 #include "radix_sort/radv_radix_sort.h"
+#include "radix_sort/shaders/push.h"
 
 #include "bvh/build_interface.h"
 #include "bvh/bvh.h"
@@ -76,6 +78,7 @@ static const uint32_t header_spv[] = {
 };
 
 #define KEY_ID_PAIR_SIZE 8
+#define MORTON_BIT_SIZE 24
 
 enum internal_build_type {
    INTERNAL_BUILD_TYPE_LBVH,
@@ -382,17 +385,6 @@ cleanup:
    return result;
 }
 
-static void
-radix_sort_fill_buffer(VkCommandBuffer commandBuffer, radix_sort_vk_buffer_info_t const *buffer_info,
-                       VkDeviceSize offset, VkDeviceSize size, uint32_t data)
-{
-   RADV_FROM_HANDLE(radv_cmd_buffer, cmd_buffer, commandBuffer);
-
-   assert(size != VK_WHOLE_SIZE);
-
-   radv_fill_buffer(cmd_buffer, NULL, NULL, buffer_info->devaddr + buffer_info->offset + offset, size, data);
-}
-
 VkResult
 radv_device_init_null_accel_struct(struct radv_device *device)
 {
@@ -576,12 +568,6 @@ radv_device_init_accel_struct_build_state(struct radv_device *device)
 
    device->meta_state.accel_struct_build.radix_sort =
       radv_create_radix_sort_u64(radv_device_to_handle(device), &device->meta_state.alloc, device->meta_state.cache);
-
-   struct radix_sort_vk_sort_devaddr_info *radix_sort_info = &device->meta_state.accel_struct_build.radix_sort_info;
-   radix_sort_info->ext = NULL;
-   radix_sort_info->key_bits = 24;
-   radix_sort_info->fill_buffer = radix_sort_fill_buffer;
-
 exit:
    mtx_unlock(&device->meta_state.mtx);
    return result;
@@ -743,28 +729,198 @@ morton_sort(VkCommandBuffer commandBuffer, uint32_t infoCount,
             const VkAccelerationStructureBuildGeometryInfoKHR *pInfos, struct bvh_state *bvh_states,
             enum radv_cmd_flush_bits flush_bits)
 {
+   /* Copyright 2019 The Fuchsia Authors. */
    RADV_FROM_HANDLE(radv_cmd_buffer, cmd_buffer, commandBuffer);
+
+   radix_sort_vk_t *rs = cmd_buffer->device->meta_state.accel_struct_build.radix_sort;
+
    for (uint32_t i = 0; i < infoCount; ++i) {
-      struct radix_sort_vk_sort_devaddr_info info = cmd_buffer->device->meta_state.accel_struct_build.radix_sort_info;
-      info.count = bvh_states[i].node_count;
+      uint32_t count = bvh_states[i].node_count;
+      uint64_t keyvals_even_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_buffer_offset[0];
+      uint64_t keyvals_odd_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_buffer_offset[1];
+      uint64_t internal_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_internal_offset;
+
+      /* Anything to do? */
+      if (!count) {
+         bvh_states[i].scratch_offset = bvh_states[i].scratch.sort_buffer_offset[0];
+         continue;
+      }
+
+      /*
+       * OVERVIEW
+       *
+       * 1. Pad the keyvals in `scatter_even`.
+       * 2. Zero the `histograms` and `partitions`.
+       *    --- BARRIER ---
+       * 3. HISTOGRAM is dispatched before PREFIX.
+       *    --- BARRIER ---
+       * 4. PREFIX is dispatched before the first SCATTER.
+       *    --- BARRIER ---
+       * 5. One or more SCATTER dispatches.
+       *
+       * Note that the `partitions` buffer can be zeroed anytime before the first
+       * scatter.
+       */
+
+      /* How many passes? */
+      uint32_t keyval_bytes = rs->config.keyval_dwords * (uint32_t)sizeof(uint32_t);
+      uint32_t keyval_bits = keyval_bytes * 8;
+      uint32_t key_bits = MIN2(MORTON_BIT_SIZE, keyval_bits);
+      uint32_t passes = (key_bits + RS_RADIX_LOG2 - 1) / RS_RADIX_LOG2;
+
+      bvh_states[i].scratch_offset = bvh_states[i].scratch.sort_buffer_offset[passes & 1];
+
+      /*
+       * PAD KEYVALS AND ZERO HISTOGRAM/PARTITIONS
+       *
+       * Pad fractional blocks with max-valued keyvals.
+       *
+       * Zero the histograms and partitions buffer.
+       *
+       * This assumes the partitions follow the histograms.
+       */
+
+      /* FIXME(allanmac): Consider precomputing some of these values and hang them off `rs`. */
+
+      /* How many scatter blocks? */
+      uint32_t scatter_wg_size = 1 << rs->config.scatter.workgroup_size_log2;
+      uint32_t scatter_block_kvs = scatter_wg_size * rs->config.scatter.block_rows;
+      uint32_t scatter_blocks = (count + scatter_block_kvs - 1) / scatter_block_kvs;
+      uint32_t count_ru_scatter = scatter_blocks * scatter_block_kvs;
+
+      /*
+       * How many histogram blocks?
+       *
+       * Note that it's OK to have more max-valued digits counted by the histogram
+       * than sorted by the scatters because the sort is stable.
+       */
+      uint32_t histo_wg_size = 1 << rs->config.histogram.workgroup_size_log2;
+      uint32_t histo_block_kvs = histo_wg_size * rs->config.histogram.block_rows;
+      uint32_t histo_blocks = (count_ru_scatter + histo_block_kvs - 1) / histo_block_kvs;
+      uint32_t count_ru_histo = histo_blocks * histo_block_kvs;
+
+      /* Fill with max values */
+      if (count_ru_histo > count) {
+         radv_fill_buffer(cmd_buffer, NULL, NULL, keyvals_even_addr + count * keyval_bytes,
+                          (count_ru_histo - count) * keyval_bytes, 0xFFFFFFFF);
+      }
+
+      /*
+       * Zero histograms and invalidate partitions.
+       *
+       * Note that the partition invalidation only needs to be performed once
+       * because the even/odd scatter dispatches rely on the previous pass to
+       * leave the partitions in an invalid state.
+       *
+       * Note that the last workgroup doesn't read/write a partition so it doesn't
+       * need to be initialized.
+       */
+      uint32_t histo_partition_count = passes + scatter_blocks - 1;
+      uint32_t pass_idx = (keyval_bytes - passes);
+
+      uint32_t fill_base = pass_idx * (RS_RADIX_SIZE * sizeof(uint32_t));
+
+      radv_fill_buffer(cmd_buffer, NULL, NULL, internal_addr + rs->internal.histograms.offset + fill_base,
+                       histo_partition_count * (RS_RADIX_SIZE * sizeof(uint32_t)), 0);
+
+      /*
+       * Pipeline: HISTOGRAM
+       *
+       * TODO(allanmac): All subgroups should try to process approximately the same
+       * number of blocks in order to minimize tail effects. This was implemented
+       * and reverted but should be reimplemented and benchmarked later.
+       */
+      vk_barrier_transfer_w_to_compute_r(commandBuffer);
+
+      uint64_t devaddr_histograms = internal_addr + rs->internal.histograms.offset;
+
+      /* Dispatch histogram */
+      struct rs_push_histogram push_histogram = {
+         .devaddr_histograms = devaddr_histograms,
+         .devaddr_keyvals = keyvals_even_addr,
+         .passes = passes,
+      };
+
+      radv_CmdPushConstants(commandBuffer, rs->pipeline_layouts.named.histogram, VK_SHADER_STAGE_COMPUTE_BIT, 0,
+                            sizeof(push_histogram), &push_histogram);
+
+      radv_CmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.histogram);
+
+      vk_common_CmdDispatch(commandBuffer, histo_blocks, 1, 1);
 
-      info.keyvals_even.buffer = VK_NULL_HANDLE;
-      info.keyvals_even.offset = 0;
-      info.keyvals_even.devaddr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_buffer_offset[0];
+      /*
+       * Pipeline: PREFIX
+       *
+       * Launch one workgroup per pass.
+       */
+      vk_barrier_compute_w_to_compute_r(commandBuffer);
+
+      struct rs_push_prefix push_prefix = {
+         .devaddr_histograms = devaddr_histograms,
+      };
+
+      radv_CmdPushConstants(commandBuffer, rs->pipeline_layouts.named.prefix, VK_SHADER_STAGE_COMPUTE_BIT, 0,
+                            sizeof(push_prefix), &push_prefix);
 
-      info.keyvals_odd = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_buffer_offset[1];
+      radv_CmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.prefix);
 
-      info.internal.buffer = VK_NULL_HANDLE;
-      info.internal.offset = 0;
-      info.internal.devaddr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_internal_offset;
+      vk_common_CmdDispatch(commandBuffer, passes, 1, 1);
 
-      VkDeviceAddress result_addr;
-      radix_sort_vk_sort_devaddr(cmd_buffer->device->meta_state.accel_struct_build.radix_sort, &info,
-                                 radv_device_to_handle(cmd_buffer->device), commandBuffer, &result_addr);
+      /* Pipeline: SCATTER */
+      vk_barrier_compute_w_to_compute_r(commandBuffer);
+
+      uint32_t histogram_offset = pass_idx * (RS_RADIX_SIZE * sizeof(uint32_t));
+      uint64_t devaddr_partitions = internal_addr + rs->internal.partitions.offset;
+
+      struct rs_push_scatter push_scatter = {
+         .devaddr_keyvals_even = keyvals_even_addr,
+         .devaddr_keyvals_odd = keyvals_odd_addr,
+         .devaddr_partitions = devaddr_partitions,
+         .devaddr_histograms = devaddr_histograms + histogram_offset,
+         .pass_offset = (pass_idx & 3) * RS_RADIX_LOG2,
+      };
 
-      assert(result_addr == info.keyvals_even.devaddr || result_addr == info.keyvals_odd);
+      {
+         uint32_t pass_dword = pass_idx / 4;
 
-      bvh_states[i].scratch_offset = (uint32_t)(result_addr - pInfos[i].scratchData.deviceAddress);
+         radv_CmdPushConstants(commandBuffer, rs->pipeline_layouts.named.scatter[pass_dword].even,
+                               VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(push_scatter), &push_scatter);
+
+         radv_CmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE,
+
rs->pipelines.named.scatter[pass_dword].even); + } + + bool is_even = true; + + while (true) { + vk_common_CmdDispatch(commandBuffer, scatter_blocks, 1, 1); + + /* Continue? */ + if (++pass_idx >= keyval_bytes) + break; + + vk_barrier_compute_w_to_compute_r(commandBuffer); + + is_even ^= true; + push_scatter.devaddr_histograms += (RS_RADIX_SIZE * sizeof(uint32_t)); + push_scatter.pass_offset = (pass_idx & 3) * RS_RADIX_LOG2; + + uint32_t pass_dword = pass_idx / 4; + + /* Update push constants that changed */ + VkPipelineLayout pl = is_even ? rs->pipeline_layouts.named.scatter[pass_dword].even + : rs->pipeline_layouts.named.scatter[pass_dword].odd; + radv_CmdPushConstants(commandBuffer, pl, VK_SHADER_STAGE_COMPUTE_BIT, + offsetof(struct rs_push_scatter, devaddr_histograms), + sizeof(push_scatter.devaddr_histograms) + sizeof(push_scatter.pass_offset), + &push_scatter.devaddr_histograms); + + /* Bind new pipeline */ + VkPipeline p = + is_even ? rs->pipelines.named.scatter[pass_dword].even : rs->pipelines.named.scatter[pass_dword].odd; + + radv_CmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, p); + } } cmd_buffer->state.flush_bits |= flush_bits; diff --git a/src/amd/vulkan/radv_private.h b/src/amd/vulkan/radv_private.h index 1ea606c2ca111..2b0b9189f33f6 100644 --- a/src/amd/vulkan/radv_private.h +++ b/src/amd/vulkan/radv_private.h @@ -731,7 +731,6 @@ struct radv_meta_state { VkPipeline copy_pipeline; struct radix_sort_vk *radix_sort; - struct radix_sort_vk_sort_devaddr_info radix_sort_info; struct { VkBuffer buffer; -- GitLab From 04c77145628fe9956ae44a25ba7b1dfe401a9de8 Mon Sep 17 00:00:00 2001 From: Konstantin Seurer Date: Wed, 16 Aug 2023 11:50:18 +0200 Subject: [PATCH 3/3] radv: Perform multiple sorts in parallel This was the last part that didn't scale with multiple infos. Reducing the amount of barriers in this case improves DOOM Eternal performance by 50%. (Running with low resolution) --- src/amd/vulkan/radv_acceleration_structure.c | 264 ++++++++++--------- 1 file changed, 143 insertions(+), 121 deletions(-) diff --git a/src/amd/vulkan/radv_acceleration_structure.c b/src/amd/vulkan/radv_acceleration_structure.c index 9866de2e594a8..85852453fcb29 100644 --- a/src/amd/vulkan/radv_acceleration_structure.c +++ b/src/amd/vulkan/radv_acceleration_structure.c @@ -598,6 +598,13 @@ struct bvh_state { struct acceleration_structure_layout accel_struct; struct scratch_layout scratch; struct build_config config; + + /* Radix sort state */ + uint32_t scatter_blocks; + uint32_t count_ru_scatter; + uint32_t histo_blocks; + uint32_t count_ru_histo; + struct rs_push_scatter push_scatter; }; static uint32_t @@ -734,75 +741,79 @@ morton_sort(VkCommandBuffer commandBuffer, uint32_t infoCount, radix_sort_vk_t *rs = cmd_buffer->device->meta_state.accel_struct_build.radix_sort; - for (uint32_t i = 0; i < infoCount; ++i) { - uint32_t count = bvh_states[i].node_count; - uint64_t keyvals_even_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_buffer_offset[0]; - uint64_t keyvals_odd_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_buffer_offset[1]; - uint64_t internal_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_internal_offset; + /* + * OVERVIEW + * + * 1. Pad the keyvals in `scatter_even`. + * 2. Zero the `histograms` and `partitions`. + * --- BARRIER --- + * 3. HISTOGRAM is dispatched before PREFIX. + * --- BARRIER --- + * 4. PREFIX is dispatched before the first SCATTER. + * --- BARRIER --- + * 5. 
One or more SCATTER dispatches. + * + * Note that the `partitions` buffer can be zeroed anytime before the first + * scatter. + */ + + /* How many passes? */ + uint32_t keyval_bytes = rs->config.keyval_dwords * (uint32_t)sizeof(uint32_t); + uint32_t keyval_bits = keyval_bytes * 8; + uint32_t key_bits = MIN2(MORTON_BIT_SIZE, keyval_bits); + uint32_t passes = (key_bits + RS_RADIX_LOG2 - 1) / RS_RADIX_LOG2; - /* Anything to do? */ - if (!count) { + for (uint32_t i = 0; i < infoCount; ++i) { + if (bvh_states[i].node_count) + bvh_states[i].scratch_offset = bvh_states[i].scratch.sort_buffer_offset[passes & 1]; + else bvh_states[i].scratch_offset = bvh_states[i].scratch.sort_buffer_offset[0]; - continue; - } - - /* - * OVERVIEW - * - * 1. Pad the keyvals in `scatter_even`. - * 2. Zero the `histograms` and `partitions`. - * --- BARRIER --- - * 3. HISTOGRAM is dispatched before PREFIX. - * --- BARRIER --- - * 4. PREFIX is dispatched before the first SCATTER. - * --- BARRIER --- - * 5. One or more SCATTER dispatches. - * - * Note that the `partitions` buffer can be zeroed anytime before the first - * scatter. - */ - - /* How many passes? */ - uint32_t keyval_bytes = rs->config.keyval_dwords * (uint32_t)sizeof(uint32_t); - uint32_t keyval_bits = keyval_bytes * 8; - uint32_t key_bits = MIN2(MORTON_BIT_SIZE, keyval_bits); - uint32_t passes = (key_bits + RS_RADIX_LOG2 - 1) / RS_RADIX_LOG2; + } - bvh_states[i].scratch_offset = bvh_states[i].scratch.sort_buffer_offset[passes & 1]; + /* + * PAD KEYVALS AND ZERO HISTOGRAM/PARTITIONS + * + * Pad fractional blocks with max-valued keyvals. + * + * Zero the histograms and partitions buffer. + * + * This assumes the partitions follow the histograms. + */ + + /* FIXME(allanmac): Consider precomputing some of these values and hang them off `rs`. */ + + /* How many scatter blocks? */ + uint32_t scatter_wg_size = 1 << rs->config.scatter.workgroup_size_log2; + uint32_t scatter_block_kvs = scatter_wg_size * rs->config.scatter.block_rows; + + /* + * How many histogram blocks? + * + * Note that it's OK to have more max-valued digits counted by the histogram + * than sorted by the scatters because the sort is stable. + */ + uint32_t histo_wg_size = 1 << rs->config.histogram.workgroup_size_log2; + uint32_t histo_block_kvs = histo_wg_size * rs->config.histogram.block_rows; + + uint32_t pass_idx = (keyval_bytes - passes); - /* - * PAD KEYVALS AND ZERO HISTOGRAM/PARTITIONS - * - * Pad fractional blocks with max-valued keyvals. - * - * Zero the histograms and partitions buffer. - * - * This assumes the partitions follow the histograms. - */ + for (uint32_t i = 0; i < infoCount; ++i) { + if (!bvh_states[i].node_count) + continue; - /* FIXME(allanmac): Consider precomputing some of these values and hang them off `rs`. */ + uint64_t keyvals_even_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_buffer_offset[0]; + uint64_t internal_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_internal_offset; - /* How many scatter blocks? 
*/ - uint32_t scatter_wg_size = 1 << rs->config.scatter.workgroup_size_log2; - uint32_t scatter_block_kvs = scatter_wg_size * rs->config.scatter.block_rows; - uint32_t scatter_blocks = (count + scatter_block_kvs - 1) / scatter_block_kvs; - uint32_t count_ru_scatter = scatter_blocks * scatter_block_kvs; + bvh_states[i].scatter_blocks = (bvh_states[i].node_count + scatter_block_kvs - 1) / scatter_block_kvs; + bvh_states[i].count_ru_scatter = bvh_states[i].scatter_blocks * scatter_block_kvs; - /* - * How many histogram blocks? - * - * Note that it's OK to have more max-valued digits counted by the histogram - * than sorted by the scatters because the sort is stable. - */ - uint32_t histo_wg_size = 1 << rs->config.histogram.workgroup_size_log2; - uint32_t histo_block_kvs = histo_wg_size * rs->config.histogram.block_rows; - uint32_t histo_blocks = (count_ru_scatter + histo_block_kvs - 1) / histo_block_kvs; - uint32_t count_ru_histo = histo_blocks * histo_block_kvs; + bvh_states[i].histo_blocks = (bvh_states[i].count_ru_scatter + histo_block_kvs - 1) / histo_block_kvs; + bvh_states[i].count_ru_histo = bvh_states[i].histo_blocks * histo_block_kvs; /* Fill with max values */ - if (count_ru_histo > count) { - radv_fill_buffer(cmd_buffer, NULL, NULL, keyvals_even_addr + count * keyval_bytes, - (count_ru_histo - count) * keyval_bytes, 0xFFFFFFFF); + if (bvh_states[i].count_ru_histo > bvh_states[i].node_count) { + radv_fill_buffer(cmd_buffer, NULL, NULL, keyvals_even_addr + bvh_states[i].node_count * keyval_bytes, + (bvh_states[i].count_ru_histo - bvh_states[i].node_count) * keyval_bytes, 0xFFFFFFFF); } /* @@ -815,28 +826,35 @@ morton_sort(VkCommandBuffer commandBuffer, uint32_t infoCount, * Note that the last workgroup doesn't read/write a partition so it doesn't * need to be initialized. */ - uint32_t histo_partition_count = passes + scatter_blocks - 1; - uint32_t pass_idx = (keyval_bytes - passes); + uint32_t histo_partition_count = passes + bvh_states[i].scatter_blocks - 1; uint32_t fill_base = pass_idx * (RS_RADIX_SIZE * sizeof(uint32_t)); radv_fill_buffer(cmd_buffer, NULL, NULL, internal_addr + rs->internal.histograms.offset + fill_base, histo_partition_count * (RS_RADIX_SIZE * sizeof(uint32_t)), 0); + } - /* - * Pipeline: HISTOGRAM - * - * TODO(allanmac): All subgroups should try to process approximately the same - * number of blocks in order to minimize tail effects. This was implemented - * and reverted but should be reimplemented and benchmarked later. - */ - vk_barrier_transfer_w_to_compute_r(commandBuffer); + /* + * Pipeline: HISTOGRAM + * + * TODO(allanmac): All subgroups should try to process approximately the same + * number of blocks in order to minimize tail effects. This was implemented + * and reverted but should be reimplemented and benchmarked later. 
+ */ + vk_barrier_transfer_w_to_compute_r(commandBuffer); + + radv_CmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.histogram); + + for (uint32_t i = 0; i < infoCount; ++i) { + if (!bvh_states[i].node_count) + continue; - uint64_t devaddr_histograms = internal_addr + rs->internal.histograms.offset; + uint64_t keyvals_even_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_buffer_offset[0]; + uint64_t internal_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_internal_offset; /* Dispatch histogram */ struct rs_push_histogram push_histogram = { - .devaddr_histograms = devaddr_histograms, + .devaddr_histograms = internal_addr + rs->internal.histograms.offset, .devaddr_keyvals = keyvals_even_addr, .passes = passes, }; @@ -844,83 +862,87 @@ morton_sort(VkCommandBuffer commandBuffer, uint32_t infoCount, radv_CmdPushConstants(commandBuffer, rs->pipeline_layouts.named.histogram, VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(push_histogram), &push_histogram); - radv_CmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.histogram); + vk_common_CmdDispatch(commandBuffer, bvh_states[i].histo_blocks, 1, 1); + } - vk_common_CmdDispatch(commandBuffer, histo_blocks, 1, 1); + /* + * Pipeline: PREFIX + * + * Launch one workgroup per pass. + */ + vk_barrier_compute_w_to_compute_r(commandBuffer); - /* - * Pipeline: PREFIX - * - * Launch one workgroup per pass. - */ - vk_barrier_compute_w_to_compute_r(commandBuffer); + radv_CmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.prefix); + + for (uint32_t i = 0; i < infoCount; ++i) { + if (!bvh_states[i].node_count) + continue; + + uint64_t internal_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_internal_offset; struct rs_push_prefix push_prefix = { - .devaddr_histograms = devaddr_histograms, + .devaddr_histograms = internal_addr + rs->internal.histograms.offset, }; radv_CmdPushConstants(commandBuffer, rs->pipeline_layouts.named.prefix, VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(push_prefix), &push_prefix); - radv_CmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.prefix); - vk_common_CmdDispatch(commandBuffer, passes, 1, 1); + } - /* Pipeline: SCATTER */ - vk_barrier_compute_w_to_compute_r(commandBuffer); + /* Pipeline: SCATTER */ + vk_barrier_compute_w_to_compute_r(commandBuffer); - uint32_t histogram_offset = pass_idx * (RS_RADIX_SIZE * sizeof(uint32_t)); - uint64_t devaddr_partitions = internal_addr + rs->internal.partitions.offset; + uint32_t histogram_offset = pass_idx * (RS_RADIX_SIZE * sizeof(uint32_t)); - struct rs_push_scatter push_scatter = { + for (uint32_t i = 0; i < infoCount; i++) { + uint64_t keyvals_even_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_buffer_offset[0]; + uint64_t keyvals_odd_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_buffer_offset[1]; + uint64_t internal_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_internal_offset; + + bvh_states[i].push_scatter = (struct rs_push_scatter){ .devaddr_keyvals_even = keyvals_even_addr, .devaddr_keyvals_odd = keyvals_odd_addr, - .devaddr_partitions = devaddr_partitions, - .devaddr_histograms = devaddr_histograms + histogram_offset, - .pass_offset = (pass_idx & 3) * RS_RADIX_LOG2, + .devaddr_partitions = internal_addr + rs->internal.partitions.offset, + .devaddr_histograms = internal_addr + rs->internal.histograms.offset + histogram_offset, }; + 
} - { - uint32_t pass_dword = pass_idx / 4; + bool is_even = true; - radv_CmdPushConstants(commandBuffer, rs->pipeline_layouts.named.scatter[pass_dword].even, - VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(push_scatter), &push_scatter); + while (true) { + uint32_t pass_dword = pass_idx / 4; - radv_CmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, - rs->pipelines.named.scatter[pass_dword].even); - } + /* Bind new pipeline */ + VkPipeline p = + is_even ? rs->pipelines.named.scatter[pass_dword].even : rs->pipelines.named.scatter[pass_dword].odd; + radv_CmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, p); - bool is_even = true; + /* Update push constants that changed */ + VkPipelineLayout pl = is_even ? rs->pipeline_layouts.named.scatter[pass_dword].even // + : rs->pipeline_layouts.named.scatter[pass_dword].odd; - while (true) { - vk_common_CmdDispatch(commandBuffer, scatter_blocks, 1, 1); + for (uint32_t i = 0; i < infoCount; i++) { + if (!bvh_states[i].node_count) + continue; - /* Continue? */ - if (++pass_idx >= keyval_bytes) - break; + bvh_states[i].push_scatter.pass_offset = (pass_idx & 3) * RS_RADIX_LOG2; - vk_barrier_compute_w_to_compute_r(commandBuffer); + radv_CmdPushConstants(commandBuffer, pl, VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(struct rs_push_scatter), + &bvh_states[i].push_scatter); - is_even ^= true; - push_scatter.devaddr_histograms += (RS_RADIX_SIZE * sizeof(uint32_t)); - push_scatter.pass_offset = (pass_idx & 3) * RS_RADIX_LOG2; + vk_common_CmdDispatch(commandBuffer, bvh_states[i].scatter_blocks, 1, 1); - uint32_t pass_dword = pass_idx / 4; + bvh_states[i].push_scatter.devaddr_histograms += (RS_RADIX_SIZE * sizeof(uint32_t)); + } - /* Update push constants that changed */ - VkPipelineLayout pl = is_even ? rs->pipeline_layouts.named.scatter[pass_dword].even - : rs->pipeline_layouts.named.scatter[pass_dword].odd; - radv_CmdPushConstants(commandBuffer, pl, VK_SHADER_STAGE_COMPUTE_BIT, - offsetof(struct rs_push_scatter, devaddr_histograms), - sizeof(push_scatter.devaddr_histograms) + sizeof(push_scatter.pass_offset), - &push_scatter.devaddr_histograms); + /* Continue? */ + if (++pass_idx >= keyval_bytes) + break; - /* Bind new pipeline */ - VkPipeline p = - is_even ? rs->pipelines.named.scatter[pass_dword].even : rs->pipelines.named.scatter[pass_dword].odd; + vk_barrier_compute_w_to_compute_r(commandBuffer); - radv_CmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, p); - } + is_even ^= true; } cmd_buffer->state.flush_bits |= flush_bits; -- GitLab
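
For readers who want to sanity-check the dispatch math in the vendored code outside the driver, the per-sort bookkeeping reduces to a handful of integer computations on the radix-sort configuration. The standalone sketch below mirrors that arithmetic; the RS_RADIX_LOG2/RS_RADIX_SIZE values, the workgroup-size and block-row numbers, and the example node count are illustrative assumptions (the real values come from rs->config of the sorter created by radv_create_radix_sort_u64), so treat it as a model of the calculation rather than driver code.

/*
 * Standalone model of the radix-sort pass bookkeeping used by morton_sort().
 * The configuration values below are illustrative assumptions, not values
 * taken from the patches: the real ones come from rs->config at runtime.
 */
#include <stdint.h>
#include <stdio.h>

#define RS_RADIX_LOG2   8u                    /* assumed: 8-bit digits per pass */
#define RS_RADIX_SIZE   (1u << RS_RADIX_LOG2) /* 256 histogram bins per pass */
#define MORTON_BIT_SIZE 24u                   /* matches the define added in patch 2 */
#define MIN2(a, b)      ((a) < (b) ? (a) : (b))

int
main(void)
{
   /* Hypothetical stand-ins for rs->config fields. */
   const uint32_t keyval_dwords = 2;        /* 64-bit keyvals (radv_create_radix_sort_u64) */
   const uint32_t scatter_wg_size_log2 = 8; /* assumed */
   const uint32_t scatter_block_rows = 8;   /* assumed */
   const uint32_t histo_wg_size_log2 = 8;   /* assumed */
   const uint32_t histo_block_rows = 8;     /* assumed */

   const uint32_t count = 100000;           /* example node_count */

   /* How many 8-bit passes are needed to cover the used key bits? */
   uint32_t keyval_bytes = keyval_dwords * (uint32_t)sizeof(uint32_t);
   uint32_t keyval_bits = keyval_bytes * 8;
   uint32_t key_bits = MIN2(MORTON_BIT_SIZE, keyval_bits);
   uint32_t passes = (key_bits + RS_RADIX_LOG2 - 1) / RS_RADIX_LOG2;

   /* Unused high bytes are skipped by starting at a later pass index. */
   uint32_t pass_idx = keyval_bytes - passes;

   /* Round the keyval count up to whole scatter and histogram blocks. */
   uint32_t scatter_block_kvs = (1u << scatter_wg_size_log2) * scatter_block_rows;
   uint32_t scatter_blocks = (count + scatter_block_kvs - 1) / scatter_block_kvs;
   uint32_t count_ru_scatter = scatter_blocks * scatter_block_kvs;

   uint32_t histo_block_kvs = (1u << histo_wg_size_log2) * histo_block_rows;
   uint32_t histo_blocks = (count_ru_scatter + histo_block_kvs - 1) / histo_block_kvs;
   uint32_t count_ru_histo = histo_blocks * histo_block_kvs;

   /* `passes` histograms plus (scatter_blocks - 1) partitions get zeroed; the
    * last scatter workgroup never reads or writes a partition. */
   uint32_t histo_partition_count = passes + scatter_blocks - 1;

   printf("passes=%u pass_idx=%u\n", passes, pass_idx);
   printf("scatter_blocks=%u count_ru_scatter=%u\n", scatter_blocks, count_ru_scatter);
   printf("histo_blocks=%u count_ru_histo=%u (%u keyvals padded with max values)\n",
          histo_blocks, count_ru_histo, count_ru_histo - count);
   printf("dwords zeroed in histograms/partitions: %u\n", histo_partition_count * RS_RADIX_SIZE);
   return 0;
}

With 64-bit keyvals and 24-bit Morton codes this yields passes = 3 and a starting pass_idx of 5, which is also why the sorted keyvals end up in sort_buffer_offset[passes & 1], i.e. the second (odd) buffer, after the final scatter.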