|
1 |
| -// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. |
| 1 | +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. |
2 | 2 | //
|
3 | 3 | // Permission is hereby granted, free of charge, to any person obtaining a copy
|
4 | 4 | // of this software and associated documentation files (the "Software"), to deal
|
@@ -68,24 +68,23 @@ void partition_kernel_impl(IndexIterator indices,
|
68 | 68 | const unsigned int spacing,
|
69 | 69 | BinaryFunction compare_function)
|
70 | 70 | {
|
71 |
| - const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>(); |
72 |
| - const unsigned int flat_block_id = ::rocprim::detail::block_id<0>(); |
| 71 | + const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>(); |
| 72 | + const unsigned int flat_block_id = ::rocprim::detail::block_id<0>(); |
73 | 73 | const unsigned int flat_block_size = ::rocprim::detail::block_size<0>();
|
| 74 | + const unsigned int input_size = input1_size + input2_size; |
| 75 | + const unsigned int id = flat_block_id * flat_block_size + flat_id; |
| 76 | + const unsigned int partition_id = id * spacing; |
| 77 | + const unsigned int partitions = (input_size + spacing - 1) / spacing; |
74 | 78 |
|
75 |
| - unsigned int id = flat_block_id * flat_block_size + flat_id; |
| 79 | + if(id > partitions) |
| 80 | + { |
| 81 | + return; |
| 82 | + } |
76 | 83 |
|
77 |
| - unsigned int partition_id = id * spacing; |
78 | 84 | size_t diag = min(static_cast<size_t>(partition_id), input1_size + input2_size);
|
79 | 85 |
|
80 |
| - unsigned int begin = |
81 |
| - merge_path( |
82 |
| - keys_input1, |
83 |
| - keys_input2, |
84 |
| - input1_size, |
85 |
| - input2_size, |
86 |
| - diag, |
87 |
| - compare_function |
88 |
| - ); |
| 86 | + unsigned int begin |
| 87 | + = merge_path(keys_input1, keys_input2, input1_size, input2_size, diag, compare_function); |
89 | 88 |
|
90 | 89 | indices[id] = begin;
|
91 | 90 | }
|
@@ -310,8 +309,10 @@ void merge_kernel_impl(IndexIterator indices,
|
310 | 309 | const unsigned int valid_in_last_block = count - block_offset;
|
311 | 310 | const bool is_incomplete_block = valid_in_last_block < items_per_block;
|
312 | 311 |
|
313 |
| - const unsigned int p1 = indices[flat_block_id]; |
314 |
| - const unsigned int p2 = indices[flat_block_id + 1]; |
| 312 | + const unsigned int partitions = (count + items_per_block - 1) / items_per_block; |
| 313 | + |
| 314 | + const unsigned int p1 = indices[rocprim::min(flat_block_id, partitions)]; |
| 315 | + const unsigned int p2 = indices[rocprim::min(flat_block_id + 1, partitions)]; |
315 | 316 |
|
316 | 317 | range_t range =
|
317 | 318 | compute_range(
|
|
0 commit comments