-
Notifications
You must be signed in to change notification settings - Fork 345
Proclaim return types and other fixes needed for CCCL 3.2 #5375
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
7db7b72
657504e
17dbd5a
9463815
f3d7f0e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -736,8 +736,9 @@ void multisource_backward_pass( | |
|
|
||
| auto d_first = thrust::make_transform_iterator( | ||
| distances_2d.begin(), | ||
| cuda::proclaim_return_type<vertex_t>( | ||
| [invalid_distance] __device__(auto d) { return d == invalid_distance ? vertex_t{0} : d; })); | ||
| cuda::proclaim_return_type<vertex_t>([invalid_distance] __device__(vertex_t d) { | ||
| return d == invalid_distance ? vertex_t{0} : d; | ||
| })); | ||
| vertex_t global_max_distance = thrust::reduce(handle.get_thrust_policy(), | ||
| d_first, | ||
| d_first + distances_2d.size(), | ||
|
|
@@ -795,18 +796,21 @@ void multisource_backward_pass( | |
| auto v_first = graph_view.local_vertex_partition_range_first(); | ||
|
|
||
| // Calculate offsets for each distance level in the consecutive arrays | ||
| std::vector<size_t> h_distance_offsets(global_max_distance + 1); | ||
| // Need global_max_distance + 2 elements: one for each distance level (0 to global_max_distance) | ||
| // plus a sentinel at the end for CUB segmented sort end offsets | ||
| std::vector<size_t> h_distance_offsets(global_max_distance + 2); | ||
| size_t offset = 0; | ||
| for (vertex_t d = 0; d <= global_max_distance; ++d) { | ||
| h_distance_offsets[d] = offset; | ||
| offset += host_distance_counts[d]; | ||
| } | ||
| h_distance_offsets[global_max_distance + 1] = offset; // sentinel = total_vertices | ||
|
|
||
| // Copy offsets to device for kernel access | ||
| rmm::device_uvector<size_t> d_distance_offsets(global_max_distance + 1, handle.get_stream()); | ||
| rmm::device_uvector<size_t> d_distance_offsets(global_max_distance + 2, handle.get_stream()); | ||
| raft::update_device(d_distance_offsets.data(), | ||
| h_distance_offsets.data(), | ||
| global_max_distance + 1, | ||
| global_max_distance + 2, | ||
| handle.get_stream()); | ||
|
|
||
| // Populate consecutive arrays - single scan of distance array | ||
|
|
@@ -873,7 +877,11 @@ void multisource_backward_pass( | |
| // Allocate temporary storage for CUB segmented sort | ||
| rmm::device_uvector<std::byte> d_tmp_storage(0, handle.get_stream()); | ||
|
|
||
| // Process each chunk - sort consecutive arrays directly in-place | ||
| // Allocate output buffers for CUB sort (input/output cannot overlap) | ||
| rmm::device_uvector<vertex_t> sorted_vertices(total_vertices, handle.get_stream()); | ||
| rmm::device_uvector<origin_t> sorted_sources(total_vertices, handle.get_stream()); | ||
|
Comment on lines
-876
to
+882
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Tests were failing so I was investigating this file, and I think this comment points to a violation of the CUB API preconditions. The inputs and outputs of the `SortPairs` call must not overlap (CUB segmented sort requires separate input and output buffers). As a workaround, I am initializing an output buffer for the sorted results, and then moving it back to the original variable once computation is complete. |
||
|
|
||
| // Process each chunk | ||
| for (size_t chunk_i = 0; chunk_i < num_chunks; ++chunk_i) { | ||
| size_t chunk_vertex_start = h_vertex_chunk_offsets[chunk_i]; | ||
| size_t chunk_vertex_end = h_vertex_chunk_offsets[chunk_i + 1]; | ||
|
|
@@ -884,17 +892,19 @@ void multisource_backward_pass( | |
|
|
||
| if (num_segments_in_chunk > 0) { | ||
| auto offset_first = thrust::make_transform_iterator( | ||
| h_distance_offsets.data() + chunk_distance_start, | ||
| [chunk_vertex_start] __device__(size_t offset) { return offset - chunk_vertex_start; }); | ||
| d_distance_offsets.data() + chunk_distance_start, | ||
| cuda::proclaim_return_type<size_t>([chunk_vertex_start] __device__(size_t offset) { | ||
| return offset - chunk_vertex_start; | ||
| })); | ||
|
|
||
| // CUB segmented sort directly on consecutive arrays - no copy needed! | ||
| // CUB segmented sort requires separate input and output buffers | ||
| size_t temp_storage_bytes = 0; | ||
| cub::DeviceSegmentedSort::SortPairs(nullptr, | ||
| temp_storage_bytes, | ||
| all_vertices.data() + chunk_vertex_start, | ||
| all_vertices.data() + chunk_vertex_start, | ||
| all_sources.data() + chunk_vertex_start, | ||
| sorted_vertices.data() + chunk_vertex_start, | ||
| all_sources.data() + chunk_vertex_start, | ||
| sorted_sources.data() + chunk_vertex_start, | ||
| chunk_size, | ||
| num_segments_in_chunk, | ||
| offset_first, | ||
|
|
@@ -908,16 +918,20 @@ void multisource_backward_pass( | |
| cub::DeviceSegmentedSort::SortPairs(d_tmp_storage.data(), | ||
| temp_storage_bytes, | ||
| all_vertices.data() + chunk_vertex_start, | ||
| all_vertices.data() + chunk_vertex_start, | ||
| all_sources.data() + chunk_vertex_start, | ||
| sorted_vertices.data() + chunk_vertex_start, | ||
| all_sources.data() + chunk_vertex_start, | ||
| sorted_sources.data() + chunk_vertex_start, | ||
| chunk_size, | ||
| num_segments_in_chunk, | ||
| offset_first, | ||
| offset_first + 1, | ||
| handle.get_stream()); | ||
| } | ||
| } | ||
|
|
||
| // Use the sorted arrays for subsequent processing | ||
| all_vertices = std::move(sorted_vertices); | ||
| all_sources = std::move(sorted_sources); | ||
| } | ||
|
|
||
| // Process distance levels using pre-computed buckets (now with sorted vertices) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
This appears to be an off-by-one error that results in an illegal memory access in SortPairs.