diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/device_batch_memcpy.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/device_batch_memcpy.hpp index 84198f6cbc8..2eb9c9d111e 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/device_batch_memcpy.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/device_batch_memcpy.hpp @@ -545,8 +545,14 @@ struct batch_memcpy_impl if(blev_buffer_offset < num_blev_buffers) { auto tile_buffer_id = buffer_by_size_class[blev_buffer_offset].buffer_id; + /* buffers.sizes[tile_buffer_id] may yield a proxy such as rocthrust::device_reference<T> + * rather than a plain value. Convert it explicitly to buffer_size_type so that the + * argument passed into ceiling_div is not the proxy type. This is possible since + * rocthrust::device_reference<T> is implicitly convertible to T. + */ + buffer_size_type size = static_cast<buffer_size_type>(buffers.sizes[tile_buffer_id]); tile_offsets[i] - = rocprim::detail::ceiling_div(buffers.sizes[tile_buffer_id], + = rocprim::detail::ceiling_div(size, blev_block_size * blev_bytes_per_thread); } else @@ -620,10 +626,15 @@ struct batch_memcpy_impl buffer_offset += warps_per_block) { const auto buffer_id = buffers_by_size_class[buffer_offset].buffer_id; - + /* tile_buffers.sizes[buffer_id] may yield a proxy such as rocthrust::device_reference<T> + * rather than a plain value. Convert it explicitly to buffer_size_type so that the + * argument passed into copy_items is not the proxy type. This is possible since + * rocthrust::device_reference<T> is implicitly convertible to T. + */ + buffer_size_type size = static_cast<buffer_size_type>(tile_buffers.sizes[buffer_id]); batch_memcpy::copy_items(tile_buffers.srcs[buffer_id], - tile_buffers.dsts[buffer_id], - tile_buffers.sizes[buffer_id]); + tile_buffers.dsts[buffer_id], + size); } }