Skip to content

Commit 85253f8

Browse files
Added check to partition kernel if size is smaller than items_per_block (#538) (#546)
Co-authored-by: Nick Breed <[email protected]> Co-authored-by: Nick Breed <[email protected]>
1 parent 435f7f4 commit 85253f8

File tree

1 file changed

+17
-16
lines changed

1 file changed

+17
-16
lines changed

rocprim/include/rocprim/device/detail/device_merge.hpp

+17-16
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved.
1+
// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved.
22
//
33
// Permission is hereby granted, free of charge, to any person obtaining a copy
44
// of this software and associated documentation files (the "Software"), to deal
@@ -68,24 +68,23 @@ void partition_kernel_impl(IndexIterator indices,
6868
const unsigned int spacing,
6969
BinaryFunction compare_function)
7070
{
71-
const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>();
72-
const unsigned int flat_block_id = ::rocprim::detail::block_id<0>();
71+
const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>();
72+
const unsigned int flat_block_id = ::rocprim::detail::block_id<0>();
7373
const unsigned int flat_block_size = ::rocprim::detail::block_size<0>();
74+
const unsigned int input_size = input1_size + input2_size;
75+
const unsigned int id = flat_block_id * flat_block_size + flat_id;
76+
const unsigned int partition_id = id * spacing;
77+
const unsigned int partitions = (input_size + spacing - 1) / spacing;
7478

75-
unsigned int id = flat_block_id * flat_block_size + flat_id;
79+
if(id > partitions)
80+
{
81+
return;
82+
}
7683

77-
unsigned int partition_id = id * spacing;
7884
size_t diag = min(static_cast<size_t>(partition_id), input1_size + input2_size);
7985

80-
unsigned int begin =
81-
merge_path(
82-
keys_input1,
83-
keys_input2,
84-
input1_size,
85-
input2_size,
86-
diag,
87-
compare_function
88-
);
86+
unsigned int begin
87+
= merge_path(keys_input1, keys_input2, input1_size, input2_size, diag, compare_function);
8988

9089
indices[id] = begin;
9190
}
@@ -310,8 +309,10 @@ void merge_kernel_impl(IndexIterator indices,
310309
const unsigned int valid_in_last_block = count - block_offset;
311310
const bool is_incomplete_block = valid_in_last_block < items_per_block;
312311

313-
const unsigned int p1 = indices[flat_block_id];
314-
const unsigned int p2 = indices[flat_block_id + 1];
312+
const unsigned int partitions = (count + items_per_block - 1) / items_per_block;
313+
314+
const unsigned int p1 = indices[rocprim::min(flat_block_id, partitions)];
315+
const unsigned int p2 = indices[rocprim::min(flat_block_id + 1, partitions)];
315316

316317
range_t range =
317318
compute_range(

0 commit comments

Comments
 (0)