diff --git a/benchmark/CMakeLists.txt b/benchmark/CMakeLists.txt index 9f02b161b..eb136407c 100644 --- a/benchmark/CMakeLists.txt +++ b/benchmark/CMakeLists.txt @@ -60,6 +60,7 @@ add_rocprim_benchmark_hip(benchmark_hip_block_radix_sort.cpp) add_rocprim_benchmark_hip(benchmark_hip_block_reduce.cpp) add_rocprim_benchmark_hip(benchmark_hip_block_scan.cpp) add_rocprim_benchmark_hip(benchmark_hip_block_sort.cpp) +add_rocprim_benchmark_hip(benchmark_hip_device_binary_search.cpp) add_rocprim_benchmark_hip(benchmark_hip_device_histogram.cpp) add_rocprim_benchmark_hip(benchmark_hip_device_merge.cpp) add_rocprim_benchmark_hip(benchmark_hip_device_merge_sort.cpp) diff --git a/benchmark/benchmark_hc_block_discontinuity.cpp b/benchmark/benchmark_hc_block_discontinuity.cpp index 3fe7464aa..83d10d5ea 100644 --- a/benchmark/benchmark_hc_block_discontinuity.cpp +++ b/benchmark/benchmark_hc_block_discontinuity.cpp @@ -91,7 +91,7 @@ struct flag_heads rp::block_store_direct_striped(lid, d_output.data() + block_offset, input); } ); - } + } }; struct flag_tails @@ -114,7 +114,7 @@ struct flag_tails { const unsigned int lid = idx.local[0]; const unsigned int block_offset = idx.tile[0] * ItemsPerThread * BlockSize; - + T input[ItemsPerThread]; rp::block_load_direct_striped(lid, d_input.data() + block_offset, input); @@ -154,7 +154,7 @@ struct flag_heads_and_tails bool WithTile, unsigned int Trials > -static void run(const hc::array & d_input, const hc::array & d_output, + static void run(const hc::array & d_input, const hc::array & d_output, hc::accelerator_view acc_view, size_t size) { const size_t grid_size = size / ItemsPerThread; @@ -304,7 +304,7 @@ int main(int argc, char *argv[]) benchmark::Initialize(&argc, argv); const size_t size = parser.get("size"); const int trials = parser.get("trials"); - + // HC hc::accelerator acc; auto acc_view = acc.get_default_view(); diff --git a/benchmark/benchmark_hip_block_discontinuity.cpp b/benchmark/benchmark_hip_block_discontinuity.cpp index 19610ea53..58512a481 100644 --- a/benchmark/benchmark_hip_block_discontinuity.cpp +++ b/benchmark/benchmark_hip_block_discontinuity.cpp @@ -56,6 +56,20 @@ const size_t DEFAULT_N = 1024 * 1024 * 128; namespace rp = rocprim; +template< + class Runner, + class T, + unsigned int BlockSize, + unsigned int ItemsPerThread, + bool WithTile, + unsigned int Trials +> +__global__ +void kernel(const T * d_input, T * d_output) +{ + Runner::template run(d_input, d_output); +} + struct flag_heads { template< @@ -65,8 +79,8 @@ struct flag_heads bool WithTile, unsigned int Trials > - __global__ - static void kernel(const T * d_input, T * d_output) + __device__ + static void run(const T * d_input, T * d_output) { const unsigned int lid = hipThreadIdx_x; const unsigned int block_offset = hipBlockIdx_x * ItemsPerThread * BlockSize; @@ -108,8 +122,8 @@ struct flag_tails bool WithTile, unsigned int Trials > - __global__ - static void kernel(const T * d_input, T * d_output) + __device__ + static void run(const T * d_input, T * d_output) { const unsigned int lid = hipThreadIdx_x; const unsigned int block_offset = hipBlockIdx_x * ItemsPerThread * BlockSize; @@ -151,8 +165,8 @@ struct flag_heads_and_tails bool WithTile, unsigned int Trials > - __global__ - static void kernel(const T * d_input, T * d_output) + __device__ + static void run(const T * d_input, T * d_output) { const unsigned int lid = hipThreadIdx_x; const unsigned int block_offset = hipBlockIdx_x * ItemsPerThread * BlockSize; @@ -219,7 +233,7 @@ void run_benchmark(benchmark::State& state, hipStream_t stream, size_t N) auto start = std::chrono::high_resolution_clock::now(); hipLaunchKernelGGL( - HIP_KERNEL_NAME(Benchmark::template kernel), + HIP_KERNEL_NAME(kernel), dim3(size/items_per_block), dim3(BlockSize), 0, stream, d_input, d_output ); diff --git a/benchmark/benchmark_hip_block_exchange.cpp b/benchmark/benchmark_hip_block_exchange.cpp index bc694d265..d04e2ab48 100644 --- a/benchmark/benchmark_hip_block_exchange.cpp +++ b/benchmark/benchmark_hip_block_exchange.cpp @@ -56,6 +56,19 @@ const size_t DEFAULT_N = 1024 * 1024 * 128; namespace rp = rocprim; +template< + class Runner, + class T, + unsigned int BlockSize, + unsigned int ItemsPerThread, + unsigned int Trials +> +__global__ +void kernel(const T * d_input, T * d_output) +{ + Runner::template run(d_input, d_output); +} + struct blocked_to_striped { template< @@ -64,8 +77,8 @@ struct blocked_to_striped unsigned int ItemsPerThread, unsigned int Trials > - __global__ - static void kernel(const T * d_input, T * d_output) + __device__ + static void run(const T * d_input, T * d_output) { const unsigned int lid = hipThreadIdx_x; const unsigned int block_offset = hipBlockIdx_x * ItemsPerThread * BlockSize; @@ -92,8 +105,8 @@ struct striped_to_blocked unsigned int ItemsPerThread, unsigned int Trials > - __global__ - static void kernel(const T * d_input, T * d_output) + __device__ + static void run(const T * d_input, T * d_output) { const unsigned int lid = hipThreadIdx_x; const unsigned int block_offset = hipBlockIdx_x * ItemsPerThread * BlockSize; @@ -120,8 +133,8 @@ struct blocked_to_warp_striped unsigned int ItemsPerThread, unsigned int Trials > - __global__ - static void kernel(const T * d_input, T * d_output) + __device__ + static void run(const T * d_input, T * d_output) { const unsigned int lid = hipThreadIdx_x; const unsigned int block_offset = hipBlockIdx_x * ItemsPerThread * BlockSize; @@ -148,8 +161,8 @@ struct warp_striped_to_blocked unsigned int ItemsPerThread, unsigned int Trials > - __global__ - static void kernel(const T * d_input, T * d_output) + __device__ + static void run(const T * d_input, T * d_output) { const unsigned int lid = hipThreadIdx_x; const unsigned int block_offset = hipBlockIdx_x * ItemsPerThread * BlockSize; @@ -176,8 +189,8 @@ struct scatter_to_blocked unsigned int ItemsPerThread, unsigned int Trials > - __global__ - static void kernel(const T * d_input, T * d_output) + __device__ + static void run(const T * d_input, T * d_output) { const unsigned int lid = hipThreadIdx_x; const unsigned int block_offset = hipBlockIdx_x * ItemsPerThread * BlockSize; @@ -206,8 +219,8 @@ struct scatter_to_striped unsigned int ItemsPerThread, unsigned int Trials > - __global__ - static void kernel(const T * d_input, T * d_output) + __device__ + static void run(const T * d_input, T * d_output) { const unsigned int lid = hipThreadIdx_x; const unsigned int block_offset = hipBlockIdx_x * ItemsPerThread * BlockSize; @@ -267,7 +280,7 @@ void run_benchmark(benchmark::State& state, hipStream_t stream, size_t N) auto start = std::chrono::high_resolution_clock::now(); hipLaunchKernelGGL( - HIP_KERNEL_NAME(Benchmark::template kernel), + HIP_KERNEL_NAME(kernel), dim3(size/items_per_block), dim3(BlockSize), 0, stream, d_input, d_output ); diff --git a/benchmark/benchmark_hip_block_histogram.cpp b/benchmark/benchmark_hip_block_histogram.cpp index 495e7fb11..3956bbcf6 100644 --- a/benchmark/benchmark_hip_block_histogram.cpp +++ b/benchmark/benchmark_hip_block_histogram.cpp @@ -56,6 +56,20 @@ const size_t DEFAULT_N = 1024 * 1024 * 128; namespace rp = rocprim; +template< + class Runner, + class T, + unsigned int BlockSize, + unsigned int ItemsPerThread, + unsigned int BinSize, + unsigned int Trials +> +__global__ +void kernel(const T* input, T* output) +{ + Runner::template run(input, output); +} + template struct histogram { @@ -66,8 +80,8 @@ struct histogram unsigned int BinSize, unsigned int Trials > - __global__ - static void kernel(const T* input, T* output) + __device__ + static void run(const T* input, T* output) { const unsigned int index = ((hipBlockIdx_x * BlockSize) + hipThreadIdx_x) * ItemsPerThread; unsigned int global_offset = hipBlockIdx_x * BinSize; @@ -95,7 +109,7 @@ struct histogram { output[global_offset + hipThreadIdx_x] = histogram[offset + hipThreadIdx_x]; global_offset += BlockSize; - } + } } } }; @@ -133,7 +147,7 @@ void run_benchmark(benchmark::State& state, hipStream_t stream, size_t N) { auto start = std::chrono::high_resolution_clock::now(); hipLaunchKernelGGL( - HIP_KERNEL_NAME(Benchmark::template kernel), + HIP_KERNEL_NAME(kernel), dim3(size/items_per_block), dim3(BlockSize), 0, stream, d_input, d_output ); diff --git a/benchmark/benchmark_hip_block_reduce.cpp b/benchmark/benchmark_hip_block_reduce.cpp index 46f697244..6b0937afd 100644 --- a/benchmark/benchmark_hip_block_reduce.cpp +++ b/benchmark/benchmark_hip_block_reduce.cpp @@ -56,6 +56,19 @@ const size_t DEFAULT_N = 1024 * 1024 * 128; namespace rp = rocprim; +template< + class Runner, + class T, + unsigned int BlockSize, + unsigned int ItemsPerThread, + unsigned int Trials +> +__global__ +void kernel(const T* input, T* output) +{ + Runner::template run(input, output); +} + template struct reduce { @@ -65,8 +78,8 @@ struct reduce unsigned int ItemsPerThread, unsigned int Trials > - __global__ - static void kernel(const T* input, T* output) + __device__ + static void run(const T* input, T* output) { const unsigned int i = hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x; @@ -125,7 +138,7 @@ void run_benchmark(benchmark::State& state, hipStream_t stream, size_t N) { auto start = std::chrono::high_resolution_clock::now(); hipLaunchKernelGGL( - HIP_KERNEL_NAME(Benchmark::template kernel), + HIP_KERNEL_NAME(kernel), dim3(size/items_per_block), dim3(BlockSize), 0, stream, d_input, d_output ); diff --git a/benchmark/benchmark_hip_block_scan.cpp b/benchmark/benchmark_hip_block_scan.cpp index 9632e3c9b..7105983f2 100644 --- a/benchmark/benchmark_hip_block_scan.cpp +++ b/benchmark/benchmark_hip_block_scan.cpp @@ -56,6 +56,19 @@ const size_t DEFAULT_N = 1024 * 1024 * 128; namespace rp = rocprim; +template< + class Runner, + class T, + unsigned int BlockSize, + unsigned int ItemsPerThread, + unsigned int Trials +> +__global__ +void kernel(const T* input, T* output) +{ + Runner::template run(input, output); +} + template struct inclusive_scan { @@ -65,8 +78,8 @@ struct inclusive_scan unsigned int ItemsPerThread, unsigned int Trials > - __global__ - static void kernel(const T* input, T* output) + __device__ + static void run(const T* input, T* output) { const unsigned int i = hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x; @@ -102,8 +115,8 @@ struct exclusive_scan unsigned int ItemsPerThread, unsigned int Trials > - __global__ - static void kernel(const T* input, T* output) + __device__ + static void run(const T* input, T* output) { const unsigned int i = hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x; using U = typename std::remove_reference::type; @@ -164,7 +177,7 @@ void run_benchmark(benchmark::State& state, hipStream_t stream, size_t N) { auto start = std::chrono::high_resolution_clock::now(); hipLaunchKernelGGL( - HIP_KERNEL_NAME(Benchmark::template kernel), + HIP_KERNEL_NAME(kernel), dim3(size/items_per_block), dim3(BlockSize), 0, stream, d_input, d_output ); diff --git a/benchmark/benchmark_hip_device_binary_search.cpp b/benchmark/benchmark_hip_device_binary_search.cpp new file mode 100644 index 000000000..efa4fea59 --- /dev/null +++ b/benchmark/benchmark_hip_device_binary_search.cpp @@ -0,0 +1,228 @@ +// MIT License +// +// Copyright (c) 2019 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#include +#include +#include +#include +#include +#include +#include + +// Google Benchmark +#include "benchmark/benchmark.h" +// CmdParser +#include "cmdparser.hpp" +#include "benchmark_utils.hpp" + +// HIP API +#include +#include + +// rocPRIM +#include + +#define HIP_CHECK(condition) \ + { \ + hipError_t error = condition; \ + if(error != hipSuccess){ \ + std::cout << "HIP error: " << error << " line: " << __LINE__ << std::endl; \ + exit(error); \ + } \ + } + +#ifndef DEFAULT_N +const size_t DEFAULT_N = 1024 * 1024 * 32; +#endif + +const unsigned int batch_size = 10; +const unsigned int warmup_size = 5; + +template +void run_lower_bound_benchmark(benchmark::State& state, hipStream_t stream, + size_t haystack_size, size_t needles_size, + bool sorted_needles) +{ + using haystack_type = T; + using needle_type = T; + using output_type = size_t; + + // Generate data + std::vector haystack(haystack_size); + std::iota(haystack.begin(), haystack.end(), 0); + + std::vector needles = get_random_data( + needles_size, needle_type(0), needle_type(haystack_size) + ); + if(sorted_needles) + { + std::sort(needles.begin(), needles.end()); + } + + haystack_type * d_haystack; + needle_type * d_needles; + output_type * d_output; + HIP_CHECK(hipMalloc(&d_haystack, haystack_size * sizeof(haystack_type))); + HIP_CHECK(hipMalloc(&d_needles, needles_size * sizeof(needle_type))); + HIP_CHECK(hipMalloc(&d_output, needles_size * sizeof(output_type))); + HIP_CHECK( + hipMemcpy( + d_haystack, haystack.data(), + haystack_size * sizeof(haystack_type), + hipMemcpyHostToDevice + ) + ); + HIP_CHECK( + hipMemcpy( + d_needles, needles.data(), + needles_size * sizeof(needle_type), + hipMemcpyHostToDevice + ) + ); + + void * d_temporary_storage = nullptr; + size_t temporary_storage_bytes; + HIP_CHECK( + rocprim::lower_bound( + d_temporary_storage, temporary_storage_bytes, + d_haystack, d_needles, d_output, + haystack_size, needles_size, + rocprim::less<>(), + stream + ) + ); + + HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); + + // Warm-up + for(size_t i = 0; i < warmup_size; i++) + { + HIP_CHECK( + rocprim::lower_bound( + d_temporary_storage, temporary_storage_bytes, + d_haystack, d_needles, d_output, + haystack_size, needles_size, + rocprim::less<>(), + stream + ) + ); + } + HIP_CHECK(hipDeviceSynchronize()); + + for(auto _ : state) + { + auto start = std::chrono::high_resolution_clock::now(); + + for(size_t i = 0; i < batch_size; i++) + { + HIP_CHECK( + rocprim::lower_bound( + d_temporary_storage, temporary_storage_bytes, + d_haystack, d_needles, d_output, + haystack_size, needles_size, + rocprim::less<>(), + stream + ) + ); + } + HIP_CHECK(hipDeviceSynchronize()); + + auto end = std::chrono::high_resolution_clock::now(); + auto elapsed_seconds = + std::chrono::duration_cast>(end - start); + state.SetIterationTime(elapsed_seconds.count()); + } + state.SetBytesProcessed(state.iterations() * batch_size * needles_size * sizeof(needle_type)); + state.SetItemsProcessed(state.iterations() * batch_size * needles_size); + + HIP_CHECK(hipFree(d_temporary_storage)); + HIP_CHECK(hipFree(d_haystack)); + HIP_CHECK(hipFree(d_needles)); + HIP_CHECK(hipFree(d_output)); +} + +#define CREATE_LOWER_BOUND_BENCHMARK(T, K, SORTED) \ +benchmark::RegisterBenchmark( \ + ( \ + std::string("lower_bound") + "<" #T ">(" #K "\% " + \ + (SORTED ? "sorted" : "random") + " needles)" \ + ).c_str(), \ + [=](benchmark::State& state) { run_lower_bound_benchmark(state, stream, size, size * K / 100, SORTED); } \ +) + +int main(int argc, char *argv[]) +{ + cli::Parser parser(argc, argv); + parser.set_optional("size", "size", DEFAULT_N, "number of values"); + parser.set_optional("trials", "trials", -1, "number of iterations"); + parser.run_and_exit_if_error(); + + // Parse argv + benchmark::Initialize(&argc, argv); + const size_t size = parser.get("size"); + const int trials = parser.get("trials"); + + // HIP + hipStream_t stream = 0; // default + hipDeviceProp_t devProp; + int device_id = 0; + HIP_CHECK(hipGetDevice(&device_id)); + HIP_CHECK(hipGetDeviceProperties(&devProp, device_id)); + std::cout << "[HIP] Device name: " << devProp.name << std::endl; + + using custom_float2 = custom_type; + using custom_double2 = custom_type; + + // Add benchmarks + std::vector benchmarks = + { + CREATE_LOWER_BOUND_BENCHMARK(float, 10, false), + CREATE_LOWER_BOUND_BENCHMARK(double, 10, false), + CREATE_LOWER_BOUND_BENCHMARK(custom_float2, 10, false), + CREATE_LOWER_BOUND_BENCHMARK(custom_double2, 10, false), + + CREATE_LOWER_BOUND_BENCHMARK(float, 10, true), + CREATE_LOWER_BOUND_BENCHMARK(double, 10, true), + CREATE_LOWER_BOUND_BENCHMARK(custom_float2, 10, true), + CREATE_LOWER_BOUND_BENCHMARK(custom_double2, 10, true), + }; + + // Use manual timing + for(auto& b : benchmarks) + { + b->UseManualTime(); + b->Unit(benchmark::kMillisecond); + } + + // Force number of iterations + if(trials > 0) + { + for(auto& b : benchmarks) + { + b->Iterations(trials); + } + } + + // Run benchmarks + benchmark::RunSpecifiedBenchmarks(); + return 0; +} diff --git a/benchmark/benchmark_hip_device_merge.cpp b/benchmark/benchmark_hip_device_merge.cpp index 8c7c22261..d40bfc911 100644 --- a/benchmark/benchmark_hip_device_merge.cpp +++ b/benchmark/benchmark_hip_device_merge.cpp @@ -67,36 +67,123 @@ void run_merge_keys_benchmark(benchmark::State& state, hipStream_t stream, size_ const size_t size1 = size / 2; const size_t size2 = size - size1; + ::rocprim::less compare_op; + // Generate data - std::vector keys_input1; - std::vector keys_input2; - if(std::is_floating_point::value) + std::vector keys_input1 = get_random_data(size1, 0, size); + std::vector keys_input2 = get_random_data(size2, 0, size); + std::sort(keys_input1.begin(), keys_input1.end(), compare_op); + std::sort(keys_input2.begin(), keys_input2.end(), compare_op); + + key_type * d_keys_input1; + key_type * d_keys_input2; + key_type * d_keys_output; + HIP_CHECK(hipMalloc(&d_keys_input1, size1 * sizeof(key_type))); + HIP_CHECK(hipMalloc(&d_keys_input2, size2 * sizeof(key_type))); + HIP_CHECK(hipMalloc(&d_keys_output, size * sizeof(key_type))); + HIP_CHECK( + hipMemcpy( + d_keys_input1, keys_input1.data(), + size1 * sizeof(key_type), + hipMemcpyHostToDevice + ) + ); + HIP_CHECK( + hipMemcpy( + d_keys_input2, keys_input2.data(), + size2 * sizeof(key_type), + hipMemcpyHostToDevice + ) + ); + + void * d_temporary_storage = nullptr; + size_t temporary_storage_bytes = 0; + HIP_CHECK( + rp::merge( + d_temporary_storage, temporary_storage_bytes, + d_keys_input1, d_keys_input2, d_keys_output, size1, size2, + compare_op, stream, false + ) + ); + + HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); + + // Warm-up + for(size_t i = 0; i < warmup_size; i++) { - keys_input1 = get_random_data(size1, (key_type)-1000, (key_type)+1000); - keys_input2 = get_random_data(size2, (key_type)-1000, (key_type)+1000); + HIP_CHECK( + rp::merge( + d_temporary_storage, temporary_storage_bytes, + d_keys_input1, d_keys_input2, d_keys_output, size1, size2, + compare_op, stream, false + ) + ); } - else + HIP_CHECK(hipDeviceSynchronize()); + + for (auto _ : state) { - keys_input1 = get_random_data( - size1, - std::numeric_limits::min(), - std::numeric_limits::max() - ); - keys_input2 = get_random_data( - size2, - std::numeric_limits::min(), - std::numeric_limits::max() - ); + auto start = std::chrono::high_resolution_clock::now(); + + for(size_t i = 0; i < batch_size; i++) + { + HIP_CHECK( + rp::merge( + d_temporary_storage, temporary_storage_bytes, + d_keys_input1, d_keys_input2, d_keys_output, size1, size2, + compare_op, stream, false + ) + ); + } + HIP_CHECK(hipDeviceSynchronize()); + + auto end = std::chrono::high_resolution_clock::now(); + auto elapsed_seconds = + std::chrono::duration_cast>(end - start); + state.SetIterationTime(elapsed_seconds.count()); } - std::sort(keys_input1.begin(), keys_input1.end()); - std::sort(keys_input2.begin(), keys_input2.end()); + state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(key_type)); + state.SetItemsProcessed(state.iterations() * batch_size * size); + + HIP_CHECK(hipFree(d_temporary_storage)); + HIP_CHECK(hipFree(d_keys_input1)); + HIP_CHECK(hipFree(d_keys_input2)); + HIP_CHECK(hipFree(d_keys_output)); +} + +template +void run_merge_pairs_benchmark(benchmark::State& state, hipStream_t stream, size_t size) +{ + using key_type = Key; + using value_type = Value; + + const size_t size1 = size / 2; + const size_t size2 = size - size1; + + ::rocprim::less compare_op; + + // Generate data + std::vector keys_input1 = get_random_data(size1, 0, size); + std::vector keys_input2 = get_random_data(size2, 0, size); + std::sort(keys_input1.begin(), keys_input1.end(), compare_op); + std::sort(keys_input2.begin(), keys_input2.end(), compare_op); + std::vector values_input1(size1); + std::vector values_input2(size2); + std::iota(values_input1.begin(), values_input1.end(), 0); + std::iota(values_input2.begin(), values_input2.end(), size1); key_type * d_keys_input1; key_type * d_keys_input2; key_type * d_keys_output; + value_type * d_values_input1; + value_type * d_values_input2; + value_type * d_values_output; HIP_CHECK(hipMalloc(&d_keys_input1, size1 * sizeof(key_type))); HIP_CHECK(hipMalloc(&d_keys_input2, size2 * sizeof(key_type))); HIP_CHECK(hipMalloc(&d_keys_output, size * sizeof(key_type))); + HIP_CHECK(hipMalloc(&d_values_input1, size1 * sizeof(value_type))); + HIP_CHECK(hipMalloc(&d_values_input2, size2 * sizeof(value_type))); + HIP_CHECK(hipMalloc(&d_values_output, size * sizeof(value_type))); HIP_CHECK( hipMemcpy( d_keys_input1, keys_input1.data(), @@ -111,17 +198,16 @@ void run_merge_keys_benchmark(benchmark::State& state, hipStream_t stream, size_ hipMemcpyHostToDevice ) ); - HIP_CHECK(hipDeviceSynchronize()); - - ::rocprim::less lesser_op; void * d_temporary_storage = nullptr; size_t temporary_storage_bytes = 0; HIP_CHECK( rp::merge( d_temporary_storage, temporary_storage_bytes, - d_keys_input1, d_keys_input2, d_keys_output, size1, size2, - lesser_op, stream, false + d_keys_input1, d_keys_input2, d_keys_output, + d_values_input1, d_values_input2, d_values_output, + size1, size2, + compare_op, stream, false ) ); @@ -134,8 +220,10 @@ void run_merge_keys_benchmark(benchmark::State& state, hipStream_t stream, size_ HIP_CHECK( rp::merge( d_temporary_storage, temporary_storage_bytes, - d_keys_input1, d_keys_input2, d_keys_output, size1, size2, - lesser_op, stream, false + d_keys_input1, d_keys_input2, d_keys_output, + d_values_input1, d_values_input2, d_values_output, + size1, size2, + compare_op, stream, false ) ); } @@ -150,8 +238,10 @@ void run_merge_keys_benchmark(benchmark::State& state, hipStream_t stream, size_ HIP_CHECK( rp::merge( d_temporary_storage, temporary_storage_bytes, - d_keys_input1, d_keys_input2, d_keys_output, size1, size2, - lesser_op, stream, false + d_keys_input1, d_keys_input2, d_keys_output, + d_values_input1, d_values_input2, d_values_output, + size1, size2, + compare_op, stream, false ) ); } @@ -162,35 +252,29 @@ void run_merge_keys_benchmark(benchmark::State& state, hipStream_t stream, size_ std::chrono::duration_cast>(end - start); state.SetIterationTime(elapsed_seconds.count()); } - state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(key_type)); + state.SetBytesProcessed(state.iterations() * batch_size * size * (sizeof(key_type) + sizeof(value_type))); state.SetItemsProcessed(state.iterations() * batch_size * size); HIP_CHECK(hipFree(d_temporary_storage)); HIP_CHECK(hipFree(d_keys_input1)); HIP_CHECK(hipFree(d_keys_input2)); HIP_CHECK(hipFree(d_keys_output)); + HIP_CHECK(hipFree(d_values_input1)); + HIP_CHECK(hipFree(d_values_input2)); + HIP_CHECK(hipFree(d_values_output)); } #define CREATE_MERGE_KEYS_BENCHMARK(Key) \ benchmark::RegisterBenchmark( \ - (std::string("merge_keys") + "<" #Key ">").c_str(), \ + (std::string("merge") + "<" #Key ">").c_str(), \ [=](benchmark::State& state) { run_merge_keys_benchmark(state, stream, size); } \ ) -void add_merge_keys_benchmarks(std::vector& benchmarks, - hipStream_t stream, - size_t size) -{ - std::vector bs = - { - CREATE_MERGE_KEYS_BENCHMARK(int), - CREATE_MERGE_KEYS_BENCHMARK(long long), - - CREATE_MERGE_KEYS_BENCHMARK(char), - CREATE_MERGE_KEYS_BENCHMARK(short), - }; - benchmarks.insert(benchmarks.end(), bs.begin(), bs.end()); -} +#define CREATE_MERGE_PAIRS_BENCHMARK(Key, Value) \ +benchmark::RegisterBenchmark( \ + (std::string("merge") + "<" #Key ", " #Value ">").c_str(), \ + [=](benchmark::State& state) { run_merge_pairs_benchmark(state, stream, size); } \ +) int main(int argc, char *argv[]) { @@ -212,9 +296,26 @@ int main(int argc, char *argv[]) HIP_CHECK(hipGetDeviceProperties(&devProp, device_id)); std::cout << "[HIP] Device name: " << devProp.name << std::endl; + using custom_int2 = custom_type; + using custom_double2 = custom_type; + // Add benchmarks - std::vector benchmarks; - add_merge_keys_benchmarks(benchmarks, stream, size); + std::vector benchmarks = + { + CREATE_MERGE_KEYS_BENCHMARK(int), + CREATE_MERGE_KEYS_BENCHMARK(long long), + CREATE_MERGE_KEYS_BENCHMARK(char), + CREATE_MERGE_KEYS_BENCHMARK(short), + CREATE_MERGE_KEYS_BENCHMARK(custom_int2), + CREATE_MERGE_KEYS_BENCHMARK(custom_double2), + + CREATE_MERGE_PAIRS_BENCHMARK(int, int), + CREATE_MERGE_PAIRS_BENCHMARK(long long, long long), + CREATE_MERGE_PAIRS_BENCHMARK(char, char), + CREATE_MERGE_PAIRS_BENCHMARK(short, short), + CREATE_MERGE_PAIRS_BENCHMARK(custom_int2, custom_int2), + CREATE_MERGE_PAIRS_BENCHMARK(custom_double2, custom_double2), + }; // Use manual timing for(auto& b : benchmarks) diff --git a/rocprim/include/rocprim/detail/binary_op_wrappers.hpp b/rocprim/include/rocprim/detail/binary_op_wrappers.hpp index e8466007a..4b3def3e9 100644 --- a/rocprim/include/rocprim/detail/binary_op_wrappers.hpp +++ b/rocprim/include/rocprim/detail/binary_op_wrappers.hpp @@ -73,12 +73,7 @@ struct headflag_scan_op_wrapper { static_assert(std::is_convertible::value, "F must be convertible to bool"); - #ifdef __cpp_lib_is_invocable - using value_type = typename std::invoke_result::type; - #else - using value_type = typename std::result_of::type; - #endif - using result_type = rocprim::tuple; + using result_type = rocprim::tuple; using input_type = result_type; ROCPRIM_HOST_DEVICE inline @@ -110,52 +105,6 @@ struct headflag_scan_op_wrapper BinaryFunction scan_op_; }; -// Wrapper for performing scan-by-key -template< - class V, - class K, - class BinaryFunction, - class KCompare = ::rocprim::equal_to -> -struct scan_by_key_op_wrapper -{ - #ifdef __cpp_lib_is_invocable - using value_type = typename std::invoke_result::type; - #else - using value_type = typename std::result_of::type; - #endif - using result_type = rocprim::tuple; - using input_type = result_type; - - ROCPRIM_HOST_DEVICE inline - scan_by_key_op_wrapper() = default; - - ROCPRIM_HOST_DEVICE inline - scan_by_key_op_wrapper(BinaryFunction scan_op, KCompare compare_keys_op = KCompare()) - : scan_op_(scan_op), compare_keys_op_(compare_keys_op) - { - } - - ROCPRIM_HOST_DEVICE inline - ~scan_by_key_op_wrapper() = default; - - ROCPRIM_HOST_DEVICE inline - result_type operator()(const input_type& t1, const input_type& t2) - { - if(compare_keys_op_(rocprim::get<1>(t1), rocprim::get<1>(t2))) - { - return rocprim::make_tuple( - scan_op_(rocprim::get<0>(t1), rocprim::get<0>(t2)), - rocprim::get<1>(t2) - ); - } - return t2; - } - -private: - BinaryFunction scan_op_; - KCompare compare_keys_op_; -}; template struct inequality_wrapper diff --git a/rocprim/include/rocprim/detail/match_result_type.hpp b/rocprim/include/rocprim/detail/match_result_type.hpp index 06036370a..75ba31b01 100644 --- a/rocprim/include/rocprim/detail/match_result_type.hpp +++ b/rocprim/include/rocprim/detail/match_result_type.hpp @@ -44,7 +44,7 @@ struct tuple_contains_type> : tuple_contains_type< template struct tuple_contains_type> : std::true_type {}; -template +template struct match_result_type { private: @@ -54,22 +54,11 @@ struct match_result_type using binary_result_type = typename std::result_of::type; #endif - // Fixed output_type in case OutputType is void or is a tuple containing void - static constexpr bool is_output_type_invalid = - std::is_void::value || tuple_contains_type::value; - using value_type = - typename std::conditional::type; - - // value_type is not a valid result_type if we can't covert it to binary_result_type - static constexpr bool is_value_type_valid = - std::is_convertible::value; - public: - using type = typename std::conditional::type; + using type = binary_result_type; }; } // end namespace detail END_ROCPRIM_NAMESPACE #endif // ROCPRIM_DETAIL_MATCH_RESULT_TYPE_HPP_ - diff --git a/rocprim/include/rocprim/device/detail/device_binary_search.hpp b/rocprim/include/rocprim/device/detail/device_binary_search.hpp new file mode 100644 index 000000000..02effbd16 --- /dev/null +++ b/rocprim/include/rocprim/device/detail/device_binary_search.hpp @@ -0,0 +1,125 @@ +// Copyright (c) 2019 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_BINARY_SEARCH_HPP_ +#define ROCPRIM_DEVICE_DETAIL_DEVICE_BINARY_SEARCH_HPP_ + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +template +Size get_binary_search_middle(Size left, Size right) +{ + // Instead of `/ 2` we use `* 33 / 64`, i.e. the middle is slightly moved. + // This greatly reduces address aliasing and hence cache misses for (nearly-)power-of-two + // sizes of haystack (when addresses are mapped to the same cache line). + // For random needles and (nearly-)power-of-two sizes, this change increases performance + // 4-20 times making it equal to performance of arbitrary sizes of haystack. + // See https://www.pvk.ca/Blog/2012/07/30/binary-search-is-a-pathological-case-for-caches/ + const Size d = right - left; + return left + d / 2 + d / 64; +} + +template +ROCPRIM_DEVICE +Size lower_bound_n(RandomAccessIterator first, + Size size, + const T& value, + BinaryPredicate compare_op) +{ + Size left = 0; + Size right = size; + while(left < right) + { + const Size mid = get_binary_search_middle(left, right); + if(compare_op(first[mid], value)) + { + left = mid + 1; + } + else + { + right = mid; + } + } + return left; +} + +template +ROCPRIM_DEVICE +Size upper_bound_n(RandomAccessIterator first, + Size size, + const T& value, + BinaryPredicate compare_op) +{ + Size left = 0; + Size right = size; + while(left < right) + { + const Size mid = get_binary_search_middle(left, right); + if(compare_op(value, first[mid])) + { + right = mid; + } + else + { + left = mid + 1; + } + } + return left; +} + +struct lower_bound_search_op +{ + template + ROCPRIM_DEVICE + Size operator()(HaystackIterator haystack, Size size, const T& value, CompareOp compare_op) const + { + return lower_bound_n(haystack, size, value, compare_op); + } +}; + +struct upper_bound_search_op +{ + template + ROCPRIM_DEVICE + Size operator()(HaystackIterator haystack, Size size, const T& value, CompareOp compare_op) const + { + return upper_bound_n(haystack, size, value, compare_op); + } +}; + +struct binary_search_op +{ + template + ROCPRIM_DEVICE + bool operator()(HaystackIterator haystack, Size size, const T& value, CompareOp compare_op) const + { + const Size n = lower_bound_n(haystack, size, value, compare_op); + return n != size && !compare_op(value, haystack[n]); + } +}; + +} // end of detail namespace + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_DEVICE_DETAIL_DEVICE_BINARY_SEARCH_HPP_ diff --git a/rocprim/include/rocprim/device/detail/device_merge.hpp b/rocprim/include/rocprim/device/detail/device_merge.hpp index 8e27a009b..40043ee6e 100644 --- a/rocprim/include/rocprim/device/detail/device_merge.hpp +++ b/rocprim/include/rocprim/device/detail/device_merge.hpp @@ -310,8 +310,7 @@ merge_values(unsigned int flat_id, #pragma unroll for(unsigned int i = 0; i < ItemsPerThread; ++i) { - unsigned int id = ItemsPerThread * i + flat_id; - if(id < count) + if(flat_id * ItemsPerThread + i < count) { values[i] = (index[i] < input1_size) ? values_input1[index[i]] : values_input2[index[i] - input1_size]; diff --git a/rocprim/include/rocprim/device/detail/device_partition.hpp b/rocprim/include/rocprim/device/detail/device_partition.hpp index 97d39fe29..e38998b76 100644 --- a/rocprim/include/rocprim/device/detail/device_partition.hpp +++ b/rocprim/include/rocprim/device/detail/device_partition.hpp @@ -402,7 +402,6 @@ template< select_method SelectMethod, bool OnlySelected, class Config, - class ResultType, class InputIterator, class FlagIterator, class OutputIterator, @@ -428,7 +427,7 @@ void partition_kernel_impl(InputIterator input, constexpr unsigned int items_per_block = block_size * items_per_thread; using offset_type = typename OffsetLookbackScanState::value_type; - using value_type = ResultType; + using value_type = typename std::iterator_traits::value_type; // Block primitives using block_load_value_type = ::rocprim::block_load< diff --git a/rocprim/include/rocprim/device/detail/device_reduce_by_key.hpp b/rocprim/include/rocprim/device/detail/device_reduce_by_key.hpp index 286c72c5b..968023c6a 100644 --- a/rocprim/include/rocprim/device/detail/device_reduce_by_key.hpp +++ b/rocprim/include/rocprim/device/detail/device_reduce_by_key.hpp @@ -41,20 +41,18 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -template +template struct carry_out { ROCPRIM_DEVICE inline carry_out& operator=(carry_out rhs) { - key = rhs.key; value = rhs.value; destination = rhs.destination; next_has_carry_in = rhs.next_has_carry_in; return *this; } - Key key; Value value; // carry-out of the current batch unsigned int destination; bool next_has_carry_in; // the next batch has carry-in (i.e. it continues the last segment from the current batch) @@ -311,7 +309,6 @@ template< unsigned int ItemsPerThread, class KeysInputIterator, class ValuesInputIterator, - class Key, class Result, class UniqueOutputIterator, class AggregatesOutputIterator, @@ -323,7 +320,7 @@ void reduce_by_key(KeysInputIterator keys_input, ValuesInputIterator values_input, unsigned int size, const unsigned int * unique_starts, - carry_out * carry_outs, + carry_out * carry_outs, Result * leading_aggregates, UniqueOutputIterator unique_output, AggregatesOutputIterator aggregates_output, @@ -334,7 +331,7 @@ void reduce_by_key(KeysInputIterator keys_input, { constexpr unsigned int items_per_block = BlockSize * ItemsPerThread; - using key_type = Key; + using key_type = typename std::iterator_traits::value_type; using result_type = Result; using keys_load_type = ::rocprim::block_load< @@ -498,7 +495,6 @@ void reduce_by_key(KeysInputIterator keys_input, if(bi == blocks_per_batch - 1) { // Save carry-out of the last block of the current batch - carry_outs[batch_id].key = keys[ItemsPerThread - 1]; carry_outs[batch_id].value = values[ItemsPerThread - 1]; carry_outs[batch_id].destination = block_start + ranks[ItemsPerThread - 1]; carry_outs[batch_id].next_has_carry_in = !tail_flags[ItemsPerThread - 1]; @@ -549,24 +545,20 @@ void reduce_by_key(KeysInputIterator keys_input, template< unsigned int BlockSize, unsigned int ItemsPerThread, - class Key, class Result, class AggregatesOutputIterator, - class KeyCompareFunction, class BinaryFunction > ROCPRIM_DEVICE inline -void scan_and_scatter_carry_outs(const carry_out * carry_outs, +void scan_and_scatter_carry_outs(const carry_out * carry_outs, const Result * leading_aggregates, AggregatesOutputIterator aggregates_output, - KeyCompareFunction key_compare_op, BinaryFunction reduce_op, unsigned int batches) { - using key_type = Key; using result_type = Result; - using discontinuity_type = ::rocprim::block_discontinuity; + using discontinuity_type = ::rocprim::block_discontinuity; using scan_type = ::rocprim::block_scan, BlockSize>; ROCPRIM_SHARED_MEMORY struct @@ -577,25 +569,27 @@ void scan_and_scatter_carry_outs(const carry_out * carry_outs, const unsigned int flat_id = ::rocprim::flat_block_thread_id(); - carry_out cs[ItemsPerThread]; + carry_out cs[ItemsPerThread]; block_load_direct_blocked(flat_id, carry_outs, cs, batches - 1); - key_type keys[ItemsPerThread]; + unsigned int destinations[ItemsPerThread]; result_type values[ItemsPerThread]; for(unsigned int i = 0; i < ItemsPerThread; i++) { - keys[i] = cs[i].key; + destinations[i] = cs[i].destination; values[i] = cs[i].value; } bool head_flags[ItemsPerThread]; bool tail_flags[ItemsPerThread]; - const key_type successor_key = keys[ItemsPerThread - 1]; // Do not always flag the last item in the block - + ::rocprim::equal_to compare_op; + // If a carry-out of the current batch has the same destination as previous batches, + // then we need to scan its value with values of those previous batches. discontinuity_type().flag_heads_and_tails( head_flags, tail_flags, - successor_key, keys, - guarded_key_flag_op(key_compare_op, batches - 1), + destinations[ItemsPerThread - 1], // Do not always flag the last item in the block + destinations, + guarded_key_flag_op(compare_op, batches - 1), storage.discontinuity ); @@ -614,7 +608,7 @@ void scan_and_scatter_carry_outs(const carry_out * carry_outs, { if(tail_flags[i]) { - const unsigned int dst = cs[i].destination; + const unsigned int dst = destinations[i]; const result_type aggregate = pairs[i].value; if(cs[i].next_has_carry_in) { diff --git a/rocprim/include/rocprim/device/detail/device_scan_lookback.hpp b/rocprim/include/rocprim/device/detail/device_scan_lookback.hpp index d7710a3a8..5b6df0685 100644 --- a/rocprim/include/rocprim/device/detail/device_scan_lookback.hpp +++ b/rocprim/include/rocprim/device/detail/device_scan_lookback.hpp @@ -241,6 +241,7 @@ void lookback_scan_kernel_impl(InputIterator input, input + block_offset, values, valid_in_last_block, + *(input + block_offset), storage.load ); } diff --git a/rocprim/include/rocprim/device/detail/device_scan_reduce_then_scan.hpp b/rocprim/include/rocprim/device/detail/device_scan_reduce_then_scan.hpp index 47be6a07d..775616b24 100644 --- a/rocprim/include/rocprim/device/detail/device_scan_reduce_then_scan.hpp +++ b/rocprim/include/rocprim/device/detail/device_scan_reduce_then_scan.hpp @@ -141,6 +141,7 @@ void single_scan_kernel_impl(InputIterator input, input, values, input_size, + *(input), storage.load ); ::rocprim::syncthreads(); // sync threads to reuse shared memory @@ -373,6 +374,7 @@ void final_scan_kernel_impl(InputIterator input, input + block_offset, values, valid_in_last_block, + *(input + block_offset), storage.load ); } diff --git a/rocprim/include/rocprim/device/detail/device_transform.hpp b/rocprim/include/rocprim/device/detail/device_transform.hpp index fb15f7821..24febc211 100644 --- a/rocprim/include/rocprim/device/detail/device_transform.hpp +++ b/rocprim/include/rocprim/device/detail/device_transform.hpp @@ -96,14 +96,13 @@ void transform_kernel_impl(InputIterator input, const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>(); const unsigned int flat_block_id = ::rocprim::detail::block_id<0>(); - const unsigned int block_offset = flat_block_id * BlockSize * ItemsPerThread; - const unsigned int number_of_blocks = (input_size + items_per_block - 1)/items_per_block; - auto valid_in_last_block = input_size - items_per_block * (number_of_blocks - 1); + const unsigned int block_offset = flat_block_id * items_per_block; + const unsigned int number_of_blocks = ::rocprim::detail::grid_size<0>(); + const unsigned int valid_in_last_block = input_size - block_offset; input_type input_values[ItemsPerThread]; result_type output_values[ItemsPerThread]; - // load input values into values if(flat_block_id == (number_of_blocks - 1)) // last block { block_load_direct_striped( @@ -112,25 +111,16 @@ void transform_kernel_impl(InputIterator input, input_values, valid_in_last_block ); - } - else - { - block_load_direct_striped( - flat_id, - input + block_offset, - input_values - ); - } - #pragma unroll - for(unsigned int i = 0; i < ItemsPerThread; i++) - { - output_values[i] = transform_op(input_values[i]); - } + #pragma unroll + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + if(BlockSize * i + flat_id < valid_in_last_block) + { + output_values[i] = transform_op(input_values[i]); + } + } - // Save values into output array - if(flat_block_id == (number_of_blocks - 1)) // last block - { block_store_direct_striped( flat_id, output + block_offset, @@ -140,6 +130,18 @@ void transform_kernel_impl(InputIterator input, } else { + block_load_direct_striped( + flat_id, + input + block_offset, + input_values + ); + + #pragma unroll + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + output_values[i] = transform_op(input_values[i]); + } + block_store_direct_striped( flat_id, output + block_offset, diff --git a/rocprim/include/rocprim/device/device_binary_search_hc.hpp b/rocprim/include/rocprim/device/device_binary_search_hc.hpp new file mode 100644 index 000000000..be958876c --- /dev/null +++ b/rocprim/include/rocprim/device/device_binary_search_hc.hpp @@ -0,0 +1,175 @@ +// Copyright (c) 2019 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_DEVICE_DEVICE_BINARY_SEARCH_HC_HPP_ +#define ROCPRIM_DEVICE_DEVICE_BINARY_SEARCH_HC_HPP_ + +#include +#include + +#include "../config.hpp" +#include "../detail/various.hpp" + +#include "detail/device_binary_search.hpp" + +#include "device_transform_hc.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +/// \addtogroup devicemodule_hc +/// @{ + +namespace detail +{ + +template< + class Config, + class HaystackIterator, + class NeedlesIterator, + class OutputIterator, + class SearchFunction, + class CompareFunction +> +inline +void binary_search(void * temporary_storage, + size_t& storage_size, + HaystackIterator haystack, + NeedlesIterator needles, + OutputIterator output, + size_t haystack_size, + size_t needles_size, + SearchFunction search_op, + CompareFunction compare_op, + hc::accelerator_view acc_view, + bool debug_synchronous) +{ + using value_type = typename std::iterator_traits::value_type; + + if(temporary_storage == nullptr) + { + // Make sure user won't try to allocate 0 bytes memory, otherwise + // user may again pass nullptr as temporary_storage + storage_size = 4; + return; + } + + transform( + needles, output, + needles_size, + [haystack, haystack_size, search_op, compare_op](const value_type& value) + { + return search_op(haystack, haystack_size, value, compare_op); + }, + acc_view, debug_synchronous + ); +} + +} // end of detail namespace + +template< + class Config = default_config, + class HaystackIterator, + class NeedlesIterator, + class OutputIterator, + class CompareFunction = ::rocprim::less<> +> +inline +void lower_bound(void * temporary_storage, + size_t& storage_size, + HaystackIterator haystack, + NeedlesIterator needles, + OutputIterator output, + size_t haystack_size, + size_t needles_size, + CompareFunction compare_op = CompareFunction(), + hc::accelerator_view acc_view = hc::accelerator().get_default_view(), + bool debug_synchronous = false) +{ + detail::binary_search( + temporary_storage, storage_size, + haystack, needles, output, + haystack_size, needles_size, + detail::lower_bound_search_op(), compare_op, + acc_view, debug_synchronous + ); +} + +template< + class Config = default_config, + class HaystackIterator, + class NeedlesIterator, + class OutputIterator, + class CompareFunction = ::rocprim::less<> +> +inline +void upper_bound(void * temporary_storage, + size_t& storage_size, + HaystackIterator haystack, + NeedlesIterator needles, + OutputIterator output, + size_t haystack_size, + size_t needles_size, + CompareFunction compare_op = CompareFunction(), + hc::accelerator_view acc_view = hc::accelerator().get_default_view(), + bool debug_synchronous = false) +{ + detail::binary_search( + temporary_storage, storage_size, + haystack, needles, output, + haystack_size, needles_size, + detail::upper_bound_search_op(), compare_op, + acc_view, debug_synchronous + ); +} + +template< + class Config = default_config, + class HaystackIterator, + class NeedlesIterator, + class OutputIterator, + class CompareFunction = ::rocprim::less<> +> +inline +void binary_search(void * temporary_storage, + size_t& storage_size, + HaystackIterator haystack, + NeedlesIterator needles, + OutputIterator output, + size_t haystack_size, + size_t needles_size, + CompareFunction compare_op = CompareFunction(), + hc::accelerator_view acc_view = hc::accelerator().get_default_view(), + bool debug_synchronous = false) +{ + detail::binary_search( + temporary_storage, storage_size, + haystack, needles, output, + haystack_size, needles_size, + detail::binary_search_op(), compare_op, + acc_view, debug_synchronous + ); +} + +/// @} +// end of group devicemodule_hc + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_DEVICE_DEVICE_BINARY_SEARCH_HC_HPP_ diff --git a/rocprim/include/rocprim/device/device_binary_search_hip.hpp b/rocprim/include/rocprim/device/device_binary_search_hip.hpp new file mode 100644 index 000000000..6b7734ed4 --- /dev/null +++ b/rocprim/include/rocprim/device/device_binary_search_hip.hpp @@ -0,0 +1,175 @@ +// Copyright (c) 2019 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_DEVICE_DEVICE_BINARY_SEARCH_HIP_HPP_ +#define ROCPRIM_DEVICE_DEVICE_BINARY_SEARCH_HIP_HPP_ + +#include +#include + +#include "../config.hpp" +#include "../detail/various.hpp" + +#include "detail/device_binary_search.hpp" + +#include "device_transform_hip.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +/// \addtogroup devicemodule_hip +/// @{ + +namespace detail +{ + +template< + class Config, + class HaystackIterator, + class NeedlesIterator, + class OutputIterator, + class SearchFunction, + class CompareFunction +> +inline +hipError_t binary_search(void * temporary_storage, + size_t& storage_size, + HaystackIterator haystack, + NeedlesIterator needles, + OutputIterator output, + size_t haystack_size, + size_t needles_size, + SearchFunction search_op, + CompareFunction compare_op, + hipStream_t stream, + bool debug_synchronous) +{ + using value_type = typename std::iterator_traits::value_type; + + if(temporary_storage == nullptr) + { + // Make sure user won't try to allocate 0 bytes memory, otherwise + // user may again pass nullptr as temporary_storage + storage_size = 4; + return hipSuccess; + } + + return transform( + needles, output, + needles_size, + [haystack, haystack_size, search_op, compare_op](const value_type& value) + { + return search_op(haystack, haystack_size, value, compare_op); + }, + stream, debug_synchronous + ); +} + +} // end of detail namespace + +template< + class Config = default_config, + class HaystackIterator, + class NeedlesIterator, + class OutputIterator, + class CompareFunction = ::rocprim::less<> +> +inline +hipError_t lower_bound(void * temporary_storage, + size_t& storage_size, + HaystackIterator haystack, + NeedlesIterator needles, + OutputIterator output, + size_t haystack_size, + size_t needles_size, + CompareFunction compare_op = CompareFunction(), + hipStream_t stream = 0, + bool debug_synchronous = false) +{ + return detail::binary_search( + temporary_storage, storage_size, + haystack, needles, output, + haystack_size, needles_size, + detail::lower_bound_search_op(), compare_op, + stream, debug_synchronous + ); +} + +template< + class Config = default_config, + class HaystackIterator, + class NeedlesIterator, + class OutputIterator, + class CompareFunction = ::rocprim::less<> +> +inline +hipError_t upper_bound(void * temporary_storage, + size_t& storage_size, + HaystackIterator haystack, + NeedlesIterator needles, + OutputIterator output, + size_t haystack_size, + size_t needles_size, + CompareFunction compare_op = CompareFunction(), + hipStream_t stream = 0, + bool debug_synchronous = false) +{ + return detail::binary_search( + temporary_storage, storage_size, + haystack, needles, output, + haystack_size, needles_size, + detail::upper_bound_search_op(), compare_op, + stream, debug_synchronous + ); +} + +template< + class Config = default_config, + class HaystackIterator, + class NeedlesIterator, + class OutputIterator, + class CompareFunction = ::rocprim::less<> +> +inline +hipError_t binary_search(void * temporary_storage, + size_t& storage_size, + HaystackIterator haystack, + NeedlesIterator needles, + OutputIterator output, + size_t haystack_size, + size_t needles_size, + CompareFunction compare_op = CompareFunction(), + hipStream_t stream = 0, + bool debug_synchronous = false) +{ + return detail::binary_search( + temporary_storage, storage_size, + haystack, needles, output, + haystack_size, needles_size, + detail::binary_search_op(), compare_op, + stream, debug_synchronous + ); +} + +/// @} +// end of group devicemodule_hip + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_DEVICE_DEVICE_BINARY_SEARCH_HIP_HPP_ diff --git a/rocprim/include/rocprim/device/device_merge_hc.hpp b/rocprim/include/rocprim/device/device_merge_hc.hpp index 6805a40c7..0bad903c5 100644 --- a/rocprim/include/rocprim/device/device_merge_hc.hpp +++ b/rocprim/include/rocprim/device/device_merge_hc.hpp @@ -152,8 +152,8 @@ void merge_impl(void * temporary_storage, /// \brief HC parallel merge primitive for device level. /// -/// \p merge function performs a device-wide merge of keys. -/// Function merges two ordered sets of input keys based on comparison function. +/// \p merge function performs a device-wide merge. +/// Function merges two ordered sets of input values based on comparison function. /// /// \par Overview /// * The contents of the inputs are not altered by the merging function. @@ -163,20 +163,20 @@ void merge_impl(void * temporary_storage, /// /// \tparam Config - [optional] configuration of the primitive. It can be \p merge_config or /// a custom class with the same members. -/// \tparam KeysInputIterator1 - random-access iterator type of the first input range. Must meet the +/// \tparam InputIterator1 - random-access iterator type of the first input range. Must meet the /// requirements of a C++ InputIterator concept. It can be a simple pointer type. -/// \tparam KeysInputIterator2 - random-access iterator type of the second input range. Must meet the +/// \tparam InputIterator2 - random-access iterator type of the second input range. Must meet the /// requirements of a C++ InputIterator concept. It can be a simple pointer type. -/// \tparam KeysOutputIterator - random-access iterator type of the output range. Must meet the +/// \tparam OutputIterator - random-access iterator type of the output range. Must meet the /// requirements of a C++ OutputIterator concept. It can be a simple pointer type. /// /// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When /// a null pointer is passed, the required allocation size (in bytes) is written to /// \p storage_size and function returns without performing the sort operation. /// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. -/// \param [in] keys_input1 - pointer to the first element in the first range to merge. -/// \param [in] keys_input2 - pointer to the first element in the second range to merge. -/// \param [out] keys_output - pointer to the first element in the output range. +/// \param [in] input1 - iterator to the first element in the first range to merge. +/// \param [in] input2 - iterator to the first element in the second range to merge. +/// \param [out] output - iterator to the first element in the output range. /// \param [in] input1_size - number of element in the first input range. /// \param [in] input2_size - number of element in the second input range. /// \param [in] compare_function - binary operation function object that will be used for comparison. @@ -200,7 +200,8 @@ void merge_impl(void * temporary_storage, /// hc::accelerator_view acc_view = ...; /// /// // Prepare input and output (declare pointers, allocate device memory etc.) -/// size_t input_size; // e.g., 4 +/// size_t input_size1; // e.g., 4 +/// size_t input_size2; // e.g., 4 /// hc::array input1; // e.g., [0, 1, 2, 3] /// hc::array input2; // e.g., [0, 1, 2, 3] /// hc::array output; // empty array of 8 elements @@ -211,7 +212,7 @@ void merge_impl(void * temporary_storage, /// rocprim::merge( /// temporary_storage_ptr, temporary_storage_size_bytes, /// input1.accelerator_pointer(), input2.accelerator_pointer(), -/// output.accelerator_pointer(), input_size, input_size +/// output.accelerator_pointer(), input_size1, input_size2 /// ); /// /// // allocate temporary storage @@ -221,9 +222,135 @@ void merge_impl(void * temporary_storage, /// rocprim::merge( /// temporary_storage_ptr, temporary_storage_size_bytes, /// input1.accelerator_pointer(), input2.accelerator_pointer(), -/// output.accelerator_pointer(), input_size, input_size +/// output.accelerator_pointer(), input_size1, input_size2 +/// ); +/// // output: [0, 0, 1, 1, 2, 2, 3, 3] +/// \endcode +/// \endparblock +template< + class Config = default_config, + class InputIterator1, + class InputIterator2, + class OutputIterator, + class BinaryFunction = ::rocprim::less::value_type> +> +inline +void merge(void * temporary_storage, + size_t& storage_size, + InputIterator1 input1, + InputIterator2 input2, + OutputIterator output, + const size_t input1_size, + const size_t input2_size, + BinaryFunction compare_function = BinaryFunction(), + hc::accelerator_view acc_view = hc::accelerator().get_default_view(), + bool debug_synchronous = false) +{ + empty_type * values = nullptr; + detail::merge_impl( + temporary_storage, storage_size, + input1, input2, output, + values, values, values, + input1_size, input2_size, compare_function, + acc_view, debug_synchronous + ); +} + +/// \brief HC parallel merge primitive for device level. +/// +/// \p merge function performs a device-wide merge of (key, value) pairs. +/// Function merges two ordered sets of input keys and corresponding values +/// based on key comparison function. +/// +/// \par Overview +/// * The contents of the inputs are not altered by the merging function. +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage in a null pointer. +/// * Accepts custom compare_functions for merging across the device. +/// +/// \tparam Config - [optional] configuration of the primitive. It can be \p merge_config or +/// a custom class with the same members. +/// \tparam KeysInputIterator1 - random-access iterator type of the first keys input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam KeysInputIterator2 - random-access iterator type of the second keys input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam KeysOutputIterator - random-access iterator type of the keys output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// \tparam ValuesInputIterator1 - random-access iterator type of the first values input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam ValuesInputIterator2 - random-access iterator type of the second values input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam ValuesOutputIterator - random-access iterator type of the values output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// +/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the sort operation. +/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. +/// \param [in] keys_input1 - iterator to the first key in the first range to merge. +/// \param [in] keys_input2 - iterator to the first key in the second range to merge. +/// \param [out] keys_output - iterator to the first key in the output range. +/// \param [in] values_input1 - iterator to the first value in the first range to merge. +/// \param [in] values_input2 - iterator to the first value in the second range to merge. +/// \param [out] values_output - iterator to the first value in the output range. +/// \param [in] input1_size - number of element in the first input range. +/// \param [in] input2_size - number of element in the second input range. +/// \param [in] compare_function - binary operation function object that will be used for key comparison. +/// The signature of the function should be equivalent to the following: +/// bool f(const T &a, const T &b);. The signature does not need to have +/// const &, but function object must not modify the objects passed to it. +/// The default value is \p BinaryFunction(). +/// \param [in] acc_view - [optional] \p hc::accelerator_view object. The default value +/// is \p hc::accelerator().get_default_view() (default view of the default accelerator). +/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. +/// +/// \par Example +/// \parblock +/// In this example a device-level ascending merge is performed on an array of +/// \p int values. +/// +/// \code{.cpp} +/// #include +/// +/// hc::accelerator_view acc_view = ...; +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) +/// size_t input_size1; // e.g., 4 +/// size_t input_size2; // e.g., 4 +/// hc::array keys_input1; // e.g., [0, 1, 2, 3] +/// hc::array keys_input2; // e.g., [0, 1, 2, 3] +/// hc::array keys_output; // empty array of 8 elements +/// hc::array values_input1; // e.g., [10, 11, 12, 13] +/// hc::array values_input2; // e.g., [20, 21, 22, 23] +/// hc::array values_output; // empty array of 8 elements +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::merge( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys_input1.accelerator_pointer(), keys_input2.accelerator_pointer(), +/// keys_output.accelerator_pointer(), +/// values_input1.accelerator_pointer(), values_input2.accelerator_pointer(), +/// values_output.accelerator_pointer(), +/// input_size1, input_size2 +/// ); +/// +/// // allocate temporary storage +/// hc::array temporary_storage(temporary_storage_size_bytes, acc_view); +/// +/// // perform merge +/// rocprim::merge( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys_input1.accelerator_pointer(), keys_input2.accelerator_pointer(), +/// keys_output.accelerator_pointer(), +/// values_input1.accelerator_pointer(), values_input2.accelerator_pointer(), +/// values_output.accelerator_pointer(), +/// input_size1, input_size2 /// ); /// // keys_output: [0, 0, 1, 1, 2, 2, 3, 3] +/// // values_output: [10, 20, 11, 21, 12, 22, 13, 23] /// \endcode /// \endparblock template< @@ -231,6 +358,9 @@ template< class KeysInputIterator1, class KeysInputIterator2, class KeysOutputIterator, + class ValuesInputIterator1, + class ValuesInputIterator2, + class ValuesOutputIterator, class BinaryFunction = ::rocprim::less::value_type> > inline @@ -239,17 +369,19 @@ void merge(void * temporary_storage, KeysInputIterator1 keys_input1, KeysInputIterator2 keys_input2, KeysOutputIterator keys_output, + ValuesInputIterator1 values_input1, + ValuesInputIterator2 values_input2, + ValuesOutputIterator values_output, const size_t input1_size, const size_t input2_size, BinaryFunction compare_function = BinaryFunction(), hc::accelerator_view acc_view = hc::accelerator().get_default_view(), bool debug_synchronous = false) { - empty_type * values = nullptr; detail::merge_impl( temporary_storage, storage_size, keys_input1, keys_input2, keys_output, - values, values, values, + values_input1, values_input2, values_output, input1_size, input2_size, compare_function, acc_view, debug_synchronous ); diff --git a/rocprim/include/rocprim/device/device_merge_hip.hpp b/rocprim/include/rocprim/device/device_merge_hip.hpp index f235eaf3f..852188730 100644 --- a/rocprim/include/rocprim/device/device_merge_hip.hpp +++ b/rocprim/include/rocprim/device/device_merge_hip.hpp @@ -198,8 +198,8 @@ hipError_t merge_impl(void * temporary_storage, /// \brief HIP parallel merge primitive for device level. /// -/// \p merge function performs a device-wide merge of keys. -/// Function merges two ordered sets of input keys based on comparison function. +/// \p merge function performs a device-wide merge. +/// Function merges two ordered sets of input values based on comparison function. /// /// \par Overview /// * The contents of the inputs are not altered by the merging function. @@ -209,20 +209,20 @@ hipError_t merge_impl(void * temporary_storage, /// /// \tparam Config - [optional] configuration of the primitive. It can be \p merge_config or /// a custom class with the same members. -/// \tparam KeysInputIterator1 - random-access iterator type of the first input range. Must meet the +/// \tparam InputIterator1 - random-access iterator type of the first input range. Must meet the /// requirements of a C++ InputIterator concept. It can be a simple pointer type. -/// \tparam KeysInputIterator2 - random-access iterator type of the second input range. Must meet the +/// \tparam InputIterator2 - random-access iterator type of the second input range. Must meet the /// requirements of a C++ InputIterator concept. It can be a simple pointer type. -/// \tparam KeysOutputIterator - random-access iterator type of the output range. Must meet the +/// \tparam OutputIterator - random-access iterator type of the output range. Must meet the /// requirements of a C++ OutputIterator concept. It can be a simple pointer type. /// /// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When /// a null pointer is passed, the required allocation size (in bytes) is written to /// \p storage_size and function returns without performing the sort operation. /// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. -/// \param [in] keys_input1 - pointer to the first element in the first range to merge. -/// \param [in] keys_input2 - pointer to the first element in the second range to merge. -/// \param [out] keys_output - pointer to the first element in the output range. +/// \param [in] input1 - iterator to the first element in the first range to merge. +/// \param [in] input2 - iterator to the first element in the second range to merge. +/// \param [out] output - iterator to the first element in the output range. /// \param [in] input1_size - number of element in the first input range. /// \param [in] input2_size - number of element in the second input range. /// \param [in] compare_function - binary operation function object that will be used for comparison. @@ -246,7 +246,8 @@ hipError_t merge_impl(void * temporary_storage, /// #include /// /// // Prepare input and output (declare pointers, allocate device memory etc.) -/// size_t input_size; // e.g., 4 +/// size_t input_size1; // e.g., 4 +/// size_t input_size2; // e.g., 4 /// int * input1; // e.g., [0, 1, 2, 3] /// int * input2; // e.g., [0, 1, 2, 3] /// int * output; // empty array of 8 elements @@ -256,7 +257,7 @@ hipError_t merge_impl(void * temporary_storage, /// // Get required size of the temporary storage /// rocprim::merge( /// temporary_storage_ptr, temporary_storage_size_bytes, -/// input1, input2, output, input_size, input_size +/// input1, input2, output, input_size1, input_size2 /// ); /// /// // allocate temporary storage @@ -265,9 +266,131 @@ hipError_t merge_impl(void * temporary_storage, /// // perform merge /// rocprim::merge( /// temporary_storage_ptr, temporary_storage_size_bytes, -/// input1, input2, output, input_size, input_size +/// input1, input2, output, input_size1, input_size2 +/// ); +/// // output: [0, 0, 1, 1, 2, 2, 3, 3] +/// \endcode +/// \endparblock +template< + class Config = default_config, + class InputIterator1, + class InputIterator2, + class OutputIterator, + class BinaryFunction = ::rocprim::less::value_type> +> +inline +hipError_t merge(void * temporary_storage, + size_t& storage_size, + InputIterator1 input1, + InputIterator2 input2, + OutputIterator output, + const size_t input1_size, + const size_t input2_size, + BinaryFunction compare_function = BinaryFunction(), + const hipStream_t stream = 0, + bool debug_synchronous = false) +{ + empty_type * values = nullptr; + return detail::merge_impl( + temporary_storage, storage_size, + input1, input2, output, + values, values, values, + input1_size, input2_size, compare_function, + stream, debug_synchronous + ); +} + +/// \brief HIP parallel merge primitive for device level. +/// +/// \p merge function performs a device-wide merge of (key, value) pairs. +/// Function merges two ordered sets of input keys and corresponding values +/// based on key comparison function. +/// +/// \par Overview +/// * The contents of the inputs are not altered by the merging function. +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage in a null pointer. +/// * Accepts custom compare_functions for merging across the device. +/// +/// \tparam Config - [optional] configuration of the primitive. It can be \p merge_config or +/// a custom class with the same members. +/// \tparam KeysInputIterator1 - random-access iterator type of the first keys input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam KeysInputIterator2 - random-access iterator type of the second keys input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam KeysOutputIterator - random-access iterator type of the keys output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// \tparam ValuesInputIterator1 - random-access iterator type of the first values input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam ValuesInputIterator2 - random-access iterator type of the second values input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam ValuesOutputIterator - random-access iterator type of the values output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// +/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and function returns without performing the sort operation. +/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. +/// \param [in] keys_input1 - iterator to the first key in the first range to merge. +/// \param [in] keys_input2 - iterator to the first key in the second range to merge. +/// \param [out] keys_output - iterator to the first key in the output range. +/// \param [in] values_input1 - iterator to the first value in the first range to merge. +/// \param [in] values_input2 - iterator to the first value in the second range to merge. +/// \param [out] values_output - iterator to the first value in the output range. +/// \param [in] input1_size - number of element in the first input range. +/// \param [in] input2_size - number of element in the second input range. +/// \param [in] compare_function - binary operation function object that will be used for key comparison. +/// The signature of the function should be equivalent to the following: +/// bool f(const T &a, const T &b);. The signature does not need to have +/// const &, but function object must not modify the objects passed to it. +/// The default value is \p BinaryFunction(). +/// \param [in] stream - [optional] HIP stream object. Default is \p 0 (default stream). +/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is \p false. +/// +/// \returns \p hipSuccess (\p 0) after successful sort; otherwise a HIP runtime error of +/// type \p hipError_t. +/// +/// \par Example +/// \parblock +/// In this example a device-level ascending merge is performed on an array of +/// \p int values. +/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) +/// size_t input_size1; // e.g., 4 +/// size_t input_size2; // e.g., 4 +/// int * keys_input1; // e.g., [0, 1, 2, 3] +/// int * keys_input2; // e.g., [0, 1, 2, 3] +/// int * keys_output; // empty array of 8 elements +/// int * values_input1; // e.g., [10, 11, 12, 13] +/// int * values_input2; // e.g., [20, 21, 22, 23] +/// int * values_output; // empty array of 8 elements +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::merge( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys_input1, keys_input2, keys_output, +/// values_input1, values_input2, values_output, +// input_size1, input_size2 +/// ); +/// +/// // allocate temporary storage +/// hipMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform merge +/// rocprim::merge( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// keys_input1, keys_input2, keys_output, +/// values_input1, values_input2, values_output, +// input_size1, input_size2 /// ); /// // keys_output: [0, 0, 1, 1, 2, 2, 3, 3] +/// // values_output: [10, 20, 11, 21, 12, 22, 13, 23] /// \endcode /// \endparblock template< @@ -275,6 +398,9 @@ template< class KeysInputIterator1, class KeysInputIterator2, class KeysOutputIterator, + class ValuesInputIterator1, + class ValuesInputIterator2, + class ValuesOutputIterator, class BinaryFunction = ::rocprim::less::value_type> > inline @@ -283,17 +409,19 @@ hipError_t merge(void * temporary_storage, KeysInputIterator1 keys_input1, KeysInputIterator2 keys_input2, KeysOutputIterator keys_output, + ValuesInputIterator1 values_input1, + ValuesInputIterator2 values_input2, + ValuesOutputIterator values_output, const size_t input1_size, const size_t input2_size, BinaryFunction compare_function = BinaryFunction(), const hipStream_t stream = 0, bool debug_synchronous = false) { - empty_type * values = nullptr; return detail::merge_impl( temporary_storage, storage_size, keys_input1, keys_input2, keys_output, - values, values, values, + values_input1, values_input2, values_output, input1_size, input2_size, compare_function, stream, debug_synchronous ); diff --git a/rocprim/include/rocprim/device/device_partition_hc.hpp b/rocprim/include/rocprim/device/device_partition_hc.hpp index 1b2ebcd9d..6dea99331 100644 --- a/rocprim/include/rocprim/device/device_partition_hc.hpp +++ b/rocprim/include/rocprim/device/device_partition_hc.hpp @@ -28,7 +28,6 @@ #include "../functional.hpp" #include "../type_traits.hpp" #include "../detail/various.hpp" -#include "../detail/match_result_type.hpp" #include "device_select_config.hpp" #include "detail/device_partition.hpp" @@ -81,31 +80,11 @@ void partition_impl(void * temporary_storage, { using offset_type = unsigned int; using input_type = typename std::iterator_traits::value_type; - using output_type = typename std::iterator_traits::value_type; - // Fix for cases when output_type is void (there's no sizeof(void)), it's - // a tuple which contains an item of type void, or is not convertible to - // input_type which is used in InequalityOp - static constexpr bool is_output_type_voidlike = - std::is_same::value || tuple_contains_type::value; - // If output_type is voidlike we don't want std::is_convertible to - // be evaluated, it leads to errors if input_type is a tuple - using is_output_type_convertible = typename std::conditional< - is_output_type_voidlike, std::false_type, std::is_convertible - >::type; - static constexpr bool is_output_type_invalid = - is_output_type_voidlike || !(is_output_type_convertible::value); - using value_type = typename std::conditional< - is_output_type_invalid, input_type, output_type - >::type; - // Use smaller type for private storage - using result_type = typename std::conditional< - (sizeof(value_type) > sizeof(input_type)), input_type, value_type - >::type; // Get default config if Config is default_config using config = default_or_custom_config< Config, - default_select_config + default_select_config >; using offset_scan_state_type = detail::lookback_scan_state; @@ -170,7 +149,7 @@ void partition_impl(void * temporary_storage, hc::tiled_extent<1>(grid_size, block_size), [=](hc::tiled_index<1>) [[hc]] { - partition_kernel_impl( + partition_kernel_impl( input, flags, output, selected_count_output, size, predicate, inequality_op, offset_scan_state, number_of_blocks, ordered_bid ); diff --git a/rocprim/include/rocprim/device/device_partition_hip.hpp b/rocprim/include/rocprim/device/device_partition_hip.hpp index 811b39b57..dd0a72f34 100644 --- a/rocprim/include/rocprim/device/device_partition_hip.hpp +++ b/rocprim/include/rocprim/device/device_partition_hip.hpp @@ -28,7 +28,6 @@ #include "../functional.hpp" #include "../type_traits.hpp" #include "../detail/various.hpp" -#include "../detail/match_result_type.hpp" #include "device_select_config.hpp" #include "detail/device_partition.hpp" @@ -45,7 +44,6 @@ template< select_method SelectMethod, bool OnlySelected, class Config, - class ResultType, class InputIterator, class FlagIterator, class OutputIterator, @@ -66,7 +64,7 @@ void partition_kernel(InputIterator input, const unsigned int number_of_blocks, ordered_block_id ordered_bid) { - partition_kernel_impl( + partition_kernel_impl( input, flags, output, selected_count_output, size, predicate, inequality_op, offset_scan_state, number_of_blocks, ordered_bid ); @@ -137,31 +135,11 @@ hipError_t partition_impl(void * temporary_storage, { using offset_type = unsigned int; using input_type = typename std::iterator_traits::value_type; - using output_type = typename std::iterator_traits::value_type; - // Fix for cases when output_type is void (there's no sizeof(void)), it's - // a tuple which contains an item of type void, or is not convertible to - // input_type which is used in InequalityOp - static constexpr bool is_output_type_voidlike = - std::is_same::value || tuple_contains_type::value; - // If output_type is voidlike we don't want std::is_convertible to - // be evaluated, it leads to errors if input_type is a tuple - using is_output_type_convertible = typename std::conditional< - is_output_type_voidlike, std::false_type, std::is_convertible - >::type; - static constexpr bool is_output_type_invalid = - is_output_type_voidlike || !(is_output_type_convertible::value); - using value_type = typename std::conditional< - is_output_type_invalid, input_type, output_type - >::type; - // Use smaller type for private storage - using result_type = typename std::conditional< - (sizeof(value_type) > sizeof(input_type)), input_type, value_type - >::type; // Get default config if Config is default_config using config = default_or_custom_config< Config, - default_select_config + default_select_config >; using offset_scan_state_type = detail::lookback_scan_state; @@ -218,7 +196,7 @@ hipError_t partition_impl(void * temporary_storage, grid_size = number_of_blocks; hipLaunchKernelGGL( HIP_KERNEL_NAME(partition_kernel< - SelectMethod, OnlySelected, config, result_type, + SelectMethod, OnlySelected, config, InputIterator, FlagIterator, OutputIterator, SelectedCountOutputIterator, UnaryPredicate, decltype(inequality_op), offset_scan_state_type >), diff --git a/rocprim/include/rocprim/device/device_reduce_by_key_hc.hpp b/rocprim/include/rocprim/device/device_reduce_by_key_hc.hpp index 16212b365..a14365296 100644 --- a/rocprim/include/rocprim/device/device_reduce_by_key_hc.hpp +++ b/rocprim/include/rocprim/device/device_reduce_by_key_hc.hpp @@ -80,10 +80,9 @@ void reduce_by_key_impl(void * temporary_storage, using key_type = typename std::iterator_traits::value_type; using result_type = typename ::rocprim::detail::match_result_type< typename std::iterator_traits::value_type, - typename std::iterator_traits::value_type, BinaryFunction >::type; - using carry_out_type = carry_out; + using carry_out_type = carry_out; using config = default_or_custom_config< Config, @@ -183,7 +182,7 @@ void reduce_by_key_impl(void * temporary_storage, scan_and_scatter_carry_outs( carry_outs, leading_aggregates, aggregates_output, - key_compare_op, reduce_op, + reduce_op, batches ); } diff --git a/rocprim/include/rocprim/device/device_reduce_by_key_hip.hpp b/rocprim/include/rocprim/device/device_reduce_by_key_hip.hpp index f3f5d1ddf..97773d35f 100644 --- a/rocprim/include/rocprim/device/device_reduce_by_key_hip.hpp +++ b/rocprim/include/rocprim/device/device_reduce_by_key_hip.hpp @@ -81,7 +81,6 @@ template< unsigned int ItemsPerThread, class KeysInputIterator, class ValuesInputIterator, - class Key, class Result, class UniqueOutputIterator, class AggregatesOutputIterator, @@ -93,7 +92,7 @@ void reduce_by_key_kernel(KeysInputIterator keys_input, ValuesInputIterator values_input, unsigned int size, const unsigned int * unique_starts, - carry_out * carry_outs, + carry_out * carry_outs, Result * leading_aggregates, UniqueOutputIterator unique_output, AggregatesOutputIterator aggregates_output, @@ -114,23 +113,20 @@ void reduce_by_key_kernel(KeysInputIterator keys_input, template< unsigned int BlockSize, unsigned int ItemsPerThread, - class Key, class Result, class AggregatesOutputIterator, - class KeyCompareFunction, class BinaryFunction > __global__ -void scan_and_scatter_carry_outs_kernel(const carry_out * carry_outs, +void scan_and_scatter_carry_outs_kernel(const carry_out * carry_outs, const Result * leading_aggregates, AggregatesOutputIterator aggregates_output, - KeyCompareFunction key_compare_op, BinaryFunction reduce_op, unsigned int batches) { scan_and_scatter_carry_outs( carry_outs, leading_aggregates, aggregates_output, - key_compare_op, reduce_op, + reduce_op, batches ); } @@ -177,10 +173,9 @@ hipError_t reduce_by_key_impl(void * temporary_storage, using key_type = typename std::iterator_traits::value_type; using result_type = typename ::rocprim::detail::match_result_type< typename std::iterator_traits::value_type, - typename std::iterator_traits::value_type, BinaryFunction >::type; - using carry_out_type = carry_out; + using carry_out_type = carry_out; using config = default_or_custom_config< Config, @@ -263,7 +258,7 @@ hipError_t reduce_by_key_impl(void * temporary_storage, dim3(1), dim3(config::scan::block_size), 0, stream, const_cast(carry_outs), const_cast(leading_aggregates), aggregates_output, - key_compare_op, reduce_op, + reduce_op, batches ); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("scan_and_scatter_carry_outs", config::scan::block_size, start) diff --git a/rocprim/include/rocprim/device/device_reduce_hc.hpp b/rocprim/include/rocprim/device/device_reduce_hc.hpp index 2d74438ee..12522bf3f 100644 --- a/rocprim/include/rocprim/device/device_reduce_hc.hpp +++ b/rocprim/include/rocprim/device/device_reduce_hc.hpp @@ -71,9 +71,8 @@ void reduce_impl(void * temporary_storage, const bool debug_synchronous) { using input_type = typename std::iterator_traits::value_type; - using output_type = typename std::iterator_traits::value_type; using result_type = typename ::rocprim::detail::match_result_type< - input_type, output_type, BinaryFunction + input_type, BinaryFunction >::type; // Get default config if Config is default_config diff --git a/rocprim/include/rocprim/device/device_reduce_hip.hpp b/rocprim/include/rocprim/device/device_reduce_hip.hpp index c67870f8a..667a3bc4f 100644 --- a/rocprim/include/rocprim/device/device_reduce_hip.hpp +++ b/rocprim/include/rocprim/device/device_reduce_hip.hpp @@ -107,9 +107,8 @@ hipError_t reduce_impl(void * temporary_storage, bool debug_synchronous) { using input_type = typename std::iterator_traits::value_type; - using output_type = typename std::iterator_traits::value_type; using result_type = typename ::rocprim::detail::match_result_type< - input_type, output_type, BinaryFunction + input_type, BinaryFunction >::type; // Get default config if Config is default_config diff --git a/rocprim/include/rocprim/device/device_scan_by_key_hc.hpp b/rocprim/include/rocprim/device/device_scan_by_key_hc.hpp index d765e6338..d6f60a7ea 100644 --- a/rocprim/include/rocprim/device/device_scan_by_key_hc.hpp +++ b/rocprim/include/rocprim/device/device_scan_by_key_hc.hpp @@ -27,7 +27,6 @@ #include "../config.hpp" #include "../iterator/zip_iterator.hpp" #include "../iterator/discard_iterator.hpp" -#include "../iterator/detail/replace_first_iterator.hpp" #include "../types/tuple.hpp" #include "../detail/various.hpp" @@ -148,21 +147,35 @@ void inclusive_scan_by_key(void * temporary_storage, const bool debug_synchronous = false) { using input_type = typename std::iterator_traits::value_type; - using key_type = typename std::iterator_traits::value_type; - using scan_by_key_operator = detail::scan_by_key_op_wrapper< - input_type, key_type, BinaryFunction, KeyCompareFunction - >; + using result_type = typename ::rocprim::detail::match_result_type< + input_type, BinaryFunction + >::type; + using flag_type = bool; + using headflag_scan_op_wrapper_type = + detail::headflag_scan_op_wrapper< + result_type, flag_type, BinaryFunction + >; - return inclusive_scan( + // Flag the first item of each segment as its head, + // then run inclusive scan + inclusive_scan( temporary_storage, storage_size, - make_zip_iterator( - make_tuple(values_input, keys_input) - ), - make_zip_iterator( - make_tuple(values_output, make_discard_iterator()) + rocprim::make_transform_iterator( + rocprim::make_counting_iterator(0), + [values_input, keys_input, key_compare_op] + ROCPRIM_DEVICE (const size_t i) + { + flag_type flag(true); + if(i > 0) + { + flag = flag_type(!key_compare_op(keys_input[i - 1], keys_input[i])); + } + return rocprim::make_tuple(values_input[i], flag); + } ), + rocprim::make_zip_iterator(rocprim::make_tuple(values_output, rocprim::make_discard_iterator())), size, - scan_by_key_operator(scan_op, key_compare_op), + headflag_scan_op_wrapper_type(scan_op), acc_view, debug_synchronous ); @@ -283,62 +296,48 @@ void exclusive_scan_by_key(void * temporary_storage, const bool debug_synchronous = false) { using input_type = typename std::iterator_traits::value_type; - using key_type = typename std::iterator_traits::value_type; - using scan_by_key_operator = detail::scan_by_key_op_wrapper< - input_type, key_type, BinaryFunction, KeyCompareFunction - >; + using result_type = typename ::rocprim::detail::match_result_type< + input_type, BinaryFunction + >::type; + using flag_type = bool; + using headflag_scan_op_wrapper_type = + detail::headflag_scan_op_wrapper< + result_type, flag_type, BinaryFunction + >; - return inclusive_scan( + const result_type initial_value_converted = static_cast(initial_value); + + // Flag the last item of each segment as the next segment's head, use initial_value as its value, + // then run exclusive scan + exclusive_scan( temporary_storage, storage_size, - // Using replace_first_iterator shifts input one item to left and replaces - // first value with initial_value. Then transform_iterator replaces last - // elements of other segments to initial_value. That modified input data - // can be inclusively scanned and produce expected exclusive results. - // - // values_input: [1, 2, 3, 4, 5, 6, 7, 8] - // replace_first_iterator(values_input): [9, 1, 2, 3, 4, 5, 6, 7] - // keys_input: [1, 1, 1, 2, 2, 3, 3, 4] - // replace_first_iterator(keys_input): [-, 1, 1, 1, 2, 2, 3, 3] - // initial_value: 9 - // transform_iterator: [9, 1, 2, 9, 4, 9, 6, 9] - // - // inclusive_scan result: [9, 10, 12, 9, 13, 9, 15, 9] - make_transform_iterator( - make_zip_iterator( - make_tuple( - detail::replace_first_iterator( - values_input - 1, initial_value - ), - keys_input, - detail::replace_first_iterator( - keys_input - 1, key_type() - ) - ) - ), - [initial_value, key_compare_op](const ::rocprim::tuple& t) - -> ::rocprim::tuple + rocprim::make_transform_iterator( + rocprim::make_counting_iterator(0), + [values_input, keys_input, key_compare_op, initial_value_converted, size] + ROCPRIM_DEVICE (const size_t i) { - if(!key_compare_op(::rocprim::get<1>(t), ::rocprim::get<2>(t))) + flag_type flag(false); + if(i + 1 < size) + { + flag = flag_type(!key_compare_op(keys_input[i], keys_input[i + 1])); + } + result_type value = initial_value_converted; + if(!flag) { - return ::rocprim::make_tuple( - static_cast(initial_value), - ::rocprim::get<1>(t) - ); + value = values_input[i]; } - return ::rocprim::make_tuple( - ::rocprim::get<0>(t), ::rocprim::get<1>(t) - ); + return rocprim::make_tuple(value, flag); } ), - make_zip_iterator(make_tuple(values_output, make_discard_iterator())), + rocprim::make_zip_iterator(rocprim::make_tuple(values_output, rocprim::make_discard_iterator())), + rocprim::make_tuple(initial_value_converted, flag_type(true)), // init value is a head of the first segment size, - scan_by_key_operator(scan_op, key_compare_op), + headflag_scan_op_wrapper_type(scan_op), acc_view, debug_synchronous ); } - /// @} // end of group devicemodule_hc diff --git a/rocprim/include/rocprim/device/device_scan_by_key_hip.hpp b/rocprim/include/rocprim/device/device_scan_by_key_hip.hpp index 6b2ae7dec..a7dbdaaff 100644 --- a/rocprim/include/rocprim/device/device_scan_by_key_hip.hpp +++ b/rocprim/include/rocprim/device/device_scan_by_key_hip.hpp @@ -27,7 +27,6 @@ #include "../config.hpp" #include "../iterator/zip_iterator.hpp" #include "../iterator/discard_iterator.hpp" -#include "../iterator/detail/replace_first_iterator.hpp" #include "../types/tuple.hpp" #include "../detail/various.hpp" @@ -149,21 +148,35 @@ hipError_t inclusive_scan_by_key(void * temporary_storage, bool debug_synchronous = false) { using input_type = typename std::iterator_traits::value_type; - using key_type = typename std::iterator_traits::value_type; - using scan_by_key_operator = detail::scan_by_key_op_wrapper< - input_type, key_type, BinaryFunction, KeyCompareFunction - >; + using result_type = typename ::rocprim::detail::match_result_type< + input_type, BinaryFunction + >::type; + using flag_type = bool; + using headflag_scan_op_wrapper_type = + detail::headflag_scan_op_wrapper< + result_type, flag_type, BinaryFunction + >; + // Flag the first item of each segment as its head, + // then run inclusive scan return inclusive_scan( temporary_storage, storage_size, - make_zip_iterator( - make_tuple(values_input, keys_input) - ), - make_zip_iterator( - make_tuple(values_output, make_discard_iterator()) + rocprim::make_transform_iterator( + rocprim::make_counting_iterator(0), + [values_input, keys_input, key_compare_op] + ROCPRIM_DEVICE (const size_t i) + { + flag_type flag(true); + if(i > 0) + { + flag = flag_type(!key_compare_op(keys_input[i - 1], keys_input[i])); + } + return rocprim::make_tuple(values_input[i], flag); + } ), + rocprim::make_zip_iterator(rocprim::make_tuple(values_output, rocprim::make_discard_iterator())), size, - scan_by_key_operator(scan_op, key_compare_op), + headflag_scan_op_wrapper_type(scan_op), stream, debug_synchronous ); @@ -285,44 +298,43 @@ hipError_t exclusive_scan_by_key(void * temporary_storage, bool debug_synchronous = false) { using input_type = typename std::iterator_traits::value_type; - using key_type = typename std::iterator_traits::value_type; - using scan_by_key_operator = detail::scan_by_key_op_wrapper< - input_type, key_type, BinaryFunction, KeyCompareFunction - >; + using result_type = typename ::rocprim::detail::match_result_type< + input_type, BinaryFunction + >::type; + using flag_type = bool; + using headflag_scan_op_wrapper_type = + detail::headflag_scan_op_wrapper< + result_type, flag_type, BinaryFunction + >; - return inclusive_scan( + const result_type initial_value_converted = static_cast(initial_value); + + // Flag the last item of each segment as the next segment's head, use initial_value as its value, + // then run exclusive scan + return exclusive_scan( temporary_storage, storage_size, - make_transform_iterator( - make_zip_iterator( - make_tuple( - detail::replace_first_iterator( - values_input - 1, initial_value - ), - keys_input, - detail::replace_first_iterator( - keys_input - 1, key_type() - ) - ) - ), - [initial_value, key_compare_op] ROCPRIM_HOST_DEVICE - (const ::rocprim::tuple& t) - -> ::rocprim::tuple + rocprim::make_transform_iterator( + rocprim::make_counting_iterator(0), + [values_input, keys_input, key_compare_op, initial_value_converted, size] + ROCPRIM_DEVICE (const size_t i) { - if(!key_compare_op(::rocprim::get<1>(t), ::rocprim::get<2>(t))) + flag_type flag(false); + if(i + 1 < size) + { + flag = flag_type(!key_compare_op(keys_input[i], keys_input[i + 1])); + } + result_type value = initial_value_converted; + if(!flag) { - return ::rocprim::make_tuple( - static_cast(initial_value), - ::rocprim::get<1>(t) - ); + value = values_input[i]; } - return ::rocprim::make_tuple( - ::rocprim::get<0>(t), ::rocprim::get<1>(t) - ); + return rocprim::make_tuple(value, flag); } ), - make_zip_iterator(make_tuple(values_output, make_discard_iterator())), + rocprim::make_zip_iterator(rocprim::make_tuple(values_output, rocprim::make_discard_iterator())), + rocprim::make_tuple(initial_value_converted, flag_type(true)), // init value is a head of the first segment size, - scan_by_key_operator(scan_op, key_compare_op), + headflag_scan_op_wrapper_type(scan_op), stream, debug_synchronous ); diff --git a/rocprim/include/rocprim/device/device_scan_hc.hpp b/rocprim/include/rocprim/device/device_scan_hc.hpp index aa3fcb074..dad82cb79 100644 --- a/rocprim/include/rocprim/device/device_scan_hc.hpp +++ b/rocprim/include/rocprim/device/device_scan_hc.hpp @@ -75,9 +75,8 @@ auto scan_impl(void * temporary_storage, -> typename std::enable_if::type { using input_type = typename std::iterator_traits::value_type; - using output_type = typename std::iterator_traits::value_type; using result_type = typename ::rocprim::detail::match_result_type< - input_type, output_type, BinaryFunction + input_type, BinaryFunction >::type; // Get default config if Config is default_config @@ -205,9 +204,8 @@ auto scan_impl(void * temporary_storage, -> typename std::enable_if::type { using input_type = typename std::iterator_traits::value_type; - using output_type = typename std::iterator_traits::value_type; using result_type = typename ::rocprim::detail::match_result_type< - input_type, output_type, BinaryFunction + input_type, BinaryFunction >::type; // Get default config if Config is default_config @@ -394,9 +392,8 @@ void inclusive_scan(void * temporary_storage, const bool debug_synchronous = false) { using input_type = typename std::iterator_traits::value_type; - using output_type = typename std::iterator_traits::value_type; using result_type = typename ::rocprim::detail::match_result_type< - input_type, output_type, BinaryFunction + input_type, BinaryFunction >::type; // Get default config if Config is default_config @@ -516,9 +513,8 @@ void exclusive_scan(void * temporary_storage, const bool debug_synchronous = false) { using input_type = typename std::iterator_traits::value_type; - using output_type = typename std::iterator_traits::value_type; using result_type = typename ::rocprim::detail::match_result_type< - input_type, output_type, BinaryFunction + input_type, BinaryFunction >::type; // Get default config if Config is default_config diff --git a/rocprim/include/rocprim/device/device_scan_hip.hpp b/rocprim/include/rocprim/device/device_scan_hip.hpp index f88a1cff4..36d9115ab 100644 --- a/rocprim/include/rocprim/device/device_scan_hip.hpp +++ b/rocprim/include/rocprim/device/device_scan_hip.hpp @@ -189,9 +189,8 @@ auto scan_impl(void * temporary_storage, -> typename std::enable_if::type { using input_type = typename std::iterator_traits::value_type; - using output_type = typename std::iterator_traits::value_type; using result_type = typename ::rocprim::detail::match_result_type< - input_type, output_type, BinaryFunction + input_type, BinaryFunction >::type; using config = Config; @@ -319,9 +318,8 @@ auto scan_impl(void * temporary_storage, -> typename std::enable_if::type { using input_type = typename std::iterator_traits::value_type; - using output_type = typename std::iterator_traits::value_type; using result_type = typename ::rocprim::detail::match_result_type< - input_type, output_type, BinaryFunction + input_type, BinaryFunction >::type; using config = Config; @@ -503,9 +501,8 @@ hipError_t inclusive_scan(void * temporary_storage, bool debug_synchronous = false) { using input_type = typename std::iterator_traits::value_type; - using output_type = typename std::iterator_traits::value_type; using result_type = typename ::rocprim::detail::match_result_type< - input_type, output_type, BinaryFunction + input_type, BinaryFunction >::type; // Get default config if Config is default_config @@ -625,9 +622,8 @@ hipError_t exclusive_scan(void * temporary_storage, bool debug_synchronous = false) { using input_type = typename std::iterator_traits::value_type; - using output_type = typename std::iterator_traits::value_type; using result_type = typename ::rocprim::detail::match_result_type< - input_type, output_type, BinaryFunction + input_type, BinaryFunction >::type; // Get default config if Config is default_config diff --git a/rocprim/include/rocprim/device/device_segmented_reduce_hc.hpp b/rocprim/include/rocprim/device/device_segmented_reduce_hc.hpp index ac238d0f4..152a1c6bc 100644 --- a/rocprim/include/rocprim/device/device_segmented_reduce_hc.hpp +++ b/rocprim/include/rocprim/device/device_segmented_reduce_hc.hpp @@ -27,6 +27,7 @@ #include "../config.hpp" #include "../functional.hpp" #include "../detail/various.hpp" +#include "../detail/match_result_type.hpp" #include "detail/device_segmented_reduce.hpp" @@ -72,9 +73,8 @@ void segmented_reduce_impl(void * temporary_storage, bool debug_synchronous) { using input_type = typename std::iterator_traits::value_type; - using output_type = typename std::iterator_traits::value_type; using result_type = typename ::rocprim::detail::match_result_type< - input_type, output_type, BinaryFunction + input_type, BinaryFunction >::type; // Get default config if Config is default_config diff --git a/rocprim/include/rocprim/device/device_segmented_reduce_hip.hpp b/rocprim/include/rocprim/device/device_segmented_reduce_hip.hpp index a17088b40..3f3112df0 100644 --- a/rocprim/include/rocprim/device/device_segmented_reduce_hip.hpp +++ b/rocprim/include/rocprim/device/device_segmented_reduce_hip.hpp @@ -27,6 +27,7 @@ #include "../config.hpp" #include "../functional.hpp" #include "../detail/various.hpp" +#include "../detail/match_result_type.hpp" #include "detail/device_segmented_reduce.hpp" @@ -98,9 +99,8 @@ hipError_t segmented_reduce_impl(void * temporary_storage, bool debug_synchronous) { using input_type = typename std::iterator_traits::value_type; - using output_type = typename std::iterator_traits::value_type; using result_type = typename ::rocprim::detail::match_result_type< - input_type, output_type, BinaryFunction + input_type, BinaryFunction >::type; // Get default config if Config is default_config diff --git a/rocprim/include/rocprim/device/device_segmented_scan_hc.hpp b/rocprim/include/rocprim/device/device_segmented_scan_hc.hpp index a559812f8..cbd4f5ba3 100644 --- a/rocprim/include/rocprim/device/device_segmented_scan_hc.hpp +++ b/rocprim/include/rocprim/device/device_segmented_scan_hc.hpp @@ -26,11 +26,11 @@ #include "../config.hpp" #include "../detail/various.hpp" +#include "../detail/match_result_type.hpp" #include "../iterator/zip_iterator.hpp" #include "../iterator/discard_iterator.hpp" #include "../iterator/transform_iterator.hpp" -#include "../iterator/detail/replace_first_iterator.hpp" #include "../types/tuple.hpp" #include "device_scan_hc.hpp" @@ -79,9 +79,8 @@ void segmented_scan_impl(void * temporary_storage, const bool debug_synchronous) { using input_type = typename std::iterator_traits::value_type; - using output_type = typename std::iterator_traits::value_type; using result_type = typename ::rocprim::detail::match_result_type< - input_type, output_type, BinaryFunction + input_type, BinaryFunction >::type; // Get default config if Config is default_config @@ -230,9 +229,8 @@ void segmented_inclusive_scan(void * temporary_storage, const bool debug_synchronous = false) { using input_type = typename std::iterator_traits::value_type; - using output_type = typename std::iterator_traits::value_type; using result_type = typename ::rocprim::detail::match_result_type< - input_type, output_type, BinaryFunction + input_type, BinaryFunction >::type; return detail::segmented_scan_impl( @@ -460,16 +458,19 @@ void segmented_inclusive_scan(void * temporary_storage, const bool debug_synchronous = false) { using input_type = typename std::iterator_traits::value_type; + using result_type = typename ::rocprim::detail::match_result_type< + input_type, BinaryFunction + >::type; using flag_type = typename std::iterator_traits::value_type; using headflag_scan_op_wrapper_type = detail::headflag_scan_op_wrapper< - input_type, flag_type, BinaryFunction + result_type, flag_type, BinaryFunction >; - return inclusive_scan( + inclusive_scan( temporary_storage, storage_size, - make_zip_iterator(make_tuple(input, head_flags)), - make_zip_iterator(make_tuple(output, make_discard_iterator())), + rocprim::make_zip_iterator(rocprim::make_tuple(input, head_flags)), + rocprim::make_zip_iterator(rocprim::make_tuple(output, rocprim::make_discard_iterator())), size, headflag_scan_op_wrapper_type(scan_op), acc_view, debug_synchronous ); @@ -578,49 +579,45 @@ void segmented_exclusive_scan(void * temporary_storage, const bool debug_synchronous = false) { using input_type = typename std::iterator_traits::value_type; + using result_type = typename ::rocprim::detail::match_result_type< + input_type, BinaryFunction + >::type; using flag_type = typename std::iterator_traits::value_type; using headflag_scan_op_wrapper_type = detail::headflag_scan_op_wrapper< - input_type, flag_type, BinaryFunction + result_type, flag_type, BinaryFunction >; - return inclusive_scan( + const result_type initial_value_converted = static_cast(initial_value); + + // Flag the last item of each segment as the next segment's head, use initial_value as its value, + // then run exclusive scan + exclusive_scan( temporary_storage, storage_size, - // Using replace_first_iterator shifts input one item to left and replaces - // first value with initial_value. Then transform_iterator replaces last - // elements of other segments to initial_value. That modified input data - // can be inclusively scanned and produce expected exclusive results. - // - // input: [1, 2, 3, 4, 5, 6, 7, 8] - // replace_first_iterator(input): [9, 1, 2, 3, 4, 5, 6, 7] - // head_flags: [1, 0, 0, 1, 0, 1, 0, 0] - // initial_value: 9 - // transform_iterator: [9, 1, 2, 9, 4, 9, 6, 7] - // - // inclusive_scan result: [9, 10, 12, 9, 13, 9, 15, 22] - make_transform_iterator( - make_zip_iterator( - make_tuple( - detail::replace_first_iterator(input - 1, initial_value), - head_flags - ) - ), - [initial_value](const ::rocprim::tuple& t) - -> ::rocprim::tuple + rocprim::make_transform_iterator( + rocprim::make_counting_iterator(0), + [input, head_flags, initial_value_converted, size] + ROCPRIM_DEVICE (const size_t i) { - if(::rocprim::get<1>(t)) + flag_type flag(false); + if(i + 1 < size) { - return ::rocprim::make_tuple( - static_cast(initial_value), - ::rocprim::get<1>(t) - ); + flag = head_flags[i + 1]; } - return t; + result_type value = initial_value_converted; + if(!flag) + { + value = input[i]; + } + return rocprim::make_tuple(value, flag); } ), - make_zip_iterator(make_tuple(output, make_discard_iterator())), - size, headflag_scan_op_wrapper_type(scan_op), - acc_view, debug_synchronous + rocprim::make_zip_iterator(rocprim::make_tuple(output, rocprim::make_discard_iterator())), + rocprim::make_tuple(initial_value_converted, flag_type(true)), // init value is a head of the first segment + size, + headflag_scan_op_wrapper_type(scan_op), + acc_view, + debug_synchronous ); } diff --git a/rocprim/include/rocprim/device/device_segmented_scan_hip.hpp b/rocprim/include/rocprim/device/device_segmented_scan_hip.hpp index 794cc1b95..af88be8c1 100644 --- a/rocprim/include/rocprim/device/device_segmented_scan_hip.hpp +++ b/rocprim/include/rocprim/device/device_segmented_scan_hip.hpp @@ -26,11 +26,11 @@ #include "../config.hpp" #include "../detail/various.hpp" +#include "../detail/match_result_type.hpp" #include "../iterator/zip_iterator.hpp" #include "../iterator/discard_iterator.hpp" #include "../iterator/transform_iterator.hpp" -#include "../iterator/detail/replace_first_iterator.hpp" #include "../types/tuple.hpp" #include "device_scan_config.hpp" @@ -106,9 +106,8 @@ hipError_t segmented_scan_impl(void * temporary_storage, bool debug_synchronous) { using input_type = typename std::iterator_traits::value_type; - using output_type = typename std::iterator_traits::value_type; using result_type = typename ::rocprim::detail::match_result_type< - input_type, output_type, BinaryFunction + input_type, BinaryFunction >::type; // Get default config if Config is default_config @@ -250,9 +249,8 @@ hipError_t segmented_inclusive_scan(void * temporary_storage, bool debug_synchronous = false) { using input_type = typename std::iterator_traits::value_type; - using output_type = typename std::iterator_traits::value_type; using result_type = typename ::rocprim::detail::match_result_type< - input_type, output_type, BinaryFunction + input_type, BinaryFunction >::type; return detail::segmented_scan_impl( @@ -478,16 +476,19 @@ hipError_t segmented_inclusive_scan(void * temporary_storage, bool debug_synchronous = false) { using input_type = typename std::iterator_traits::value_type; + using result_type = typename ::rocprim::detail::match_result_type< + input_type, BinaryFunction + >::type; using flag_type = typename std::iterator_traits::value_type; using headflag_scan_op_wrapper_type = detail::headflag_scan_op_wrapper< - input_type, flag_type, BinaryFunction + result_type, flag_type, BinaryFunction >; return inclusive_scan( temporary_storage, storage_size, - make_zip_iterator(make_tuple(input, head_flags)), - make_zip_iterator(make_tuple(output, make_discard_iterator())), + rocprim::make_zip_iterator(rocprim::make_tuple(input, head_flags)), + rocprim::make_zip_iterator(rocprim::make_tuple(output, rocprim::make_discard_iterator())), size, headflag_scan_op_wrapper_type(scan_op), stream, debug_synchronous ); @@ -595,45 +596,45 @@ hipError_t segmented_exclusive_scan(void * temporary_storage, bool debug_synchronous = false) { using input_type = typename std::iterator_traits::value_type; + using result_type = typename ::rocprim::detail::match_result_type< + input_type, BinaryFunction + >::type; using flag_type = typename std::iterator_traits::value_type; using headflag_scan_op_wrapper_type = detail::headflag_scan_op_wrapper< - input_type, flag_type, BinaryFunction + result_type, flag_type, BinaryFunction >; - return inclusive_scan( + const result_type initial_value_converted = static_cast(initial_value); + + // Flag the last item of each segment as the next segment's head, use initial_value as its value, + // then run exclusive scan + return exclusive_scan( temporary_storage, storage_size, - // input: [1, 2, 3, 4, 5, 6, 7, 8] - // replace_first_iterator(input): [9, 1, 2, 3, 4, 5, 6, 7] - // head_flags: [1, 0, 0, 1, 0, 1, 0, 0] - // initial_value: 9 - // transform_iterator: [9, 1, 2, 9, 4, 9, 6, 7] - // - // inclusive_scan result: [9, 10, 12, 9, 13, 9, 15, 22] - make_transform_iterator( - make_zip_iterator( - make_tuple( - detail::replace_first_iterator(input - 1, initial_value), - head_flags - ) - ), - [initial_value] ROCPRIM_HOST_DEVICE - (const ::rocprim::tuple& t) - -> ::rocprim::tuple + rocprim::make_transform_iterator( + rocprim::make_counting_iterator(0), + [input, head_flags, initial_value_converted, size] + ROCPRIM_DEVICE (const size_t i) { - if(::rocprim::get<1>(t)) + flag_type flag(false); + if(i + 1 < size) + { + flag = head_flags[i + 1]; + } + result_type value = initial_value_converted; + if(!flag) { - return ::rocprim::make_tuple( - static_cast(initial_value), - ::rocprim::get<1>(t) - ); + value = input[i]; } - return t; + return rocprim::make_tuple(value, flag); } ), - make_zip_iterator(make_tuple(output, make_discard_iterator())), - size, headflag_scan_op_wrapper_type(scan_op), - stream, debug_synchronous + rocprim::make_zip_iterator(rocprim::make_tuple(output, rocprim::make_discard_iterator())), + rocprim::make_tuple(initial_value_converted, flag_type(true)), // init value is a head of the first segment + size, + headflag_scan_op_wrapper_type(scan_op), + stream, + debug_synchronous ); } diff --git a/rocprim/include/rocprim/functional.hpp b/rocprim/include/rocprim/functional.hpp index 3b638c710..4fd622cd3 100644 --- a/rocprim/include/rocprim/functional.hpp +++ b/rocprim/include/rocprim/functional.hpp @@ -54,7 +54,7 @@ void swap(T& a, T& b) b = c; } -template +template struct less { ROCPRIM_HOST_DEVICE inline @@ -64,6 +64,17 @@ struct less } }; +template<> +struct less +{ + template + ROCPRIM_HOST_DEVICE inline + constexpr bool operator()(const T& a, const U& b) const + { + return a < b; + } +}; + template struct less_equal { diff --git a/rocprim/include/rocprim/iterator/counting_iterator.hpp b/rocprim/include/rocprim/iterator/counting_iterator.hpp index 2799b73b3..1c95cb303 100644 --- a/rocprim/include/rocprim/iterator/counting_iterator.hpp +++ b/rocprim/include/rocprim/iterator/counting_iterator.hpp @@ -35,7 +35,7 @@ BEGIN_ROCPRIM_NAMESPACE /// \class counting_iterator -/// \brief A random-access input (read-only) iterator over a sequence of consecutive values. +/// \brief A random-access input (read-only) iterator over a sequence of consecutive integer values. /// /// \par Overview /// * A counting_iterator represents a pointer into a range of sequentially increasing values. @@ -65,6 +65,8 @@ class counting_iterator /// The category of the iterator. using iterator_category = std::random_access_iterator_tag; + static_assert(std::is_integral::value, "Incrementable must be integral type"); + #ifndef DOXYGEN_SHOULD_SKIP_THIS using self_type = counting_iterator; #endif @@ -213,16 +215,7 @@ class counting_iterator private: template inline - auto equal_value(const T& x, const T& y) const - -> typename std::enable_if<::rocprim::is_floating_point::value, bool>::type - { - return difference_type(y) - difference_type(x) == 0; - } - - template - inline - auto equal_value(const T& x, const T& y) const - -> typename std::enable_if::value, bool>::type + bool equal_value(const T& x, const T& y) const { return (x == y); } diff --git a/rocprim/include/rocprim/iterator/transform_iterator.hpp b/rocprim/include/rocprim/iterator/transform_iterator.hpp index c256ba2a6..d53c20738 100644 --- a/rocprim/include/rocprim/iterator/transform_iterator.hpp +++ b/rocprim/include/rocprim/iterator/transform_iterator.hpp @@ -63,13 +63,6 @@ template< > class transform_iterator { -private: - using input_category = typename std::iterator_traits::iterator_category; - static_assert( - std::is_same::value, - "InputIterator must be a random-access iterator" - ); - public: /// The type of the value that can be obtained by dereferencing the iterator. using value_type = ValueType; diff --git a/rocprim/include/rocprim/rocprim.hpp b/rocprim/include/rocprim/rocprim.hpp index 8da8996f0..2bf0ddc76 100644 --- a/rocprim/include/rocprim/rocprim.hpp +++ b/rocprim/include/rocprim/rocprim.hpp @@ -50,6 +50,7 @@ #include "block/block_store.hpp" #ifdef ROCPRIM_HC_API + #include "device/device_binary_search_hc.hpp" #include "device/device_histogram_hc.hpp" #include "device/device_merge_hc.hpp" #include "device/device_merge_sort_hc.hpp" @@ -66,6 +67,7 @@ #include "device/device_select_hc.hpp" #include "device/device_transform_hc.hpp" #else + #include "device/device_binary_search_hip.hpp" #include "device/device_histogram_hip.hpp" #include "device/device_merge_hip.hpp" #include "device/device_merge_sort_hip.hpp" diff --git a/test/rocprim/CMakeLists.txt b/test/rocprim/CMakeLists.txt index c7a22124f..69091bfea 100644 --- a/test/rocprim/CMakeLists.txt +++ b/test/rocprim/CMakeLists.txt @@ -77,6 +77,7 @@ add_rocprim_test_hc("rocprim.hc.block_scan" test_hc_block_scan.cpp) add_rocprim_test_hc("rocprim.hc.block_sort" test_hc_block_sort.cpp) add_rocprim_test_hc("rocprim.hc.constant_iterator" test_hc_constant_iterator.cpp) add_rocprim_test_hc("rocprim.hc.counting_iterator" test_hc_counting_iterator.cpp) +add_rocprim_test_hc("rocprim.hc.device_binary_search" test_hc_device_binary_search.cpp) add_rocprim_test_hc("rocprim.hc.device_histogram" test_hc_device_histogram.cpp) add_rocprim_test_hc("rocprim.hc.device_merge" test_hc_device_merge.cpp) add_rocprim_test_hc("rocprim.hc.device_merge_sort" test_hc_device_merge_sort.cpp) @@ -119,6 +120,7 @@ add_rocprim_test_hip("rocprim.hip.block_scan" test_hip_block_scan.cpp) add_rocprim_test_hip("rocprim.hip.block_sort" test_hip_block_sort.cpp) add_rocprim_test_hip("rocprim.hip.constant_iterator" test_hip_constant_iterator.cpp) add_rocprim_test_hip("rocprim.hip.counting_iterator" test_hip_counting_iterator.cpp) +add_rocprim_test_hip("rocprim.hip.device_binary_search" test_hip_device_binary_search.cpp) add_rocprim_test_hip("rocprim.hip.device_histogram" test_hip_device_histogram.cpp) add_rocprim_test_hip("rocprim.hip.device_merge" test_hip_device_merge.cpp) add_rocprim_test_hip("rocprim.hip.device_merge_sort" test_hip_device_merge_sort.cpp) diff --git a/test/rocprim/test_hc_counting_iterator.cpp b/test/rocprim/test_hc_counting_iterator.cpp index 29141a786..da3cf154e 100644 --- a/test/rocprim/test_hc_counting_iterator.cpp +++ b/test/rocprim/test_hc_counting_iterator.cpp @@ -53,7 +53,7 @@ typedef ::testing::Types< RocprimCountingIteratorParams, RocprimCountingIteratorParams, RocprimCountingIteratorParams, - RocprimCountingIteratorParams + RocprimCountingIteratorParams > RocprimCountingIteratorTestsParams; TYPED_TEST_CASE(RocprimCountingIteratorTests, RocprimCountingIteratorTestsParams); @@ -147,14 +147,6 @@ TYPED_TEST(RocprimCountingIteratorTests, Transform) std::vector output = d_output; for(size_t i = 0; i < output.size(); i++) { - if(std::is_integral::value) - { - ASSERT_EQ(output[i], expected[i]) << "where index = " << i; - } - else if(std::is_floating_point::value) - { - auto tolerance = std::max(std::abs(0.1f * expected[i]), T(0.01f)); - ASSERT_NEAR(output[i], expected[i], tolerance) << "where index = " << i; - } + ASSERT_EQ(output[i], expected[i]) << "where index = " << i; } } diff --git a/test/rocprim/test_hc_device_binary_search.cpp b/test/rocprim/test_hc_device_binary_search.cpp new file mode 100644 index 000000000..b7c30fb1c --- /dev/null +++ b/test/rocprim/test_hc_device_binary_search.cpp @@ -0,0 +1,300 @@ +// MIT License +// +// Copyright (c) 2019 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#include +#include +#include +#include +#include +#include + +// Google Test +#include + +// HC API +#include +// rocPRIM API +#include + +#include "test_utils.hpp" + +template< + class Haystack, + class Needle, + class Output = size_t, + class CompareFunction = rocprim::less<> +> +struct params +{ + using haystack_type = Haystack; + using needle_type = Needle; + using output_type = Output; + using compare_op_type = CompareFunction; +}; + +template +class RocprimDeviceBinarySearch : public ::testing::Test { +public: + using params = Params; +}; + +using custom_int2 = test_utils::custom_test_type; +using custom_double2 = test_utils::custom_test_type; + +typedef ::testing::Types< + params, + params >, + params >, + params, + params, + params > +> Params; + +TYPED_TEST_CASE(RocprimDeviceBinarySearch, Params); + +std::vector get_sizes() +{ + std::vector sizes = { 1, 10, 53, 211, 1024, 2345, 4096, 34567, (1 << 16) - 1220, (1 << 22) - 76543 }; + const std::vector random_sizes = test_utils::get_random_data(5, 1, 100000); + sizes.insert(sizes.end(), random_sizes.begin(), random_sizes.end()); + return sizes; +} + +TYPED_TEST(RocprimDeviceBinarySearch, LowerBound) +{ + using haystack_type = typename TestFixture::params::haystack_type; + using needle_type = typename TestFixture::params::needle_type; + using output_type = typename TestFixture::params::output_type; + using compare_op_type = typename TestFixture::params::compare_op_type; + + hc::accelerator acc; + hc::accelerator_view acc_view = acc.create_view(); + + const bool debug_synchronous = false; + + compare_op_type compare_op; + + for(size_t size : get_sizes()) + { + SCOPED_TRACE(testing::Message() << "with size = " << size); + + const size_t haystack_size = size; + const size_t needles_size = std::sqrt(size); + const size_t d = haystack_size / 100; + + // Generate data + std::vector haystack = test_utils::get_random_data( + haystack_size, 0, haystack_size + 2 * d + ); + std::sort(haystack.begin(), haystack.end(), compare_op); + + // Use a narrower range for needles for checking out-of-haystack cases + std::vector needles = test_utils::get_random_data( + needles_size, d, haystack_size + d + ); + + hc::array d_haystack(hc::extent<1>(haystack_size), haystack.begin(), acc_view); + hc::array d_needles(hc::extent<1>(needles_size), needles.begin(), acc_view); + hc::array d_output(needles_size, acc_view); + + // Calculate expected results on host + std::vector expected(needles_size); + for(size_t i = 0; i < needles_size; i++) + { + expected[i] = + std::lower_bound(haystack.begin(), haystack.end(), needles[i], compare_op) - + haystack.begin(); + } + + size_t temporary_storage_bytes; + rocprim::lower_bound( + nullptr, temporary_storage_bytes, + d_haystack.accelerator_pointer(), d_needles.accelerator_pointer(), d_output.accelerator_pointer(), + haystack_size, needles_size, + compare_op, + acc_view, debug_synchronous + ); + + ASSERT_GT(temporary_storage_bytes, 0); + + hc::array d_temporary_storage(temporary_storage_bytes, acc_view); + + rocprim::lower_bound( + d_temporary_storage.accelerator_pointer(), temporary_storage_bytes, + d_haystack.accelerator_pointer(), d_needles.accelerator_pointer(), d_output.accelerator_pointer(), + haystack_size, needles_size, + compare_op, + acc_view, debug_synchronous + ); + acc_view.wait(); + + std::vector output = d_output; + + ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(output, expected)); + } +} + +TYPED_TEST(RocprimDeviceBinarySearch, UpperBound) +{ + using haystack_type = typename TestFixture::params::haystack_type; + using needle_type = typename TestFixture::params::needle_type; + using output_type = typename TestFixture::params::output_type; + using compare_op_type = typename TestFixture::params::compare_op_type; + + hc::accelerator acc; + hc::accelerator_view acc_view = acc.create_view(); + + const bool debug_synchronous = false; + + compare_op_type compare_op; + + for(size_t size : get_sizes()) + { + SCOPED_TRACE(testing::Message() << "with size = " << size); + + const size_t haystack_size = size; + const size_t needles_size = std::sqrt(size); + const size_t d = haystack_size / 100; + + // Generate data + std::vector haystack = test_utils::get_random_data( + haystack_size, 0, haystack_size + 2 * d + ); + std::sort(haystack.begin(), haystack.end(), compare_op); + + // Use a narrower range for needles for checking out-of-haystack cases + std::vector needles = test_utils::get_random_data( + needles_size, d, haystack_size + d + ); + + hc::array d_haystack(hc::extent<1>(haystack_size), haystack.begin(), acc_view); + hc::array d_needles(hc::extent<1>(needles_size), needles.begin(), acc_view); + hc::array d_output(needles_size, acc_view); + + // Calculate expected results on host + std::vector expected(needles_size); + for(size_t i = 0; i < needles_size; i++) + { + expected[i] = + std::upper_bound(haystack.begin(), haystack.end(), needles[i], compare_op) - + haystack.begin(); + } + + size_t temporary_storage_bytes; + rocprim::upper_bound( + nullptr, temporary_storage_bytes, + d_haystack.accelerator_pointer(), d_needles.accelerator_pointer(), d_output.accelerator_pointer(), + haystack_size, needles_size, + compare_op, + acc_view, debug_synchronous + ); + + ASSERT_GT(temporary_storage_bytes, 0); + + hc::array d_temporary_storage(temporary_storage_bytes, acc_view); + + rocprim::upper_bound( + d_temporary_storage.accelerator_pointer(), temporary_storage_bytes, + d_haystack.accelerator_pointer(), d_needles.accelerator_pointer(), d_output.accelerator_pointer(), + haystack_size, needles_size, + compare_op, + acc_view, debug_synchronous + ); + acc_view.wait(); + + std::vector output = d_output; + + ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(output, expected)); + } +} + +TYPED_TEST(RocprimDeviceBinarySearch, BinarySearch) +{ + using haystack_type = typename TestFixture::params::haystack_type; + using needle_type = typename TestFixture::params::needle_type; + using output_type = typename TestFixture::params::output_type; + using compare_op_type = typename TestFixture::params::compare_op_type; + + hc::accelerator acc; + hc::accelerator_view acc_view = acc.create_view(); + + const bool debug_synchronous = false; + + compare_op_type compare_op; + + for(size_t size : get_sizes()) + { + SCOPED_TRACE(testing::Message() << "with size = " << size); + + const size_t haystack_size = size; + const size_t needles_size = std::sqrt(size); + const size_t d = haystack_size / 100; + + // Generate data + std::vector haystack = test_utils::get_random_data( + haystack_size, 0, haystack_size + 2 * d + ); + std::sort(haystack.begin(), haystack.end(), compare_op); + + // Use a narrower range for needles for checking out-of-haystack cases + std::vector needles = test_utils::get_random_data( + needles_size, d, haystack_size + d + ); + + hc::array d_haystack(hc::extent<1>(haystack_size), haystack.begin(), acc_view); + hc::array d_needles(hc::extent<1>(needles_size), needles.begin(), acc_view); + hc::array d_output(needles_size, acc_view); + + // Calculate expected results on host + std::vector expected(needles_size); + for(size_t i = 0; i < needles_size; i++) + { + expected[i] = std::binary_search(haystack.begin(), haystack.end(), needles[i], compare_op); + } + + size_t temporary_storage_bytes; + rocprim::binary_search( + nullptr, temporary_storage_bytes, + d_haystack.accelerator_pointer(), d_needles.accelerator_pointer(), d_output.accelerator_pointer(), + haystack_size, needles_size, + compare_op, + acc_view, debug_synchronous + ); + + ASSERT_GT(temporary_storage_bytes, 0); + + hc::array d_temporary_storage(temporary_storage_bytes, acc_view); + + rocprim::binary_search( + d_temporary_storage.accelerator_pointer(), temporary_storage_bytes, + d_haystack.accelerator_pointer(), d_needles.accelerator_pointer(), d_output.accelerator_pointer(), + haystack_size, needles_size, + compare_op, + acc_view, debug_synchronous + ); + acc_view.wait(); + + std::vector output = d_output; + + ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(output, expected)); + } +} diff --git a/test/rocprim/test_hc_device_merge.cpp b/test/rocprim/test_hc_device_merge.cpp index a58b43cae..2906c0140 100644 --- a/test/rocprim/test_hc_device_merge.cpp +++ b/test/rocprim/test_hc_device_merge.cpp @@ -34,17 +34,17 @@ #include "test_utils.hpp" -namespace rp = rocprim; - // Params for tests template< class KeyType, - class ValueType + class ValueType, + class CompareOp = ::rocprim::less > struct DeviceMergeParams { using key_type = KeyType; using value_type = ValueType; + using compare_op_type = CompareOp; }; template @@ -53,6 +53,7 @@ class RocprimDeviceMergeTests : public ::testing::Test public: using key_type = typename Params::key_type; using value_type = typename Params::value_type; + using compare_op_type = typename Params::compare_op_type; const bool debug_synchronous = false; }; @@ -61,10 +62,10 @@ using custom_double2 = test_utils::custom_test_type; typedef ::testing::Types< DeviceMergeParams, - DeviceMergeParams, + DeviceMergeParams >, DeviceMergeParams, DeviceMergeParams, - DeviceMergeParams, + DeviceMergeParams >, DeviceMergeParams > RocprimDeviceMergeTestsParams; @@ -93,6 +94,7 @@ TYPED_TEST_CASE(RocprimDeviceMergeTests, RocprimDeviceMergeTestsParams); TYPED_TEST(RocprimDeviceMergeTests, MergeKey) { using key_type = typename TestFixture::key_type; + using compare_op_type = typename TestFixture::compare_op_type; const bool debug_synchronous = TestFixture::debug_synchronous; hc::accelerator acc; @@ -108,17 +110,14 @@ TYPED_TEST(RocprimDeviceMergeTests, MergeKey) const size_t size1 = std::get<0>(sizes); const size_t size2 = std::get<1>(sizes); + // compare function + compare_op_type compare_op; + // Generate data std::vector keys_input1 = test_utils::get_random_data(size1, 0, size1); std::vector keys_input2 = test_utils::get_random_data(size2, 0, size2); - std::sort(keys_input1.begin(), keys_input1.end()); - std::sort(keys_input2.begin(), keys_input2.end()); - - test_utils::out_of_bounds_flag out_of_bounds(acc_view); - - hc::array d_keys_input1(hc::extent<1>(size1), keys_input1.begin(), acc_view); - hc::array d_keys_input2(hc::extent<1>(size2), keys_input2.begin(), acc_view); - hc::array d_keys_output(size1 + size2, acc_view); + std::sort(keys_input1.begin(), keys_input1.end(), compare_op); + std::sort(keys_input2.begin(), keys_input2.end(), compare_op); // Calculate expected results on host std::vector expected(size1 + size2); @@ -127,17 +126,23 @@ TYPED_TEST(RocprimDeviceMergeTests, MergeKey) keys_input1.end(), keys_input2.begin(), keys_input2.end(), - expected.begin() + expected.begin(), + compare_op ); + test_utils::out_of_bounds_flag out_of_bounds(acc_view); + + hc::array d_keys_input1(hc::extent<1>(size1), keys_input1.begin(), acc_view); + hc::array d_keys_input2(hc::extent<1>(size2), keys_input2.begin(), acc_view); + hc::array d_keys_output(size1 + size2, acc_view); + + test_utils::bounds_checking_iterator d_keys_checking_output( d_keys_output.accelerator_pointer(), out_of_bounds.device_pointer(), size1 + size2 ); - // compare function - ::rocprim::less lesser_op; // temp storage size_t temp_storage_size_bytes; @@ -148,7 +153,7 @@ TYPED_TEST(RocprimDeviceMergeTests, MergeKey) d_keys_input2.accelerator_pointer(), d_keys_checking_output, keys_input1.size(), keys_input2.size(), - lesser_op, acc_view, debug_synchronous + compare_op, acc_view, debug_synchronous ); // temp_storage_size_bytes must be >0 @@ -164,7 +169,7 @@ TYPED_TEST(RocprimDeviceMergeTests, MergeKey) d_keys_input2.accelerator_pointer(), d_keys_checking_output, keys_input1.size(), keys_input2.size(), - lesser_op, acc_view, debug_synchronous + compare_op, acc_view, debug_synchronous ); acc_view.wait(); @@ -178,3 +183,131 @@ TYPED_TEST(RocprimDeviceMergeTests, MergeKey) } } } + +TYPED_TEST(RocprimDeviceMergeTests, MergeKeyValue) +{ + using key_type = typename TestFixture::key_type; + using value_type = typename TestFixture::value_type; + using compare_op_type = typename TestFixture::compare_op_type; + const bool debug_synchronous = TestFixture::debug_synchronous; + + using key_value = std::pair; + + hc::accelerator acc; + hc::accelerator_view acc_view = acc.create_view(); + + for(auto sizes : get_sizes()) + { + SCOPED_TRACE( + testing::Message() << "with sizes = {" << + std::get<0>(sizes) << ", " << std::get<1>(sizes) << "}" + ); + + const size_t size1 = std::get<0>(sizes); + const size_t size2 = std::get<1>(sizes); + + // compare function + compare_op_type compare_op; + + // Generate data + std::vector keys_input1 = test_utils::get_random_data(size1, 0, size1); + std::vector keys_input2 = test_utils::get_random_data(size2, 0, size2); + std::sort(keys_input1.begin(), keys_input1.end(), compare_op); + std::sort(keys_input2.begin(), keys_input2.end(), compare_op); + std::vector values_input1(size1); + std::vector values_input2(size2); + std::iota(values_input1.begin(), values_input1.end(), 0); + std::iota(values_input2.begin(), values_input2.end(), size1); + + // Calculate expected results on host + std::vector vector1(size1); + std::vector vector2(size2); + + for(size_t i = 0; i < size1; i++) + { + vector1[i] = key_value(keys_input1[i], values_input1[i]); + } + for(size_t i = 0; i < size2; i++) + { + vector2[i] = key_value(keys_input2[i], values_input2[i]); + } + + std::vector expected(size1 + size2); + std::merge( + vector1.begin(), + vector1.end(), + vector2.begin(), + vector2.end(), + expected.begin(), + [compare_op](const key_value& a, const key_value& b) { return compare_op(a.first, b.first); } + ); + + test_utils::out_of_bounds_flag out_of_bounds(acc_view); + + hc::array d_keys_input1(hc::extent<1>(size1), keys_input1.begin(), acc_view); + hc::array d_keys_input2(hc::extent<1>(size2), keys_input2.begin(), acc_view); + hc::array d_keys_output(size1 + size2, acc_view); + hc::array d_values_input1(hc::extent<1>(size1), values_input1.begin(), acc_view); + hc::array d_values_input2(hc::extent<1>(size2), values_input2.begin(), acc_view); + hc::array d_values_output(size1 + size2, acc_view); + + + test_utils::bounds_checking_iterator d_keys_checking_output( + d_keys_output.accelerator_pointer(), + out_of_bounds.device_pointer(), + size1 + size2 + ); + test_utils::bounds_checking_iterator d_values_checking_output( + d_values_output.accelerator_pointer(), + out_of_bounds.device_pointer(), + size1 + size2 + ); + + // temp storage + size_t temp_storage_size_bytes; + + // Get size of d_temp_storage + rocprim::merge( + nullptr, temp_storage_size_bytes, + d_keys_input1.accelerator_pointer(), + d_keys_input2.accelerator_pointer(), + d_keys_checking_output, + d_values_input1.accelerator_pointer(), + d_values_input2.accelerator_pointer(), + d_values_checking_output, + keys_input1.size(), keys_input2.size(), + compare_op, acc_view, debug_synchronous + ); + + // temp_storage_size_bytes must be >0 + ASSERT_GT(temp_storage_size_bytes, 0); + + // allocate temporary storage + hc::array d_temp_storage(temp_storage_size_bytes, acc_view); + + // Run + rocprim::merge( + d_temp_storage.accelerator_pointer(), temp_storage_size_bytes, + d_keys_input1.accelerator_pointer(), + d_keys_input2.accelerator_pointer(), + d_keys_checking_output, + d_values_input1.accelerator_pointer(), + d_values_input2.accelerator_pointer(), + d_values_checking_output, + keys_input1.size(), keys_input2.size(), + compare_op, acc_view, debug_synchronous + ); + acc_view.wait(); + + ASSERT_FALSE(out_of_bounds.get()); + + // Check if keys_output values are as expected + std::vector keys_output = d_keys_output; + std::vector values_output = d_values_output; + for(size_t i = 0; i < keys_output.size(); i++) + { + ASSERT_EQ(keys_output[i], expected[i].first); + ASSERT_EQ(values_output[i], expected[i].second); + } + } +} diff --git a/test/rocprim/test_hc_device_reduce_by_key.cpp b/test/rocprim/test_hc_device_reduce_by_key.cpp index beecfbc7e..9ff8ed12e 100644 --- a/test/rocprim/test_hc_device_reduce_by_key.cpp +++ b/test/rocprim/test_hc_device_reduce_by_key.cpp @@ -153,6 +153,8 @@ TYPED_TEST(RocprimDeviceReduceByKey, ReduceByKey) { SCOPED_TRACE(testing::Message() << "with size = " << size); + const bool use_unique_keys = bool(test_utils::get_random_value(0, 1)); + // Generate data and calculate expected results std::vector unique_expected; std::vector aggregates_expected; @@ -167,12 +169,11 @@ TYPED_TEST(RocprimDeviceReduceByKey, ReduceByKey) std::vector values_input = test_utils::get_random_data(size, 0, 100); size_t offset = 0; - key_type current_key = key_distribution_type(0, 100)(gen); - key_type prev_key = current_key; + key_type prev_key = key_distribution_type(0, 100)(gen); + key_type current_key = prev_key + key_delta_dis(gen); while(offset < size) { const size_t key_count = key_count_dis(gen); - current_key = current_key + key_delta_dis(gen); const size_t end = std::min(size, offset + key_count); for(size_t i = offset; i < end; i++) @@ -198,7 +199,18 @@ TYPED_TEST(RocprimDeviceReduceByKey, ReduceByKey) aggregates_expected.back() = reduce_op(aggregates_expected.back(), aggregate); } - prev_key = current_key; + if (use_unique_keys) + { + prev_key = current_key; + // e.g. 1,1,1,2,5,5,8,8,8 + current_key = current_key + key_delta_dis(gen); + } + else + { + // e.g. 1,1,5,1,5,5,5,1 + std::swap(current_key, prev_key); + } + offset += key_count; } diff --git a/test/rocprim/test_hc_device_scan.cpp b/test/rocprim/test_hc_device_scan.cpp index 1405966ec..6694c8aa3 100644 --- a/test/rocprim/test_hc_device_scan.cpp +++ b/test/rocprim/test_hc_device_scan.cpp @@ -312,10 +312,20 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScanByKey) { SCOPED_TRACE(testing::Message() << "with size = " << size); + const bool use_unique_keys = bool(test_utils::get_random_value(0, 1)); + // Generate data - std::vector input = test_utils::get_random_data(size, 1, 100); - std::vector keys = test_utils::get_random_data(size, 1, 16); - std::sort(keys.begin(), keys.end()); + std::vector input = test_utils::get_random_data(size, 0, 9); + std::vector keys; + if(use_unique_keys) + { + keys = test_utils::get_random_data(size, 0, 16); + std::sort(keys.begin(), keys.end()); + } + else + { + keys = test_utils::get_random_data(size, 0, 3); + } hc::array d_input(hc::extent<1>(size), input.begin(), acc_view); hc::array d_keys(hc::extent<1>(size), keys.begin(), acc_view); @@ -415,11 +425,21 @@ TYPED_TEST(RocprimDeviceScanTests, ExclusiveScanByKey) { SCOPED_TRACE(testing::Message() << "with size = " << size); + const bool use_unique_keys = bool(test_utils::get_random_value(0, 1)); + // Generate data T initial_value = test_utils::get_random_value(1, 100); - std::vector input = test_utils::get_random_data(size, 1, 100); - std::vector keys = test_utils::get_random_data(size, 1, 16); - std::sort(keys.begin(), keys.end()); + std::vector input = test_utils::get_random_data(size, 0, 9); + std::vector keys; + if(use_unique_keys) + { + keys = test_utils::get_random_data(size, 0, 16); + std::sort(keys.begin(), keys.end()); + } + else + { + keys = test_utils::get_random_data(size, 0, 3); + } hc::array d_input(hc::extent<1>(size), input.begin(), acc_view); hc::array d_keys(hc::extent<1>(size), keys.begin(), acc_view); diff --git a/test/rocprim/test_hc_device_segmented_scan.cpp b/test/rocprim/test_hc_device_segmented_scan.cpp index 697e3b425..44dab710f 100644 --- a/test/rocprim/test_hc_device_segmented_scan.cpp +++ b/test/rocprim/test_hc_device_segmented_scan.cpp @@ -282,7 +282,7 @@ TYPED_TEST(RocprimDeviceSegmentedScan, InclusiveScanUsingHeadFlags) SCOPED_TRACE(testing::Message() << "with size = " << size); // Generate data - std::vector input = test_utils::get_random_data(size, 1, 1); + std::vector input = test_utils::get_random_data(size, 1, 10); std::vector flags = test_utils::get_random_data(size, 0, 10); flags[0] = 1U; std::transform( @@ -386,7 +386,7 @@ TYPED_TEST(RocprimDeviceSegmentedScan, ExclusiveScanUsingHeadFlags) SCOPED_TRACE(testing::Message() << "with size = " << size); // Generate data - std::vector input = test_utils::get_random_data(size, 1, 1); + std::vector input = test_utils::get_random_data(size, 1, 10); std::vector flags = test_utils::get_random_data(size, 0, 10); flags[0] = 1U; std::transform( diff --git a/test/rocprim/test_hc_transform_iterator.cpp b/test/rocprim/test_hc_transform_iterator.cpp index 00ebd2d1c..6cfebde3a 100644 --- a/test/rocprim/test_hc_transform_iterator.cpp +++ b/test/rocprim/test_hc_transform_iterator.cpp @@ -79,7 +79,7 @@ typedef ::testing::Types< RocprimTransformIteratorParams>, RocprimTransformIteratorParams, RocprimTransformIteratorParams, - RocprimTransformIteratorParams, double> + RocprimTransformIteratorParams, size_t> > RocprimTransformIteratorTestsParams; TYPED_TEST_CASE(RocprimTransformIteratorTests, RocprimTransformIteratorTestsParams); @@ -272,4 +272,3 @@ TYPED_TEST(RocprimTransformIteratorTests, TransformReduceCountingIterator) ASSERT_NEAR(output[0], expected, tolerance); } } - diff --git a/test/rocprim/test_hip_counting_iterator.cpp b/test/rocprim/test_hip_counting_iterator.cpp index 252bd3635..0bc4032bb 100644 --- a/test/rocprim/test_hip_counting_iterator.cpp +++ b/test/rocprim/test_hip_counting_iterator.cpp @@ -36,8 +36,7 @@ #include "test_utils.hpp" -#define HIP_CHECK(error) \ - ASSERT_EQ(static_cast(error),hipSuccess) +#define HIP_CHECK(error) ASSERT_EQ(error, hipSuccess) namespace rp = rocprim; @@ -60,7 +59,7 @@ typedef ::testing::Types< RocprimCountingIteratorParams, RocprimCountingIteratorParams, RocprimCountingIteratorParams, - RocprimCountingIteratorParams + RocprimCountingIteratorParams > RocprimCountingIteratorTestsParams; TYPED_TEST_CASE(RocprimCountingIteratorTests, RocprimCountingIteratorTestsParams); @@ -125,15 +124,7 @@ TYPED_TEST(RocprimCountingIteratorTests, Transform) // Validating results for(size_t i = 0; i < output.size(); i++) { - if(std::is_integral::value) - { - ASSERT_EQ(output[i], expected[i]) << "where index = " << i; - } - else if(std::is_floating_point::value) - { - auto tolerance = std::max(std::abs(0.1f * expected[i]), T(0.01f)); - ASSERT_NEAR(output[i], expected[i], tolerance) << "where index = " << i; - } + ASSERT_EQ(output[i], expected[i]) << "where index = " << i; } hipFree(d_output); diff --git a/test/rocprim/test_hip_device_binary_search.cpp b/test/rocprim/test_hip_device_binary_search.cpp new file mode 100644 index 000000000..780d5b454 --- /dev/null +++ b/test/rocprim/test_hip_device_binary_search.cpp @@ -0,0 +1,399 @@ +// MIT License +// +// Copyright (c) 2019 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#include +#include +#include +#include +#include +#include + +// Google Test +#include + +// HIP API +#include +#include +// rocPRIM API +#include + +#include "test_utils.hpp" + +#define HIP_CHECK(error) ASSERT_EQ(error, hipSuccess) + +template< + class Haystack, + class Needle, + class Output = size_t, + class CompareFunction = rocprim::less<> +> +struct params +{ + using haystack_type = Haystack; + using needle_type = Needle; + using output_type = Output; + using compare_op_type = CompareFunction; +}; + +template +class RocprimDeviceBinarySearch : public ::testing::Test { +public: + using params = Params; +}; + +using custom_int2 = test_utils::custom_test_type; +using custom_double2 = test_utils::custom_test_type; + +typedef ::testing::Types< + params, + params >, + params >, + params, + params, + params > +> Params; + +TYPED_TEST_CASE(RocprimDeviceBinarySearch, Params); + +std::vector get_sizes() +{ + std::vector sizes = { 1, 10, 53, 211, 1024, 2345, 4096, 34567, (1 << 16) - 1220, (1 << 22) - 76543 }; + const std::vector random_sizes = test_utils::get_random_data(5, 1, 100000); + sizes.insert(sizes.end(), random_sizes.begin(), random_sizes.end()); + return sizes; +} + +TYPED_TEST(RocprimDeviceBinarySearch, LowerBound) +{ + using haystack_type = typename TestFixture::params::haystack_type; + using needle_type = typename TestFixture::params::needle_type; + using output_type = typename TestFixture::params::output_type; + using compare_op_type = typename TestFixture::params::compare_op_type; + + hipStream_t stream = 0; + + const bool debug_synchronous = false; + + compare_op_type compare_op; + + for(size_t size : get_sizes()) + { + SCOPED_TRACE(testing::Message() << "with size = " << size); + + const size_t haystack_size = size; + const size_t needles_size = std::sqrt(size); + const size_t d = haystack_size / 100; + + // Generate data + std::vector haystack = test_utils::get_random_data( + haystack_size, 0, haystack_size + 2 * d + ); + std::sort(haystack.begin(), haystack.end(), compare_op); + + // Use a narrower range for needles for checking out-of-haystack cases + std::vector needles = test_utils::get_random_data( + needles_size, d, haystack_size + d + ); + + haystack_type * d_haystack; + needle_type * d_needles; + output_type * d_output; + HIP_CHECK(hipMalloc(&d_haystack, haystack_size * sizeof(haystack_type))); + HIP_CHECK(hipMalloc(&d_needles, needles_size * sizeof(needle_type))); + HIP_CHECK(hipMalloc(&d_output, needles_size * sizeof(output_type))); + HIP_CHECK( + hipMemcpy( + d_haystack, haystack.data(), + haystack_size * sizeof(haystack_type), + hipMemcpyHostToDevice + ) + ); + HIP_CHECK( + hipMemcpy( + d_needles, needles.data(), + needles_size * sizeof(needle_type), + hipMemcpyHostToDevice + ) + ); + + // Calculate expected results on host + std::vector expected(needles_size); + for(size_t i = 0; i < needles_size; i++) + { + expected[i] = + std::lower_bound(haystack.begin(), haystack.end(), needles[i], compare_op) - + haystack.begin(); + } + + void * d_temporary_storage = nullptr; + size_t temporary_storage_bytes; + HIP_CHECK( + rocprim::lower_bound( + d_temporary_storage, temporary_storage_bytes, + d_haystack, d_needles, d_output, + haystack_size, needles_size, + compare_op, + stream, debug_synchronous + ) + ); + + ASSERT_GT(temporary_storage_bytes, 0); + + HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); + + HIP_CHECK( + rocprim::lower_bound( + d_temporary_storage, temporary_storage_bytes, + d_haystack, d_needles, d_output, + haystack_size, needles_size, + compare_op, + stream, debug_synchronous + ) + ); + + std::vector output(needles_size); + HIP_CHECK( + hipMemcpy( + output.data(), d_output, + needles_size * sizeof(output_type), + hipMemcpyDeviceToHost + ) + ); + + HIP_CHECK(hipFree(d_temporary_storage)); + HIP_CHECK(hipFree(d_haystack)); + HIP_CHECK(hipFree(d_needles)); + HIP_CHECK(hipFree(d_output)); + + ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(output, expected)); + } +} + +TYPED_TEST(RocprimDeviceBinarySearch, UpperBound) +{ + using haystack_type = typename TestFixture::params::haystack_type; + using needle_type = typename TestFixture::params::needle_type; + using output_type = typename TestFixture::params::output_type; + using compare_op_type = typename TestFixture::params::compare_op_type; + + hipStream_t stream = 0; + + const bool debug_synchronous = false; + + compare_op_type compare_op; + + for(size_t size : get_sizes()) + { + SCOPED_TRACE(testing::Message() << "with size = " << size); + + const size_t haystack_size = size; + const size_t needles_size = std::sqrt(size); + const size_t d = haystack_size / 100; + + // Generate data + std::vector haystack = test_utils::get_random_data( + haystack_size, 0, haystack_size + 2 * d + ); + std::sort(haystack.begin(), haystack.end(), compare_op); + + // Use a narrower range for needles for checking out-of-haystack cases + std::vector needles = test_utils::get_random_data( + needles_size, d, haystack_size + d + ); + + haystack_type * d_haystack; + needle_type * d_needles; + output_type * d_output; + HIP_CHECK(hipMalloc(&d_haystack, haystack_size * sizeof(haystack_type))); + HIP_CHECK(hipMalloc(&d_needles, needles_size * sizeof(needle_type))); + HIP_CHECK(hipMalloc(&d_output, needles_size * sizeof(output_type))); + HIP_CHECK( + hipMemcpy( + d_haystack, haystack.data(), + haystack_size * sizeof(haystack_type), + hipMemcpyHostToDevice + ) + ); + HIP_CHECK( + hipMemcpy( + d_needles, needles.data(), + needles_size * sizeof(needle_type), + hipMemcpyHostToDevice + ) + ); + + // Calculate expected results on host + std::vector expected(needles_size); + for(size_t i = 0; i < needles_size; i++) + { + expected[i] = + std::upper_bound(haystack.begin(), haystack.end(), needles[i], compare_op) - + haystack.begin(); + } + + void * d_temporary_storage = nullptr; + size_t temporary_storage_bytes; + HIP_CHECK( + rocprim::upper_bound( + d_temporary_storage, temporary_storage_bytes, + d_haystack, d_needles, d_output, + haystack_size, needles_size, + compare_op, + stream, debug_synchronous + ) + ); + + ASSERT_GT(temporary_storage_bytes, 0); + + HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); + + HIP_CHECK( + rocprim::upper_bound( + d_temporary_storage, temporary_storage_bytes, + d_haystack, d_needles, d_output, + haystack_size, needles_size, + compare_op, + stream, debug_synchronous + ) + ); + + std::vector output(needles_size); + HIP_CHECK( + hipMemcpy( + output.data(), d_output, + needles_size * sizeof(output_type), + hipMemcpyDeviceToHost + ) + ); + + HIP_CHECK(hipFree(d_temporary_storage)); + HIP_CHECK(hipFree(d_haystack)); + HIP_CHECK(hipFree(d_needles)); + HIP_CHECK(hipFree(d_output)); + + ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(output, expected)); + } +} + +TYPED_TEST(RocprimDeviceBinarySearch, BinarySearch) +{ + using haystack_type = typename TestFixture::params::haystack_type; + using needle_type = typename TestFixture::params::needle_type; + using output_type = typename TestFixture::params::output_type; + using compare_op_type = typename TestFixture::params::compare_op_type; + + hipStream_t stream = 0; + + const bool debug_synchronous = false; + + compare_op_type compare_op; + + for(size_t size : get_sizes()) + { + SCOPED_TRACE(testing::Message() << "with size = " << size); + + const size_t haystack_size = size; + const size_t needles_size = std::sqrt(size); + const size_t d = haystack_size / 100; + + // Generate data + std::vector haystack = test_utils::get_random_data( + haystack_size, 0, haystack_size + 2 * d + ); + std::sort(haystack.begin(), haystack.end(), compare_op); + + // Use a narrower range for needles for checking out-of-haystack cases + std::vector needles = test_utils::get_random_data( + needles_size, d, haystack_size + d + ); + + haystack_type * d_haystack; + needle_type * d_needles; + output_type * d_output; + HIP_CHECK(hipMalloc(&d_haystack, haystack_size * sizeof(haystack_type))); + HIP_CHECK(hipMalloc(&d_needles, needles_size * sizeof(needle_type))); + HIP_CHECK(hipMalloc(&d_output, needles_size * sizeof(output_type))); + HIP_CHECK( + hipMemcpy( + d_haystack, haystack.data(), + haystack_size * sizeof(haystack_type), + hipMemcpyHostToDevice + ) + ); + HIP_CHECK( + hipMemcpy( + d_needles, needles.data(), + needles_size * sizeof(needle_type), + hipMemcpyHostToDevice + ) + ); + + // Calculate expected results on host + std::vector expected(needles_size); + for(size_t i = 0; i < needles_size; i++) + { + expected[i] = std::binary_search(haystack.begin(), haystack.end(), needles[i], compare_op); + } + + void * d_temporary_storage = nullptr; + size_t temporary_storage_bytes; + HIP_CHECK( + rocprim::binary_search( + d_temporary_storage, temporary_storage_bytes, + d_haystack, d_needles, d_output, + haystack_size, needles_size, + compare_op, + stream, debug_synchronous + ) + ); + + ASSERT_GT(temporary_storage_bytes, 0); + + HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); + + HIP_CHECK( + rocprim::binary_search( + d_temporary_storage, temporary_storage_bytes, + d_haystack, d_needles, d_output, + haystack_size, needles_size, + compare_op, + stream, debug_synchronous + ) + ); + + std::vector output(needles_size); + HIP_CHECK( + hipMemcpy( + output.data(), d_output, + needles_size * sizeof(output_type), + hipMemcpyDeviceToHost + ) + ); + + HIP_CHECK(hipFree(d_temporary_storage)); + HIP_CHECK(hipFree(d_haystack)); + HIP_CHECK(hipFree(d_needles)); + HIP_CHECK(hipFree(d_output)); + + ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(output, expected)); + } +} diff --git a/test/rocprim/test_hip_device_merge.cpp b/test/rocprim/test_hip_device_merge.cpp index 0a04e0deb..d8e92509f 100644 --- a/test/rocprim/test_hip_device_merge.cpp +++ b/test/rocprim/test_hip_device_merge.cpp @@ -35,20 +35,19 @@ #include "test_utils.hpp" -#define HIP_CHECK(error) \ - ASSERT_EQ(static_cast(error),hipSuccess) - -namespace rp = rocprim; +#define HIP_CHECK(error) ASSERT_EQ(error, hipSuccess) // Params for tests template< class KeyType, - class ValueType + class ValueType, + class CompareOp = ::rocprim::less > struct DeviceMergeParams { using key_type = KeyType; using value_type = ValueType; + using compare_op_type = CompareOp; }; template @@ -57,6 +56,7 @@ class RocprimDeviceMergeTests : public ::testing::Test public: using key_type = typename Params::key_type; using value_type = typename Params::value_type; + using compare_op_type = typename Params::compare_op_type; const bool debug_synchronous = false; }; @@ -65,10 +65,10 @@ using custom_double2 = test_utils::custom_test_type; typedef ::testing::Types< DeviceMergeParams, - DeviceMergeParams, + DeviceMergeParams >, DeviceMergeParams, DeviceMergeParams, - DeviceMergeParams, + DeviceMergeParams >, DeviceMergeParams > RocprimDeviceMergeTestsParams; @@ -97,6 +97,7 @@ TYPED_TEST_CASE(RocprimDeviceMergeTests, RocprimDeviceMergeTestsParams); TYPED_TEST(RocprimDeviceMergeTests, MergeKey) { using key_type = typename TestFixture::key_type; + using compare_op_type = typename TestFixture::compare_op_type; const bool debug_synchronous = TestFixture::debug_synchronous; hipStream_t stream = 0; // default @@ -111,13 +112,27 @@ TYPED_TEST(RocprimDeviceMergeTests, MergeKey) const size_t size1 = std::get<0>(sizes); const size_t size2 = std::get<1>(sizes); + // compare function + compare_op_type compare_op; + // Generate data std::vector keys_input1 = test_utils::get_random_data(size1, 0, size1); std::vector keys_input2 = test_utils::get_random_data(size2, 0, size2); - std::sort(keys_input1.begin(), keys_input1.end()); - std::sort(keys_input2.begin(), keys_input2.end()); + std::sort(keys_input1.begin(), keys_input1.end(), compare_op); + std::sort(keys_input2.begin(), keys_input2.end(), compare_op); std::vector keys_output(size1 + size2, 0); + // Calculate expected results on host + std::vector expected(keys_output.size()); + std::merge( + keys_input1.begin(), + keys_input1.end(), + keys_input2.begin(), + keys_input2.end(), + expected.begin(), + compare_op + ); + test_utils::out_of_bounds_flag out_of_bounds; key_type * d_keys_input1; @@ -140,16 +155,170 @@ TYPED_TEST(RocprimDeviceMergeTests, MergeKey) hipMemcpyHostToDevice ) ); + + test_utils::bounds_checking_iterator d_keys_checking_output( + d_keys_output, + out_of_bounds.device_pointer(), + size1 + size2 + ); + + // temp storage + size_t temp_storage_size_bytes; + void * d_temp_storage = nullptr; + // Get size of d_temp_storage + HIP_CHECK( + rocprim::merge( + d_temp_storage, temp_storage_size_bytes, + d_keys_input1, d_keys_input2, + d_keys_checking_output, + keys_input1.size(), keys_input2.size(), + compare_op, stream, debug_synchronous + ) + ); + + // temp_storage_size_bytes must be >0 + ASSERT_GT(temp_storage_size_bytes, 0); + + // allocate temporary storage + HIP_CHECK(hipMalloc(&d_temp_storage, temp_storage_size_bytes)); + + // Run + HIP_CHECK( + rocprim::merge( + d_temp_storage, temp_storage_size_bytes, + d_keys_input1, d_keys_input2, + d_keys_checking_output, + keys_input1.size(), keys_input2.size(), + compare_op, stream, debug_synchronous + ) + ); + HIP_CHECK(hipPeekAtLastError()); HIP_CHECK(hipDeviceSynchronize()); + ASSERT_FALSE(out_of_bounds.get()); + + // Copy keys_output to host + HIP_CHECK( + hipMemcpy( + keys_output.data(), d_keys_output, + keys_output.size() * sizeof(key_type), + hipMemcpyDeviceToHost + ) + ); + + // Check if keys_output values are as expected + for(size_t i = 0; i < keys_output.size(); i++) + { + ASSERT_EQ(keys_output[i], expected[i]); + } + + hipFree(d_keys_input1); + hipFree(d_keys_input2); + hipFree(d_keys_output); + hipFree(d_temp_storage); + } +} + +TYPED_TEST(RocprimDeviceMergeTests, MergeKeyValue) +{ + using key_type = typename TestFixture::key_type; + using value_type = typename TestFixture::value_type; + using compare_op_type = typename TestFixture::compare_op_type; + const bool debug_synchronous = TestFixture::debug_synchronous; + + using key_value = std::pair; + + hipStream_t stream = 0; // default + + for(auto sizes : get_sizes()) + { + SCOPED_TRACE( + testing::Message() << "with sizes = {" << + std::get<0>(sizes) << ", " << std::get<1>(sizes) << "}" + ); + + const size_t size1 = std::get<0>(sizes); + const size_t size2 = std::get<1>(sizes); + + // compare function + compare_op_type compare_op; + + // Generate data + std::vector keys_input1 = test_utils::get_random_data(size1, 0, size1); + std::vector keys_input2 = test_utils::get_random_data(size2, 0, size2); + std::sort(keys_input1.begin(), keys_input1.end(), compare_op); + std::sort(keys_input2.begin(), keys_input2.end(), compare_op); + std::vector values_input1(size1); + std::vector values_input2(size2); + std::iota(values_input1.begin(), values_input1.end(), 0); + std::iota(values_input2.begin(), values_input2.end(), size1); + std::vector keys_output(size1 + size2, 0); + std::vector values_output(size1 + size2, 0); + // Calculate expected results on host - std::vector expected(keys_output.size()); + std::vector vector1(size1); + std::vector vector2(size2); + + for(size_t i = 0; i < size1; i++) + { + vector1[i] = key_value(keys_input1[i], values_input1[i]); + } + for(size_t i = 0; i < size2; i++) + { + vector2[i] = key_value(keys_input2[i], values_input2[i]); + } + + std::vector expected(size1 + size2); std::merge( - keys_input1.begin(), - keys_input1.end(), - keys_input2.begin(), - keys_input2.end(), - expected.begin() + vector1.begin(), + vector1.end(), + vector2.begin(), + vector2.end(), + expected.begin(), + [compare_op](const key_value& a, const key_value& b) { return compare_op(a.first, b.first); } + ); + + test_utils::out_of_bounds_flag out_of_bounds; + + key_type * d_keys_input1; + key_type * d_keys_input2; + key_type * d_keys_output; + value_type * d_values_input1; + value_type * d_values_input2; + value_type * d_values_output; + HIP_CHECK(hipMalloc(&d_keys_input1, keys_input1.size() * sizeof(key_type))); + HIP_CHECK(hipMalloc(&d_keys_input2, keys_input2.size() * sizeof(key_type))); + HIP_CHECK(hipMalloc(&d_keys_output, keys_output.size() * sizeof(key_type))); + HIP_CHECK(hipMalloc(&d_values_input1, values_input1.size() * sizeof(value_type))); + HIP_CHECK(hipMalloc(&d_values_input2, values_input2.size() * sizeof(value_type))); + HIP_CHECK(hipMalloc(&d_values_output, values_output.size() * sizeof(value_type))); + HIP_CHECK( + hipMemcpy( + d_keys_input1, keys_input1.data(), + keys_input1.size() * sizeof(key_type), + hipMemcpyHostToDevice + ) + ); + HIP_CHECK( + hipMemcpy( + d_keys_input2, keys_input2.data(), + keys_input2.size() * sizeof(key_type), + hipMemcpyHostToDevice + ) + ); + HIP_CHECK( + hipMemcpy( + d_values_input1, values_input1.data(), + values_input1.size() * sizeof(value_type), + hipMemcpyHostToDevice + ) + ); + HIP_CHECK( + hipMemcpy( + d_values_input2, values_input2.data(), + values_input2.size() * sizeof(value_type), + hipMemcpyHostToDevice + ) ); test_utils::bounds_checking_iterator d_keys_checking_output( @@ -157,9 +326,12 @@ TYPED_TEST(RocprimDeviceMergeTests, MergeKey) out_of_bounds.device_pointer(), size1 + size2 ); + test_utils::bounds_checking_iterator d_values_checking_output( + d_values_output, + out_of_bounds.device_pointer(), + size1 + size2 + ); - // compare function - ::rocprim::less lesser_op; // temp storage size_t temp_storage_size_bytes; void * d_temp_storage = nullptr; @@ -169,8 +341,10 @@ TYPED_TEST(RocprimDeviceMergeTests, MergeKey) d_temp_storage, temp_storage_size_bytes, d_keys_input1, d_keys_input2, d_keys_checking_output, + d_values_input1, d_values_input2, + d_values_checking_output, keys_input1.size(), keys_input2.size(), - lesser_op, stream, debug_synchronous + compare_op, stream, debug_synchronous ) ); @@ -179,7 +353,6 @@ TYPED_TEST(RocprimDeviceMergeTests, MergeKey) // allocate temporary storage HIP_CHECK(hipMalloc(&d_temp_storage, temp_storage_size_bytes)); - HIP_CHECK(hipDeviceSynchronize()); // Run HIP_CHECK( @@ -187,8 +360,10 @@ TYPED_TEST(RocprimDeviceMergeTests, MergeKey) d_temp_storage, temp_storage_size_bytes, d_keys_input1, d_keys_input2, d_keys_checking_output, + d_values_input1, d_values_input2, + d_values_checking_output, keys_input1.size(), keys_input2.size(), - lesser_op, stream, debug_synchronous + compare_op, stream, debug_synchronous ) ); HIP_CHECK(hipPeekAtLastError()); @@ -196,7 +371,6 @@ TYPED_TEST(RocprimDeviceMergeTests, MergeKey) ASSERT_FALSE(out_of_bounds.get()); - // Copy keys_output to host HIP_CHECK( hipMemcpy( keys_output.data(), d_keys_output, @@ -204,17 +378,27 @@ TYPED_TEST(RocprimDeviceMergeTests, MergeKey) hipMemcpyDeviceToHost ) ); - HIP_CHECK(hipDeviceSynchronize()); + HIP_CHECK( + hipMemcpy( + values_output.data(), d_values_output, + values_output.size() * sizeof(value_type), + hipMemcpyDeviceToHost + ) + ); // Check if keys_output values are as expected for(size_t i = 0; i < keys_output.size(); i++) { - ASSERT_EQ(keys_output[i], expected[i]); + ASSERT_EQ(keys_output[i], expected[i].first); + ASSERT_EQ(values_output[i], expected[i].second); } hipFree(d_keys_input1); hipFree(d_keys_input2); hipFree(d_keys_output); + hipFree(d_values_input1); + hipFree(d_values_input2); + hipFree(d_values_output); hipFree(d_temp_storage); } } diff --git a/test/rocprim/test_hip_device_reduce_by_key.cpp b/test/rocprim/test_hip_device_reduce_by_key.cpp index 7f69e803b..f3e697b99 100644 --- a/test/rocprim/test_hip_device_reduce_by_key.cpp +++ b/test/rocprim/test_hip_device_reduce_by_key.cpp @@ -37,8 +37,7 @@ #include "test_utils.hpp" -#define HIP_CHECK(error) \ - ASSERT_EQ(static_cast(error),hipSuccess) +#define HIP_CHECK(error) ASSERT_EQ(error, hipSuccess) namespace rp = rocprim; @@ -158,6 +157,8 @@ TYPED_TEST(RocprimDeviceReduceByKey, ReduceByKey) hipStream_t stream = 0; // default + const bool use_unique_keys = bool(test_utils::get_random_value(0, 1)); + // Generate data and calculate expected results std::vector unique_expected; std::vector aggregates_expected; @@ -172,12 +173,11 @@ TYPED_TEST(RocprimDeviceReduceByKey, ReduceByKey) std::vector values_input = test_utils::get_random_data(size, 0, 100); size_t offset = 0; - key_type current_key = key_distribution_type(0, 100)(gen); - key_type prev_key = current_key; + key_type prev_key = key_distribution_type(0, 100)(gen); + key_type current_key = prev_key + key_delta_dis(gen); while(offset < size) { const size_t key_count = key_count_dis(gen); - current_key = current_key + key_delta_dis(gen); const size_t end = std::min(size, offset + key_count); for(size_t i = offset; i < end; i++) @@ -203,7 +203,18 @@ TYPED_TEST(RocprimDeviceReduceByKey, ReduceByKey) aggregates_expected.back() = reduce_op(aggregates_expected.back(), aggregate); } - prev_key = current_key; + if (use_unique_keys) + { + prev_key = current_key; + // e.g. 1,1,1,2,5,5,8,8,8 + current_key = current_key + key_delta_dis(gen); + } + else + { + // e.g. 1,1,5,1,5,5,5,1 + std::swap(current_key, prev_key); + } + offset += key_count; } diff --git a/test/rocprim/test_hip_device_scan.cpp b/test/rocprim/test_hip_device_scan.cpp index 2990d42c9..234a880b0 100644 --- a/test/rocprim/test_hip_device_scan.cpp +++ b/test/rocprim/test_hip_device_scan.cpp @@ -371,10 +371,20 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScanByKey) SCOPED_TRACE(testing::Message() << "with size = " << size); + const bool use_unique_keys = bool(test_utils::get_random_value(0, 1)); + // Generate data - std::vector input = test_utils::get_random_data(size, 1, 10); - std::vector keys = test_utils::get_random_data(size, 1, 16); - std::sort(keys.begin(), keys.end()); + std::vector input = test_utils::get_random_data(size, 0, 9); + std::vector keys; + if(use_unique_keys) + { + keys = test_utils::get_random_data(size, 0, 16); + std::sort(keys.begin(), keys.end()); + } + else + { + keys = test_utils::get_random_data(size, 0, 3); + } std::vector output(input.size(), 0); T * d_input; @@ -497,11 +507,21 @@ TYPED_TEST(RocprimDeviceScanTests, ExclusiveScanByKey) SCOPED_TRACE(testing::Message() << "with size = " << size); + const bool use_unique_keys = bool(test_utils::get_random_value(0, 1)); + // Generate data - T initial_value = test_utils::get_random_value(1, 1); - std::vector input = test_utils::get_random_data(size, 1, 10); - std::vector keys = test_utils::get_random_data(size, 1, 16); - std::sort(keys.begin(), keys.end()); + T initial_value = test_utils::get_random_value(1, 100); + std::vector input = test_utils::get_random_data(size, 0, 9); + std::vector keys; + if(use_unique_keys) + { + keys = test_utils::get_random_data(size, 0, 16); + std::sort(keys.begin(), keys.end()); + } + else + { + keys = test_utils::get_random_data(size, 0, 3); + } std::vector output(input.size(), 0); T * d_input; diff --git a/test/rocprim/test_hip_device_segmented_scan.cpp b/test/rocprim/test_hip_device_segmented_scan.cpp index 050c2f4b2..009238a63 100644 --- a/test/rocprim/test_hip_device_segmented_scan.cpp +++ b/test/rocprim/test_hip_device_segmented_scan.cpp @@ -37,7 +37,7 @@ #include "test_utils.hpp" -#define HIP_CHECK(error) ASSERT_EQ(static_cast(error),hipSuccess) +#define HIP_CHECK(error) ASSERT_EQ(error, hipSuccess) namespace rp = rocprim; @@ -370,7 +370,7 @@ TYPED_TEST(RocprimDeviceSegmentedScan, InclusiveScanUsingHeadFlags) SCOPED_TRACE(testing::Message() << "with size = " << size); // Generate data - std::vector input = test_utils::get_random_data(size, 1, 1); + std::vector input = test_utils::get_random_data(size, 1, 10); std::vector flags = test_utils::get_random_data(size, 0, 10); flags[0] = 1U; std::transform( @@ -500,7 +500,7 @@ TYPED_TEST(RocprimDeviceSegmentedScan, ExclusiveScanUsingHeadFlags) SCOPED_TRACE(testing::Message() << "with size = " << size); // Generate data - std::vector input = test_utils::get_random_data(size, 1, 1); + std::vector input = test_utils::get_random_data(size, 1, 10); std::vector flags = test_utils::get_random_data(size, 0, 10); flags[0] = 1U; std::transform( diff --git a/test/rocprim/test_utils.hpp b/test/rocprim/test_utils.hpp index 6c407d584..d4a978d56 100644 --- a/test/rocprim/test_utils.hpp +++ b/test/rocprim/test_utils.hpp @@ -241,9 +241,8 @@ OutputIt host_inclusive_scan(InputIt first, InputIt last, OutputIt d_first, BinaryOperation op) { using input_type = typename std::iterator_traits::value_type; - using output_type = typename std::iterator_traits::value_type; using result_type = typename ::rocprim::detail::match_result_type< - input_type, output_type, BinaryOperation + input_type, BinaryOperation >::type; if (first == last) return d_first; @@ -264,9 +263,8 @@ OutputIt host_exclusive_scan(InputIt first, InputIt last, BinaryOperation op) { using input_type = typename std::iterator_traits::value_type; - using output_type = typename std::iterator_traits::value_type; using result_type = typename ::rocprim::detail::match_result_type< - input_type, output_type, BinaryOperation + input_type, BinaryOperation >::type; if (first == last) return d_first; @@ -289,9 +287,8 @@ OutputIt host_exclusive_scan_by_key(InputIt first, InputIt last, KeyIt k_first, BinaryOperation op, KeyCompare key_compare_op) { using input_type = typename std::iterator_traits::value_type; - using output_type = typename std::iterator_traits::value_type; using result_type = typename ::rocprim::detail::match_result_type< - input_type, output_type, BinaryOperation + input_type, BinaryOperation >::type; if (first == last) return d_first;